
Training loss of BART is going to nan in transformers>=4.21.0

See original GitHub issue

System Info

transformers==4.20.1 and transformers>=4.21.0; torch==1.12.1

Who can help?

@patil-suraj

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Hi, I’m using a huge dataset, so it is hard to show how to reproduce my problem. I’m using a pre-trained BART model and trying to fine-tune it as a translation model. However, the training loss is calculated completely differently depending on the transformers version.

My pseudo code looks like this:

import torch
from torch.cuda import amp
from transformers import BartForConditionalGeneration

# `rank` and the batch tensors (input_ids, attention_mask, decoder_input_ids,
# decoder_attention_mask) come from the omitted data-loading / DDP setup.
net = BartForConditionalGeneration.from_pretrained("gogamza/kobart-base-v1").to(rank)
net.train()

with amp.autocast(enabled=True):
    output = net(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
    # draw graphs of output.loss
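
(For context: a forward pass like this would typically be wrapped in a standard torch.cuda.amp training step. The scaler, optimizer, learning rate, and labels below are illustrative assumptions, not taken from the pseudo code above.)

import torch
from torch.cuda import amp

# Assumed setup, not shown in the pseudo code above:
scaler = amp.GradScaler()
optimizer = torch.optim.AdamW(net.parameters(), lr=3e-5)

optimizer.zero_grad()
with amp.autocast(enabled=True):
    output = net(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,  # assumption: the model only returns output.loss when labels are passed
    )
    loss = output.loss

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # gradients are unscaled; the step is skipped on inf/nan grads
scaler.update()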

I drew the graph (training loss by iteration) using wandb.

[Figure: training loss by iteration, logged with wandb]
effortless-water-23 (green): transformers>=4.21.0
swept-tree-24 (pink): transformers==4.20.1

swept-tree-24 slowly converged toward zero, but effortless-water-23 eventually got nan at 80k+ iterations. (The graph above doesn’t show that.) I’ve looked into the differences between transformers>=4.21.0 and transformers==4.20.1, especially around BART, and I’m suspicious of this part.

So when I reverted that part in transformers>=4.21.0 like this:

# mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))  # transformers>=4.21.0
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))  # reverted to the transformers==4.20.1 fill value

the problem was gone. (The result matches swept-tree-24, which used transformers==4.20.1.)
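
For reference, here is a quick illustration of how the two fill values differ numerically in half precision (just an observation, not necessarily the root cause):

import torch

print(float("-inf"))                   # -inf
print(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.finfo(torch.float16).min)  # -65504.0

# -inf stays -inf no matter what is added to it, while a large finite fill
# value can overflow once two mask terms are accumulated in fp16:
fill = torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16)
print(fill + fill)                     # tensor(-inf, dtype=torch.float16)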

Anyway, my problem is solved, but I’m wondering what the real cause of the problem is. Thanks in advance.

Expected behavior

I explained this in the Reproduction section above.

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 10 (1 by maintainers)

Top GitHub Comments

1 reaction
ydshieh commented on Sep 5, 2022

@soocheolnoh I am very happy that you found the cause and a solution! I also really appreciate your effort on the detailed issue description and the further investigation!

It’s better to use the same decoder start token as the one used in pretraining. Regarding using the pad token, it might work in some cases, but we should be very careful. I believe that in the case of this issue, it might be related to the decoder attention mask. I saw you have

decoder_attention_mask = torch.ne(decoder_input_ids, tokenizer.pad_token_id)

So when you used the pad token as the decoder start token, you prepared a (decoder) attention mask that ignores the decoder start token (which shouldn’t be ignored).
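
A small sketch of that interaction, with made-up token ids (purely illustrative, not from the actual tokenizer):

import torch

pad_token_id = 3  # hypothetical pad id, also used here as the decoder start token
# position 0 is the decoder start token (= pad); the trailing 3s are real padding
decoder_input_ids = torch.tensor([[3, 42, 17, 8, 3, 3]])

decoder_attention_mask = torch.ne(decoder_input_ids, pad_token_id)
print(decoder_attention_mask)
# tensor([[False,  True,  True,  True, False, False]])
# -> position 0 (the decoder start token) is masked out together with the padding

# one possible workaround: never mask the first decoder position
decoder_attention_mask[:, 0] = True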

Of course, this is just one observation that might be related.

I am going to close the issue. Don’t hesitate to reopen if you still have further questions.

0 reactions
soocheolnoh commented on Sep 5, 2022

Thank you! @ydshieh
