
inf/nan in generate (beam_sample) with small temperature values


Environment info

  • transformers version: 4.6.0.dev0
  • Platform: Linux
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.8.0 (yes)

Information

The generate function (beam_sample) throws an error when small temperature values are passed.

To reproduce

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

model_name = "sshleifer/distilbart-xsum-12-3"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

text = "New York City (NYC), often simply called New York, is the most populous city in the United States"
input_ids = tokenizer.encode(text, return_tensors='pt')

sample_outputs = model.generate(
    input_ids,
    num_beams=3,
    do_sample=True,
    temperature=0.2,
)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    temperature=0.2
  File "/opt/anaconda3/envs/tensorflow2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/anaconda3/envs/tensorflow2/lib/python3.6/site-packages/transformers/generation_utils.py", line 1113, in generate
    **model_kwargs,
  File "/opt/anaconda3/envs/tensorflow2/lib/python3.6/site-packages/transformers/generation_utils.py", line 2134, in beam_sample
    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Another way to reproduce this error is using higher temperatures and more iterations (generate a longer output).

It looks like this error is caused by next_token_scores growing toward -inf, which makes probs become nan. Large absolute values accumulate over iterations because next_token_scores are no longer normalized after the unnormalized beam_scores are added. beam_scores are calculated from the output of logits_warper(input_ids, next_token_scores) and can grow quickly with low temperatures (the warper does: scores = scores / self.temperature).
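As a minimal illustration of the failure mode (a toy sketch, not the transformers internals): dividing log-probabilities by a small temperature inflates their magnitudes, and once every entry of a score row has drifted to -inf, softmax yields nan and torch.multinomial raises exactly the RuntimeError from the traceback above.

```python
import torch
import torch.nn.functional as F

# Toy sketch (not the transformers internals): with temperature=0.2 the
# warper divides scores by 0.2, multiplying their magnitudes by 5 every
# step; adding the unnormalized running beam score on top pushes the
# values further toward -inf across iterations.
temperature = 0.2
scores = F.log_softmax(torch.randn(1, 5), dim=-1) / temperature
print(scores.min())  # already ~5x more negative than the raw log-probs

# Once a whole score row has reached -inf, softmax produces nan
# (because -inf - max(-inf) = nan) and multinomial rejects the row.
dead_row = torch.full((1, 5), float("-inf"))
probs = F.softmax(dead_row, dim=-1)
print(probs)  # tensor([[nan, nan, nan, nan, nan]])

try:
    torch.multinomial(probs, num_samples=2)
except RuntimeError as e:
    print(e)  # probability tensor contains either `inf`, `nan` or element < 0
```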

Expected behavior

Is the increase of unscaled values a desired behaviour and should one just implement their own logits_warper handling float overflow?

If not, a quick fix (just for demonstration) is to scale the values of beam_scores added to next_token_scores, replacing:

next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

with:

beam_scores_softmax = F.softmax(beam_scores, dim=-1)
next_token_scores = next_token_scores + beam_scores_softmax[:, None].expand_as(next_token_scores)

It works fine but changes the absolute values of scores that users may rely on.
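A standalone sketch of that workaround (illustrative values, not the library code): when one beam's running score has fully underflowed, adding it directly produces a nan probability row, while softmax-squashing the beam scores into [0, 1] first keeps the combined scores bounded.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_beams, vocab_size = 3, 8
# warped per-token scores for each beam (temperature 0.2, as in the repro)
next_token_scores = F.log_softmax(torch.randn(num_beams, vocab_size), dim=-1) / 0.2
# pretend beam 0's running score has underflowed after many iterations
beam_scores = torch.tensor([float("-inf"), -1e4, -5.0])

# original: unnormalized beam_scores added directly -> beam 0 row is all -inf
bad = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
print(torch.softmax(bad, dim=-1)[0])  # nan row -> multinomial would raise

# workaround: squash beam_scores through softmax before adding
beam_scores_softmax = F.softmax(beam_scores, dim=-1)
good = next_token_scores + beam_scores_softmax[:, None].expand_as(next_token_scores)
print(torch.isfinite(torch.softmax(good, dim=-1)).all())  # tensor(True)
```

The trade-off the author notes applies here too: the `good` scores no longer carry the accumulated beam log-probability, so any downstream code reading absolute score values would see different numbers.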

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
elsanns commented on Jan 10, 2022

Hi @patrickvonplaten,

Thank you for a detailed answer.

I noticed this behaviour testing various decoding methods, and I don’t recall seeing a significant advantage of beam_sample in any particular use case.

Since the new approach would be a breaking change, keeping it the way it is for now seems like the right solution.

Thanks again for your answer.

0 reactions
github-actions[bot] commented on Feb 4, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
