
[Bug] InternVLChatModel.batch_chat() is missing the step that sets template.system_message #463

Closed
Andempathy opened this issue Aug 7, 2024 · 1 comment
Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

In InternVLChatModel.chat(), template.system_message can be overridden via model.system_message:

template = get_conv_template(self.template)
template.system_message = self.system_message

However, InternVLChatModel.batch_chat() is missing this step:

template = get_conv_template(self.template)
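# (missing here: template.system_message = self.system_message)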
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()

As a result, setting model.system_message or model.conv_template.system_message has no effect during batch_chat inference. A model fine-tuned on SFT data that contains a system_message therefore behaves noticeably differently between chat and batch_chat. After manually adding template.system_message = self.system_message to batch_chat, the two become consistent again.
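For reference, a minimal sketch of the fix inside batch_chat(), simply mirroring what chat() already does (untested against the current codebase):

template = get_conv_template(self.template)
template.system_message = self.system_message  # the missing assignment
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()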

This issue exists both in https://github.com/OpenGVLab/InternVL/blob/main/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py and in the modeling_internvl_chat.py hosted on Hugging Face. Please fix it in both places.

Reproduction

import torch
from transformers import AutoModel, AutoTokenizer

model = (
    AutoModel.from_pretrained(
        "my_finetuned_ckpt",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    .eval()
    .cuda()
)

model.system_message = "my_custom_system_message"
model.conv_template.system_message = "my_custom_system_message"

tokenizer = AutoTokenizer.from_pretrained("my_finetuned_ckpt", trust_remote_code=True)

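# load_image is the image-preprocessing helper from the InternVL quick-start
# example, and images is a list of image file paths; both are omitted here.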
pixel_values = [load_image(image, max_num=1) for image in images]
num_patches_list = [pixel_value.size(0) for pixel_value in pixel_values]
pixel_values = torch.cat(pixel_values, dim=0)

questions = ["my_custom_question"] * len(num_patches_list)

generation_config = dict(
    num_beams=1,
    max_new_tokens=1024,
    do_sample=False,
)

responses = model.batch_chat(
    tokenizer,
    pixel_values,
    num_patches_list=num_patches_list,
    questions=questions,
    generation_config=generation_config,
)
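To see the discrepancy, the same prompt can be run through chat(), which does apply the custom system message (a hedged sketch; chat()'s positional signature follows the InternVL quick-start example):

# Single-sample comparison: chat() honors model.system_message, so its
# output matches batch_chat() only once the fix above is applied.
single_response = model.chat(
    tokenizer,
    pixel_values[: num_patches_list[0]],
    questions[0],
    generation_config,
)
print(single_response)
print(responses[0])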

Environment

sys.platform: linux
Python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA GeForce RTX 4090
CUDA_HOME: /usr/local/cuda-12.1
NVCC: Cuda compilation tools, release 12.1, V12.1.105
GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.3.1+cu121
...

Error traceback

No response

czczup (Member) commented Aug 8, 2024

Thanks for the feedback. I will fix this bug tonight.

czczup added a commit that referenced this issue Aug 9, 2024
czczup closed this as completed Aug 9, 2024
yqyao pushed a commit to ModelTC/InternVL that referenced this issue Aug 16, 2024