fix model parallel
Jiaxin-Wen committed Apr 1, 2022
1 parent 6f7ecf2 commit bd74192
Showing 2 changed files with 51 additions and 23 deletions.
src/eva_interactive.py (62 changes: 40 additions & 22 deletions)

@@ -156,29 +156,47 @@ def generate_samples(model, tokenizer: EVATokenizer, args, device)
     with torch.no_grad():
         full_context_list = []
         while True:
-            input_text = input("Usr >>> ")
-            if input_text == "clear":
-                print("Clear Dialog")
-                # set_random_seed(args.seed) # reset rng
-                full_context_list = []
-                continue
-            if input_text == "seed":
-                seed = int(input("Seed >>> "))
-                print("Clear Dialog")
-                set_random_seed(seed)
-                full_context_list = []
-                continue
+            if dist.get_rank() == 0:
+                input_text = input("Usr >>> ")
+                if input_text == "clear":
+                    print("Clear Dialog")
+                    # set_random_seed(args.seed) # reset rng
+                    full_context_list = []
+                    length_tensor = torch.tensor([-1], dtype=torch.long).to(device)
+                    continue
+                if input_text == "seed":
+                    seed = int(input("Seed >>> "))
+                    print("Clear Dialog")
+                    set_random_seed(seed)
+                    full_context_list = []
+                    length_tensor = torch.tensor([-1], dtype=torch.long).to(device)
+                    continue
+                else:
+                    full_context_list.append(tokenizer.encode(input_text) + [tokenizer.sep_id])
+                    full_context = [x for y in full_context_list for x in y]
+                    trunc_context = []
+                    for utt in full_context_list[:-9:-1]:
+                        if len(trunc_context) + len(utt) + 1 <= 128:
+                            trunc_context = utt + trunc_context
+                    trunc_context.append(tokenizer.get_sentinel_id(0))
+                    length_tensor = torch.tensor([len(trunc_context), len(full_context)], dtype=torch.long).to(device)
+                    trunc_context = torch.tensor(trunc_context, dtype=torch.long).to(device)
+                    full_context = torch.tensor(full_context, dtype=torch.long).to(device)
+
             else:
-                full_context_list.append(tokenizer.encode(input_text) + [tokenizer.sep_id])
-                full_context = [x for y in full_context_list for x in y]
-                trunc_context = []
-                for utt in full_context_list[:-9:-1]:
-                    if len(trunc_context) + len(utt) + 1 <= 128:
-                        trunc_context = utt + trunc_context
-                trunc_context.append(tokenizer.get_sentinel_id(0))
-                trunc_context = torch.tensor(trunc_context, dtype=torch.long).to(device)
-                full_context = torch.tensor(full_context, dtype=torch.long).to(device)
+                length_tensor = torch.zeros(2, dtype=torch.long).to(device)
+
+            dist.barrier()
+            dist.broadcast(length_tensor, 0)
+            if length_tensor[0] < 0:
+                continue
+            if dist.get_rank() != 0:
+                trunc_context = torch.zeros(int(length_tensor[0]), dtype=torch.long).to(device)
+                full_context = torch.zeros(int(length_tensor[1]), dtype=torch.long).to(device)
+            dist.broadcast(trunc_context, 0)
+            dist.broadcast(full_context, 0)
+
             # encoder tensor
             trunc_context = trunc_context.unsqueeze(0).repeat(args.batch_size, 1) # repeat
             full_context = full_context.unsqueeze(0).repeat(args.batch_size, 1)
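The change above keeps console input on rank 0 only and then synchronizes the encoded context to every other model-parallel rank: rank 0 broadcasts a length tensor first (with -1 acting as a "reset dialog" sentinel), the other ranks allocate receive buffers of that size, and the token tensors are broadcast so all ranks decode the same input. Below is a minimal standalone sketch of that broadcast pattern, not the repository's code; the function name sync_context is hypothetical, it assumes torch.distributed is already initialized, and it takes a plain list of token ids in place of the tokenizer output.

import torch
import torch.distributed as dist


def sync_context(token_ids, device):
    """Rank 0 supplies token_ids; every rank returns the same tensor.

    A broadcast length of -1 is a sentinel for "nothing to generate"
    (mirroring the dialog-reset case in the diff), and None is returned.
    """
    if dist.get_rank() == 0:
        length = torch.tensor([len(token_ids) if token_ids else -1], dtype=torch.long, device=device)
        context = torch.tensor(token_ids if token_ids else [0], dtype=torch.long, device=device)
    else:
        length = torch.zeros(1, dtype=torch.long, device=device)
        context = None

    dist.barrier()                  # keep all ranks in step before broadcasting
    dist.broadcast(length, src=0)   # every rank learns the length (or the -1 sentinel)
    if int(length.item()) < 0:
        return None
    if dist.get_rank() != 0:
        # non-zero ranks allocate a buffer of the right size to receive into
        context = torch.zeros(int(length.item()), dtype=torch.long, device=device)
    dist.broadcast(context, src=0)  # now every rank holds identical input ids
    return context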
src/generation_utils.py (12 changes: 11 additions & 1 deletion)

@@ -2,6 +2,7 @@

 import os
 import torch
+import mpu
 
 import torch.nn.functional as F
 
@@ -313,7 +314,12 @@ def generate_no_beam(model_batch, full_context, model, tokenizer: EVATokenizer,
             past_key_values=past_key_values,
         )
         past_key_values = dec_outputs['past_key_values']
-        lm_logits = dec_outputs["lm_logits"]
+        lm_logits = dec_outputs['lm_logits']
+
+        gathered_lm_logits = [torch.zeros_like(lm_logits).to(device) for _ in range(mpu.get_model_parallel_world_size())]
+        torch.distributed.all_gather(gathered_lm_logits, lm_logits.data, mpu.get_model_parallel_group())
+        lm_logits = torch.cat(gathered_lm_logits, dim=-1)
+
         logits = lm_logits[:, -1, :] / args.temperature
 
         prev_output_tokens = torch.cat([full_context, output_ids], dim=-1)
@@ -436,6 +442,10 @@ def generate_beam(model_batch, full_context, model, tokenizer: EVATokenizer, arg
         past_key_values = dec_outputs['past_key_values']
         lm_logits = dec_outputs["lm_logits"]
 
+        gathered_lm_logits = [torch.zeros_like(lm_logits).to(device) for _ in range(mpu.get_model_parallel_world_size())]
+        torch.distributed.all_gather(gathered_lm_logits, lm_logits.data, mpu.get_model_parallel_group())
+        lm_logits = torch.cat(gathered_lm_logits, dim=-1)
+
         logits = lm_logits[:, -1, :] / args.temperature
         scores = F.log_softmax(logits, dim=-1)
 
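The gather step added in both decoding paths follows the usual pattern for a vocabulary-parallel output layer: each model-parallel rank holds only its slice of the vocabulary logits, so the slices are all-gathered across the model-parallel group and concatenated on the last dimension before temperature scaling and sampling. The sketch below is a rough standalone illustration of that idea, not the repository's code; the function name gather_vocab_parallel_logits and the shapes are hypothetical, and it uses a plain torch.distributed process group in place of mpu.

import torch
import torch.distributed as dist


def gather_vocab_parallel_logits(local_logits, group=None):
    """Concatenate per-rank vocabulary shards into full logits.

    local_logits: [batch, seq_len, vocab_size // world_size] on each rank.
    Returns:      [batch, seq_len, vocab_size] on every rank.
    """
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return local_logits
    shards = [torch.zeros_like(local_logits) for _ in range(world_size)]
    # all_gather fills `shards` with every rank's slice, ordered by rank index
    dist.all_gather(shards, local_logits.contiguous(), group=group)
    return torch.cat(shards, dim=-1)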
