Have seq2seq just use gather #27025

Merged · 9 commits · Nov 14, 2023
12 changes: 8 additions & 4 deletions src/transformers/trainer.py
@@ -3131,13 +3131,13 @@ def evaluation_loop(

# Update containers on host
if loss is not None:
- losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
+ losses = self.gather_function((loss.repeat(batch_size)))
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
if inputs_decode is not None:
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
- inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
+ inputs_decode = self.gather_function((inputs_decode))
inputs_host = (
inputs_decode
if inputs_host is None
@@ -3147,11 +3147,11 @@ def evaluation_loop(
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
- logits = self.accelerator.gather_for_metrics((logits))
+ logits = self.gather_function((logits))
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)

if labels is not None:
- labels = self.accelerator.gather_for_metrics((labels))
+ labels = self.gather_function((labels))
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)

self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
@@ -3184,6 +3184,8 @@ def evaluation_loop(
# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = None, None, None, None

+ # After all calls to `.gather_function`, reset to `gather_for_metrics`:
+ self.gather_function = self.accelerator.gather_for_metrics
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
@@ -3853,6 +3855,8 @@ def create_accelerator_and_postprocess(self):
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
)
+ # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+ self.gather_function = self.accelerator.gather_for_metrics

# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
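Taken together with the trainer_seq2seq.py changes below, the idea is that the base Trainer routes all metric gathering through a `self.gather_function` attribute that defaults to `accelerator.gather_for_metrics` and is reset to that default at the end of every evaluation loop, while a subclass may temporarily swap in plain `accelerator.gather`. The following is a minimal, self-contained sketch of that indirection; the `Toy*` classes are hypothetical stand-ins, not the real Trainer or Accelerate API.

class ToyAccelerator:
    def gather(self, rows):
        # keep every row from every process
        return rows

    def gather_for_metrics(self, rows):
        # in real Accelerate this also drops samples that were duplicated to
        # pad the last batch; here it simply stands in for that behaviour
        return rows


class ToyTrainer:
    def __init__(self, accelerator):
        self.accelerator = accelerator
        # default: the metric-safe gather
        self.gather_function = self.accelerator.gather_for_metrics

    def evaluation_loop(self, logits):
        gathered = self.gather_function(logits)
        # after all calls to `.gather_function`, reset to the default
        self.gather_function = self.accelerator.gather_for_metrics
        return gathered


class ToySeq2SeqTrainer(ToyTrainer):
    def evaluate(self, logits):
        # generation can produce more rows than the dataset has samples,
        # so swap in plain `gather` and drop nothing
        self.gather_function = self.accelerator.gather
        return self.evaluation_loop(logits)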
4 changes: 3 additions & 1 deletion src/transformers/trainer_seq2seq.py
@@ -160,8 +160,9 @@ def evaluate(
gen_kwargs["max_length"] = self.args.generation_max_length
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
gen_kwargs["num_beams"] = self.args.generation_num_beams
+ # We don't want to drop samples in general
+ self.gather_function = self.accelerator.gather
self._gen_kwargs = gen_kwargs

return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

def predict(
@@ -223,6 +224,7 @@ def predict(
gen_kwargs["max_length"] = self.args.generation_max_length
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
gen_kwargs["num_beams"] = self.args.generation_num_beams
+ self.gather_function = self.accelerator.gather
self._gen_kwargs = gen_kwargs

return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
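The reason the seq2seq trainer swaps in plain `gather` is that `gather_for_metrics` truncates the gathered tensors back to the dataset length to undo the padding a distributed sampler adds, whereas with `num_return_sequences > 1` those "extra" rows are real predictions rather than padding. A small illustrative calculation (the numbers mirror the test below and are only an example of the row counts involved, not library behaviour verbatim):

# Illustrative only: row counts when generating several sequences per sample.
dataset_len = 38                     # samples in the eval dataset
num_return_sequences = 3             # sequences generated per sample
generated_rows = dataset_len * num_return_sequences   # 114 prediction rows

# a gather_for_metrics-style truncation back to the dataset length would
# silently drop two thirds of the predictions
rows_after_truncation = min(generated_rows, dataset_len)   # 38
rows_with_plain_gather = generated_rows                     # 114

assert rows_with_plain_gather == dataset_len * num_return_sequences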
61 changes: 59 additions & 2 deletions tests/trainer/test_trainer_seq2seq.py
@@ -12,8 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

- from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
+ from transformers import (
+     AutoModelForSeq2SeqLM,
+     BertTokenizer,
+     DataCollatorForSeq2Seq,
+     EncoderDecoderModel,
+     GenerationConfig,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     T5Tokenizer,
+ )
from transformers.testing_utils import TestCasePlus, require_torch, slow
from transformers.utils import is_datasets_available

@@ -124,3 +132,52 @@ def _compute_metrics(pred):

# start training
trainer.train()

@slow
@require_torch
def test_return_sequences(self):
# Tests that the number of generated sequences is correct when num_return_sequences > 1
# and essentially ensuring that `accelerator.gather()` is used instead of `gather_for_metrics`
INPUT_COLUMN = "question"
TARGET_COLUMN = "answer"
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 256

dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
gen_config = GenerationConfig.from_pretrained(
"t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
)

training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True)

trainer = Seq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)

def prepare_data(examples):
# Remove pairs where at least one record is none
inputs = examples[INPUT_COLUMN]
targets = examples[TARGET_COLUMN]

model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)
model_inputs["labels"] = labels["input_ids"]

return model_inputs

prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN])
dataset_len = len(prepared_dataset) # 38

for num_return_sequences in range(3, 0, -1):
gen_config.num_return_sequences = num_return_sequences
metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
assert (
metrics["eval_samples"] == dataset_len * num_return_sequences
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"