Cannot use trace_module on a model's generate function
Environment info
- transformers version: 3.3.1
- Platform: Linux-5.8.14_1-x86_64-with-glibc2.10
- Python version: 3.8.5
- PyTorch version (GPU?): 1.6.0 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help
Information
Model I am using: BART
The problem arises when using:
- [x] the official example scripts: (give details below)
- [ ] my own modified scripts: (give details below)
The task I am working on is:
- [ ] an official GLUE/SQuAD task: (give the name)
- [x] my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
- Load any model that uses the generate function.
- Try to trace it with trace_module.
The problem can easily be reproduced with the following snippet:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = 'sshleifer/bart-tiny-random'
tokenizer = AutoTokenizer.from_pretrained(model)
sqgen_model = AutoModelForSeq2SeqLM.from_pretrained(model, torchscript=True)
sqgen_model.eval()

# Build a maximum-length dummy batch.
dummy_input = ' '.join('dummy' for _ in range(512))
batch = tokenizer(
    [dummy_input], return_tensors='pt', truncation=True, padding='longest',
)

with torch.no_grad():
    traced_model = torch.jit.trace_module(  # type: ignore
        sqgen_model,
        {
            'forward': (batch.input_ids, batch.attention_mask),
            'generate': (batch.input_ids, batch.attention_mask),
        },
    )
It throws an error:
File "/home/void/.miniconda3/envs/lexml/src/transformers/src/transformers/generation_utils.py", line 288, in generate
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
AssertionError: `max_length` should be a strictly positive integer.
This happens because generate's second positional parameter is max_length, not attention_mask, so the attention mask tensor fails the integer check.
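For illustration, torch.jit.trace_module unpacks each example tuple as positional arguments, so the 'generate' entry above is effectively running:

sqgen_model.generate(batch.input_ids, batch.attention_mask)
# i.e. generate(input_ids=batch.input_ids, max_length=batch.attention_mask),
# which trips the `max_length` assertion shown above.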
Expected behavior
Should be able to trace models that use the generate function.
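Note that fixing the argument order alone would not be enough: tracing records a single unrolled execution path, while generate's decoding loop length depends on the input. A hypothetical wrapper (GenerateWrapper below is illustrative, not part of transformers) that pins max_length shows the shape such an attempt would take:

import torch

class GenerateWrapper(torch.nn.Module):
    # Hypothetical wrapper: pins max_length so trace_module only has to
    # pass tensors. Tracing would still bake in one decoding path, so the
    # traced artifact is generally not valid for other inputs.
    def __init__(self, model, max_length=20):
        super().__init__()
        self.model = model
        self.max_length = max_length

    def forward(self, input_ids, attention_mask):
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_length,
        )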
generate currently does not support torch.jit.trace. This is sadly also not on the short-term roadmap.

Hey @DevBey,
Could you maybe open a new issue that states exactly what doesn't work in your case? This issue is quite old now and it would be nice to have a reproducible code snippet with the current transformers version.
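Until generate supports tracing, a possible workaround (a sketch, assuming only the encoder/decoder forward pass needs to be exported) is to trace forward alone and keep generation in eager mode:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

name = 'sshleifer/bart-tiny-random'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name, torchscript=True)
model.eval()

batch = tokenizer(['a dummy input'], return_tensors='pt')
# A seq2seq forward pass also needs decoder inputs; start from the
# decoder start token.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    # Trace only forward; generate stays in eager Python.
    traced = torch.jit.trace(
        model, (batch.input_ids, batch.attention_mask, decoder_input_ids),
    )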