Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support #5777

Merged
merged 11 commits into from
Sep 18, 2024
33 changes: 33 additions & 0 deletions docs/my-website/docs/assistants.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ Covers Threads, Messages, Assistants.

LiteLLM currently covers:
- Create Assistants
- Delete Assistants
- Get Assistants
- Create Thread
- Get Thread
- Add Messages
- Get Messages
- Run Thread


## **Supported Providers**:
- [OpenAI](#quick-start)
- [Azure OpenAI](#azure-openai)
- [OpenAI-Compatible APIs](#openai-compatible-apis)

## Quick Start

Call an existing Assistant.
Expand Down Expand Up @@ -283,6 +290,32 @@ curl -X POST 'http://0.0.0.0:4000/threads/{thread_id}/runs' \

## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/assistants)


## Azure OpenAI

**config**
```yaml
assistant_settings:
custom_llm_provider: azure
litellm_params:
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
```

**curl**

```bash
curl -X POST "http://localhost:4000/v1/assistants" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
"name": "Math Tutor",
"tools": [{"type": "code_interpreter"}],
"model": "<my-azure-deployment-name>"
}'
```

## OpenAI-Compatible APIs

To call OpenAI-compatible Assistants APIs (e.g. the Astra Assistants API), just add `openai/` to the model name:
Expand Down
2 changes: 1 addition & 1 deletion docs/my-website/docs/batches.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';

Covers Batches, Files

Supported Providers:
## **Supported Providers**:
- Azure OpenAI
- OpenAI

Expand Down
160 changes: 125 additions & 35 deletions litellm/assistants/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import litellm
from litellm import client
from litellm.llms.AzureOpenAI import assistants
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
exception_type,
Expand All @@ -21,7 +22,7 @@
supports_httpx_timeout,
)

from ..llms.AzureOpenAI.azure import AzureAssistantsAPI
from ..llms.AzureOpenAI.assistants import AzureAssistantsAPI
from ..llms.OpenAI.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
Expand Down Expand Up @@ -210,8 +211,8 @@ async def acreate_assistants(
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["async_create_assistants"] = True
model = kwargs.pop("model", None)
try:
model = kwargs.pop("model", None)
kwargs["client"] = client
# Use a partial function to pass your keyword arguments
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
Expand Down Expand Up @@ -258,7 +259,7 @@ def create_assistants(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs,
) -> Assistant:
) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
async_create_assistants: Optional[bool] = kwargs.pop(
"async_create_assistants", None
)
Expand Down Expand Up @@ -288,7 +289,20 @@ def create_assistants(
elif timeout is None:
timeout = 600.0

response: Optional[Assistant] = None
create_assistant_data = {
"model": model,
"name": name,
"description": description,
"instructions": instructions,
"tools": tools,
"tool_resources": tool_resources,
"metadata": metadata,
"temperature": temperature,
"top_p": top_p,
"response_format": response_format,
}

response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
Expand All @@ -310,19 +324,6 @@ def create_assistants(
or os.getenv("OPENAI_API_KEY")
)

create_assistant_data = {
"model": model,
"name": name,
"description": description,
"instructions": instructions,
"tools": tools,
"tool_resources": tool_resources,
"metadata": metadata,
"temperature": temperature,
"top_p": top_p,
"response_format": response_format,
}

response = openai_assistants_api.create_assistants(
api_base=api_base,
api_key=api_key,
Expand All @@ -333,6 +334,46 @@ def create_assistants(
client=client,
async_create_assistants=async_create_assistants, # type: ignore
) # type: ignore
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore

if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI

response = azure_assistants_api.create_assistants(
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
async_create_assistants=async_create_assistants,
create_assistant_data=create_assistant_data,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' and 'azure' are supported.".format(
Expand Down Expand Up @@ -401,7 +442,7 @@ def delete_assistant(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs,
) -> AssistantDeleted:
) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
Expand Down Expand Up @@ -432,7 +473,9 @@ def delete_assistant(
elif timeout is None:
timeout = 600.0

response: Optional[AssistantDeleted] = None
response: Optional[
Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
Expand Down Expand Up @@ -464,6 +507,46 @@ def delete_assistant(
client=client,
async_delete_assistants=async_delete_assistants,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore

if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI

response = azure_assistants_api.delete_assistant(
assistant_id=assistant_id,
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
async_delete_assistants=async_delete_assistants,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' and 'azure' are supported.".format(
Expand Down Expand Up @@ -575,6 +658,9 @@ def create_thread(
elif timeout is None:
timeout = 600.0

api_base: Optional[str] = None
api_key: Optional[str] = None

response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
Expand Down Expand Up @@ -612,12 +698,6 @@ def create_thread(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore

api_key = (
optional_params.api_key
or litellm.api_key
Expand All @@ -626,8 +706,14 @@ def create_thread(
or get_secret("AZURE_API_KEY")
) # type: ignore

api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
Expand All @@ -647,7 +733,7 @@ def create_thread(
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
) # type :ignore
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' and 'azure' are supported.".format(
Expand Down Expand Up @@ -727,7 +813,8 @@ def get_thread(
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0

api_base: Optional[str] = None
api_key: Optional[str] = None
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
Expand Down Expand Up @@ -765,7 +852,7 @@ def get_thread(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
Expand All @@ -780,7 +867,7 @@ def get_thread(
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
Expand Down Expand Up @@ -912,7 +999,8 @@ def add_message(
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0

api_key: Optional[str] = None
api_base: Optional[str] = None
response: Optional[OpenAIMessage] = None
if custom_llm_provider == "openai":
api_base = (
Expand Down Expand Up @@ -950,7 +1038,7 @@ def add_message(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
Expand All @@ -965,7 +1053,7 @@ def add_message(
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
Expand Down Expand Up @@ -1071,6 +1159,8 @@ def get_messages(
timeout = 600.0

response: Optional[SyncCursorPage[OpenAIMessage]] = None
api_key: Optional[str] = None
api_base: Optional[str] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
Expand Down Expand Up @@ -1106,7 +1196,7 @@ def get_messages(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore

api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
Expand All @@ -1121,7 +1211,7 @@ def get_messages(
) # type: ignore

extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
Expand Down
Loading
Loading