LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842) #5858

Merged 5 commits on Sep 24, 2024
3 changes: 3 additions & 0 deletions docs/my-website/docs/providers/bedrock.md
@@ -987,6 +987,9 @@ Here's an example of using a bedrock model with LiteLLM. For a complete list, re
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
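The new Llama 3.1 rows use the same calling convention as the existing Bedrock entries. A minimal usage sketch (credentials are placeholders; the region value is illustrative):

```python
import os

from litellm import completion

# placeholder credentials; the region value is illustrative
os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
os.environ["AWS_REGION_NAME"] = "us-west-2"

response = completion(
    model="bedrock/meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```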
2 changes: 1 addition & 1 deletion litellm/__init__.py
@@ -963,8 +963,8 @@ class LlmProviders(str, Enum):
MistralEmbeddingConfig,
DeepInfraConfig,
GroqConfig,
AzureAIStudioConfig,
)
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.OpenAI.chat.o1_transformation import (
OpenAIO1Config,
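This moves `AzureAIStudioConfig` out of `litellm/llms/OpenAI/openai.py` (see the deletion below) into the new `azure_ai` package, while re-exporting it from the package root so existing imports keep working. A minimal sketch of the public path, based only on what the diff shows:

```python
import litellm

# the class still resolves from the package root after the move
config = litellm.AzureAIStudioConfig()
for field in config.get_required_params():
    print(field.field_name)  # api_key, api_base
```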
19 changes: 3 additions & 16 deletions litellm/integrations/SlackAlerting/slack_alerting.py
@@ -10,7 +10,7 @@
from datetime import datetime as dt
from datetime import timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union, get_args

import aiohttp
import dotenv
@@ -57,20 +57,7 @@ def __init__(
float
] = None, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"fallback_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
"failed_tracking_spend",
],
alert_types: List[AlertType] = list(get_args(AlertType)),
alert_to_webhook_url: Optional[
Dict[AlertType, Union[List[str], str]]
] = None, # if user wants to separate alerts to diff channels
@@ -613,7 +600,7 @@ async def failed_tracking_alert(self, error_message: str):
await self.send_alert(
message=message,
level="High",
alert_type="budget_alerts",
alert_type="failed_tracking_spend",
alerting_metadata={},
)
await _cache.async_set_cache(
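Two fixes in this file: the default `alert_types` is now derived from the `AlertType` alias itself via `typing.get_args`, so new alert types are picked up automatically, and failed-tracking alerts now use the `failed_tracking_spend` alert type instead of `budget_alerts`. A minimal sketch of the `get_args` pattern, with a shortened stand-in for litellm's `AlertType`:

```python
from typing import Literal, get_args

# shortened stand-in for litellm's AlertType alias
AlertType = Literal["llm_exceptions", "llm_too_slow", "budget_alerts"]

# get_args unpacks a Literal into its allowed values, so extending the
# alias automatically extends the default list
default_alert_types = list(get_args(AlertType))
print(default_alert_types)  # ['llm_exceptions', 'llm_too_slow', 'budget_alerts']
```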
9 changes: 6 additions & 3 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -2498,13 +2498,16 @@ def get_standard_logging_object_payload(
else:
cache_key = None

saved_cache_cost: Optional[float] = None
saved_cache_cost: float = 0.0
if cache_hit is True:

id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id

saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False # type: ignore
saved_cache_cost = (
logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False # type: ignore
)
or 0.0
)

## Get model cost information ##
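The change retypes `saved_cache_cost` as a plain float and coalesces a possible `None` return from the cost calculator to `0.0`. A minimal sketch of the coalescing pattern, with a hypothetical calculator in place of litellm's:

```python
from typing import Optional

def response_cost_calculator(result: dict) -> Optional[float]:
    # hypothetical: returns None when no pricing is known for the model
    return result.get("cost")

# `(...) or 0.0` maps a None cost to 0.0, so downstream arithmetic
# never sees Optional[float]
saved_cache_cost: float = response_cost_calculator({"model": "unmapped"}) or 0.0
print(saved_cache_cost)  # 0.0
```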
19 changes: 0 additions & 19 deletions litellm/llms/OpenAI/openai.py
@@ -103,25 +103,6 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):
return optional_params


class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]


class DeepInfraConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api
59 changes: 59 additions & 0 deletions litellm/llms/azure_ai/chat/handler.py
@@ -0,0 +1,59 @@
from typing import Any, Callable, List, Optional, Union

from httpx._config import Timeout

from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from .transformation import AzureAIStudioConfig


class AzureAIChatCompletion(OpenAIChatCompletion):
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):

transformed_messages = AzureAIStudioConfig()._transform_messages(
messages=messages # type: ignore
)

return super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
transformed_messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
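The handler's only job is to normalize messages before delegating to the OpenAI-compatible base class. A self-contained sketch of that delegate-after-transform pattern; `BaseChatCompletion` and its two-argument signature are simplified stand-ins, not litellm's actual classes:

```python
from typing import List, Optional


class BaseChatCompletion:
    """Simplified stand-in for OpenAIChatCompletion."""

    def completion(self, model: str, messages: Optional[list] = None) -> dict:
        return {"model": model, "messages": messages}


def flatten_content(messages: List[dict]) -> List[dict]:
    # collapse list-form content into a single string (see common_utils.py below)
    for m in messages:
        if isinstance(m.get("content"), list):
            m["content"] = "".join(part.get("text", "") for part in m["content"])
    return messages


class AzureAIChat(BaseChatCompletion):
    def completion(self, model: str, messages: Optional[list] = None) -> dict:
        # transform first, then delegate everything else to the base class
        return super().completion(model, flatten_content(messages or []))


print(AzureAIChat().completion(
    "azure_ai/mistral-large",
    [{"role": "user", "content": [{"type": "text", "text": "hi"}]}],
))
```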
31 changes: 31 additions & 0 deletions litellm/llms/azure_ai/chat/transformation.py
@@ -0,0 +1,31 @@
from typing import List

from litellm.llms.OpenAI.openai import OpenAIConfig
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField


class AzureAIStudioConfig(OpenAIConfig):
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]

def _transform_messages(self, messages: List[AllMessageValues]) -> List:
for message in messages:
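# convert_content_list_to_str updates the message dict in place,
# so returning `messages` below reflects the change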
message = convert_content_list_to_str(message=message)

return messages
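A brief usage sketch of the new transform (the message dict is illustrative):

```python
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig

messages = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
print(AzureAIStudioConfig()._transform_messages(messages=messages))
# [{'role': 'user', 'content': 'hi'}]
```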
18 changes: 12 additions & 6 deletions litellm/llms/fireworks_ai/cost_calculator.py
@@ -10,7 +10,7 @@

# Extract the number of billion parameters from the model name
# only used for fireworks_ai LLMs
def get_model_params_and_category(model_name: str) -> str:
def get_base_model_for_pricing(model_name: str) -> str:
"""
Helper function for calculating fireworks ai pricing.
@@ -43,7 +43,7 @@ def get_model_params_and_category(model_name: str) -> str:
return "fireworks-ai-16b-80b"

# If no matches, return the original model_name
return model_name
return "fireworks-ai-default"


def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@@ -57,10 +57,16 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
base_model = get_model_params_and_category(model_name=model)

## GET MODEL INFO
model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
## check if model mapped, else use default pricing
try:
model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
except Exception:
base_model = get_base_model_for_pricing(model_name=model)

## GET MODEL INFO
model_info = get_model_info(
model=base_model, custom_llm_provider="fireworks_ai"
)

## CALCULATE INPUT COST

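The reworked cost flow tries an exact model lookup first and only falls back to the parameter-size bucket (which now always ends in a mapped key, `fireworks-ai-default`) when the model is unmapped. A minimal sketch with a hypothetical pricing table standing in for litellm's `get_model_info`:

```python
from typing import Dict, Tuple

# hypothetical per-token prices (prompt, completion); not litellm's real table
PRICING: Dict[str, Tuple[float, float]] = {
    "accounts/fireworks/models/llama-v3p1-8b-instruct": (2e-07, 2e-07),
    "fireworks-ai-default": (9e-07, 9e-07),
}

def get_model_info(model: str) -> Tuple[float, float]:
    return PRICING[model]  # raises KeyError for unmapped models

def cost_per_token(model: str) -> Tuple[float, float]:
    try:
        return get_model_info(model)  # exact match first
    except Exception:
        # fall back to bucket pricing when the model isn't mapped
        return get_model_info("fireworks-ai-default")

print(cost_per_token("accounts/fireworks/models/unmapped"))  # (9e-07, 9e-07)
```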
32 changes: 32 additions & 0 deletions litellm/llms/prompt_templates/common_utils.py
@@ -0,0 +1,32 @@
"""
Common utility functions used for translating messages across providers
"""

from typing import List

from litellm.types.llms.openai import AllMessageValues


def convert_content_list_to_str(message: AllMessageValues) -> AllMessageValues:
"""
- Handles the scenario where content is a list rather than a string
- A text-only content list is flattened into a single string
- If an image is passed in, the message is returned as-is (user-intended)
Motivation: the Mistral API and Azure AI don't support content as a list
"""
texts = ""
message_content = message.get("content")
if message_content:
if message_content is not None and isinstance(message_content, list):
for c in message_content:
text_content = c.get("text")
if text_content:
texts += text_content
elif message_content is not None and isinstance(message_content, str):
texts = message_content

if texts:
message["content"] = texts

return message
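A brief usage sketch of the helper (the message dict is illustrative):

```python
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str

msg = {
    "role": "user",
    "content": [{"type": "text", "text": "Hello "}, {"type": "text", "text": "world"}],
}
print(convert_content_list_to_str(message=msg))
# {'role': 'user', 'content': 'Hello world'}
```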