Cannot trace_module on models using model's generate function

Environment info

  • transformers version: 3.3.1
  • Platform: Linux-5.8.14_1-x86_64-with-glibc2.10
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.6.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help

@patrickvonplaten, @sshleifer

Information

Model I am using: BART

The problem arises when using:

  • [*] the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • [*] my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Load any model that uses the generate function.
  2. Try to trace it with torch.jit.trace_module.

This can easily be reproduced with the following snippet:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = 'sshleifer/bart-tiny-random'
tokenizer = AutoTokenizer.from_pretrained(model)
sqgen_model = AutoModelForSeq2SeqLM.from_pretrained(model, torchscript=True)
sqgen_model.eval()

# Build a 512-token dummy input and tokenize it.
dummy_input = ' '.join('dummy' for _ in range(512))
batch = tokenizer(
    [dummy_input], return_tensors='pt', truncation=True, padding='longest',
)

# Trace both forward and generate with the same example inputs.
# trace_module passes each tuple positionally, which is the root of
# the problem for generate (see below).
with torch.no_grad():
    traced_model = torch.jit.trace_module(  # type: ignore
        sqgen_model,
        {
            'forward': (batch.input_ids, batch.attention_mask),
            'generate': (batch.input_ids, batch.attention_mask),
        },
    )

It throws an error:

  File "/home/void/.miniconda3/envs/lexml/src/transformers/src/transformers/generation_utils.py", line 288, in generate
    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
AssertionError: `max_length` should be a strictly positive integer.

This happens because generate's second positional argument is max_length, not attention_mask: trace_module passes the example inputs positionally, so the attention mask lands where the maximum length is expected.

Expected behavior

It should be possible to trace models that use the generate function.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Oct 14, 2020

generate currently does not support torch.jit.trace. This is sadly also not on the short-term roadmap.

0 reactions
patrickvonplaten commented, Nov 30, 2021

Hey @DevBey,

Could you maybe open a new issue that states exactly what doesn’t work in your case? This issue is quite old now and it would be nice to have a reproducible code snippet with the current transformers version.
