fix(utils.py): support raw response headers for streaming requests
krrishdholakia committed Jul 23, 2024
1 parent d1ffb4d commit f64a330
Showing 5 changed files with 60 additions and 30 deletions.
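For orientation: with this change, raw provider response headers are attached to streaming responses as well, under _hidden_params["additional_headers"] with an llm_provider- prefix. A minimal usage sketch of the behavior the updated tests below exercise (assumes litellm is installed and an OpenAI API key is configured; the sample message is illustrative, not part of the diff):

import litellm

litellm.return_response_headers = True

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)

# Provider headers are exposed on the streaming response too, keyed with an
# "llm_provider-" prefix and kept as raw strings.
headers = response._hidden_params.get("additional_headers", {})
print(headers.get("llm_provider-x-ratelimit-remaining-requests"))

for chunk in response:
    pass  # consume the stream as usual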
8 changes: 6 additions & 2 deletions litellm/proxy/proxy_server.py
@@ -2909,6 +2909,7 @@ async def chat_completion(
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
)
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}

# Post Call Processing
if llm_router is not None:
@@ -2931,6 +2932,7 @@
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
**additional_headers,
)
selected_data_generator = select_data_generator(
response=response,
@@ -2948,8 +2950,10 @@
user_api_key_dict=user_api_key_dict, response=response
)

- hidden_params = getattr(response, "_hidden_params", {}) or {}
- additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
additional_headers = hidden_params.get("additional_headers", {}) or {}

fastapi_response.headers.update(
get_custom_headers(
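In the chat_completion handler above, _hidden_params is re-read after post-call processing so that headers attached during the request (including streaming requests) are spread into the proxy's outgoing response headers via **additional_headers. A simplified, self-contained sketch of that pass-through pattern; merge_provider_headers and the sample values are illustrative, not litellm internals:

def merge_provider_headers(base_headers: dict, hidden_params: dict) -> dict:
    # Merge any provider headers recorded on the response object into the
    # proxy's own custom headers (hypothetical helper, for illustration only).
    additional_headers = hidden_params.get("additional_headers", {}) or {}
    return {**base_headers, **additional_headers}

hidden_params = {
    "additional_headers": {
        "llm_provider-x-ratelimit-remaining-requests": "9999",  # made-up value
    }
}

print(merge_provider_headers({"x-litellm-model-id": "my-model-id"}, hidden_params))
# {'x-litellm-model-id': 'my-model-id',
#  'llm_provider-x-ratelimit-remaining-requests': '9999'}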
18 changes: 18 additions & 0 deletions litellm/tests/test_completion.py
@@ -1364,6 +1364,12 @@ def test_completion_openai_response_headers():
print("response_headers=", response._response_headers)
assert response._response_headers is not None
assert "x-ratelimit-remaining-tokens" in response._response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)

# /chat/completion - with streaming

@@ -1376,6 +1382,12 @@
print("streaming response_headers=", response_headers)
assert response_headers is not None
assert "x-ratelimit-remaining-tokens" in response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)

for chunk in streaming_response:
print("chunk=", chunk)
@@ -1390,6 +1402,12 @@
print("embedding_response_headers=", embedding_response_headers)
assert embedding_response_headers is not None
assert "x-ratelimit-remaining-tokens" in embedding_response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)

litellm.return_response_headers = False

4 changes: 1 addition & 3 deletions litellm/tests/test_completion_cost.py
@@ -881,6 +881,7 @@ def test_completion_azure_ai():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_cost_hidden_params(sync_mode):
litellm.return_response_headers = True
if sync_mode:
response = litellm.completion(
model="gpt-3.5-turbo",
@@ -896,9 +897,6 @@ async def test_completion_cost_hidden_params(sync_mode):

assert "response_cost" in response._hidden_params
assert isinstance(response._hidden_params["response_cost"], float)
- assert isinstance(
-     response._hidden_params["llm_provider-x-ratelimit-remaining-requests"], float
- )


def test_vertex_ai_gemini_predict_cost():
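The dropped float assertion no longer matches how these headers are exposed: provider rate-limit headers now live under additional_headers with the llm_provider- prefix and keep their raw string values, while response_cost stays a float (see the assertions added in test_completion.py above). A small illustration with made-up values:

hidden_params = {
    "response_cost": 0.000123,  # made-up value
    "additional_headers": {
        "llm_provider-x-ratelimit-remaining-requests": "9999",  # made-up value
    },
}

assert isinstance(hidden_params["response_cost"], float)
assert isinstance(
    hidden_params["additional_headers"]["llm_provider-x-ratelimit-remaining-requests"],
    str,
)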
43 changes: 24 additions & 19 deletions litellm/tests/test_streaming.py
@@ -1988,25 +1988,30 @@ async def test_hf_completion_tgi_stream():

# test on openai completion call
def test_openai_chat_completion_call():
- try:
-     litellm.set_verbose = False
-     print(f"making openai chat completion call")
-     response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
-     complete_response = ""
-     start_time = time.time()
-     for idx, chunk in enumerate(response):
-         chunk, finished = streaming_format_tests(idx, chunk)
-         print(f"outside chunk: {chunk}")
-         if finished:
-             break
-         complete_response += chunk
-         # print(f'complete_chunk: {complete_response}')
-     if complete_response.strip() == "":
-         raise Exception("Empty response received")
-     print(f"complete response: {complete_response}")
- except:
-     print(f"error occurred: {traceback.format_exc()}")
-     pass
litellm.set_verbose = False
litellm.return_response_headers = True
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)

print(f"response._hidden_params: {response._hidden_params}")
complete_response = ""
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
print(f"outside chunk: {chunk}")
if finished:
break
complete_response += chunk
# print(f'complete_chunk: {complete_response}')
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")


# test_openai_chat_completion_call()
17 changes: 11 additions & 6 deletions litellm/utils.py
@@ -5679,13 +5679,13 @@ def convert_to_model_response_object(
):
received_args = locals()
if _response_headers is not None:
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}
if hidden_params is not None:
hidden_params["additional_headers"] = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
}
hidden_params["additional_headers"] = llm_response_headers
else:
hidden_params = {"additional_headers": _response_headers}
hidden_params = {"additional_headers": llm_response_headers}
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
if (
response_object is not None
@@ -8320,8 +8320,13 @@ def __init__(
or {}
)
self._hidden_params = {
"model_id": (_model_info.get("id", None))
"model_id": (_model_info.get("id", None)),
} # returned as x-litellm-model-id response header in proxy
if _response_headers is not None:
self._hidden_params["additional_headers"] = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
}
self._response_headers = _response_headers
self.response_id = None
self.logging_loop = None
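Both utils.py hunks above apply the same re-keying, once when converting a full response object and once in the __init__ shown in the second hunk, which is what makes the raw headers available on streaming responses. A sketch of the prefixing logic itself, with made-up header values:

# Raw provider headers are re-keyed with an "llm_provider-" prefix before being
# stored in _hidden_params["additional_headers"].
_response_headers = {
    "x-ratelimit-remaining-requests": "9999",
    "x-ratelimit-remaining-tokens": "59973",
}
llm_response_headers = {
    "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}

print(llm_response_headers)
# {'llm_provider-x-ratelimit-remaining-requests': '9999',
#  'llm_provider-x-ratelimit-remaining-tokens': '59973'}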
