torch.fx not working on ViT model
System Info
- transformers version: 4.23.0.dev0
- platform: Windows 11, AMD64
- python version: 3.7.9
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
- Write the code below and execute it
import torch
import numpy as np
from transformers import ViTFeatureExtractor, ViTModel
from datasets import load_dataset
from torch.fx import symbolic_trace
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
img = torch.Tensor(np.random.randn(1, 3, 224, 224))
tracing_info = {"head_mask": None, "interpolate_pos_encoding": None, "bool_masked_pos": None, "output_hidden_states": None, "output_attentions": None, "return_dict": None}
traced = symbolic_trace(model, tracing_info) # bug here
with torch.no_grad():
    outputs = model(img)
    outputs2 = traced(img)
assert torch.allclose(dict(outputs)["last_hidden_state"], outputs2["last_hidden_state"])
- error traceback
symbolically traced variables cannot be used as inputs to control flow
File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 166, in forward
if num_channels != self.num_channels:
File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 118, in forward
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
File "C:\Users\NOTA2001\Desktop\abab\transformers\src\transformers\models\vit\modeling_vit.py", line 558, in forward
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
File "C:\Users\NOTA2001\Desktop\abab\transformers\do_something.py", line 19, in <module>
traced = symbolic_trace(model, inputs2)
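The error above can be reproduced without downloading any model. The following is a minimal sketch, assuming a toy module (`ChannelCheck` and `try_trace` are hypothetical names, not part of transformers) that performs the same kind of channel check as the line shown in the traceback. During `torch.fx` symbolic tracing the input is a `Proxy`, so `pixel_values.shape[1]` is also a `Proxy`, and using it in an `if` statement raises the `TraceError`.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class ChannelCheck(nn.Module):
    """Hypothetical toy module mirroring the channel check from the traceback."""

    def __init__(self, num_channels: int = 3):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, pixel_values):
        # While tracing, pixel_values is a Proxy, so this value is a Proxy too.
        num_channels = pixel_values.shape[1]
        # Converting a Proxy to bool for the `if` is what triggers the error.
        if num_channels != self.num_channels:
            raise ValueError("channel mismatch")
        return pixel_values * 2


def try_trace(module: nn.Module) -> str:
    """Return "traced" on success, or the TraceError message on failure."""
    try:
        symbolic_trace(module)
        return "traced"
    except TraceError as err:
        return f"TraceError: {err}"


result = try_trace(ChannelCheck())
print(result)
```

Calling the module eagerly with a real tensor works, because the comparison then involves plain integers; only the symbolic trace fails.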
Expected behavior
Once this issue is fixed, the code below should work:
import torch
import numpy as np
from transformers import ViTFeatureExtractor, ViTModel
from datasets import load_dataset
from torch.fx import symbolic_trace
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
img = torch.Tensor(np.random.randn(1, 3, 224, 224))
tracing_info = {"head_mask": None, "interpolate_pos_encoding": None, "bool_masked_pos": None, "output_hidden_states": None, "output_attentions": None, "return_dict": None}
traced = symbolic_trace(model, tracing_info) # bug here
with torch.no_grad():
    outputs = model(img)
    outputs2 = traced(img)
assert torch.allclose(dict(outputs)["last_hidden_state"], outputs2["last_hidden_state"])
Top GitHub Comments
Thank you for your comment, @michaelbenayoun. I will close this issue 😃
Hi @dwlim-nota, we actually do support torch.fx symbolic tracing through our custom tracer, which is supposed to handle the ViT case and what you do in your PR. Compared to the original torch.fx.symbolic_trace function, you need to specify which inputs the traced model must have (because our models usually support different inputs, which is not possible in fx), so that is why I specified pixel_values here.
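The snippet the comment refers to is not preserved on this page. As a rough sketch of what using the custom tracer might look like, assuming a transformers release that ships `transformers.utils.fx.symbolic_trace` with an `input_names` argument and supports ViT there: the small `ViTConfig` values below are arbitrary and chosen only so that a randomly initialised model can be built without downloading a checkpoint.

```python
import torch
from transformers import ViTConfig, ViTModel
# Hugging Face's tracer, not torch.fx.symbolic_trace.
from transformers.utils.fx import symbolic_trace

# Tiny, randomly initialised model (assumed sizes, for illustration only).
config = ViTConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    image_size=32,
    patch_size=8,
)
model = ViTModel(config).eval()

# Name the inputs the traced model should accept.
traced = symbolic_trace(model, input_names=["pixel_values"])

img = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    expected = model(pixel_values=img)["last_hidden_state"]
    actual = traced(pixel_values=img)["last_hidden_state"]

matches = torch.allclose(expected, actual, atol=1e-5)
print(matches)
```

With a pretrained checkpoint, the same call would presumably apply to the model loaded with `ViTModel.from_pretrained(...)` in the reproduction script, using 224x224 inputs.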