BartForConditionalGeneration breaks with label smoothing loss
Environment info
- transformers version: 4.3.3
- The other parameters are irrelevant
Information
I apologize for not using the provided template for this issue.

When generating entries with `PreTrainedTokenizer.prepare_seq2seq_batch`, collating them with `DataCollatorForSeq2Seq`, and training a `BartForConditionalGeneration` with `Seq2SeqTrainer`, I ran into the following error:
```
Traceback (most recent call last):
  File "playground.py", line 42, in <module>
    trainer.train()
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 940, in train
    tr_loss += self.training_step(model, inputs)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 1304, in training_step
    loss = self.compute_loss(model, inputs)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer.py", line 1341, in compute_loss
    loss = self.label_smoother(outputs, labels)
  File "/Users/mingrui.wang/miniconda3/envs/bert/lib/python3.7/site-packages/transformers/trainer_pt_utils.py", line 398, in __call__
    nll_loss = log_probs.gather(dim=-1, index=labels)
RuntimeError: Size does not match at dimension 1 expected index [1, 7, 1] to be smaller than src [1, 5, 50265] apart from dimension 2
```
Below is a script to reproduce the error:
```python
from torch.utils.data import Dataset
from transformers import (BartForConditionalGeneration, BartTokenizer,
                          BatchEncoding, DataCollatorForSeq2Seq,
                          PreTrainedTokenizer, Seq2SeqTrainer,
                          TrainingArguments)


class DummySeq2SeqDataset(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.data = [
            ("Hello world!", "Hallo welt!"),
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int) -> BatchEncoding:
        src_text, tgt_text = self.data[index]
        return self.tokenizer.prepare_seq2seq_batch(
            src_text, tgt_text,
            return_token_type_ids=False
        )


train_args = TrainingArguments(output_dir='tmp', label_smoothing_factor=0.1)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
train_dataset = DummySeq2SeqDataset(tokenizer)
data_collator = DataCollatorForSeq2Seq(tokenizer)
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
Source of problem
The problem lies in the interaction between `BartForConditionalGeneration` and the way label smoothing is implemented in `Trainer`. `BartForConditionalGeneration.forward` is tightly coupled to `labels`, because `labels` is also used to generate `decoder_input_ids` when they are not supplied explicitly.
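A simplified sketch of that coupling (paraphrased from the v4.3-era Bart modeling code, not verbatim; `shift_tokens_right` is the helper in `modeling_bart`):

```python
# Paraphrased sketch from BartForConditionalGeneration.forward (transformers
# ~4.3, not verbatim): when decoder_input_ids is not provided, it is derived
# from labels by shifting them one position to the right.
if labels is not None and decoder_input_ids is None:
    decoder_input_ids = shift_tokens_right(
        labels, self.config.pad_token_id, self.config.decoder_start_token_id
    )
```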
Label smoothing as implemented in the `Trainer` class currently does roughly the following.
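A paraphrased sketch of `Trainer.compute_loss` (reconstructed from the traceback above and the behavior described below, not copied verbatim from transformers 4.3.3):

```python
# Paraphrased sketch of Trainer.compute_loss with label smoothing enabled
# (reconstructed, not verbatim from transformers 4.3.3).
def compute_loss(self, model, inputs):
    if self.label_smoother is not None and "labels" in inputs:
        # labels is popped here, so it never reaches model.forward
        labels = inputs.pop("labels")
    else:
        labels = None
    outputs = model(**inputs)
    if labels is not None:
        loss = self.label_smoother(outputs, labels)
    else:
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    return loss
```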
The whoopsie is that `labels` is removed from the arguments passed to `BartForConditionalGeneration.forward`. The computation of logits then falls back to using `input_ids` as `decoder_input_ids`, which is exactly what the traceback shows: the logits cover the 5-token source sequence, while the gather is indexed with the 7-token target labels.
Possible solutions
A possible fix would be to move the label-smoothed loss computation into the loss computation of each model rather than keeping it in `Trainer` (see the hypothetical sketch after the cons below). Doing it this way comes with its own set of pros and cons.
Pros
- Backward compatibility can be completely maintained
- Removes the little bit of code smell where the true training loss is not reflected in `model.forward` when `label_smoothing > 0`
Cons
- Complicates configuration: the label-smoothing factor would be defined in the model config rather than in the training arguments
- Requires changes in many places in this repository (albeit the exact same set of changes everywhere)
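As a rough illustration of the proposal, a hypothetical model-side loss (not an existing transformers API: `config.label_smoothing` is an invented field, and the `label_smoothing` argument of `F.cross_entropy` requires PyTorch >= 1.10):

```python
import torch
import torch.nn.functional as F


def label_smoothed_lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor,
                           config) -> torch.Tensor:
    """Hypothetical model-side loss: the smoothing factor lives on the model
    config (config.label_smoothing is an invented field) instead of
    TrainingArguments.label_smoothing_factor."""
    return F.cross_entropy(
        lm_logits.view(-1, config.vocab_size),
        labels.view(-1),
        ignore_index=-100,  # padded label positions use -100 by HF convention
        label_smoothing=getattr(config, "label_smoothing", 0.0),
    )
```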
Top GitHub Comments
I see, it’s a fair point. Implementing the feature this way also ensures that the DataCollator performs all the preprocessing required for the training data. Closing the issue.
We only use the method of the model to generate decoder input IDs, not the actual model, so I think it’s completely fine in this case. Passing the method from the model would be way weirder in terms of user API.
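For anyone landing here with the same error: the resolution the comments describe corresponds to passing the model to the collator, so the collator can precompute `decoder_input_ids` from `labels` before the `Trainer` pops `labels` (the `model` argument and the underlying `prepare_decoder_input_ids_from_labels` method are from later transformers releases than the 4.3.3 used above):

```python
# Assumed fix (later transformers versions): give the collator the model so it
# can build decoder_input_ids from labels itself before the Trainer removes
# labels from the forward arguments.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
```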