From ce1c6ec99b633f4e3de74678e7bc49f971d89de5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 09:07:10 -0700 Subject: [PATCH 01/15] fix(caching.py): set ttl for async_increment cache fixes issue where ttl for redis client was not being set on increment_cache Fixes https://github.com/BerriAI/litellm/issues/5609 --- litellm/caching.py | 12 +++- litellm/router.py | 6 +- litellm/tests/test_tpm_rpm_routing_v2.py | 70 ++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 3 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 5add0cd8e923..0a806dc37ee9 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -606,12 +606,22 @@ async def batch_cache_write(self, key, value, **kwargs): if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: await self.flush_cache_buffer() # logging done in here - async def async_increment(self, key, value: float, **kwargs) -> float: + async def async_increment( + self, key, value: float, ttl: Optional[float] = None, **kwargs + ) -> float: _redis_client = self.init_async_client() start_time = time.time() try: async with _redis_client as redis_client: result = await redis_client.incrbyfloat(name=key, amount=value) + + if ttl is not None: + # check if key already has ttl, if not -> set ttl + current_ttl = await redis_client.ttl(key) + if current_ttl == -1: + # Key has no expiration + await redis_client.expire(key, ttl) + ## LOGGING ## end_time = time.time() _duration = end_time - start_time diff --git a/litellm/router.py b/litellm/router.py index 5a01f4f39584..a7a2fa9e22d8 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -263,7 +263,7 @@ def __init__( self.debug_level = debug_level self.enable_pre_call_checks = enable_pre_call_checks self.enable_tag_filtering = enable_tag_filtering - if self.set_verbose == True: + if self.set_verbose is True: if debug_level == "INFO": verbose_router_logger.setLevel(logging.INFO) elif debug_level == "DEBUG": @@ -3690,7 +3690,9 @@ def _set_cooldown_deployments( exception=original_exception, ) - allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails + allowed_fails = ( + _allowed_fails if _allowed_fails is not None else self.allowed_fails + ) dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py index 1f3de09104f3..4d50f0e896f7 100644 --- a/litellm/tests/test_tpm_rpm_routing_v2.py +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -17,6 +17,8 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from unittest.mock import AsyncMock, MagicMock, patch + import pytest import litellm @@ -459,3 +461,71 @@ async def test_router_completion_streaming(): - Unit test for sync 'pre_call_checks' - Unit test for async 'async_pre_call_checks' """ + + +@pytest.mark.asyncio +async def test_router_caching_ttl(): + """ + Confirm caching ttl's work as expected. 
+ + Relevant issue: https://github.com/BerriAI/litellm/issues/5609 + """ + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + "mock_response": "Hello world", + }, + "model_info": {"id": 1}, + } + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT"), + ) + + assert router.cache.redis_cache is not None + + increment_cache_kwargs = {} + with patch.object( + router.cache.redis_cache, + "async_increment", + new=AsyncMock(), + ) as mock_client: + await router.acompletion(model=model, messages=messages) + + mock_client.assert_called_once() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") + print(f"mock_client.call_args.args: {mock_client.call_args.args}") + + increment_cache_kwargs = { + "key": mock_client.call_args.args[0], + "value": mock_client.call_args.args[1], + "ttl": mock_client.call_args.kwargs["ttl"], + } + + assert mock_client.call_args.kwargs["ttl"] == 60 + + ## call redis async increment and check if ttl correctly set + await router.cache.redis_cache.async_increment(**increment_cache_kwargs) + + _redis_client = router.cache.redis_cache.init_async_client() + + async with _redis_client as redis_client: + current_ttl = await redis_client.ttl(increment_cache_kwargs["key"]) + + assert current_ttl >= 0 + + print(f"current_ttl: {current_ttl}") From 017283ba0bc8905b3f71c52e7afd0544e4dd2ba9 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 09:24:07 -0700 Subject: [PATCH 02/15] fix(caching.py): fix increment cache w/ ttl for sync increment cache on redis Fixes https://github.com/BerriAI/litellm/issues/5609 --- litellm/caching.py | 35 ++++--------- litellm/tests/test_tpm_rpm_routing_v2.py | 67 ++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 0a806dc37ee9..0a9fef4175f3 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -304,40 +304,25 @@ def set_cache(self, key, value, **kwargs): f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" ) - def increment_cache(self, key, value: int, **kwargs) -> int: + def increment_cache( + self, key, value: int, ttl: Optional[float] = None, **kwargs + ) -> int: _redis_client = self.redis_client start_time = time.time() try: result = _redis_client.incr(name=key, amount=value) - ## LOGGING ## - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - self.service_logger_obj.service_success_hook( - service=ServiceTypes.REDIS, - duration=_duration, - call_type="increment_cache", - start_time=start_time, - end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), - ) - ) + + if ttl is not None: + # check if key already has ttl, if not -> set ttl + current_ttl = _redis_client.ttl(key) + if current_ttl == -1: + # Key has no expiration + _redis_client.expire(key, ttl) return result except Exception as e: ## LOGGING ## end_time = time.time() _duration = end_time - start_time - asyncio.create_task( - self.service_logger_obj.async_service_failure_hook( - service=ServiceTypes.REDIS, - duration=_duration, - error=e, 
- call_type="increment_cache", - start_time=start_time, - end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), - ) - ) verbose_logger.error( "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s", str(e), diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py index 4d50f0e896f7..259bd0ee0ec5 100644 --- a/litellm/tests/test_tpm_rpm_routing_v2.py +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -529,3 +529,70 @@ async def test_router_caching_ttl(): assert current_ttl >= 0 print(f"current_ttl: {current_ttl}") + + +def test_router_caching_ttl_sync(): + """ + Confirm caching ttl's work as expected. + + Relevant issue: https://github.com/BerriAI/litellm/issues/5609 + """ + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + "mock_response": "Hello world", + }, + "model_info": {"id": 1}, + } + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT"), + ) + + assert router.cache.redis_cache is not None + + increment_cache_kwargs = {} + with patch.object( + router.cache.redis_cache, + "increment_cache", + new=MagicMock(), + ) as mock_client: + router.completion(model=model, messages=messages) + + print(mock_client.call_args_list) + mock_client.assert_called() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") + print(f"mock_client.call_args.args: {mock_client.call_args.args}") + + increment_cache_kwargs = { + "key": mock_client.call_args.args[0], + "value": mock_client.call_args.args[1], + "ttl": mock_client.call_args.kwargs["ttl"], + } + + assert mock_client.call_args.kwargs["ttl"] == 60 + + ## call redis async increment and check if ttl correctly set + router.cache.redis_cache.increment_cache(**increment_cache_kwargs) + + _redis_client = router.cache.redis_cache.redis_client + + current_ttl = _redis_client.ttl(increment_cache_kwargs["key"]) + + assert current_ttl >= 0 + + print(f"current_ttl: {current_ttl}") From 0d83028cfd991484ee19b2e0cfc8b98a01c6bc56 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 10:41:21 -0700 Subject: [PATCH 03/15] fix(router.py): support adding retry policy + allowed fails policy via config.yaml --- docs/my-website/docs/routing.md | 24 ++++++++++++ litellm/proxy/_experimental/out/404.html | 1 - .../proxy/_experimental/out/model_hub.html | 1 - .../proxy/_experimental/out/onboarding.html | 1 - litellm/proxy/_new_secret_config.yaml | 6 ++- litellm/router.py | 39 +++++++++++++++---- 6 files changed, 61 insertions(+), 11 deletions(-) delete mode 100644 litellm/proxy/_experimental/out/404.html delete mode 100644 litellm/proxy/_experimental/out/model_hub.html delete mode 100644 litellm/proxy/_experimental/out/onboarding.html diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index c7c6c3c97045..87925516aa04 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -1038,6 +1038,12 @@ print(f"response: {response}") - Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved - Use 
`AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment +[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436) + + + + + Example: ```python @@ -1101,6 +1107,24 @@ response = await router.acompletion( ) ``` + + + +```yaml +router_settings: + retry_policy: { + "BadRequestErrorRetries": 3, + "ContentPolicyViolationErrorRetries": 4 + } + allowed_fails_policy: { + "ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment + "RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment + } +``` + + + + ### Fallbacks diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html deleted file mode 100644 index 34d1b613de74..000000000000 --- a/litellm/proxy/_experimental/out/404.html +++ /dev/null @@ -1 +0,0 @@ -404: This page could not be found.LiteLLM Dashboard

\ No newline at end of file diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html deleted file mode 100644 index 07e68f30e38d..000000000000 --- a/litellm/proxy/_experimental/out/model_hub.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index abb658918b3a..000000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index bf86da1e12a4..ac873ae8942f 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -6,4 +6,8 @@ model_list: api_base: os.environ/AZURE_API_BASE router_settings: - model_group_alias: {"gpt-4": "gpt-turbo"} \ No newline at end of file + model_group_alias: {"gpt-4": "gpt-turbo"} + retry_policy: { + "BadRequestErrorRetries": 3, + "ContentPolicyViolationErrorRetries": 4 + } \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index a7a2fa9e22d8..c4510e316370 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -161,10 +161,10 @@ def __init__( enable_tag_filtering: bool = False, retry_after: int = 0, # min time to wait before retrying a failed request retry_policy: Optional[ - RetryPolicy + Union[RetryPolicy, dict] ] = None, # set custom retries for different exceptions - model_group_retry_policy: Optional[ - Dict[str, RetryPolicy] + model_group_retry_policy: Dict[ + str, RetryPolicy ] = {}, # set custom retry policies based on model group allowed_fails: Optional[ int @@ -454,11 +454,35 @@ def __init__( ) self.routing_strategy_args = routing_strategy_args - self.retry_policy: Optional[RetryPolicy] = retry_policy + self.retry_policy: Optional[RetryPolicy] = None + if retry_policy is not None: + if isinstance(retry_policy, dict): + self.retry_policy = RetryPolicy(**retry_policy) + elif isinstance(retry_policy, RetryPolicy): + self.retry_policy = retry_policy + verbose_router_logger.info( + "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format( + self.retry_policy.model_dump(exclude_none=True) + ) + ) + self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( model_group_retry_policy ) - self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy + + self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None + if allowed_fails_policy is not None: + if isinstance(allowed_fails_policy, dict): + self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy) + elif isinstance(allowed_fails_policy, AllowedFailsPolicy): + self.allowed_fails_policy = allowed_fails_policy + + verbose_router_logger.info( + "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format( + self.allowed_fails_policy.model_dump(exclude_none=True) + ) + ) + self.alerting_config: Optional[AlertingConfig] = alerting_config if self.alerting_config is not None: self._initialize_alerting() @@ -5530,18 +5554,19 @@ def get_num_retries_from_retry_policy( ContentPolicyViolationErrorRetries: Optional[int] = None """ # if we can find the exception then in the retry policy -> return the number of retries - retry_policy = self.retry_policy + retry_policy: Optional[RetryPolicy] = self.retry_policy if ( self.model_group_retry_policy is not None and model_group is not None and model_group in 
self.model_group_retry_policy ): - retry_policy = self.model_group_retry_policy.get(model_group, None) + retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore if retry_policy is None: return None if isinstance(retry_policy, dict): retry_policy = RetryPolicy(**retry_policy) + if ( isinstance(exception, litellm.BadRequestError) and retry_policy.BadRequestErrorRetries is not None From 68707ea2a06f9fe823bde017723788b4f05ece53 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 11:26:39 -0700 Subject: [PATCH 04/15] fix(router.py): don't cooldown single deployments No point, as there's no other deployment to loadbalance with. --- litellm/router.py | 47 ++++++++++++++++++++++++-- litellm/tests/test_router_cooldowns.py | 38 +++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index c4510e316370..bc3f86163c16 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3189,6 +3189,7 @@ def function_with_fallbacks(self, *args, **kwargs): If it fails after num_retries, fall back to another model group """ mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) + model_group = kwargs.get("model") fallbacks = kwargs.get("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.get( @@ -3197,6 +3198,7 @@ def function_with_fallbacks(self, *args, **kwargs): content_policy_fallbacks = kwargs.get( "content_policy_fallbacks", self.content_policy_fallbacks ) + try: if mock_testing_fallbacks is not None and mock_testing_fallbacks == True: raise Exception( @@ -3348,6 +3350,9 @@ def function_with_retries(self, *args, **kwargs): f"Inside function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") + mock_testing_rate_limit_error = kwargs.pop( + "mock_testing_rate_limit_error", None + ) num_retries = kwargs.pop("num_retries") fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( @@ -3356,9 +3361,22 @@ def function_with_retries(self, *args, **kwargs): content_policy_fallbacks = kwargs.pop( "content_policy_fallbacks", self.content_policy_fallbacks ) + model_group = kwargs.get("model") try: # if the function call is successful, no exception will be raised and we'll break out of the loop + if ( + mock_testing_rate_limit_error is not None + and mock_testing_rate_limit_error is True + ): + verbose_router_logger.info( + "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError." + ) + raise litellm.RateLimitError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", + ) response = original_function(*args, **kwargs) return response except Exception as e: @@ -3595,17 +3613,26 @@ def _update_usage(self, deployment_id: str): ) # don't change existing ttl def _is_cooldown_required( - self, exception_status: Union[str, int], exception_str: Optional[str] = None - ): + self, + model_id: str, + exception_status: Union[str, int], + exception_str: Optional[str] = None, + ) -> bool: """ A function to determine if a cooldown is required based on the exception status. Parameters: + model_id (str) The id of the model in the model list exception_status (Union[str, int]): The status of the exception. Returns: bool: True if a cooldown is required, False otherwise. 
""" + ## BASE CASE - single deployment + model_group = self.get_model_group(id=model_id) + if model_group is not None and len(model_group) == 1: + return False + try: ignored_strings = ["APIConnectionError"] if ( @@ -3701,7 +3728,9 @@ def _set_cooldown_deployments( if ( self._is_cooldown_required( - exception_status=exception_status, exception_str=str(original_exception) + model_id=deployment, + exception_status=exception_status, + exception_str=str(original_exception), ) is False ): @@ -4324,6 +4353,18 @@ def get_model_info(self, id: str) -> Optional[dict]: return model return None + def get_model_group(self, id: str) -> Optional[List]: + """ + Return list of all models in the same model group as that model id + """ + + model_info = self.get_model_info(id=id) + if model_info is None: + return None + + model_name = model_info["model_name"] + return self.get_model_list(model_name=model_name) + def _set_model_group_info( self, model_group: str, user_facing_model_group_name: str ) -> Optional[ModelGroupInfo]: diff --git a/litellm/tests/test_router_cooldowns.py b/litellm/tests/test_router_cooldowns.py index 3eef6e54231c..ac92dfbf0705 100644 --- a/litellm/tests/test_router_cooldowns.py +++ b/litellm/tests/test_router_cooldowns.py @@ -21,6 +21,7 @@ import litellm from litellm import Router from litellm.integrations.custom_logger import CustomLogger +from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict @pytest.mark.asyncio @@ -112,3 +113,40 @@ async def test_dynamic_cooldowns(): assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"] assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0 + + +@pytest.mark.parametrize("num_deployments", [1, 2]) +def test_single_deployment_no_cooldowns(num_deployments): + """ + Do not cooldown on single deployment. + + Cooldown on multiple deployments. + """ + model_list = [] + for i in range(num_deployments): + model = DeploymentTypedDict( + model_name="gpt-3.5-turbo", + litellm_params=LiteLLMParamsTypedDict( + model="gpt-3.5-turbo", + ), + ) + model_list.append(model) + + router = Router(model_list=model_list, allowed_fails=0, num_retries=0) + + with patch.object( + router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock() + ) as mock_client: + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + mock_response="litellm.RateLimitError", + ) + except litellm.RateLimitError: + pass + + if num_deployments == 1: + mock_client.assert_not_called() + else: + mock_client.assert_called_once() From f9486ed00a68a9910195c01780e49e633f90d9d7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 14:56:51 -0700 Subject: [PATCH 05/15] fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens Closes https://github.com/BerriAI/litellm/issues/5605 --- litellm/proxy/_types.py | 4 + litellm/proxy/auth/handle_jwt.py | 43 ++++++- litellm/proxy/auth/user_api_key_auth.py | 30 ++++- litellm/tests/test_jwt.py | 150 +++++++++++++++++++++++- 4 files changed, 217 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 3559a4792f1c..e7c750051aad 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`. - team_allowed_routes: list of allowed routes for proxy team roles. 
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees. + - user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees. + - user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy. - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers. - public_key_ttl: Default - 600s. TTL for caching public JWT keys. @@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): ) org_id_jwt_field: Optional[str] = None user_id_jwt_field: Optional[str] = None + user_email_jwt_field: Optional[str] = None + user_allowed_email_domain: Optional[str] = None user_id_upsert: bool = Field( default=False, description="If user doesn't exist, upsert them into the db." ) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index f8618781ff20..b39064ae6143 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -78,6 +78,19 @@ def is_required_team_id(self) -> bool: return False return True + def is_enforced_email_domain(self) -> bool: + """ + Returns: + - True: if 'user_allowed_email_domain' is set + - False: if 'user_allowed_email_domain' is None + """ + + if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance( + self.litellm_jwtauth.user_allowed_email_domain, str + ): + return True + return False + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: if self.litellm_jwtauth.team_id_jwt_field is not None: @@ -90,12 +103,14 @@ def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str team_id = default_value return team_id - def is_upsert_user_id(self) -> bool: + def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool: """ Returns: - - True: if 'user_id_upsert' is set + - True: if 'user_id_upsert' is set AND valid_user_email is not False - False: if not """ + if valid_user_email is False: + return False return self.litellm_jwtauth.user_id_upsert def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: @@ -103,11 +118,23 @@ def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str if self.litellm_jwtauth.user_id_jwt_field is not None: user_id = token[self.litellm_jwtauth.user_id_jwt_field] else: - user_id = None + user_id = default_value except KeyError: user_id = default_value return user_id + def get_user_email( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: + try: + if self.litellm_jwtauth.user_email_jwt_field is not None: + user_email = token[self.litellm_jwtauth.user_email_jwt_field] + else: + user_email = None + except KeyError: + user_email = default_value + return user_email + def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: if self.litellm_jwtauth.org_id_jwt_field is not None: @@ -183,6 +210,16 @@ async def get_public_key(self, kid: Optional[str]) -> dict: return public_key + def is_allowed_domain(self, user_email: str) -> bool: + if self.litellm_jwtauth.user_allowed_email_domain is None: + return True + + email_domain = user_email.split("@")[-1] # Extract domain from email + if email_domain == self.litellm_jwtauth.user_allowed_email_domain: + return True + else: + return False + 
async def auth_jwt(self, token: str) -> dict: # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index deee81ffdd71..114f27d4429b 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -250,6 +250,7 @@ async def user_api_key_auth( raise Exception( f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}" ) + # get team id team_id = jwt_handler.get_team_id( token=jwt_valid_token, default_value=None @@ -296,10 +297,30 @@ async def user_api_key_auth( parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) + # [OPTIONAL] allowed user email domains + valid_user_email: Optional[bool] = None + user_email: Optional[str] = None + if jwt_handler.is_enforced_email_domain(): + """ + if 'allowed_email_subdomains' is set, + + - checks if token contains 'email' field + - checks if 'email' is from an allowed domain + """ + user_email = jwt_handler.get_user_email( + token=jwt_valid_token, default_value=None + ) + if user_email is None: + valid_user_email = False + else: + valid_user_email = jwt_handler.is_allowed_domain( + user_email=user_email + ) + # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` user_object = None user_id = jwt_handler.get_user_id( - token=jwt_valid_token, default_value=None + token=jwt_valid_token, default_value=user_email ) if user_id is not None: # get the user object @@ -307,11 +328,12 @@ async def user_api_key_auth( user_id=user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, - user_id_upsert=jwt_handler.is_upsert_user_id(), + user_id_upsert=jwt_handler.is_upsert_user_id( + valid_user_email=valid_user_email + ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) - # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` end_user_object = None end_user_id = jwt_handler.get_end_user_id( @@ -802,7 +824,7 @@ async def user_api_key_auth( # collect information for alerting # #################################### - user_email: Optional[str] = None + user_email = None # Check if the token has any user id information if user_obj is not None: user_email = user_obj.user_email diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index ddafdb933fc3..51bf55c9c202 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -23,8 +23,9 @@ import pytest from fastapi import Request +import litellm from litellm.caching import DualCache -from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes +from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.team_endpoints import new_team from litellm.proxy.proxy_server import chat_completion @@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience): raise e -from unittest.mock import AsyncMock - import pytest @@ -844,3 +843,148 @@ async def test_team_cache_update_called(): await asyncio.sleep(3) mock_call_cache.assert_awaited_once() + + +@pytest.fixture +def public_jwt_key(): + import json + + import jwt + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + # Generate a private 
/ public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + return {"private_key": private_key, "public_jwk": public_jwk} + + +def mock_user_object(*args, **kwargs): + print("Args: {}".format(args)) + print("kwargs: {}".format(kwargs)) + assert kwargs["user_id_upsert"] is True + + +@pytest.mark.parametrize( + "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)] +) +@pytest.mark.asyncio +async def test_allow_access_by_email(public_jwt_key, user_email, should_work): + """ + Allow anyone with an `@xyz.com` email make a request to the proxy. + + Relevant issue: https://github.com/BerriAI/litellm/issues/5605 + """ + import jwt + from starlette.datastructures import URL + + from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth + from litellm.proxy.proxy_server import user_api_key_auth + + public_jwk = public_jwt_key["public_jwk"] + private_key = public_jwt_key["private_key"] + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + user_email_jwt_field="email", + user_allowed_email_domain="berri.ai", + user_id_upsert=True, + ) + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + team_id = f"team123_{uuid.uuid4()}" + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_team", + "client_id": team_id, + "aud": "litellm-proxy", + "email": user_email, + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + + ## team token + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + # Expect the call to succeed + response = await jwt_handler.auth_jwt(token=token) + assert response is not None # Adjust this based on your actual response check + + ## RUN IT THROUGH USER API KEY AUTH + bearer_token = "Bearer " + token + + request = Request(scope={"type": "http"}) + + request._url = URL(url="/chat/completions") + + ## 1. 
INITIAL TEAM CALL - should fail + # use generated key to auth in + setattr( + litellm.proxy.proxy_server, + "general_settings", + { + "enable_jwt_auth": True, + }, + ) + setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + setattr(litellm.proxy.proxy_server, "prisma_client", {}) + + # AsyncMock( + # return_value=LiteLLM_UserTable( + # spend=0, user_id=user_email, max_budget=None, user_email=user_email + # ) + # ), + with patch.object( + litellm.proxy.auth.user_api_key_auth, + "get_user_object", + side_effect=mock_user_object, + ) as mock_client: + if should_work: + # Expect the call to succeed + result = await user_api_key_auth(request=request, api_key=bearer_token) + assert result is not None # Adjust this based on your actual response check + else: + # Expect the call to fail + with pytest.raises( + Exception + ): # Replace with the actual exception raised on failure + resp = await user_api_key_auth(request=request, api_key=bearer_token) + print(resp) From 3a9a0767c1e4a2648c5b228b524c09c5ce2ae293 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 16:00:52 -0700 Subject: [PATCH 06/15] docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs --- docs/my-website/docs/proxy/token_auth.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 659cc6edf062..049ea0f98ec4 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \ }' ``` + +## Advanced - Upsert Users + Allowed Email Domains + +Allow users who belong to a specific email domain, automatic access to the proxy. + +```yaml +general_settings: + master_key: sk-1234 + enable_jwt_auth: True + litellm_jwtauth: + user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload + user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy + user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db +``` \ No newline at end of file From a741fd3009de8553a28ed68efa2e1f0ab1dd294e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 17:47:24 -0700 Subject: [PATCH 07/15] fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set Fixes issue where key logging would not be set if team metadata was not none --- litellm/proxy/_new_secret_config.yaml | 9 +-- litellm/proxy/litellm_pre_call_utils.py | 21 +++---- litellm/tests/test_proxy_utils.py | 74 ++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 20 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ac873ae8942f..985ab8710121 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -3,11 +3,4 @@ model_list: litellm_params: model: azure/chatgpt-v-2 api_key: os.environ/AZURE_API_KEY - api_base: os.environ/AZURE_API_BASE - -router_settings: - model_group_alias: {"gpt-4": "gpt-turbo"} - retry_policy: { - "BadRequestErrorRetries": 3, - "ContentPolicyViolationErrorRetries": 4 - } \ No newline at end of file + api_base: os.environ/AZURE_API_BASE \ No newline at end of file diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 890c576c9443..4c6172a4db6e 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -107,7 +107,16 @@ def _get_dynamic_logging_metadata( 
user_api_key_dict: UserAPIKeyAuth, ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - if user_api_key_dict.team_metadata is not None: + if ( + user_api_key_dict.metadata is not None + and "logging" in user_api_key_dict.metadata + ): + for item in user_api_key_dict.metadata["logging"]: + callback_settings_obj = convert_key_logging_metadata_to_callback( + data=AddTeamCallback(**item), + team_callback_settings_obj=callback_settings_obj, + ) + elif user_api_key_dict.team_metadata is not None: team_metadata = user_api_key_dict.team_metadata if "callback_settings" in team_metadata: callback_settings = team_metadata.get("callback_settings", None) or {} @@ -124,15 +133,7 @@ def _get_dynamic_logging_metadata( } } """ - elif ( - user_api_key_dict.metadata is not None - and "logging" in user_api_key_dict.metadata - ): - for item in user_api_key_dict.metadata["logging"]: - callback_settings_obj = convert_key_logging_metadata_to_callback( - data=AddTeamCallback(**item), - team_callback_settings_obj=callback_settings_obj, - ) + return callback_settings_obj diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py index 63361b09aaaf..576e86f070dd 100644 --- a/litellm/tests/test_proxy_utils.py +++ b/litellm/tests/test_proxy_utils.py @@ -11,8 +11,11 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request +from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.litellm_pre_call_utils import ( + _get_dynamic_logging_metadata, + add_litellm_data_to_request, +) from litellm.types.utils import SupportedCacheControls @@ -204,3 +207,70 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request( # assert ( # new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"] # ) + + +def test_dynamic_logging_metadata_key_and_team_metadata(): + user_api_key_dict = UserAPIKeyAuth( + token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432", + key_name="sk-...63Fg", + key_alias=None, + spend=0.000111, + max_budget=None, + expires=None, + models=[], + aliases={}, + config={}, + user_id=None, + team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898", + max_parallel_requests=None, + metadata={ + "logging": [ + { + "callback_name": "langfuse", + "callback_type": "success", + "callback_vars": { + "langfuse_host": "https://us.cloud.langfuse.com", + "langfuse_public_key": "pk-lf-9636b7a6-c066", + "langfuse_secret_key": "sk-lf-7cc8b620", + }, + } + ] + }, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + allowed_cache_controls=[], + permissions={}, + model_spend={}, + model_max_budget={}, + soft_budget_cooldown=False, + litellm_budget_table=None, + org_id=None, + team_spend=0.000132, + team_alias=None, + team_tpm_limit=None, + team_rpm_limit=None, + team_max_budget=None, + team_models=[], + team_blocked=False, + soft_budget=None, + team_model_aliases=None, + team_member_spend=None, + team_member=None, + team_metadata={}, + end_user_id=None, + end_user_tpm_limit=None, + end_user_rpm_limit=None, + end_user_max_budget=None, + last_refreshed_at=1726101560.967527, + api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1", + user_role=LitellmUserRoles.INTERNAL_USER, + allowed_model_region=None, + parent_otel_span=None, + rpm_limit_per_model=None, + 
tpm_limit_per_model=None, + ) + callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict) + + assert callbacks is not None From 5008ba4332a6beb517b399bc0c9723613a6b4747 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 18:11:47 -0700 Subject: [PATCH 08/15] fix(secret_managers/main.py): load environment variables correctly Fixes issue where os.environ/ was not being loaded correctly --- litellm/secret_managers/main.py | 38 ++++++++++++++++++++++--------- litellm/tests/test_proxy_utils.py | 26 ++++++++++++++++----- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index e136654c1c4c..5d1f72cf7fb9 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -29,6 +29,27 @@ def _is_base64(s): return False +def str_to_bool(value: str) -> Optional[bool]: + """ + Converts a string to a boolean if it's a recognized boolean string. + Returns None if the string is not a recognized boolean value. + + :param value: The string to be checked. + :return: True or False if the string is a recognized boolean, otherwise None. + """ + true_values = {"true"} + false_values = {"false"} + + value_lower = value.strip().lower() + + if value_lower in true_values: + return True + elif value_lower in false_values: + return False + else: + return None + + def get_secret( secret_name: str, default_value: Optional[Union[str, bool]] = None, @@ -257,17 +278,12 @@ def get_secret( return secret else: secret = os.environ.get(secret_name) - try: - secret_value_as_bool = ( - ast.literal_eval(secret) if secret is not None else None - ) - if isinstance(secret_value_as_bool, bool): - return secret_value_as_bool - else: - return secret - except Exception: - if default_value is not None: - return default_value + secret_value_as_bool = str_to_bool(secret) if secret is not None else None + if secret_value_as_bool is not None and isinstance( + secret_value_as_bool, bool + ): + return secret_value_as_bool + else: return secret except Exception as e: if default_value is not None: diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py index 576e86f070dd..76f555cc95ca 100644 --- a/litellm/tests/test_proxy_utils.py +++ b/litellm/tests/test_proxy_utils.py @@ -209,7 +209,22 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request( # ) -def test_dynamic_logging_metadata_key_and_team_metadata(): +@pytest.mark.parametrize( + "callback_vars", + [ + { + "langfuse_host": "https://us.cloud.langfuse.com", + "langfuse_public_key": "pk-lf-9636b7a6-c066", + "langfuse_secret_key": "sk-lf-7cc8b620", + }, + { + "langfuse_host": "os.environ/LANGFUSE_HOST", + "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", + "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", + }, + ], +) +def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars): user_api_key_dict = UserAPIKeyAuth( token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432", key_name="sk-...63Fg", @@ -228,11 +243,7 @@ def test_dynamic_logging_metadata_key_and_team_metadata(): { "callback_name": "langfuse", "callback_type": "success", - "callback_vars": { - "langfuse_host": "https://us.cloud.langfuse.com", - "langfuse_public_key": "pk-lf-9636b7a6-c066", - "langfuse_secret_key": "sk-lf-7cc8b620", - }, + "callback_vars": callback_vars, } ] }, @@ -274,3 +285,6 @@ def test_dynamic_logging_metadata_key_and_team_metadata(): callbacks = 
_get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict) assert callbacks is not None + + for var in callbacks.callback_vars.values(): + assert "os.environ" not in var From f8d9d4433ac6228fe9e3f1f74c05ba493ab30cc4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 18:27:42 -0700 Subject: [PATCH 09/15] test(test_router.py): fix test --- litellm/tests/test_router.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index fd89130feb1b..643c34543265 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): "litellm_params": { "model": "openai/text-embedding-ada-002", }, - } + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "openai/text-embedding-ada-002", + }, + }, ] ) From 6de8720cf17fb2e750079789213c531c19fa577b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 19:00:46 -0700 Subject: [PATCH 10/15] feat(spend_tracking_utils.py): support logging additional usage params - e.g. prompt caching values for deepseek --- .pre-commit-config.yaml | 12 ++++++------ litellm/proxy/_new_secret_config.yaml | 6 +++--- litellm/proxy/_types.py | 3 +++ litellm/proxy/spend_tracking/spend_tracking_utils.py | 8 ++++++++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a33473b72465..d429bc6b8c14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: local hooks: - - id: mypy - name: mypy - entry: python3 -m mypy --ignore-missing-imports - language: system - types: [python] - files: ^litellm/ + # - id: mypy + # name: mypy + # entry: python3 -m mypy --ignore-missing-imports + # language: system + # types: [python] + # files: ^litellm/ - id: isort name: isort entry: isort diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 985ab8710121..cb52992e2eec 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,6 +1,6 @@ model_list: - model_name: "gpt-turbo" litellm_params: - model: azure/chatgpt-v-2 - api_key: os.environ/AZURE_API_KEY - api_base: os.environ/AZURE_API_BASE \ No newline at end of file + model: deepseek/deepseek-chat + # api_key: os.environ/AZURE_API_KEY + # api_base: os.environ/AZURE_API_BASE \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e7c750051aad..c76be80394f3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1688,6 +1688,9 @@ class SpendLogsMetadata(TypedDict): Specific metadata k,v pairs logged to spendlogs for easier cost tracking """ + additional_usage_values: Optional[ + dict + ] # covers provider-specific usage information - e.g. 
prompt caching user_api_key: Optional[str] user_api_key_alias: Optional[str] user_api_key_team_id: Optional[str] diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index a1a0b97339d6..bdeef92cce80 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -84,6 +84,7 @@ def get_logging_payload( user_api_key_team_alias=None, spend_logs_metadata=None, requester_ip_address=None, + additional_usage_values=None, ) if isinstance(metadata, dict): verbose_proxy_logger.debug( @@ -100,6 +101,13 @@ def get_logging_payload( } ) + special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"] + additional_usage_values = {} + for k, v in usage.items(): + if k not in special_usage_fields: + additional_usage_values.update({k: v}) + clean_metadata["additional_usage_values"] = additional_usage_values + if litellm.cache is not None: cache_key = litellm.cache.get_cache_key(**kwargs) else: From cceb5bf92cfd4e2389ab9463436ee710653c1380 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 19:57:51 -0700 Subject: [PATCH 11/15] test: fix tests --- litellm/litellm_core_utils/litellm_logging.py | 17 ++++++++--- litellm/router.py | 7 +++++ litellm/tests/test_custom_callback_router.py | 29 ++++++++++--------- litellm/tests/test_lunary.py | 26 ++++++++--------- litellm/tests/test_traceloop.py | 3 +- 5 files changed, 50 insertions(+), 32 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 43273224cb41..fe6acf345be5 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1608,15 +1608,24 @@ async def special_failure_handlers(self, exception: Exception): """ from litellm.types.router import RouterErrors + litellm_params: dict = self.model_call_details.get("litellm_params") or {} + metadata = litellm_params.get("metadata") or {} + + ## BASE CASE ## check if rate limit error for model group size 1 + is_base_case = False + if metadata.get("model_group_size") is not None: + model_group_size = metadata.get("model_group_size") + if isinstance(model_group_size, int) and model_group_size == 1: + is_base_case = True ## check if special error ## - if RouterErrors.no_deployments_available.value not in str(exception): + if ( + RouterErrors.no_deployments_available.value not in str(exception) + and is_base_case is False + ): return ## get original model group ## - litellm_params: dict = self.model_call_details.get("litellm_params") or {} - metadata = litellm_params.get("metadata") or {} - model_group = metadata.get("model_group") or None for callback in litellm._async_failure_callback: if isinstance(callback, CustomLogger): # custom logger class diff --git a/litellm/router.py b/litellm/router.py index bc3f86163c16..c187474f1ac8 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3027,6 +3027,13 @@ async def async_function_with_retries(self, *args, **kwargs): model_group = kwargs.get("model") num_retries = kwargs.pop("num_retries") + ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking + _metadata: dict = kwargs.get("metadata") or {} + if "model_group" in _metadata and isinstance(_metadata["model_group"], str): + model_list = self.get_model_list(model_name=_metadata["model_group"]) + if model_list is not None: + _metadata.update({"model_group_size": len(model_list)}) + verbose_router_logger.debug( 
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}" ) diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py index 80fc096e78ac..6ffa97d89ecf 100644 --- a/litellm/tests/test_custom_callback_router.py +++ b/litellm/tests/test_custom_callback_router.py @@ -38,6 +38,8 @@ ## 1. router.completion() + router.embeddings() ## 2. proxy.completions + proxy.embeddings +litellm.num_retries = 0 + class CompletionCustomHandler( CustomLogger @@ -401,7 +403,7 @@ async def test_async_chat_azure(): "rpm": 1800, }, ] - router = Router(model_list=model_list) # type: ignore + router = Router(model_list=model_list, num_retries=0) # type: ignore response = await router.acompletion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], @@ -413,7 +415,7 @@ async def test_async_chat_azure(): ) # pre, post, success # streaming litellm.callbacks = [customHandler_streaming_azure_router] - router2 = Router(model_list=model_list) # type: ignore + router2 = Router(model_list=model_list, num_retries=0) # type: ignore response = await router2.acompletion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], @@ -443,7 +445,7 @@ async def test_async_chat_azure(): }, ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore + router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: response = await router3.acompletion( model="gpt-3.5-turbo", @@ -505,7 +507,7 @@ async def test_async_embedding_azure(): }, ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore + router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: response = await router3.aembedding( model="azure-embedding-model", input=["hello from litellm!"] @@ -678,22 +680,21 @@ async def test_rate_limit_error_callback(): pass with patch.object( - customHandler, "log_model_group_rate_limit_error", new=MagicMock() + customHandler, "log_model_group_rate_limit_error", new=AsyncMock() ) as mock_client: print( f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}" ) - for _ in range(3): - try: - _ = await router.acompletion( - model="my-test-gpt", - messages=[{"role": "user", "content": "Hey, how's it going?"}], - litellm_logging_obj=litellm_logging_obj, - ) - except (litellm.RateLimitError, ValueError): - pass + try: + _ = await router.acompletion( + model="my-test-gpt", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + litellm_logging_obj=litellm_logging_obj, + ) + except (litellm.RateLimitError, ValueError): + pass await asyncio.sleep(3) mock_client.assert_called_once() diff --git a/litellm/tests/test_lunary.py b/litellm/tests/test_lunary.py index cd068d9900f0..d181d24c782d 100644 --- a/litellm/tests/test_lunary.py +++ b/litellm/tests/test_lunary.py @@ -1,16 +1,17 @@ -import sys -import os import io +import os +import sys sys.path.insert(0, os.path.abspath("../..")) -from litellm import completion import litellm +from litellm import completion litellm.failure_callback = ["lunary"] litellm.success_callback = ["lunary"] litellm.set_verbose = True + def test_lunary_logging(): try: response = completion( @@ -24,6 +25,7 @@ def test_lunary_logging(): except Exception as e: print(e) + test_lunary_logging() @@ -37,8 +39,6 @@ def test_lunary_template(): except Exception as e: print(e) -test_lunary_template() - def 
test_lunary_logging_with_metadata(): try: @@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata(): metadata={ "run_name": "litellmRUN", "project_name": "litellm-completion", - "tags": ["tag1", "tag2"] + "tags": ["tag1", "tag2"], }, ) print(response) except Exception as e: print(e) -test_lunary_logging_with_metadata() def test_lunary_with_tools(): import litellm - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + messages = [ + { + "role": "user", + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + } + ] tools = [ { "type": "function", @@ -90,13 +94,11 @@ def test_lunary_with_tools(): tools=tools, tool_choice="auto", # auto is default, but we'll be explicit ) - + response_message = response.choices[0].message print("\nLLM Response:\n", response.choices[0].message) -test_lunary_with_tools() - def test_lunary_logging_with_streaming_and_metadata(): try: response = completion( @@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata(): continue except Exception as e: print(e) - -test_lunary_logging_with_streaming_and_metadata() diff --git a/litellm/tests/test_traceloop.py b/litellm/tests/test_traceloop.py index bcc120323cb4..74d58228efeb 100644 --- a/litellm/tests/test_traceloop.py +++ b/litellm/tests/test_traceloop.py @@ -4,7 +4,6 @@ import pytest from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from traceloop.sdk import Traceloop import litellm @@ -13,6 +12,8 @@ @pytest.fixture() def exporter(): + from traceloop.sdk import Traceloop + exporter = InMemorySpanExporter() Traceloop.init( app_name="test_litellm", From 6422501a1731c87634a3474e1cafbdb63f92b2fb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 20:01:41 -0700 Subject: [PATCH 12/15] test: fix test --- litellm/tests/test_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 643c34543265..05d9f9f769f1 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error(): except litellm.APIConnectionError as e: assert ( Router()._is_cooldown_required( - exception_status=e.code, exception_str=str(e) + model_id="", exception_status=e.code, exception_str=str(e) ) is False ) From 451cdb598b01cfc4ebe28368e4a97785d81e123e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 20:07:00 -0700 Subject: [PATCH 13/15] test: fix test --- litellm/tests/test_router_retries.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index f0503cd3f18b..f4574212d697 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type): "tpm": 240000, "rpm": 1800, }, + { + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": _api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, ] router = Router(model_list=model_list, allowed_fails=3) From 80fe1471aed2995570ef6cdc9849b0221fa48415 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 20:39:54 -0700 Subject: [PATCH 14/15] test: fix test --- 
litellm/tests/test_amazing_vertex_completion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 4c065b8d3712..89a86559dec8 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -53,6 +53,7 @@ "gemini-pro-experimental", "gemini-flash-experimental", "gemini-pro-flash", + "gemini-1.5-flash-exp-0827", ] From ea5f52f30b6879b72691f427a47fb549f4822a86 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Sep 2024 21:26:23 -0700 Subject: [PATCH 15/15] test: fix test --- litellm/tests/test_proxy_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py index 76f555cc95ca..b5aac09d1261 100644 --- a/litellm/tests/test_proxy_utils.py +++ b/litellm/tests/test_proxy_utils.py @@ -218,13 +218,16 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request( "langfuse_secret_key": "sk-lf-7cc8b620", }, { - "langfuse_host": "os.environ/LANGFUSE_HOST", - "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", - "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", + "langfuse_host": "os.environ/LANGFUSE_HOST_TEMP", + "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP", + "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP", }, ], ) def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars): + os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066" + os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620" + os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com" user_api_key_dict = UserAPIKeyAuth( token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432", key_name="sk-...63Fg",