`Comparison exception: The values for attribute 'dtype' do not match: torch.float64 != torch.int64.` when pruning YoloV5-face
Describe the issue: I’m trying to prune the pre-trained model yolov5n-0.5 from Yolov5-face. Here is the code I used:
```python
from models.experimental import attempt_load
from nni.algorithms.compression.pytorch.pruning import LevelPruner
from nni.compression.pytorch import ModelSpeedup
import torch

if __name__ == '__main__':
    model = attempt_load('weights/yolov5n-0.5.pt', map_location=torch.device('cpu'))  # FP32
    model.eval()
    # Prune 50% of the weights in every layer of the pruner's default op types
    config_list = [{
        'sparsity': 0.5,
        'op_types': ['default'],
    }]
    pruner = LevelPruner(model, config_list)
    model = pruner.compress()
    # Save the pruned weights and the binary masks
    pruner.export_model(model_path='pruned_yolov5n-0.5.pth', mask_path='mask_yolov5n-0.5.pth')
    # Apply the exported masks; building the speedup graph traces the model
    m_speedup = ModelSpeedup(model, dummy_input=torch.rand((1, 3, 384, 640)), masks_file='mask_yolov5n-0.5.pth')
    m_speedup.speedup_model()
```
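For context on what `LevelPruner` computes: NNI documents it as a one-shot, fine-grained pruner that sorts the weights of each selected layer by absolute value and masks the smallest ones to zero until the target sparsity is reached. The sketch below illustrates that idea on a single weight tensor in plain PyTorch. `level_mask` is a hypothetical helper written for illustration, not NNI's implementation, and it ignores details such as how NNI selects layers or stores masks.

```python
import torch


def level_mask(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Hypothetical helper: 0/1 mask that keeps the largest-magnitude weights.

    Simplified illustration of magnitude (level) pruning for one tensor;
    this is not NNI's LevelPruner code.
    """
    num_prune = int(weight.numel() * sparsity)
    num_keep = weight.numel() - num_prune
    mask = torch.zeros(weight.numel(), dtype=weight.dtype, device=weight.device)
    if num_keep > 0:
        # Indices of the num_keep weights with the largest absolute value.
        keep = weight.detach().abs().flatten().topk(num_keep).indices
        mask[keep] = 1
    return mask.view_as(weight)


if __name__ == "__main__":
    torch.manual_seed(0)
    conv = torch.nn.Conv2d(3, 8, kernel_size=3)
    mask = level_mask(conv.weight, sparsity=0.5)
    pruned = conv.weight.detach() * mask
    print(f"zeroed fraction: {(pruned == 0).float().mean().item():.2f}")
```

A mask like this zeroes individual weights but leaves every tensor's shape unchanged.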
Running this script always throws the following error:
```
...
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
    Node:
        %1693 : Tensor = prim::Constant[value={0.5}](), scope: __module.model.21 # /home/vinhtq115/PycharmProjects/yolov5-face/models/yolo.py:71:0
    Source Location:
        /home/vinhtq115/PycharmProjects/yolov5-face/models/yolo.py(71): forward
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
        /home/vinhtq115/PycharmProjects/yolov5-face/models/yolo.py(167): forward_once
        /home/vinhtq115/PycharmProjects/yolov5-face/models/yolo.py(151): forward
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/nni/common/graph_utils.py(78): _trace
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/nni/common/graph_utils.py(66): __init__
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/nni/common/graph_utils.py(252): __init__
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/nni/common/graph_utils.py(24): build_module_graph
        /home/vinhtq115/miniconda3/envs/tensorrt/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py(57): __init__
        prune2.py(21): <module>
    Comparison exception: The values for attribute 'dtype' do not match: torch.float64 != torch.int64.
```
Full output: out.txt
I’m not sure what is wrong here, because the input and weights are `torch.float32`, not `torch.float64`.
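The log points at a Tensor-valued `prim::Constant` node (`value={0.5}`) recorded at `yolo.py` line 71, so the `torch.float64 != torch.int64` comparison concerns that traced constant rather than the weights or the input. The traceback also shows `ModelSpeedup` reaching `torch.jit.trace` through `nni/common/graph_utils.py`; by default, `torch.jit.trace` re-traces the module as a sanity check and compares Tensor-valued constants between the two runs. The toy module below is a minimal sketch that triggers the same class of error in plain PyTorch. It assumes nothing about what `yolo.py` actually does at line 71: `FlipFlopConstant` and `trace_check` are hypothetical names, and the alternating constant is contrived purely to make the two traces disagree.

```python
import warnings

import torch
import torch.nn as nn


class FlipFlopConstant(nn.Module):
    """Hypothetical toy module whose traced constant changes between calls."""

    def __init__(self):
        super().__init__()
        self.calls = 0  # plain Python state, invisible to the tracer

    def forward(self, x):
        self.calls += 1
        # torch.tensor(...) inside forward is baked into the trace as a
        # prim::Constant. Odd calls create a float64 constant, even calls an
        # int64 one, so the first trace and the check re-trace disagree.
        if self.calls % 2:
            scale = torch.tensor(0.5, dtype=torch.float64)
        else:
            scale = torch.tensor(1)
        return x * scale


def trace_check(model: nn.Module, dummy_input: torch.Tensor):
    """Return None if torch.jit.trace passes its sanity check, else the error text."""
    model.eval()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # hide TracerWarnings for this demo
        try:
            torch.jit.trace(model, dummy_input)  # check_trace=True by default
        except Exception as err:  # the failed check raises TracingCheckError
            return str(err)
    return None


if __name__ == "__main__":
    print(trace_check(FlipFlopConstant(), torch.rand((1, 3, 8, 8))))
```

Pointing `trace_check` at the model returned by `attempt_load`, with the same `torch.rand((1, 3, 384, 640))` input, would be one way to check whether the failure reproduces without NNI involved.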
Environment:
- NNI version: 2.6
- Training service (local|remote|pai|aml|etc): N/A (using pretrained weights)
- Client OS: Ubuntu 20.04.3
- Server OS (for remote mode only): N/A
- Python version: 3.8.12
- PyTorch/TensorFlow version: PyTorch 1.10.0
- Is conda/virtualenv/venv used?: Conda
- Is running in Docker?: N/A
Issue Analytics
- Created 2 years ago
- Comments: 5 (2 by maintainers)
Top GitHub Comments
I tried `L1FilterPruner` and still have the same issue. It seems to be a model-related issue, so I will close the issue.
@vinhtq115 - could you please report back on how things are going with the suggestions? We are closing resolved issues and wondering whether this issue is resolved or not. Thanks.
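The reporter also tried `L1FilterPruner`. NNI documents it as a structured pruner for convolutional layers: it sums the absolute kernel weights of each filter and masks the filters with the smallest sums. A minimal sketch of that ranking for one `Conv2d`, in plain PyTorch, is below. `l1_filter_mask` is a hypothetical helper, not NNI's code; it only computes which output filters would be masked.

```python
import torch
import torch.nn as nn


def l1_filter_mask(conv: nn.Conv2d, sparsity: float) -> torch.Tensor:
    """Hypothetical helper: per-output-channel 0/1 mask from filter L1 norms.

    Simplified illustration of L1-norm filter ranking; not NNI's implementation.
    """
    weight = conv.weight.detach()  # shape: (out_channels, in_channels/groups, kH, kW)
    num_filters = weight.size(0)
    num_prune = int(num_filters * sparsity)
    # Sum of absolute kernel weights for each output filter.
    l1_norms = weight.abs().sum(dim=(1, 2, 3))
    mask = torch.ones(num_filters, dtype=weight.dtype, device=weight.device)
    if num_prune > 0:
        # Filters with the smallest L1 norm are masked out.
        drop = l1_norms.topk(num_prune, largest=False).indices
        mask[drop] = 0
    return mask


if __name__ == "__main__":
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 16, kernel_size=3)
    mask = l1_filter_mask(conv, sparsity=0.5)
    print(f"filters kept: {int(mask.sum().item())} of {mask.numel()}")
```

Unlike the per-weight masks of `LevelPruner`, a mask like this selects whole output channels.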