
Trainer.evaluate does not support seq2seq models

See original GitHub issue

🐛 Bug

Information

Hi! I can’t thank you enough for Transformers. I know that the Trainer is still under development, but I would like to report this just to understand the current status.

Currently Trainer._prediction_loop assumes that different batches of data have the same shape. Specifically, this line

preds = torch.cat((preds, logits.detach()), dim=0)

This prevents Trainer.evaluate from being used with models whose output length varies across batches (e.g. seq2seq models). One possible solution is to pad all batches to the same length, but that is quite inefficient.
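To make the failure concrete, here is a small sketch of both the error and the pad-to-a-common-length workaround mentioned above. NumPy arrays stand in for the logits tensors; the shapes are illustrative, chosen to match the 29-vs-22 mismatch in the traceback below:

```python
import numpy as np

# Stand-ins for two batches of logits with shape (batch, seq_len, vocab);
# seq_len differs across batches (29 vs 22, as in the traceback).
a = np.zeros((8, 29, 100))
b = np.zeros((8, 22, 100))

try:
    np.concatenate((a, b), axis=0)  # mirrors the torch.cat failure
except ValueError as err:
    print("cannot concatenate:", err)

# The (inefficient) workaround: pad every batch to a common width first.
def pad_to(logits, width, pad_value=-100.0):
    padded = np.full((logits.shape[0], width, logits.shape[2]), pad_value)
    padded[:, : logits.shape[1], :] = logits
    return padded

width = max(a.shape[1], b.shape[1])
merged = np.concatenate((pad_to(a, width), pad_to(b, width)), axis=0)
print(merged.shape)  # (16, 29, 100)
```

torch.cat behaves the same way: every dimension except the concatenation dimension must match, which is exactly the RuntimeError reported.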

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. create seq2seq model
  2. pad each batch to the maximum sequence length within that batch
  3. create Trainer for the model, call .evaluate()
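Step 2 can be sketched with a hypothetical collate function (not part of transformers) that pads only within each batch, which is what produces batches of different widths:

```python
def collate(batch, pad_id=0):
    """Pad each sequence only to the longest sequence in *this* batch."""
    width = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (width - len(seq)) for seq in batch]

# Two consecutive batches end up with different widths (3 vs 2), which is
# what later breaks the torch.cat call inside _prediction_loop.
batch_a = collate([[1, 2, 3], [4]])  # [[1, 2, 3], [4, 0, 0]]
batch_b = collate([[5, 6], [7]])     # [[5, 6], [7, 0]]
```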
Traceback (most recent call last):
  File "/home/vlialin/miniconda3/lib/python3.7/site-packages/transformers/trainer.py", line 509, in train
    self.evaluate()
  File "/home/vlialin/miniconda3/lib/python3.7/site-packages/transformers/trainer.py", line 696, in evaluate
    output = self._prediction_loop(eval_dataloader, description="Evaluation")
  File "/home/vlialin/miniconda3/lib/python3.7/site-packages/transformers/trainer.py", line 767, in _prediction_loop
    preds = torch.cat((preds, logits.detach()), dim=0)
RuntimeError: Sizes of tensors must match except in dimension 0. Got 29 and 22 in dimension 1

Expected behavior

Trainer is able to evaluate seq2seq models.

Environment info

  • transformers version: 2.11
  • Platform: Linux
  • Python version: 3.7.6
  • PyTorch version (GPU?): 1.5.0
  • Tensorflow version (GPU?): 2.2.0
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
ggaemo commented, Aug 7, 2020

Still no updates on this issue?

2 reactions
patil-suraj commented, Jun 12, 2020

Hi @Guitaricet, if you only want to evaluate the loss (AFAIK this is the case for seq2seq models), then you can set prediction_loss_only to True.
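As a rough sketch of why this sidesteps the shape problem: with prediction_loss_only, each batch contributes a single scalar loss, and scalars can always be averaged no matter how the sequence lengths vary. The numbers below are made up for illustration:

```python
# Per-batch scalar eval losses; the batch widths (29, 22, ...) no longer matter,
# because nothing of variable shape is ever concatenated.
batch_losses = [2.5, 2.0, 1.5]
eval_loss = sum(batch_losses) / len(batch_losses)
print(eval_loss)  # 2.0
```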


