Skip to content

Commit

Permalink
Merge pull request THUDM#432 from zRzRzRzRzRzRzR/main
Browse files Browse the repository at this point in the history
增加OpenAI API的示例
  • Loading branch information
zRzRzRzRzRzRzR committed Nov 24, 2023
2 parents e3fc190 + e31cff5 commit 29af9b1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 16 deletions.
9 changes: 4 additions & 5 deletions openai_api_demo/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
response["text"] = response["text"][1:]
response["text"] = response["text"].strip()
usage = UsageInfo()

function_call, finish_reason = None, "stop"
if request.functions:
try:
function_call = process_response(response["text"], use_tool=True)
except:
logger.warning("Failed to parse tool call")
logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.")

if isinstance(function_call, dict):
finish_reason = "function_call"
Expand All @@ -172,16 +171,16 @@ async def create_chat_completion(request: ChatCompletionRequest):
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
)

logger.debug(f"==== message ====\n{message}")

choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason=finish_reason,
)

task_usage = UsageInfo.model_validate(response["usage"])
for usage_key, usage_value in task_usage.model_dump().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)


Expand Down Expand Up @@ -211,7 +210,7 @@ async def predict(model_id: str, params: dict):
try:
function_call = process_response(decoded_unicode, use_tool=True)
except:
print("Failed to parse tool call")
logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.")

if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
Expand Down
76 changes: 66 additions & 10 deletions openai_api_demo/openai_api_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@

base_url = "http://127.0.0.1:8000"

def create_chat_completion(model, messages, use_stream=False):

def create_chat_completion(model, messages, functions, use_stream=False):
data = {
"model": model, # 模型名称
"messages": messages, # 会话历史
"stream": use_stream, # 是否流式响应
"max_tokens": 100, # 最多生成字数
"temperature": 0.8, # 温度
"top_p": 0.8, # 采样概率
"function": functions, # 函数定义
"model": model, # 模型名称
"messages": messages, # 会话历史
"stream": use_stream, # 是否流式响应
"max_tokens": 100, # 最多生成字数
"temperature": 0.8, # 温度
"top_p": 0.8, # 采样概率
}

response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream)
Expand All @@ -35,15 +37,64 @@ def create_chat_completion(model, messages, use_stream=False):
else:
# 处理非流式响应
decoded_line = response.json()
print(decoded_line)
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
print(content)
else:
print("Error:", response.status_code)
return None


if __name__ == "__main__":
def function_chat(use_stream=True):
    """Demo a function-calling (tool-use) round trip against the local API server.

    Sends one tool definition plus a pre-scripted conversation (user question,
    assistant tool call, tool result) so the model can produce the final answer.

    Args:
        use_stream: forward streaming mode to create_chat_completion.
    """
    # JSON-schema style description of the single tool the model may call.
    weather_tool = {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. Beijing",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
    functions = [weather_tool]

    # Scripted history: user asks, assistant emits a tool call, tool replies.
    user_turn = {
        "role": "user",
        "content": "波士顿天气如何?",
    }
    assistant_tool_call = {
        "role": "assistant",
        "content": "get_current_weather\n ```python\ntool_call(location='Beijing', unit='celsius')\n```",
        "function_call": {
            "name": "get_current_weather",
            "arguments": '{"location": "Beijing", "unit": "celsius"}',
        },
    }
    tool_result = {
        "role": "function",
        "name": "get_current_weather",
        "content": '{"temperature": "12", "unit": "celsius", "description": "Sunny"}',
    }
    chat_messages = [
        user_turn,
        assistant_tool_call,
        tool_result,
        # A continued exchange would follow here: the assistant's final
        # natural-language answer and the user's reply, e.g.
        # {
        #     "role": "assistant",
        #     "content": "根据最新的天气预报,目前北京的天气情况是晴朗的,温度为12摄氏度。",
        # },
        # {
        #     "role": "user",
        #     "content": "谢谢",
        # }
    ]
    create_chat_completion("chatglm3-6b", messages=chat_messages, functions=functions, use_stream=use_stream)


def simple_chat(use_stream=True):
functions = None
chat_messages = [
{
"role": "system",
Expand All @@ -54,4 +105,9 @@ def create_chat_completion(model, messages, use_stream=False):
"content": "你好,给我讲一个故事,大概100字"
}
]
create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
create_chat_completion("chatglm3-6b", messages=chat_messages, functions=functions, use_stream=use_stream)


if __name__ == "__main__":
    # Entry point: run the function-calling demo (non-streaming).
    # Swap the comment below to try the plain-chat demo instead.
    function_chat(use_stream=False)
    # simple_chat(use_stream=True)
1 change: 0 additions & 1 deletion openai_api_demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokeni
def process_chatglm_messages(messages, functions=None):
_messages = messages
messages = []

if functions:
messages.append(
{
Expand Down

0 comments on commit 29af9b1

Please sign in to comment.