BartForConditionalGeneration breaks with label smoothing loss

Environment info

  • transformers version: 4.3.3
  • The other parameters are irrelevant

Who can help

@patrickvonplaten @sgugger

Information

I apologize for not using the provided template for this issue.

Generating entries with PreTrainedTokenizer.prepare_seq2seq_batch, collating them with DataCollatorForSeq2Seq, and training a BartForConditionalGeneration with Seq2SeqTrainer (label smoothing enabled) produces this error:

Traceback (most recent call last):
  File "playground.py", line 42, in <module>
    trainer.train()
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 940, in train
    tr_loss += self.training_step(model, inputs)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 1304, in training_step
    loss = self.compute_loss(model, inputs)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 1341, in compute_loss
    loss = self.label_smoother(outputs, labels)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer_pt_utils.py", line 398, in __call__
    nll_loss = log_probs.gather(dim=-1, index=labels)
RuntimeError: Size does not match at dimension 1 expected index [1, 7, 1] to be smaller than src [1, 5, 50265] apart from dimension 2
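The error is raised by torch.gather, which (per the PyTorch documentation) requires index.size(d) <= input.size(d) for every dimension d other than the gather dimension. The hypothetical helper below is only a pure-Python sketch of that rule, applied to the two shapes from the traceback: the labels span 7 positions but the log-probabilities only 5.

```python
from typing import Sequence


def gather_index_fits(index_shape: Sequence[int], src_shape: Sequence[int], dim: int) -> bool:
    """Sketch of torch.gather's shape rule: index.size(d) <= src.size(d) for every d != dim."""
    if len(index_shape) != len(src_shape):
        return False  # gather also needs both tensors to have the same number of dimensions
    dim = dim % len(src_shape)  # accept negative dims such as -1
    return all(
        i <= s
        for d, (i, s) in enumerate(zip(index_shape, src_shape))
        if d != dim
    )


# Shapes from the traceback: labels [1, 7, 1] vs log_probs [1, 5, 50265].
print(gather_index_fits([1, 7, 1], [1, 5, 50265], dim=-1))  # False: 7 > 5 at dimension 1
print(gather_index_fits([1, 5, 1], [1, 5, 50265], dim=-1))  # True
```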

The following script reproduces the error:

from torch.utils.data import Dataset
from transformers import (BartForConditionalGeneration, BartTokenizer,
                          BatchEncoding, DataCollatorForSeq2Seq,
                          PreTrainedTokenizer, Seq2SeqTrainer,
                          TrainingArguments)


class DummySeq2SeqDataset(Dataset):

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.data = [
            ("Hello world!", "Hallo welt!"),
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int) -> BatchEncoding:
        src_text, tgt_text = self.data[index]
        return self.tokenizer.prepare_seq2seq_batch(
            src_text, tgt_text,
            return_token_type_ids=False
        )


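# label_smoothing_factor > 0 enables the Trainer's label smoother, which is what triggers the error
train_args = TrainingArguments(output_dir='tmp', label_smoothing_factor=0.1)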

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

train_dataset = DummySeq2SeqDataset(tokenizer)
data_collator = DataCollatorForSeq2Seq(tokenizer)

trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()

Source of problem

The problem lies with the interaction between BartForConditionalGeneration and how label-smoothing is implemented in Trainer.

BartForConditionalGeneration.forward depends heavily on labels: when decoder_input_ids is not given, it builds them by shifting labels one position to the right.
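In the transformers source this shift is done by a shift_tokens_right helper in the BART modeling file. The list-based sketch below only approximates it (the real helper works on tensors), and the token IDs are illustrative: prepend the decoder start token, drop the last label, and replace the -100 ignore index with the pad token.

```python
from typing import List

IGNORE_INDEX = -100


def shift_tokens_right(labels: List[List[int]], pad_token_id: int, decoder_start_token_id: int) -> List[List[int]]:
    """Build decoder inputs by prepending the start token and dropping the last label."""
    shifted = []
    for seq in labels:
        decoder_seq = [decoder_start_token_id] + seq[:-1]
        # -100 marks ignored label positions; the decoder needs a real token there, so use padding.
        shifted.append([pad_token_id if tok == IGNORE_INDEX else tok for tok in decoder_seq])
    return shifted


labels = [[0, 21, 22, 2, -100]]
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
# [[2, 0, 21, 22, 2]]
```

Because the shifted sequence has the same length as labels, the logits and labels line up position by position, which is exactly what the loss computation relies on.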

When label smoothing is enabled, the Trainer class currently does the following:

https://github.com/huggingface/transformers/blob/3c733f320870261ea948049505a30c30fd6ea23a/src/transformers/trainer.py#L1447-L1469

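The catch is that labels is popped from the inputs before they are passed to BartForConditionalGeneration.forward. With neither labels nor decoder_input_ids to work from, the model derives decoder_input_ids from input_ids instead, so the logits follow the source length (5 tokens) while the labels keep the target length (7 tokens), which is the mismatch reported in the traceback.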
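The length mismatch can be reproduced without transformers at all. In this sketch, toy_bart_forward is a hypothetical stand-in for BartForConditionalGeneration.forward that only reports how many logits positions it would produce, and logits_and_label_lengths mimics the label-smoothing path by popping labels before the forward call. The token IDs are made up; only the sequence lengths (5 and 7) match the traceback.

```python
from typing import Dict, List, Optional, Tuple

Batch = Dict[str, List[List[int]]]
DECODER_START_TOKEN_ID = 2  # illustrative value


def toy_bart_forward(
    input_ids: List[List[int]],
    labels: Optional[List[List[int]]] = None,
    decoder_input_ids: Optional[List[List[int]]] = None,
) -> int:
    """Hypothetical stand-in for the BART forward pass; returns the logits length only."""
    if decoder_input_ids is None:
        # Decoder inputs are shifted from labels when present, otherwise from input_ids.
        source = labels if labels is not None else input_ids
        decoder_input_ids = [[DECODER_START_TOKEN_ID] + seq[:-1] for seq in source]
    # One logits position per decoder token (the vocabulary axis is left out).
    return len(decoder_input_ids[0])


def logits_and_label_lengths(inputs: Batch, label_smoothing: bool) -> Tuple[int, int]:
    inputs = dict(inputs)  # shallow copy so the caller's batch is left untouched
    if label_smoothing:
        # Mimics the Trainer path described in the issue: labels are removed before the forward call.
        labels = inputs.pop("labels")
    else:
        labels = inputs["labels"]
    logits_len = toy_bart_forward(**inputs)
    return logits_len, len(labels[0])


batch = {
    "input_ids": [[0, 11, 12, 13, 2]],  # 5 source tokens
    "labels": [[0, 21, 22, 23, 24, 25, 2]],  # 7 target tokens
}
print(logits_and_label_lengths(batch, label_smoothing=False))  # (7, 7)
print(logits_and_label_lengths(batch, label_smoothing=True))   # (5, 7)
```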

Possible solutions

One possible fix is to move the label-smoothed loss into each model's own loss computation instead of computing it in the Trainer. This approach has its own pros and cons.

Pros

  • Backward compatibility can be completely maintained
  • Removes a small code smell: when label_smoothing > 0, the loss returned by model.forward is not the loss actually used for training.

Cons

  • Complicates configuration
    • the label-smoothing factor would be defined in the model config rather than in the training arguments
  • Requires changes in many places in this repository (albeit the exact same set of changes each time)
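Under the proposed fix, each model would compute the smoothed loss itself, where logits and labels are known to line up. A minimal pure-Python sketch of label-smoothed cross-entropy, (1 - epsilon) * NLL + epsilon * mean(-log p) over the vocabulary: label_smoothed_loss is a hypothetical helper, positions labelled -100 are skipped, and averaging over the non-ignored positions is an assumption of this sketch (the Trainer's own smoother may normalize differently). It also fails loudly on a length mismatch instead of surfacing a cryptic gather error.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one position's vocabulary scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def label_smoothed_loss(logits: Sequence[Sequence[float]], labels: Sequence[int], epsilon: float) -> float:
    """(1 - epsilon) * NLL + epsilon * mean(-log p), averaged over non-ignored positions."""
    if len(logits) != len(labels):
        raise ValueError(f"logits cover {len(logits)} positions but labels cover {len(labels)}")
    nll_total = 0.0
    smooth_total = 0.0
    count = 0
    for position_logits, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padded label positions contribute nothing
        log_probs = log_softmax(position_logits)
        nll_total += -log_probs[label]
        smooth_total += -sum(log_probs) / len(log_probs)
        count += 1
    if count == 0:
        return 0.0
    return (1 - epsilon) * nll_total / count + epsilon * smooth_total / count


# Uniform logits over 4 classes: both terms equal log(4), so the loss is about 1.386.
print(label_smoothed_loss([[0.0, 0.0, 0.0, 0.0]], [1], epsilon=0.1))
```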

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

mingruimingrui commented on Mar 2, 2021 (1 reaction)

I see, it’s a fair point. Implementing the feature this way also ensures that DataCollator performs all required preprocessing for training data input. Closing issue.

sgugger commented on Mar 1, 2021 (0 reactions)

We only use the method of the model to generate decoder input IDs, not the actual model, so I think it’s completely fine in this case. Passing the method from the model would be way weirder in terms of user API.
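The resolution the thread converged on is to let the data collator build decoder_input_ids from the labels, using the model's own method, so the forward pass no longer needs labels. As far as I know, transformers releases after 4.3.3 expose this as DataCollatorForSeq2Seq(tokenizer, model=model), which calls the model's prepare_decoder_input_ids_from_labels; check the documentation for your version. The pure-Python sketch below only illustrates the idea: collate_with_decoder_inputs is a hypothetical function and the token IDs are made up.

```python
from typing import Dict, List

LABEL_PAD_TOKEN_ID = -100


def collate_with_decoder_inputs(
    features: List[Dict[str, List[int]]], pad_token_id: int, decoder_start_token_id: int
) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: pad inputs and labels, then derive decoder_input_ids from the padded labels."""
    max_src = max(len(f["input_ids"]) for f in features)
    max_tgt = max(len(f["labels"]) for f in features)
    batch: Dict[str, List[List[int]]] = {
        "input_ids": [], "attention_mask": [], "labels": [], "decoder_input_ids": []
    }
    for f in features:
        src_pad = max_src - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * src_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * src_pad)
        labels = f["labels"] + [LABEL_PAD_TOKEN_ID] * (max_tgt - len(f["labels"]))
        batch["labels"].append(labels)
        # Same right shift the model would apply to labels; -100 becomes the pad token.
        shifted = [decoder_start_token_id] + labels[:-1]
        batch["decoder_input_ids"].append([pad_token_id if t == LABEL_PAD_TOKEN_ID else t for t in shifted])
    return batch


features = [{"input_ids": [0, 11, 12, 13, 2], "labels": [0, 21, 22, 23, 24, 25, 2]}]
batch = collate_with_decoder_inputs(features, pad_token_id=1, decoder_start_token_id=2)
labels = batch.pop("labels")  # what the label-smoothing path does before calling the model
print(len(batch["decoder_input_ids"][0]), len(labels[0]))  # 7 7
```

Since decoder_input_ids is already in the batch, popping labels no longer changes the decoder length, and the logits stay aligned with the labels the smoother receives.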
