torch.fx not working on ViT model

System Info

  • transformers version: 4.23.0.dev0
  • platform: Windows 11, AMD64
  • Python version: 3.7.9

Who can help?

@NielsRogge

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

  1. Run the code below:
import torch
import numpy as np
from transformers import ViTFeatureExtractor, ViTModel
from datasets import load_dataset
from torch.fx import symbolic_trace


dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")


img = torch.Tensor(np.random.randn(1, 3, 224, 224))

tracing_info = {"head_mask": None, "interpolate_pos_encoding": None, "bool_masked_pos": None, "output_hidden_states": None, "output_attentions": None, "return_dict": None}
traced = symbolic_trace(model, tracing_info) # bug here

with torch.no_grad():
    outputs = model(img)
    outputs2 = traced(img)
assert torch.allclose(dict(outputs)["last_hidden_state"], outputs2["last_hidden_state"])
  2. Error traceback:
symbolically traced variables cannot be used as inputs to control flow
  File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 166, in forward
    if num_channels != self.num_channels:
  File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 118, in forward
    embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 558, in forward
    pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  File "C:\Users\NOTA2001\Desktop\abab\transformers\do_something.py", line 19, in <module>
    traced = symbolic_trace(model, inputs2)
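
For context, this error comes from torch.fx itself: during symbolic tracing, tensor shapes are Proxy objects, and using one in an `if` condition raises a TraceError. The sketch below reproduces that behaviour on a small module. `ToyPatchEmbeddings` is a hypothetical stand-in written for illustration, not the actual code in modeling_vit.py.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class ToyPatchEmbeddings(nn.Module):
    """Hypothetical stand-in for a patch-embedding layer (not the real ViT code)."""

    def __init__(self, num_channels: int = 3, hidden_size: int = 8, patch_size: int = 4):
        super().__init__()
        self.num_channels = num_channels
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values):
        num_channels = pixel_values.shape[1]
        # While tracing, num_channels is a Proxy, so this `if` asks a Proxy
        # for a concrete bool value, which torch.fx cannot provide.
        if num_channels != self.num_channels:
            raise ValueError("channel mismatch")
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


module = ToyPatchEmbeddings()

# Eager execution works because the shape is a concrete integer.
eager_out = module(torch.randn(1, 3, 16, 16))

# Plain symbolic tracing fails at the shape-dependent branch.
try:
    symbolic_trace(module)
    trace_error = None
except TraceError as err:
    trace_error = err

print(type(trace_error).__name__, trace_error)
```

The same pattern (a shape check inside `forward`) is what the traceback above points at in the ViT patch embeddings.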

Expected behavior

Once this issue is fixed, the reproduction code above should run without raising the tracing error, and the final torch.allclose assertion should pass.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

dwlim-nota commented, Sep 29, 2022 (1 reaction)

thank you for your comment @michaelbenayoun I will close this issue 😃

michaelbenayoun commented, Sep 28, 2022 (0 reactions)

Hi @dwlim-nota, We actually do support torch.fx symbolic tracing through our custom tracer, which is supposed to handle the ViT case and what you do in your PR.

from transformers.utils.fx import symbolic_trace
traced = symbolic_trace(vit_model, input_names=["pixel_values"])

Compared to the original torch.fx.symbolic_trace function, you need to specify which inputs the traced model must have (because our models usually support different inputs, which is not possible in fx), so that is why I specified pixel_values here.
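
As a rough end-to-end sketch of the suggestion above, the following traces a ViT model with the transformers tracer and compares its output with eager execution. It assumes a transformers release that still ships `transformers.utils.fx`. The small `ViTConfig` values are arbitrary choices for illustration, used so that no pretrained weights need to be downloaded; they are not the configuration from the issue.

```python
import torch
from transformers import ViTConfig, ViTModel
from transformers.utils.fx import symbolic_trace

# Small, randomly initialised ViT (illustrative sizes, not the pretrained checkpoint).
config = ViTConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    image_size=32,
    patch_size=8,
)
vit_model = ViTModel(config).eval()

# The transformers tracer needs the input names the traced model should accept.
traced = symbolic_trace(vit_model, input_names=["pixel_values"])

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    eager_out = vit_model(pixel_values)
    traced_out = traced(pixel_values=pixel_values)

# Both outputs are indexed by key, as in the reproduction script.
same = torch.allclose(
    eager_out["last_hidden_state"], traced_out["last_hidden_state"], atol=1e-5
)
print("outputs match:", same)
```

To trace the checkpoint from the issue instead, the model could presumably be loaded with `ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")` and passed to the same `symbolic_trace` call.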
