support internvl2-76b
kennymckormick committed Jul 16, 2024
1 parent a11b8ca commit a32622a
Showing 1 changed file with 50 additions and 7 deletions.
57 changes: 50 additions & 7 deletions vlmeval/vlm/internvl_chat.py
@@ -95,6 +95,36 @@ def load_image(image_file, input_size=448, max_num=6, upscale=False):
     return pixel_values
 
 
+# This function is used to split InternVL2-Llama3-76B
+def split_model(model_name):
+    import math
+    device_map = {}
+    num_gpus = torch.cuda.device_count()
+    rank, world_size = get_rank_and_world_size()
+    num_gpus = num_gpus // world_size
+
+    num_layers = {'InternVL2-8B': 32, 'InternVL2-26B': 48,
+                  'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
+    # Since the first GPU will be used for ViT, treat it as 0.8 GPU.
+    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
+    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
+    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
+    layer_cnt = 0
+    for i, num_layer in enumerate(num_layers_per_gpu):
+        for j in range(num_layer):
+            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
+            layer_cnt += 1
+    device_map['vision_model'] = 0
+    device_map['mlp1'] = 0
+    device_map['language_model.model.tok_embeddings'] = 0
+    device_map['language_model.model.embed_tokens'] = 0
+    device_map['language_model.output'] = 0
+    device_map['language_model.model.norm'] = 0
+    device_map['language_model.lm_head'] = 0
+    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
+    return device_map
+
+
 class InternVLChat(BaseModel):
 
     INSTALL_REQ = False
@@ -106,13 +136,26 @@ def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False
 
         self.model_path = model_path
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
-        device = torch.cuda.current_device()
-        self.device = device
-        self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                               trust_remote_code=True,
-                                               load_in_8bit=load_in_8bit).eval()
-        if not load_in_8bit:
-            self.model = self.model.to(device)
+
+        if listinstr(['InternVL2-Llama3-76B'], model_path):
+            device_map = split_model(model_path.split('/')[1])
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                load_in_8bit=load_in_8bit,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+                device_map=device_map).eval()
+        else:
+            device = torch.cuda.current_device()
+            self.device = device
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+                load_in_8bit=load_in_8bit).eval()
+            if not load_in_8bit:
+                self.model = self.model.to(device)
 
         self.image_size = self.model.config.vision_config.image_size
         self.version = version
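A quick sanity check of the split heuristic above, as a standalone sketch. Everything here is illustrative and not part of the commit: it assumes a single-process run (rank 0, world_size 1) with 8 visible GPUs, and it only redoes the arithmetic, so it needs neither torch nor a GPU. Note that the committed code pins the vision tower, embeddings, final norm, output head, and the last decoder layer to physical GPU 0, which matches this single-process case.

import math
from collections import Counter

# Hypothetical setup (assumption, not from the commit): rank 0,
# world_size 1, 8 visible GPUs, InternVL2-Llama3-76B (80 layers).
rank, world_size, num_gpus = 0, 1, 8
num_layers = 80

# Same heuristic as split_model: GPU 0 also hosts the ViT, so it is
# treated as 0.8 of a GPU when decoder layers are apportioned.
per_gpu = math.ceil(num_layers / (num_gpus - 0.2))  # ceil(80 / 7.8) = 11
counts = [per_gpu] * num_gpus
counts[0] = math.ceil(counts[0] * 0.8)              # ceil(11 * 0.8) = 9

device_map = {}
layer_cnt = 0
for i, n in enumerate(counts):
    for _ in range(n):
        if layer_cnt < num_layers:  # guard added here for a clean tally; the commit simply over-assigns
            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
        layer_cnt += 1
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0  # last layer rehomed to GPU 0

print(Counter(device_map.values()))
# -> GPU 0 ends up with 10 layers (plus the pinned modules),
#    GPUs 1-6 with 11 each, and GPU 7 with the remaining 4.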

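For completeness, a minimal usage sketch of the new code path (again an illustration, not part of the commit; it assumes VLMEvalKit and its dependencies are installed, the OpenGVLab checkpoints are reachable, and enough GPUs are visible):

from vlmeval.vlm.internvl_chat import InternVLChat

# A model_path containing 'InternVL2-Llama3-76B' takes the new branch:
# split_model() builds the device map and from_pretrained shards the
# checkpoint across the visible GPUs.
model_76b = InternVLChat(model_path='OpenGVLab/InternVL2-Llama3-76B')

# Any other checkpoint still takes the original single-device branch.
model_8b = InternVLChat(model_path='OpenGVLab/InternVL2-8B')

One detail worth noting: split_model indexes its num_layers table by the repo name after the slash, and although the table lists four models, only a path matching 'InternVL2-Llama3-76B' actually triggers the device-map branch; that key maps to 80 decoder layers.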