
TypeError: forward() got an unexpected keyword argument 'attention_mask'


Environment info

  • transformers version: 4.10.0
  • Platform: Windows-10-10.0.19042-SP0
  • Python version: 3.9.6
  • PyTorch version (GPU?): 1.9.0+cpu (False)
  • Tensorflow version (GPU?): 2.6.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

@patrickvonplaten @patil-suraj

Information

I am using EncoderDecoderModel (encoder=TransfoXLModel, decoder=TransfoXLLMHeadModel) to train a generative model for text summarization on the 'multi_x_science_sum' Hugging Face dataset.

When training starts, the following error is raised and training stops: TypeError: forward() got an unexpected keyword argument 'attention_mask'

To reproduce


from transformers import (
    EncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TransfoXLTokenizer,
)

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
txl2txl = EncoderDecoderModel.from_encoder_decoder_pretrained('transfo-xl-wt103', 'transfo-xl-wt103')

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,  # batch_size = 4
    per_device_eval_batch_size=batch_size,   # batch_size = 4
    output_dir="output",
    logging_steps=2,
    save_steps=10,
    eval_steps=4,
    num_train_epochs=1
)

# compute_metrics, train_data_processed and validation_data_processed are the
# metric function and the tokenized 'multi_x_science_sum' splits prepared earlier.
trainer = Seq2SeqTrainer(
    model=txl2txl,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_processed,
    eval_dataset=validation_data_processed
)
trainer.train()

TypeError: forward() got an unexpected keyword argument 'attention_mask'
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
C:\Users\DILEEP~1\AppData\Local\Temp/ipykernel_21416/3777690609.py in <module>
      7     eval_dataset=validation_data_processed
      8 )
----> 9 trainer.train()

~\.conda\envs\msresearch\lib\site-packages\transformers\trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1282                         tr_loss += self.training_step(model, inputs)
   1283                 else:
-> 1284                     tr_loss += self.training_step(model, inputs)
   1285                 self.current_flos += float(self.floating_point_ops(inputs))
   1286 

~\.conda\envs\msresearch\lib\site-packages\transformers\trainer.py in training_step(self, model, inputs)
   1787                 loss = self.compute_loss(model, inputs)
   1788         else:
-> 1789             loss = self.compute_loss(model, inputs)
   1790 
   1791         if self.args.n_gpu > 1:

~\.conda\envs\msresearch\lib\site-packages\transformers\trainer.py in compute_loss(self, model, inputs, return_outputs)
   1819         else:
   1820             labels = None
-> 1821         outputs = model(**inputs)
   1822         # Save past state if it exists
   1823         # TODO: this needs to be fixed and made cleaner later.

~\.conda\envs\msresearch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~\.conda\envs\msresearch\lib\site-packages\transformers\models\encoder_decoder\modeling_encoder_decoder.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    423 
    424         if encoder_outputs is None:
--> 425             encoder_outputs = self.encoder(
    426                 input_ids=input_ids,
    427                 attention_mask=attention_mask,

~\.conda\envs\msresearch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'attention_mask'
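
As a quick, editor-added check (not part of the original report), you can confirm what the traceback points at: TransfoXLModel.forward() has no attention_mask parameter at all, while BertModel.forward() does, which is why the bert2bert setup mentioned below runs fine.

# Editor's sketch: compare the forward() signatures of the two encoders.
import inspect

from transformers import BertModel, TransfoXLModel

print("attention_mask" in inspect.signature(TransfoXLModel.forward).parameters)  # False -> the TypeError above
print("attention_mask" in inspect.signature(BertModel.forward).parameters)       # True  -> why bert2bert works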

As a side note, when I do the same task with the following setup, training starts without a problem: tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') and bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased').

Please advise on how to train a TransformerXL-to-TransformerXL model.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
patil-suraj commented on Dec 14, 2021

Hey @dpitawela sorry to only answer now.

  1. I'm not very familiar with TransformerXL, so I'm not sure about the attention_mask. @patrickvonplaten, do you know why?
  2. Instead of removing the attention mask etc., I would suggest using a different model that can process long sequences; I'll explain below.
  3. Yes, IMO TransformerXL might not be a good choice for the encoder, since the model is trained as a decoder. Also, it is trained on WikiText-103, which is not a good enough dataset for pre-training. There are two other models that can process long sequences: Longformer and BigBird.
  4. IMO TransformerXL is not a good choice for such a task, so probably not.

You could use Longformer as the encoder and BERT/GPT-2 as the decoder, or you could use the LED model.

And BigBird can be used as both encoder and decoder, so you could use bigbird2bigbird if the target sequences are also long, or BigBird to BERT/GPT-2 (see the sketch after this comment).

Hope this helps 😃
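
Editor's note: below is a hedged sketch of the alternatives suggested in this comment. The checkpoint names are the usual Hub IDs and are assumptions for illustration; adapt them to your transformers version and data.

# Rough sketch (not from the original thread) of the suggested long-input alternatives.
from transformers import AutoTokenizer, EncoderDecoderModel, LEDForConditionalGeneration

# Option 1: Longformer encoder + BERT decoder, same EncoderDecoderModel recipe as in the issue
longformer2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "allenai/longformer-base-4096", "bert-base-uncased"
)

# Option 2: LED, a ready-made long-document encoder-decoder built on Longformer
led = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
led_tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")

# Option 3: BigBird on both sides (bigbird2bigbird), for long inputs and long targets
bigbird2bigbird = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/bigbird-roberta-base", "google/bigbird-roberta-base"
)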

1 reaction
patil-suraj commented on Sep 30, 2021

We haven't really tested TransformerXL with EncoderDecoderModel, so I'm not sure whether it will work, since it's a bit of a different model. One major difference is that TransformerXL does not accept attention_mask, but EncoderDecoderModel passes it every time. You could try removing attention_mask and see if it works (a rough, untested sketch of that idea follows below).

Also, TransformerXL is a decoder-only model, so it might not give the best results as an encoder. And out of curiosity, is there any reason you want to try a TransformerXL-to-TransformerXL model?
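
Editor's note: if you want to experiment with the "remove the attention_mask" idea before switching models, one untested approach is to wrap the encoder so that the unsupported keyword is dropped before TransfoXLModel.forward() sees it (as the traceback shows, EncoderDecoderModel always forwards attention_mask to its encoder). This only clears the encoder-side error; the decoder call passes further arguments (encoder_hidden_states, etc.) that TransformerXL also does not accept, which is part of why a different model is recommended above.

# Untested editor's sketch: silently discard the attention_mask keyword that
# EncoderDecoderModel forwards to its encoder.
import torch.nn as nn

class DropAttentionMask(nn.Module):
    """Wraps a model whose forward() has no attention_mask parameter."""

    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped
        self.config = wrapped.config  # EncoderDecoderModel reads encoder.config

    def forward(self, *args, attention_mask=None, **kwargs):
        # TransfoXL builds its own attention masks internally, so the mask is dropped here.
        return self.wrapped(*args, **kwargs)

txl2txl.encoder = DropAttentionMask(txl2txl.encoder)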
