Merge branch 'BerriAI:main' into feature/watsonx-integration
simonsanvil committed Apr 21, 2024
2 parents c36cb7d + cb3c98a commit a77537d
Showing 45 changed files with 1,027 additions and 281 deletions.
2 changes: 1 addition & 1 deletion docs/my-website/docs/routing.md
@@ -279,7 +279,7 @@ router_settings:
```

</TabItem>
-<TabItem value="simple-shuffle" label="(Default) Weighted Pick">
+<TabItem value="simple-shuffle" label="(Default) Weighted Pick (Async)">

**Default** Picks a deployment based on the provided **Requests per minute (rpm) or Tokens per minute (tpm)**

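For context on this tab's content: the default simple-shuffle strategy picks a deployment with probability proportional to its configured rpm (or tpm). A minimal sketch of that weighted selection, using hypothetical deployment dicts rather than LiteLLM's actual router internals:

```python
import random

# Hypothetical deployments; in LiteLLM these come from the router's model_list.
deployments = [
    {"model_name": "gpt-4-eu", "rpm": 100},
    {"model_name": "gpt-4-us", "rpm": 300},
]

def weighted_pick(deployments: list) -> dict:
    # Pick proportionally to rpm: gpt-4-us is chosen ~3x as often here.
    weights = [d.get("rpm", 1) for d in deployments]
    return random.choices(deployments, weights=weights, k=1)[0]

print(weighted_pick(deployments)["model_name"])
```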
2 changes: 1 addition & 1 deletion litellm/integrations/prometheus.py
@@ -19,7 +19,7 @@ def __init__(
**kwargs,
):
try:
-verbose_logger.debug(f"in init prometheus metrics")
+print(f"in init prometheus metrics")
from prometheus_client import Counter

self.litellm_llm_api_failed_requests_metric = Counter(
53 changes: 41 additions & 12 deletions litellm/integrations/prometheus_services.py
@@ -44,9 +44,18 @@ def __init__(
) # store the prometheus histogram/counter we need to call for each field in payload

for service in self.services:
-histogram = self.create_histogram(service)
-counter = self.create_counter(service)
-self.payload_to_prometheus_map[service] = [histogram, counter]
+histogram = self.create_histogram(service, type_of_request="latency")
+counter_failed_request = self.create_counter(
+service, type_of_request="failed_requests"
+)
+counter_total_requests = self.create_counter(
+service, type_of_request="total_requests"
+)
+self.payload_to_prometheus_map[service] = [
+histogram,
+counter_failed_request,
+counter_total_requests,
+]

self.prometheus_to_amount_map: dict = (
{}
@@ -74,26 +83,26 @@ def get_metric(self, metric_name):
return metric
return None

-def create_histogram(self, label: str):
-metric_name = "litellm_{}_latency".format(label)
+def create_histogram(self, service: str, type_of_request: str):
+metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Histogram(
metric_name,
"Latency for {} service".format(label),
labelnames=[label],
"Latency for {} service".format(service),
labelnames=[service],
)

-def create_counter(self, label: str):
-metric_name = "litellm_{}_failed_requests".format(label)
+def create_counter(self, service: str, type_of_request: str):
+metric_name = "litellm_{}_{}".format(service, type_of_request)
is_registered = self.is_metric_registered(metric_name)
if is_registered:
return self.get_metric(metric_name)
return self.Counter(
metric_name,
"Total failed requests for {} service".format(label),
labelnames=[label],
"Total {} for {} service".format(type_of_request, service),
labelnames=[service],
)

def observe_histogram(
@@ -120,6 +129,8 @@ def service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_success_calls += 1

print(f"payload call type: {payload.call_type}")

if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
@@ -129,19 +140,27 @@ def service_success_hook(self, payload: ServiceLoggerPayload):
labels=payload.service.value,
amount=payload.duration,
)
+elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
+self.increment_counter(
+counter=obj,
+labels=payload.service.value,
+amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
+)

def service_failure_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_failure_calls += 1

print(f"payload call type: {payload.call_type}")

if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
if isinstance(obj, self.Counter):
self.increment_counter(
counter=obj,
labels=payload.service.value,
-amount=1, # LOG ERROR COUNT TO PROMETHEUS
+amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
)

async def async_service_success_hook(self, payload: ServiceLoggerPayload):
@@ -151,6 +170,8 @@ async def async_service_success_hook(self, payload: ServiceLoggerPayload):
if self.mock_testing:
self.mock_testing_success_calls += 1

print(f"payload call type: {payload.call_type}")

if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
Expand All @@ -160,12 +181,20 @@ async def async_service_success_hook(self, payload: ServiceLoggerPayload):
labels=payload.service.value,
amount=payload.duration,
)
+elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
+self.increment_counter(
+counter=obj,
+labels=payload.service.value,
+amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
+)

async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")
if self.mock_testing:
self.mock_testing_failure_calls += 1

print(f"payload call type: {payload.call_type}")

if payload.service.value in self.payload_to_prometheus_map:
prom_objects = self.payload_to_prometheus_map[payload.service.value]
for obj in prom_objects:
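For reference, the pattern introduced above registers one latency histogram and two counters (failed and total requests) per monitored service. A minimal standalone sketch of the same idea with prometheus_client (hypothetical service names; not the full LiteLLM class):

```python
from prometheus_client import Counter, Histogram

services = ["redis", "postgres"]  # hypothetical service list
metrics = {}  # service -> [histogram, failed_counter, total_counter]

for service in services:
    histogram = Histogram(
        f"litellm_{service}_latency",
        f"Latency for {service} service",
        labelnames=[service],
    )
    failed = Counter(
        f"litellm_{service}_failed_requests",
        f"Total failed_requests for {service} service",
        labelnames=[service],
    )
    total = Counter(
        f"litellm_{service}_total_requests",
        f"Total total_requests for {service} service",
        labelnames=[service],
    )
    metrics[service] = [histogram, failed, total]

# On success: record the call duration and bump the total-request counter.
metrics["redis"][0].labels("redis").observe(0.025)
metrics["redis"][2].labels("redis").inc()

# On failure: bump the failed-request counter as well.
metrics["redis"][1].labels("redis").inc()
```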
38 changes: 24 additions & 14 deletions litellm/llms/prompt_templates/factory.py
@@ -507,10 +507,11 @@ def construct_tool_use_system_prompt(
): # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
tool_str_list = []
for tool in tools:
+tool_function = get_attribute_or_key(tool, "function")
tool_str = construct_format_tool_for_claude_prompt(
tool["function"]["name"],
tool["function"].get("description", ""),
tool["function"].get("parameters", {}),
get_attribute_or_key(tool_function, "name"),
get_attribute_or_key(tool_function, "description", ""),
get_attribute_or_key(tool_function, "parameters", {}),
)
tool_str_list.append(tool_str)
tool_use_system_prompt = (
@@ -634,7 +635,8 @@ def convert_to_anthropic_tool_result_xml(message: dict) -> str:
</function_results>
"""
name = message.get("name")
-content = message.get("content")
+content = message.get("content", "")
+content = content.replace("<", "&lt;").replace(">", "&gt;").replace("&", "&amp;")

# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
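A note on the entity escaping added above: `&` must be replaced before `<` and `>`, otherwise the `&lt;`/`&gt;` entities produced by the earlier replacements are themselves re-escaped (e.g. `<` ends up as `&amp;lt;`). A minimal sketch of the order-safe version:

```python
def escape_xml(content: str) -> str:
    # Replace "&" first so "&lt;"/"&gt;" are not re-escaped afterwards.
    return (
        content.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )

assert escape_xml("<result>") == "&lt;result&gt;"
```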
@@ -655,13 +657,15 @@ def convert_to_anthropic_tool_result_xml(message: dict) -> str:
def convert_to_anthropic_tool_invoke_xml(tool_calls: list) -> str:
invokes = ""
for tool in tool_calls:
if tool["type"] != "function":
if get_attribute_or_key(tool, "type") != "function":
continue

tool_name = tool["function"]["name"]
tool_function = get_attribute_or_key(tool,"function")
tool_name = get_attribute_or_key(tool_function, "name")
tool_arguments = get_attribute_or_key(tool_function, "arguments")
parameters = "".join(
f"<{param}>{val}</{param}>\n"
-for param, val in json.loads(tool["function"]["arguments"]).items()
+for param, val in json.loads(tool_arguments).items()
)
invokes += (
"<invoke>\n"
@@ -715,7 +719,7 @@ def anthropic_messages_pt_xml(messages: list):
{
"type": "text",
"text": (
-convert_to_anthropic_tool_result(messages[msg_i])
+convert_to_anthropic_tool_result_xml(messages[msg_i])
if messages[msg_i]["role"] == "tool"
else messages[msg_i]["content"]
),
@@ -736,7 +740,7 @@ def anthropic_messages_pt_xml(messages: list):
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
-assistant_text += convert_to_anthropic_tool_invoke( # type: ignore
+assistant_text += convert_to_anthropic_tool_invoke_xml( # type: ignore
messages[msg_i]["tool_calls"]
)

@@ -848,12 +852,12 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
anthropic_tool_invoke = [
{
"type": "tool_use",
"id": tool["id"],
"name": tool["function"]["name"],
"input": json.loads(tool["function"]["arguments"]),
"id": get_attribute_or_key(tool, "id"),
"name": get_attribute_or_key(get_attribute_or_key(tool, "function"), "name"),
"input": json.loads(get_attribute_or_key(get_attribute_or_key(tool, "function"), "arguments")),
}
for tool in tool_calls
if tool["type"] == "function"
if get_attribute_or_key(tool, "type") == "function"
]

return anthropic_tool_invoke
@@ -1074,7 +1078,8 @@ def cohere_message_pt(messages: list):
tool_result = convert_openai_message_to_cohere_tool_result(message)
tool_results.append(tool_result)
else:
-prompt += message["content"]
+prompt += message["content"] + "\n\n"
+prompt = prompt.rstrip()
return prompt, tool_results


@@ -1414,3 +1419,8 @@ def prompt_factory(
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)

+def get_attribute_or_key(tool_or_function, attribute, default=None):
+if hasattr(tool_or_function, attribute):
+return getattr(tool_or_function, attribute)
+return tool_or_function.get(attribute, default)
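The new helper lets the prompt factory accept tool calls either as plain dicts or as attribute-style objects (e.g. pydantic models). A quick illustration with hypothetical inputs:

```python
from types import SimpleNamespace

def get_attribute_or_key(tool_or_function, attribute, default=None):
    # Prefer attribute access (objects); fall back to dict lookup.
    if hasattr(tool_or_function, attribute):
        return getattr(tool_or_function, attribute)
    return tool_or_function.get(attribute, default)

as_dict = {"name": "get_weather", "description": "Look up weather"}
as_obj = SimpleNamespace(name="get_weather", description="Look up weather")

assert get_attribute_or_key(as_dict, "name") == "get_weather"
assert get_attribute_or_key(as_obj, "name") == "get_weather"
assert get_attribute_or_key(as_dict, "parameters", {}) == {}
```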
12 changes: 7 additions & 5 deletions litellm/llms/vertex_ai_anthropic.py
@@ -123,7 +123,7 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):


"""
-- Run client init
+- Run client init
- Support async completion, streaming
"""

@@ -236,17 +236,19 @@ def completion(
if client is None:
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account

-json_obj = json.loads(vertex_credentials)

creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
-json_obj,
+json.loads(vertex_credentials),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
### CHECK IF ACCESS
access_token = refresh_auth(credentials=creds)
else:
+import google.auth
creds, _ = google.auth.default()
+### CHECK IF ACCESS
+access_token = refresh_auth(credentials=creds)

vertex_ai_client = AnthropicVertex(
project_id=vertex_project,
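For background, the two branches above correspond to the two standard ways of obtaining Google Cloud credentials: from an explicit service-account JSON string, or from Application Default Credentials. A minimal sketch of that pattern (the function name is hypothetical; the scope is the standard cloud-platform scope):

```python
import json
from typing import Optional

import google.auth
import google.oauth2.service_account

def load_vertex_credentials(vertex_credentials: Optional[str]):
    """Build credentials from a service-account JSON string if given,
    otherwise fall back to Application Default Credentials."""
    if vertex_credentials is not None:
        return google.oauth2.service_account.Credentials.from_service_account_info(
            json.loads(vertex_credentials),
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
    creds, _project = google.auth.default()
    return creds
```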
4 changes: 4 additions & 0 deletions litellm/main.py
@@ -610,6 +610,7 @@ def completion(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -2598,6 +2599,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
+max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
@@ -2655,6 +2657,7 @@ def embedding(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -3514,6 +3517,7 @@ def image_generation(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
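The lists above enumerate router-level parameters that must be stripped from kwargs before the request is forwarded to the provider API. A small sketch of that filtering pattern (hypothetical key set, not LiteLLM's full list):

```python
# Router-only settings the provider API should never see (subset for illustration).
ROUTER_ONLY_PARAMS = {"rpm", "tpm", "max_parallel_requests", "input_cost_per_token"}

def split_kwargs(kwargs: dict) -> tuple:
    """Separate router-only settings from parameters forwarded to the provider."""
    router_settings = {k: v for k, v in kwargs.items() if k in ROUTER_ONLY_PARAMS}
    provider_kwargs = {k: v for k, v in kwargs.items() if k not in ROUTER_ONLY_PARAMS}
    return router_settings, provider_kwargs

router, provider = split_kwargs({"max_parallel_requests": 5, "temperature": 0.2})
print(router)    # {'max_parallel_requests': 5}
print(provider)  # {'temperature': 0.2}
```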