Have seq2seq just use gather #27025
Conversation
The documentation is not available anymore as the PR was closed or merged.
Looks like it's all passing now!
Thanks! Looks clean on my end
Thanks
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Thanks for fixing this!
A few comments/questions for my own understanding of the PR before I can approve:
- Could you clarify in the PR description the issue, i.e. what does `gather_metrics` do differently from `gather` (what is the "magic")?
- Am I right in understanding this should only be applied to cases when evaluating generations from seq2seq models and the generation config specifies `num_return_sequences > 1`?
- What happens and what should happen if I call evaluate with a generation config with `num_return_sequences > 1` and then call a second time with `num_return_sequences == 1`?
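To make the first question concrete, here is a toy single-process sketch (not Accelerate's actual implementation; the function names and dataset are illustrative) of why a `gather_for_metrics`-style truncation interacts badly with `num_return_sequences > 1`:

```python
# Toy simulation (NOT Accelerate's real code) of the difference between a
# plain gather() and a gather_for_metrics()-style call that truncates the
# gathered result back to the known dataset length.

def gather(shards):
    # Plain all-gather: concatenate every process's samples, keep all of them.
    return [sample for shard in shards for sample in shard]

def gather_for_metrics(shards, dataset_len):
    # "Magic" variant: it assumes one prediction per dataset sample, so it
    # trims anything past dataset_len (normally just last-batch padding).
    return gather(shards)[:dataset_len]

dataset_len = 4
num_return_sequences = 2
# Two processes, each generating 2 sequences for each of its 2 inputs.
shards = [["a1", "a2", "b1", "b2"], ["c1", "c2", "d1", "d2"]]

print(len(gather(shards)))                           # 8 == dataset_len * num_return_sequences
print(len(gather_for_metrics(shards, dataset_len)))  # 4 -- half the generations are dropped
```

With multiple return sequences per input, the gathered tensor is legitimately longer than the dataset, so the truncation heuristic discards real samples rather than padding.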
for num_return_sequences in range(1, 4):
    gen_config.num_return_sequences = num_return_sequences
    metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
    assert (
        metrics["eval_samples"] == dataset_len * num_return_sequences
    ), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
This works because the state of the trainer is set such that `self.gather_function = self.accelerator.gather_metrics` initially and then switches to `self.accelerator.gather` when `num_return_sequences > 1`. However, I don't think this would work if you did `for num_return_sequences in range(3, 0, -1)`, as the trainer would never have `self.gather_function = self.accelerator.gather_metrics` for when `num_return_sequences = 1`.
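The ordering hazard described here can be shown with a toy model (not the real `Trainer`; the string values stand in for the two gather functions), assuming `gather_function` is only switched when `num_return_sequences > 1` and never reset:

```python
# Toy illustration of the ordering hazard: if the gather function is only
# switched when num_return_sequences > 1 and never switched back, a
# descending loop leaves the wrong function set for num_return_sequences=1.

class ToyTrainer:
    def __init__(self):
        self.gather_function = "gather_for_metrics"  # initial default

    def evaluate(self, num_return_sequences):
        if num_return_sequences > 1:
            self.gather_function = "gather"  # switched, but never reset
        return self.gather_function

ascending = ToyTrainer()
print([ascending.evaluate(n) for n in range(1, 4)])
# ['gather_for_metrics', 'gather', 'gather']

descending = ToyTrainer()
print([descending.evaluate(n) for n in range(3, 0, -1)])
# ['gather', 'gather', 'gather']  <- n=1 no longer uses gather_for_metrics
```

The ascending loop happens to mask the bug because `n=1` runs first, before any switch; the descending loop exposes it.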
Modified the test to use `range(3, 0, -1)`. It still passed beforehand, but simplified the logic to just use `gather()`
Correct, otherwise we will drop samples. Technically we can avoid this entirely I think by just using gather, and the test seems to show that will indeed work fine.
Per your recommendation of the test, I tried this, and it worked as it should (because …)
@muellerzr Sorry, I still don't fully understand and need some clarification.

> Correct, otherwise we will drop samples.

Why does using `gather_metrics` drop samples?

> Technically we can avoid this entirely I think by just using gather, and the test seems to show that will indeed work fine.

Is this only true for `Seq2SeqTrainer`? If not, why not just use `gather` everywhere?
It's some logic in Accelerate,
Yes, just Seq2Seq. Otherwise
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Thanks for iterating!
As a general comment, this seems like something that should really be resolved on the accelerate side, however the fix seems tidy enough here.
* Have seq2seq just use gather
* Change
* Reset after
* Make slow
* Apply suggestions from code review (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Clean
* Simplify and just use gather
* Update tests/trainer/test_trainer_seq2seq.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* gather always for seq2seq

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Hello @muellerzr, I think
What does this PR do?
In the case of using `Seq2Seq`, we don't want `gather_for_metrics` to use its magic and we just want to do `.gather()` (since otherwise it will drop samples in the former case, as accelerate will drop "duplicates" based on the batch size, which leads to a bug). This PR sets a new `gather_function` in the `Trainer` which by default is `gather_for_metrics`, but if a particular `Trainer` needs to modify it (such as `Seq2SeqTrainer`), then it can be specified.
Fixes #25231
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada