Merge pull request THUDM#334 from zRzRzRzRzRzRzR/main
Dependency updates and parameter updates
zhangch9 committed Nov 16, 2023
2 parents 2f95d81 + 685dd2e commit f823b4a
Showing 12 changed files with 94 additions and 58 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -75,7 +75,10 @@ cd ChatGLM3
```
pip install -r requirements.txt
```
其中 `transformers` 库版本推荐为 `4.30.2`,`torch` 推荐使用 2.0 及以上的版本,以获得最佳的推理性能。

+ `transformers` 库版本应为 `4.30.2` 及以上的版本,`torch` 库版本应为 2.0 及以上的版本,以获得最佳的推理性能。
+ 为了保证 `torch` 的版本正确,请严格按照 [官方文档](https://pytorch.org/get-started/locally/) 的说明安装。
+ `gradio` 库版本应为 `3.x`。

### 综合 Demo

5 changes: 4 additions & 1 deletion README_en.md
@@ -75,7 +75,10 @@ Then use pip to install the dependencies:
```
pip install -r requirements.txt
```
It is recommended to use version `4.30.2` for the `transformers` library, and version 2.0 or above for `torch`, to achieve the best inference performance.
+ The `transformers` library version should be `4.30.2` or above, and the `torch` library should be version 2.0 or above to obtain the best inference performance.
+ To ensure that your `torch` version is correct, please follow the installation instructions in the [official documentation](https://pytorch.org/get-started/locally/) strictly.
+ The `gradio` library version should be `3.x`.


### Integrated Demo

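As a quick sanity check on the requirements described in both README diffs above (`transformers` 4.30.2 or newer, `torch` 2.0 or newer, `gradio` 3.x), a minimal sketch like the following can be run before launching the demos. It is not part of the repository; the package names match requirements.txt and the version bounds are taken from the README text.

```
# Hypothetical helper (not part of the repository): check that the installed
# packages satisfy the versions recommended in the README.
from importlib.metadata import version
from packaging.version import Version

def check_versions() -> None:
    checks = {
        "transformers": lambda v: v >= Version("4.30.2"),
        "torch": lambda v: v >= Version("2.0"),
        "gradio": lambda v: Version("3.0") <= v < Version("4.0"),  # README asks for 3.x
    }
    for pkg, ok in checks.items():
        installed = Version(version(pkg))
        print(f"{pkg} {installed}: {'OK' if ok(installed) else 'NOT satisfied'}")

if __name__ == "__main__":
    check_versions()
```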
5 changes: 4 additions & 1 deletion basic_demo/cli_demo.py
@@ -16,13 +16,15 @@

welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"


def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt


def main():
past_key_values, history = None, []
global stop_stream
@@ -38,7 +40,8 @@ def main():
continue
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history,
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
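The cli_demo.py hunk above pins `top_p=1` and `temperature=0.01`, which makes the streamed replies close to deterministic. Below is a minimal sketch of driving the same `stream_chat` call outside the demo; the checkpoint id, device placement, and the example query are assumptions, and only the generation arguments mirror the diff.

```
# Sketch of the streaming call used in cli_demo.py. The checkpoint id and
# .half().cuda() placement are assumptions, not part of this commit.
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "THUDM/chatglm3-6b"  # assumed checkpoint location

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda().eval()

history, past_key_values = [], None
response = ""
for response, history, past_key_values in model.stream_chat(
        tokenizer, "你好", history=history,
        top_p=1, temperature=0.01,          # values introduced by this commit
        past_key_values=past_key_values,
        return_past_key_values=True):
    pass  # `response` grows incrementally; print its new suffix for live output
print(response)
```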
2 changes: 1 addition & 1 deletion basic_demo/web_demo.py
@@ -110,4 +110,4 @@ def reset_state():

emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

demo.queue().launch(share=False, server_port=8501, inbrowser=True) #在这里修改demo的端口号
demo.queue().launch(share=False, server_name="127.0.0.1", server_port=8501, inbrowser=True)
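The web_demo.py change replaces the inline Chinese comment with an explicit `server_name="127.0.0.1"`, so the Gradio server binds to the loopback interface only. A standalone sketch of the same launch options under a gradio 3.x install; the Blocks content is a placeholder and only the `launch()` arguments mirror the diff.

```
# Placeholder Blocks app illustrating the launch() options used by web_demo.py.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("ChatGLM3 web demo placeholder")

# share=False keeps the app local, server_name="127.0.0.1" binds to loopback only,
# server_port fixes the port, and inbrowser=True opens a browser tab on start.
demo.queue().launch(share=False, server_name="127.0.0.1", server_port=8501, inbrowser=True)
```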
62 changes: 38 additions & 24 deletions composite_demo/client.py
@@ -17,24 +17,27 @@
PT_PATH = os.environ.get('PT_PATH', None)
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)


@st.cache_resource
def get_client() -> Client:
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)
return client


class Client(Protocol):
def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
...


def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user",
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
logits_processor=None, return_past_key_values=False, **kwargs):

past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
repetition_penalty=1.0, length_penalty=1.0, num_beams=1,
logits_processor=None, return_past_key_values=False, **kwargs):
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

@@ -52,8 +55,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
gen_kwargs = {"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
"repetition_penalty": repetition_penalty,
"length_penalty": length_penalty,
"num_beams": num_beams,
**kwargs
}

print(gen_kwargs)
if past_key_values is None:
inputs = tokenizer.build_chat_input(query, history=history, role=role)
else:
@@ -98,8 +111,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
else:
yield response, new_history


class HFClient(Client):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None,):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None, ):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

@@ -123,11 +137,11 @@ def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | No
).eval()

def generate_stream(self,
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
system: str | None,
tools: list[dict] | None,
history: list[Conversation],
**parameters: Any
) -> Iterable[TextGenerationStreamResponse]:
chat_history = [{
'role': 'system',
'content': system if not tools else TOOL_PROMPT,
@@ -141,19 +155,19 @@ def generate_stream(self,
'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
'content': conversation.content,
})

query = history[-1].content
role = str(history[-1].role).removeprefix('<|').removesuffix('|>')

text = ''

for new_text, _ in stream_chat(self.model,
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
self.tokenizer,
query,
chat_history,
role,
**parameters,
):
word = new_text.removeprefix(text)
word_stripped = word.strip()
text = new_text
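With `repetition_penalty`, `length_penalty`, and `num_beams` now part of `gen_kwargs`, callers of `HFClient.generate_stream` can forward the new controls directly. A hedged sketch of such a call follows; `get_client` comes from the client.py shown above, while the `conversation` module name, the system prompt, and the concrete generation values are assumptions for illustration.

```
# Sketch of calling the updated client directly; intended to run from composite_demo.
from client import get_client                 # defined in the diff above
from conversation import Conversation, Role   # assumed module layout

client = get_client()
history = [Conversation(Role.USER, "用一句话介绍 ChatGLM3")]

text = ""
for response in client.generate_stream(
        system="You are ChatGLM3, a helpful assistant.",
        tools=None,
        history=history,
        do_sample=True,
        max_length=8192,
        top_p=0.8,
        temperature=0.8,
        repetition_penalty=1.2,   # newly forwarded by this commit
):
    if not response.token.special:
        text += response.token.text
print(text)
```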
31 changes: 17 additions & 14 deletions composite_demo/demo_chat.py
@@ -8,16 +8,18 @@

client = get_client()


# Append a conversation into history, while show it in a new markdown block
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None=None,
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
history.append(conversation)
conversation.show(placeholder)

def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str):

def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str, repetition_penalty: float):
placeholder = st.empty()
with placeholder.container():
if 'chat_history' not in st.session_state:
@@ -48,14 +50,15 @@ def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str)

output_text = ''
for response in client.generate_stream(
system_prompt,
tools=None,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(Role.USER)],
system_prompt,
tools=None,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(Role.USER)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
@@ -70,8 +73,8 @@ def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str)
break
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))

append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
), history, markdown_placeholder)
6 changes: 4 additions & 2 deletions composite_demo/demo_ci.py
@@ -216,7 +216,7 @@ def append_conversation(
history.append(conversation)
conversation.show(placeholder)

def main(top_p: float, temperature: float, prompt_text: str):
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
if 'ci_history' not in st.session_state:
st.session_state.ci_history = []

@@ -255,6 +255,7 @@ def main(top_p: float, temperature: float, prompt_text: str):
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
):
token = response.token
if response.token.special:
@@ -324,4 +325,5 @@ def main(top_p: float, temperature: float, prompt_text: str):
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
return

4 changes: 3 additions & 1 deletion composite_demo/demo_tool.py
@@ -57,7 +57,7 @@ def append_conversation(
history.append(conversation)
conversation.show(placeholder)

def main(top_p: float, temperature: float, prompt_text: str):
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
@@ -117,6 +117,8 @@ def main(top_p: float, temperature: float, prompt_text: str):
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
):
token = response.token
if response.token.special:
10 changes: 7 additions & 3 deletions composite_demo/main.py
@@ -14,6 +14,7 @@
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
'''.strip()


class Mode(str, Enum):
CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter'

@@ -25,6 +26,9 @@ class Mode(str, Enum):
temperature = st.slider(
'temperature', 0.0, 1.5, 0.95, step=0.01
)
repetition_penalty = st.slider(
'repetition_penalty', 0.0, 2.0, 1.2, step=0.01
)
system_prompt = st.text_area(
label="System Prompt (Only for chat mode)",
height=300,
@@ -47,10 +51,10 @@

match tab:
case Mode.CHAT:
demo_chat.main(top_p, temperature, system_prompt, prompt_text)
demo_chat.main(top_p, temperature, system_prompt, prompt_text, repetition_penalty)
case Mode.TOOL:
demo_tool.main(top_p, temperature, prompt_text)
demo_tool.main(top_p, temperature, prompt_text, repetition_penalty)
case Mode.CI:
demo_ci.main(top_p, temperature, prompt_text)
demo_ci.main(top_p, temperature, prompt_text, repetition_penalty)
case _:
st.error(f'Unexpected tab: {tab}')
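The new sidebar slider feeds `repetition_penalty` through every demo into generation. In `transformers` this value is applied by `RepetitionPenaltyLogitsProcessor`, which rescales the scores of tokens that have already been generated; values above 1.0 discourage repetition. A small illustration with made-up token ids and scores:

```
# Toy illustration of what the repetition_penalty slider controls; the token
# ids and scores are made-up values, only the penalty matches the slider default.
import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)  # slider default above

input_ids = torch.tensor([[5, 7, 7]])                      # tokens generated so far
scores = torch.tensor([[0.5, -0.2, 1.0, 0.3, 2.0, -1.0, 0.8, 1.5]])

# Positive scores of tokens 5 and 7 are divided by 1.2 (negative ones multiplied),
# making the model less likely to emit them again.
print(processor(input_ids, scores))
```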
1 change: 1 addition & 0 deletions openai_api_demo/requirements.txt
@@ -0,0 +1 @@
openai>=1.3.0
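`openai>=1.3.0` pulls in the post-1.0 client interface, which replaces the old module-level calls with an `OpenAI` client object. A hedged sketch of querying the repository's OpenAI-compatible server with that interface; the `base_url`, port, `api_key`, and model name are assumptions about a locally running openai_api_demo instance, not values taken from this diff.

```
# Sketch of the openai>=1.0 client style against a locally served endpoint;
# base_url, api_key and model are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "你好,请介绍一下你自己"}],
    temperature=0.8,
    top_p=0.8,
)
print(response.choices[0].message.content)
```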
10 changes: 6 additions & 4 deletions openai_api_demo/utils.py
@@ -6,7 +6,8 @@
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import AutoModel
from transformers.generation.logits_process import LogitsProcessor
from typing import Dict, Union, Optional,Tuple
from typing import Dict, Union, Optional, Tuple


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
# transformer.word_embeddings 占用1层
@@ -60,9 +61,11 @@ def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int =
model = dispatch_model(model, device_map=device_map)

return model


class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
@@ -120,9 +123,8 @@ def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokeni
if input_echo_len >= model.config.seq_length:
print(f"Input length larger than {model.config.seq_length}")


# TODO 废弃max_length,使用max_new_tokens
if max_new_tokens is not None and max_length is not None: # OpenAI接口的用户传入的应该是max_new_tokens才是适配OpenAI接口的。
if max_new_tokens is not None and max_length is not None: # OpenAI接口的用户传入的应该是max_new_tokens才是适配OpenAI接口的。
max_length = None

if max_new_tokens is None and max_length is None: # 什么参数都没传
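For reference, `load_model_on_gpus` from this file builds a device map with `auto_configure_device_map` and dispatches the weights across GPUs via accelerate's `dispatch_model`. A usage sketch; the checkpoint path and GPU count are assumptions, not part of the diff.

```
# Usage sketch for the multi-GPU loader defined in openai_api_demo/utils.py;
# the checkpoint path and num_gpus value are assumptions for illustration.
from utils import load_model_on_gpus

model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
model = model.eval()
```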
11 changes: 5 additions & 6 deletions requirements.txt
@@ -1,16 +1,15 @@
protobuf
transformers==4.30.2
transformers>=4.30.2
cpm_kernels
torch>=2.0
gradio==3.39
gradio~=3.39
mdtex2html
sentencepiece
accelerate
sse-starlette
streamlit>=1.24.0
fastapi==0.95.1
typing_extensions==4.4.0
fastapi>=0.95.1
typing_extensions
uvicorn
sse_starlette
loguru
openai>=1.0.0
loguru
