Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Databricks Integration: Integrate Databricks SDK as optional mechanism for fetching API base and token, if unspecified #5746

Merged
merged 17 commits into from
Sep 19, 2024
Prev Previous commit
Next Next commit
fix
Signed-off-by: dbczumar <corey.zumar@databricks.com>
  • Loading branch information
dbczumar committed Sep 17, 2024
commit 651cc9fffb4723576e5f631f7536631bd19865fd
143 changes: 134 additions & 9 deletions tests/llm_translation/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
from unittest.mock import MagicMock, Mock, patch

import litellm
from litellm.exceptions import BadRequestError, InternalServerError
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper

try:
import databricks.sdk

databricks_sdk_installed = True
except ImportError:
databricks_sdk_installed = False


def mock_chat_response() -> Dict[str, Any]:
return {
Expand All @@ -33,8 +40,8 @@ def mock_chat_response() -> Dict[str, Any]:
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
"total_tokens": 268,
"completion_tokens_details": None,
"total_tokens": 268,
},
"system_fingerprint": None,
}
Expand Down Expand Up @@ -195,20 +202,21 @@ def mock_embedding_response() -> Dict[str, Any]:


@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
monkeypatch, set_base
):
# Simulate that the databricks SDK is not installed
monkeypatch.setitem(sys.modules, "databricks.sdk", None)

err_msg = "the Databricks base URL and API key are not set"

if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
"https://my.workspace.cloud.databricks.com/serving-endpoints",
)
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
monkeypatch.delenv("DATABRICKS_API_BASE")
err_msg = "A call is being made to LLM Provider but no api base is set"

with pytest.raises(BadRequestError) as exc:
litellm.completion(
Expand Down Expand Up @@ -422,6 +430,67 @@ async def gather_responses():
)


@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config

sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()

expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}

base_url = "https://my.workspace.cloud.databricks.com"
api_key = "dapimykey"
headers = {
"Authorization": f"Bearer {api_key}",
}
messages = [{"role": "user", "content": "How are you?"}]

mock_workspace_client: WorkspaceClient = MagicMock()
mock_config: Config = MagicMock()
# Simulate the behavior of the config property and its methods
mock_config.authenticate.side_effect = lambda: headers
mock_config.host = base_url # Assign directly as if it's a property
mock_workspace_client.config = mock_config

with patch(
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
assert response.to_dict() == expected_response_json

mock_post.assert_called_once_with(
f"{base_url}/serving-endpoints/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)


def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
Expand Down Expand Up @@ -500,3 +569,59 @@ def test_embeddings_with_async_http_handler(monkeypatch):
}
),
)


@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config

base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()

base_url = "https://my.workspace.cloud.databricks.com"
api_key = "dapimykey"
headers = {
"Authorization": f"Bearer {api_key}",
}
inputs = ["Hello", "World"]

mock_workspace_client: WorkspaceClient = MagicMock()
mock_config: Config = MagicMock()
# Simulate the behavior of the config property and its methods
mock_config.authenticate.side_effect = lambda: headers
mock_config.host = base_url # Assign directly as if it's a property
mock_workspace_client.config = mock_config

with patch(
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.embedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=sync_handler,
extraparam="testpassingextraparam",
)
assert response.to_dict() == mock_embedding_response()

mock_post.assert_called_once_with(
f"{base_url}/serving-endpoints/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)
Loading