
[generation] multiple eos/pad asserts/ifs in generate search functions

See original GitHub issue

In _generate_no_beam_search eos_token_id is required: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L731 (that code always gets hit):

                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"

Why do we assert and check that eos_token_id is not None multiple times throughout the code? Why not assert once at the top of the function and then just use it?

Moreover, all those if eos_token_id is not None checks could then be removed (or reduced, where they have other parts to them).

Also, a larger question: is there a model where eos_token_id is not defined? If there is none, why not assert once at the top of generate and then use it everywhere in the sub-calls without testing whether it is defined?
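
To illustrate, the consolidation could look something like the sketch below (hypothetical signature and placement, not the actual generation_utils code):

    def generate(self, input_ids, max_length=None, eos_token_id=None, pad_token_id=None, **kwargs):
        # Hypothetical sketch: validate the special tokens once, up front ...
        assert eos_token_id is not None, "`eos_token_id` has to be defined to run `generate`"
        if pad_token_id is None:
            pad_token_id = eos_token_id  # the same fallback the current code applies later on

        # ... so that every sub-call can rely on them unconditionally, with no
        # further `if eos_token_id is not None` branches scattered through the code.
        return self._generate_no_beam_search(
            input_ids, max_length=max_length, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs
        )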

Oh, I also see that pad_token_id is used in _generate_no_beam_search without testing whether it’s defined: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L571

                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)

Is it the same situation as eos_token_id, i.e. is it always needed?
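
For example, a tiny standalone repro (hypothetical values) of what would happen if an undefined pad_token_id ever reached that line:

    import torch

    pad_token_id = None                      # undefined, as questioned above
    next_token = torch.tensor([42, 7])
    unfinished_sents = torch.tensor([1, 0])  # 1 = still generating, 0 = finished

    # The same expression as above: with pad_token_id=None it fails outright
    # with a TypeError, so if this line is reached the pad token has to be defined.
    tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)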

I see it may be defined here: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L355, but only if eos_token_id is defined:

        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

My thinking is that if this has worked until now for all models, it’s another sign that eos_token_id effectively has to be defined.

So in _generate_no_beam_search, pad_token_id is required as well, and, similarly to eos_token_id, it could be asserted once at the top rather than being checked multiple times throughout the code.

Thank you for reviewing my observations. It’s possible that some (or all?) of them are incorrect if I missed something.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Sep 7, 2020

I think your PR is fine because if no eos_token_id is defined, this condition can never happen: sent_lengths[i] < max_length. What I mean is that if no eos_token_id is defined, then no matter which generate() method is used, all sent_lengths entries will always be == max_length and the condition will not be hit.
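
Roughly, the logic I mean is the following (a schematic sketch with hypothetical names, not the actual implementation):

    import torch

    # sent_lengths starts out filled with max_length and is only ever
    # lowered when an EOS token is generated for a sentence.
    max_length = 20
    batch_size = 3
    sent_lengths = torch.full((batch_size,), max_length, dtype=torch.long)

    eos_token_id = None  # e.g. a model that defines no EOS token
    for step in range(max_length):
        # ... pick the next tokens ...
        if eos_token_id is not None:
            # only this branch could ever set sent_lengths[i] below max_length
            pass

    # With no eos_token_id, every entry stays == max_length, so the
    # `sent_lengths[i] < max_length` condition can never be hit.
    assert bool((sent_lengths == max_length).all())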

1 reaction
patrickvonplaten commented, Sep 7, 2020

Hey @stas00,

This is definitely part of the code that should be refactored 😄 Super hard to follow the logic there 😕

As a start, this PR is probably quite useful for context: https://github.com/huggingface/transformers/pull/2885. So there are a couple of models where the EOS token is not defined, and I’m quite sure that the code you linked does not always get hit. It can very well be that we apply beam search to OpenAIGPT with a given max_length; OpenAIGPT does not have an EOS token, but beam search should work nevertheless.
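
For example (an untested sketch of that scenario):

    from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")

    print(tokenizer.eos_token_id)  # None - openai-gpt defines no EOS token

    input_ids = tokenizer.encode("The quick brown fox", return_tensors="pt")

    # With no EOS token, no beam can ever "finish" early, so beam search
    # simply runs until max_length is reached.
    output = model.generate(input_ids, max_length=20, num_beams=4)
    print(tokenizer.decode(output[0]))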

It’s quite a tricky pad token / eos token / … logic that is implemented there. I think we have to be super careful to not break anything here - even if all the slow tests pass, it might not be enough (OpenAIGPT beam search is not integration tested…)

Also, I’m currently working on refactoring the generate function and will ping you guys in a couple of days with a first design proposition. My idea is to pull apart beam search + greedy / beam search + sampling / no beam search + greedy / no beam search + sampling to make everything more readable. I’m not sure whether it’s worth diving deep into the generate() logic before we have more readable code.
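
Schematically, the split could look something like this (hypothetical method names, just to illustrate the idea):

    # One thin public entry point dispatching to a separate, readable
    # method per decoding mode.
    def generate(self, input_ids, num_beams=1, do_sample=False, **kwargs):
        if num_beams == 1 and not do_sample:
            return self.greedy_search(input_ids, **kwargs)
        if num_beams == 1 and do_sample:
            return self.sample(input_ids, **kwargs)
        if num_beams > 1 and not do_sample:
            return self.beam_search(input_ids, **kwargs)
        return self.beam_sample(input_ids, **kwargs)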
