fixing prefix_allowed_tokens_fn (#3276)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not required for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes the use of `prefix_allowed_tokens_fn` in generation. It worked with `fairseq==0.9.0` (see https://github.com/facebookresearch/GENRE) but is broken in the current version.

## PR review
Anyone in the community is free to review the PR once the tests have passed.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: #3276

Reviewed By: alexeib

Differential Revision: D26725494

Pulled By: myleott

fbshipit-source-id: ce3da725f36352687e5cb5d62a59b4c89ce0b0bc
nicola-decao authored and facebook-github-bot committed May 27, 2021
1 parent e6eddd8 commit c8223e3
Showing 2 changed files with 33 additions and 3 deletions.
7 changes: 6 additions & 1 deletion fairseq/hub_utils.py
```diff
@@ -151,6 +151,7 @@ def generate(
         verbose: bool = False,
         skip_invalid_size_inputs=False,
         inference_step_args=None,
+        prefix_allowed_tokens_fn=None,
         **kwargs
     ) -> List[List[Dict[str, torch.Tensor]]]:
         if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
@@ -164,7 +165,11 @@ def generate(
             gen_args.beam = beam
             for k, v in kwargs.items():
                 setattr(gen_args, k, v)
-        generator = self.task.build_generator(self.models, gen_args)
+        generator = self.task.build_generator(
+            self.models,
+            gen_args,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        )

         inference_step_args = inference_step_args or {}
         results = []
```
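The plumbing above can be illustrated without fairseq. A minimal stand-in (the `resolve_prefix_fn` helper and the `SimpleNamespace` args are illustrative, not fairseq APIs) shows the resolution order the patch introduces: an explicitly passed `prefix_allowed_tokens_fn` wins, otherwise the code falls back to an attribute stashed on the generation args, as the old code path did.

```python
# Standalone sketch of the patch's resolution order: an explicit
# prefix_allowed_tokens_fn keyword takes precedence; otherwise fall
# back to an attribute on the generation args (the legacy behavior).
from types import SimpleNamespace

def resolve_prefix_fn(args, prefix_allowed_tokens_fn=None):
    # Mirrors the fallback added in fairseq_task.py below.
    if prefix_allowed_tokens_fn is None:
        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
    return prefix_allowed_tokens_fn

legacy_fn = lambda batch_id, input_ids: [1]
explicit_fn = lambda batch_id, input_ids: [2]

args = SimpleNamespace(prefix_allowed_tokens_fn=legacy_fn)
print(resolve_prefix_fn(args) is legacy_fn)                 # True: falls back to args
print(resolve_prefix_fn(args, explicit_fn) is explicit_fn)  # True: explicit kwarg wins
```

This keeps backward compatibility for callers that set the attribute on the args object while letting `generate()` forward the function as a keyword.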
29 changes: 27 additions & 2 deletions fairseq/tasks/fairseq_task.py
```diff
@@ -341,8 +341,32 @@ def build_criterion(self, cfg: DictConfig):
         return criterions.build_criterion(cfg, self)

     def build_generator(
-        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
+        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
     ):
+        """
+        Build a :class:`~fairseq.SequenceGenerator` instance for this
+        task.
+
+        Args:
+            models (List[~fairseq.models.FairseqModel]): ensemble of models
+            args (fairseq.dataclass.configs.GenerationConfig):
+                configuration object (dataclass) for generation
+            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
+                through to SequenceGenerator
+            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
+                If provided, this function constrains the beam search to
+                allowed tokens only at each step. The provided function
+                should take 2 arguments: the batch ID (`batch_id: int`)
+                and a unidimensional tensor of token ids (`inputs_ids:
+                torch.Tensor`). It has to return a `List[int]` with the
+                allowed tokens for the next generation step conditioned
+                on the previously generated tokens (`inputs_ids`) and
+                the batch ID (`batch_id`). This argument is useful for
+                constrained generation conditioned on the prefix, as
+                described in "Autoregressive Entity Retrieval"
+                (https://arxiv.org/abs/2010.00904) and
+                https://github.com/facebookresearch/GENRE.
+        """
         if getattr(args, "score_reference", False):
             from fairseq.sequence_scorer import SequenceScorer

@@ -369,7 +393,8 @@ def build_generator(
         match_source_len = getattr(args, "match_source_len", False)
         diversity_rate = getattr(args, "diversity_rate", -1)
         constrained = getattr(args, "constraints", False)
-        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
+        if prefix_allowed_tokens_fn is None:
+            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
         if (
             sum(
                 int(cond)
```
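The docstring's callable contract, `(batch_id, input_ids) -> List[int]`, can be sketched standalone. The trie-based example below is illustrative (hypothetical token ids, plain Python lists standing in for `torch.Tensor`), in the spirit of GENRE's prefix-constrained decoding: at each step, only tokens that extend some allowed sequence are permitted.

```python
# Sketch of a prefix_allowed_tokens_fn built from a trie of allowed
# token sequences. Token ids and the BOS id are made up for the example.

def build_trie(sequences):
    """Nest dicts so that each root-to-node path is a valid prefix."""
    trie = {}
    for seq in sequences:
        node = trie
        for token in seq:
            node = node.setdefault(token, {})
    return trie

def make_prefix_allowed_tokens_fn(trie):
    def prefix_allowed_tokens_fn(batch_id, input_ids):
        # Walk the trie along the tokens generated so far (skip the
        # leading BOS); the children of the node we reach are exactly
        # the tokens allowed at the next step.
        node = trie
        for token in input_ids[1:]:
            node = node.get(token)
            if node is None:
                return []  # prefix not in the trie: nothing is allowed
        return list(node.keys())
    return prefix_allowed_tokens_fn

# Allowed output sequences: [5, 6, 7] and [5, 8]; 0 plays the role of BOS.
fn = make_prefix_allowed_tokens_fn(build_trie([[5, 6, 7], [5, 8]]))
print(fn(0, [0]))     # → [5]: only 5 can start a sequence
print(fn(0, [0, 5]))  # → [6, 8]: after 5, either branch is allowed
```

A function like this would be passed as the new `prefix_allowed_tokens_fn` argument of `build_generator` (or forwarded through `generate` in `hub_utils.py`), with the tensor of generated ids converted to a list before the trie walk.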
