[generation] multiple eos/pad asserts/ifs in generate search functions
In `_generate_no_beam_search`, `eos_token_id` is required: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L731 (that code always gets hit)
```python
assert (
    eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
```
Why do we assert and check `eos_token_id is not None` multiple times throughout the code? Why not assert once at the top of the function and then just use it? Moreover, all those `if eos_token_id is not None` checks could then be removed (or reduced, if there are other parts to them).

Also, a larger question: is there a model where `eos_token_id` is not defined? If there is none, then why not assert once at the top of `generate` and then just use it everywhere in sub-calls without testing its definition?
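To make the suggestion concrete, here is a minimal sketch of that pattern (hypothetical, heavily simplified signature; not the actual `generation_utils` code): validate the token ids once at the top of `generate()` so the search helpers can rely on them unconditionally.

```python
# Hypothetical, simplified sketch of the proposed pattern; the real generate()
# has many more arguments and branches.
def generate(self, input_ids, max_length, eos_token_id=None, pad_token_id=None, **kwargs):
    assert eos_token_id is not None, "eos_token_id has to be defined"
    if pad_token_id is None:
        pad_token_id = eos_token_id  # fall back once, here, instead of re-checking later

    # _generate_no_beam_search / _generate_beam_search could then use
    # eos_token_id and pad_token_id without any further `is not None` checks.
    ...
```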
Oh, I also see `pad_token_id` is used in `_generate_no_beam_search` without testing whether it's defined: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L571

```python
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
```
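For readers following along, a tiny worked illustration of what that line does (toy tensors and a made-up `pad_token_id`, not the real code): `unfinished_sents` is 1 for sequences still generating and 0 for finished ones, so finished sequences receive the pad token instead of the sampled token.

```python
import torch

pad_token_id = 99                           # made-up pad id for the illustration
next_token = torch.tensor([11, 22, 33])     # sampled/argmax tokens for 3 sequences
unfinished_sents = torch.tensor([1, 0, 1])  # the middle sequence already hit EOS

tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
print(tokens_to_add)  # tensor([11, 99, 33])
```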
Is it the same situation as `eos_token_id`, that is, is it always needed? I see it may be defined here: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L355 but only if `eos_token_id` is defined.
```python
if pad_token_id is None and eos_token_id is not None:
    logger.warning(
        "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
    )
    pad_token_id = eos_token_id
```
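As a concrete (hedged) example of when that fallback fires, consider a checkpoint that ships an `eos_token_id` but no `pad_token_id`, such as `gpt2`; assuming that is still the case, `generate()` logs the warning above and pads finished sequences with the EOS token.

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my name is", return_tensors="pt")
# expect a warning along the lines of:
# "Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence"
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output[0]))
```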
My thinking is that if this has worked until now for all models, it's another sign that `eos_token_id` has to be required.
In `_generate_no_beam_search`, `pad_token_id` is required and, similarly to `eos_token_id`, could be asserted once at the top rather than multiple times throughout the code.
Thank you for reviewing my observations. It’s possible that some (all?) are incorrect if I missed something.
I think your PR is fine because, if no `eos_token_id` is defined, this condition can never happen: `sent_lengths[i] < max_length`. What I mean is that if no `eos_token_id` is defined, then no matter which `generate()` method is used, all `sent_lengths` will always be `== max_length` and the condition will not be hit.
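A tiny sketch of that reasoning (hypothetical simplified variables, not the actual `generation_utils` code): without an EOS token nothing can terminate early, so every sentence length equals `max_length` and the padding branch is never reached.

```python
max_length = 20
eos_token_id = None            # model defines no EOS token

# With no EOS token, no sequence can finish early:
sent_lengths = [max_length] * 4

for i in range(len(sent_lengths)):
    if sent_lengths[i] < max_length:   # never True when eos_token_id is None
        pass                           # this is where pad_token_id would be needed
```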
Hey @stas00,

This is definitely part of the code that should be refactored 😄 Super hard to follow the logic there 😕
As a start, this PR is probably quite useful for context: https://github.com/huggingface/transformers/pull/2885. So there are a couple of models where the EOS token is not defined, and I'm quite sure that the code you linked does not always get hit. It can very well be that we apply beam search to `OpenAIGPT` with a given `max_length`. `OpenAIGPT` does not have an EOS token, but beam search should work nevertheless.

It's quite a tricky pad token / eos token / … logic that is implemented there. I think we have to be super careful not to break anything here - even if all the slow tests pass, it might not be enough (`OpenAIGPT` beam search is not integration tested…).

Also, I'm currently working on refactoring the generate function and will ping you guys in a couple of days with a first design proposition. My idea is to pull apart beam search + greedy / beam search + sampling / no beam search + greedy / no beam search + sampling to make everything more readable. I'm not sure whether it's worth diving deep into the `generate()` logic before we have more readable code.
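For readers unfamiliar with what such a split might look like, here is a hedged sketch (illustrative method names only, not the API that actually shipped) of `generate()` dispatching to one dedicated routine per decoding mode:

```python
# Illustrative names only -- a sketch of the proposed split, not the real API.
def generate(self, input_ids, num_beams=1, do_sample=False, **kwargs):
    if num_beams > 1 and do_sample:
        return self._beam_sample(input_ids, num_beams=num_beams, **kwargs)
    if num_beams > 1:
        return self._beam_search(input_ids, num_beams=num_beams, **kwargs)
    if do_sample:
        return self._sample(input_ids, **kwargs)
    return self._greedy_search(input_ids, **kwargs)
```

Each routine would then own its eos/pad handling explicitly, which is the readability gain being argued for above.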