feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embeddings endpoint

Closes #5385
krrishdholakia committed Aug 28, 2024
1 parent f9c93c3 commit 301b116
Showing 5 changed files with 110 additions and 40 deletions.
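
For orientation, here is a minimal usage sketch of the feature this commit wires up. It assumes a GEMINI_API_KEY in the environment, and the model name "gemini/text-embedding-004" is an illustrative assumption, not taken from the diff:

import litellm

# Google AI Studio ("gemini" provider) embeddings via the new endpoint
# support; the model name and response-field access are assumed examples.
response = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["hello world"],
)
print(len(response.data[0]["embedding"]))
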
61 changes: 60 additions & 1 deletion litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Tuple

import httpx

@@ -37,3 +37,62 @@ def get_supports_system_message(
supports_system_message = False

return supports_system_message


from typing import Literal, Optional

all_gemini_url_modes = Literal["chat", "embedding"]


def _get_vertex_url(
mode: all_gemini_url_modes,
model: str,
stream: Optional[bool],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
if mode == "chat":
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

# if model is only numeric chars then it's a fine tuned gemini model
# model = 4965075652664360960
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if model.isdigit():
# It's a fine-tuned Gemini model
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if stream is True:
url += "?alt=sse"

return url, endpoint
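
As rendered, _get_vertex_url only sets url and endpoint in its "chat" branch, so an "embedding" call would reach the return with both names unbound. A minimal sketch of the branch it presumably carries (assuming Vertex AI serves embedding models via the :predict endpoint; the branch may simply be folded out of this diff view):

    elif mode == "embedding":
        # Assumed: Vertex embedding models (e.g. multimodalembedding@001)
        # are invoked through the :predict endpoint.
        endpoint = "predict"
        url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
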


def _get_gemini_url(
mode: all_gemini_url_modes,
model: str,
stream: Optional[bool],
gemini_api_key: Optional[str],
) -> Tuple[str, str]:
if mode == "chat":
_gemini_model_name = "models/{}".format(model)
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
elif mode == "embedding":
pass
return url, endpoint
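
The "embedding" branch of _get_gemini_url is a bare pass here, leaving url and endpoint unbound on that path (again, likely folded in the rendered diff). A hedged sketch of what it might contain, assuming Google AI Studio's embedContent endpoint:

    elif mode == "embedding":
        # Assumed endpoint name for Google AI Studio embeddings.
        endpoint = "embedContent"
        _gemini_model_name = "models/{}".format(model)
        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
            _gemini_model_name, endpoint, gemini_api_key
        )
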
litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -54,10 +54,16 @@
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

from ..base import BaseLLM
from .common_utils import VertexAIError, get_supports_system_message
from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
from .gemini_transformation import transform_system_message
from ...base import BaseLLM
from ..common_utils import (
VertexAIError,
_get_gemini_url,
_get_vertex_url,
all_gemini_url_modes,
get_supports_system_message,
)
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
from .transformation import transform_system_message

context_caching_endpoints = ContextCachingEndpoints()

@@ -309,6 +315,7 @@ def get_supported_openai_params(self):
"n",
"stop",
]

def _map_function(self, value: List[dict]) -> List[Tools]:
gtool_func_declarations = []
googleSearchRetrieval: Optional[dict] = None
@@ -1164,6 +1171,7 @@ def _get_token_and_url(
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
mode: all_gemini_url_modes = "chat",
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
@@ -1174,42 +1182,31 @@
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
url, endpoint = _get_gemini_url(
mode=mode,
model=model,
stream=stream,
gemini_api_key=gemini_api_key,
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)

### SET RUNTIME ENDPOINT ###
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
endpoint = "generateContent"
litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

# if model is only numeric chars then it's a fine tuned gemini model
# model = 4965075652664360960
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if model.isdigit():
# It's a fine-tuned Gemini model
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if stream is True:
url += "?alt=sse"
version: Literal["v1beta1", "v1"] = (
"v1beta1" if should_use_v1beta1_features is True else "v1"
)
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
stream=stream,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version=version,
)

if (
api_base is not None
@@ -1793,8 +1790,10 @@ def multimodal_embedding(
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
@@ -1804,6 +1803,17 @@
timeout=300,
client=None,
):
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
)

if client is None:
_params = {}
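Note that the _get_token_and_url call above passes no mode, so it falls back to the "chat" default added to the signature earlier in this file. For the embedding URLs built in common_utils.py to be used here, the call would presumably read (a sketch, not the commit's verbatim code):

        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",  # assumed; routes to the embedding endpoints
        )
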
7 changes: 4 additions & 3 deletions litellm/main.py
@@ -126,12 +126,12 @@
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@@ -3568,6 +3568,7 @@ def embedding(
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
)
else:
response = vertex_ai_non_gemini.embedding(
8 changes: 4 additions & 4 deletions litellm/tests/test_amazing_vertex_completion.py
@@ -28,7 +28,7 @@
completion_cost,
embedding,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
_gemini_convert_messages_with_history,
)
from litellm.tests.test_streaming import streaming_format_tests
@@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():


def test_get_token_url():
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)

@@ -2087,7 +2087,7 @@ def test_get_token_url():
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,
@@ -2107,7 +2107,7 @@ def test_get_token_url():
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,
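A hedged sketch of how the new embedding mode might be pinned down alongside these tests. It assumes the embedContent branch sketched for _get_gemini_url above; the expected URL is an assumption, not an assertion from this commit:

def test_get_token_url_embedding_mode():
    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
        VertexLLM,
    )

    vertex_llm = VertexLLM()
    _, url = vertex_llm._get_token_and_url(
        model="text-embedding-004",
        gemini_api_key="fake-key",
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        stream=None,
        custom_llm_provider="gemini",
        api_base=None,
        mode="embedding",
    )
    assert "generativelanguage.googleapis.com" in url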
