diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
index 8faf7a3afad3..7e2f9b29d0d7 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, Tuple
 
 import httpx
 
@@ -37,3 +37,62 @@ def get_supports_system_message(
         supports_system_message = False
 
     return supports_system_message
+
+
+from typing import Literal, Optional
+
+all_gemini_url_modes = Literal["chat", "embedding"]
+
+
+def _get_vertex_url(
+    mode: all_gemini_url_modes,
+    model: str,
+    stream: Optional[bool],
+    vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_api_version: Literal["v1", "v1beta1"],
+) -> Tuple[str, str]:
+    if mode == "chat":
+        ### SET RUNTIME ENDPOINT ###
+        endpoint = "generateContent"
+        if stream is True:
+            endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+        else:
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+        # if model is only numeric chars then it's a fine tuned gemini model
+        # model = 4965075652664360960
+        # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+        if model.isdigit():
+            # It's a fine-tuned Gemini model
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+            if stream is True:
+                url += "?alt=sse"
+
+    return url, endpoint
+
+
+def _get_gemini_url(
+    mode: all_gemini_url_modes,
+    model: str,
+    stream: Optional[bool],
+    gemini_api_key: Optional[str],
+) -> Tuple[str, str]:
+    if mode == "chat":
+        _gemini_model_name = "models/{}".format(model)
+        endpoint = "generateContent"
+        if stream is True:
+            endpoint = "streamGenerateContent"
+            url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
+                _gemini_model_name, endpoint, gemini_api_key
+            )
+        else:
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+    elif mode == "embedding":
+        pass
+    return url, endpoint
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py
similarity index 100%
rename from litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py
rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
similarity index 97%
rename from litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 5392f253f801..d897f5bfbd69 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -54,10 +54,16 @@ from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
-from ..base import BaseLLM
-from .common_utils import VertexAIError, get_supports_system_message
-from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
-from .gemini_transformation import transform_system_message
+from ...base import BaseLLM
+from ..common_utils import (
+    VertexAIError,
+    _get_gemini_url,
+    _get_vertex_url,
+    all_gemini_url_modes,
+    get_supports_system_message,
+)
+from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+from .transformation import transform_system_message
 
 context_caching_endpoints = ContextCachingEndpoints()
 
@@ -309,6 +315,7 @@ def get_supported_openai_params(self):
             "n",
             "stop",
         ]
+
     def _map_function(self, value: List[dict]) -> List[Tools]:
         gtool_func_declarations = []
         googleSearchRetrieval: Optional[dict] = None
@@ -1164,6 +1171,7 @@ def _get_token_and_url(
         custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
         api_base: Optional[str],
         should_use_v1beta1_features: Optional[bool] = False,
+        mode: all_gemini_url_modes = "chat",
     ) -> Tuple[Optional[str], str]:
         """
         Internal function. Returns the token and url for the call.
@@ -1174,18 +1182,13 @@ def _get_token_and_url(
             token, url
         """
         if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
             auth_header = None
-            endpoint = "generateContent"
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
-            else:
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
+            url, endpoint = _get_gemini_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                gemini_api_key=gemini_api_key,
+            )
         else:
             auth_header, vertex_project = self._ensure_access_token(
                 credentials=vertex_credentials, project_id=vertex_project
@@ -1193,23 +1196,17 @@ def _get_token_and_url(
             vertex_location = self.get_vertex_region(vertex_region=vertex_location)
 
             ### SET RUNTIME ENDPOINT ###
-            version = "v1beta1" if should_use_v1beta1_features is True else "v1"
-            endpoint = "generateContent"
-            litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
-            else:
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
-
-            # if model is only numeric chars then it's a fine tuned gemini model
-            # model = 4965075652664360960
-            # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-            if model.isdigit():
-                # It's a fine-tuned Gemini model
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-                if stream is True:
-                    url += "?alt=sse"
+            version: Literal["v1beta1", "v1"] = (
+                "v1beta1" if should_use_v1beta1_features is True else "v1"
+            )
+            url, endpoint = _get_vertex_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_api_version=version,
+            )
 
         if (
             api_base is not None
@@ -1793,8 +1790,10 @@ def multimodal_embedding(
         input: Union[list, str],
         print_verbose,
         model_response: litellm.EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         logging_obj=None,
         encoding=None,
         vertex_project=None,
@@ -1804,6 +1803,17 @@ def multimodal_embedding(
         timeout=300,
         client=None,
     ):
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+        )
 
         if client is None:
             _params = {}
diff --git a/litellm/main.py b/litellm/main.py
index a77a03522a02..b83a583f4a82 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -126,12 +126,12 @@
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
+from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
-    VertexLLM,
-)
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
@@ -3568,6 +3568,7 @@ def embedding(
                 vertex_credentials=vertex_credentials,
                 aembedding=aembedding,
                 print_verbose=print_verbose,
+                custom_llm_provider="vertex_ai",
             )
         else:
             response = vertex_ai_non_gemini.embedding(
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 53005fac0988..f542d18d16a2 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -28,7 +28,7 @@
     completion_cost,
     embedding,
 )
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     _gemini_convert_messages_with_history,
 )
 from litellm.tests.test_streaming import streaming_format_tests
@@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():
 
 
 def test_get_token_url():
-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
         VertexLLM,
     )
 
@@ -2087,7 +2087,7 @@
         vertex_credentials=vertex_credentials,
         gemini_api_key="",
         custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
+        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
         api_base=None,
         model="",
         stream=False,
@@ -2107,7 +2107,7 @@
         vertex_credentials=vertex_credentials,
         gemini_api_key="",
         custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
+        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
         api_base=None,
         model="",
         stream=False,
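
For reference, a minimal usage sketch of the two URL helpers this diff adds to `common_utils.py` (not part of the diff itself; it assumes this branch of litellm is importable, and the project, location, and API-key values are placeholders):

```python
# Sketch: exercises _get_vertex_url / _get_gemini_url as defined in this diff.
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _get_gemini_url,
    _get_vertex_url,
)

# Standard Vertex AI chat model, streaming -> ":streamGenerateContent?alt=sse"
url, endpoint = _get_vertex_url(
    mode="chat",
    model="gemini-1.5-pro",
    stream=True,
    vertex_project="my-project",    # placeholder
    vertex_location="us-central1",  # placeholder
    vertex_api_version="v1",
)
assert endpoint == "streamGenerateContent"
assert url.endswith("models/gemini-1.5-pro:streamGenerateContent?alt=sse")

# All-numeric model name -> treated as a fine-tuned Gemini model and routed
# to the /endpoints/ URL instead of /publishers/google/models/
url, _ = _get_vertex_url(
    mode="chat",
    model="4965075652664360960",  # example id from the diff's own comment
    stream=False,
    vertex_project="my-project",
    vertex_location="us-central1",
    vertex_api_version="v1",
)
assert "/endpoints/4965075652664360960:generateContent" in url

# Google AI Studio (gemini) chat URL, non-streaming; the key rides in the query string
url, endpoint = _get_gemini_url(
    mode="chat",
    model="gemini-1.5-flash",
    stream=False,
    gemini_api_key="AIza-placeholder",
)
assert endpoint == "generateContent"
assert url.startswith("https://generativelanguage.googleapis.com/v1beta/models/")
```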
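Likewise, a sketch of the new `mode` parameter threaded through `VertexLLM._get_token_and_url` (again an assumption-laden example, not the PR's code: it assumes this branch is installed; no HTTP request is made, the method only assembles the URL for the `gemini` provider path):

```python
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)

token, url = VertexLLM()._get_token_and_url(
    model="gemini-1.5-flash",
    gemini_api_key="AIza-placeholder",  # placeholder; auth travels as ?key=... in the URL
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    stream=False,
    custom_llm_provider="gemini",
    api_base=None,
    mode="chat",  # new parameter; "embedding" is still stubbed out in this diff
)
print(token)  # None for the Google AI Studio path
print(url)    # .../v1beta/models/gemini-1.5-flash:generateContent?key=AIza-placeholder
```

This default (`mode="chat"`) is what lets `multimodal_embedding` reuse `_get_token_and_url` in this diff without changing any existing chat call sites.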