Encounter error during ONNX export
See original GitHub issue🐛 Describe the bug
Hi team,
I tried to figure out whether pyg model can be exported to ONNX for inference. And I tested the code from the example in the pyg repo.
import os
import torch
import onnx
import onnxruntime as ort
from torch_geometric.nn import SAGEConv
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = SAGEConv(8, 16)
self.conv2 = SAGEConv(16, 16)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
model = MyModel()
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
expected = model(x, edge_index)
assert expected.size() == (3, 16)
torch.onnx.export(model, (x, edge_index), 'model.onnx',
input_names=('x', 'edge_index'), opset_version=16)
model = onnx.load('model.onnx')
onnx.checker.check_model(model)
ort_session = ort.InferenceSession('model.onnx')
out = ort_session.run(None, {
'x': x.numpy(),
'edge_index': edge_index.numpy()
})[0]
out = torch.from_numpy(out)
assert torch.allclose(out, expected, atol=1e-6)
os.remove('model.onnx')
However, error occurs at the step of onnx export:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In [11], line 27
24 expected = model(x, edge_index)
25 assert expected.size() == (3, 16)
---> 27 torch.onnx.export(model, (x, edge_index), 'model.onnx',
28 input_names=('x', 'edge_index'), opset_version=16)
30 model = onnx.load('model.onnx')
31 onnx.checker.check_model(model)
File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:504, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
186 @_beartype.beartype
187 def export(
188 model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
(...)
204 export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
205 ) -> None:
206 r"""Exports a model into ONNX format.
207
208 If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
(...)
501 All errors are subclasses of :class:`errors.OnnxExporterError`.
502 """
--> 504 _export(
505 model,
506 args,
507 f,
508 export_params,
509 verbose,
510 training,
511 input_names,
512 output_names,
513 operator_export_type=operator_export_type,
514 opset_version=opset_version,
515 do_constant_folding=do_constant_folding,
516 dynamic_axes=dynamic_axes,
517 keep_initializers_as_inputs=keep_initializers_as_inputs,
518 custom_opsets=custom_opsets,
519 export_modules_as_functions=export_modules_as_functions,
520 )
File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:1529, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
1526 dynamic_axes = {}
1527 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1529 graph, params_dict, torch_out = _model_to_graph(
1530 model,
1531 args,
1532 verbose,
1533 input_names,
1534 output_names,
1535 operator_export_type,
1536 val_do_constant_folding,
1537 fixed_batch_size=fixed_batch_size,
1538 training=training,
1539 dynamic_axes=dynamic_axes,
1540 )
1542 # TODO: Don't allocate a in-memory string for the protobuf
1543 defer_weight_export = (
1544 export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
1545 )
File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:1111, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1108 args = (args,)
1110 model = _pre_trace_quant_model(model, args)
-> 1111 graph, params, torch_out, module = _create_jit_graph(model, args)
1112 params_dict = _get_named_param_dict(graph, params)
1114 try:
File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:987, in _create_jit_graph(model, args)
982 graph = _C._propagate_and_assign_input_shapes(
983 graph, flattened_args, param_count_list, False, False
984 )
985 return graph, params, torch_out, None
--> 987 graph, torch_out = _trace_and_get_graph_from_model(model, args)
988 _C._jit_pass_onnx_lint(graph)
989 state_dict = torch.jit._unique_state_dict(model)
File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:891, in _trace_and_get_graph_from_model(model, args)
889 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
890 torch.set_autocast_cache_enabled(False)
--> 891 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
892 model,
893 args,
894 strict=False,
895 _force_outplace=False,
896 _return_inputs_states=True,
897 )
898 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
900 warn_on_static_input_change(inputs_states)
File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:1184, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1182 if not isinstance(args, tuple):
1183 args = (args,)
-> 1184 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1185 return outs
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:127, in ONNXTracedModule.forward(self, *args)
124 else:
125 return tuple(out_vars)
--> 127 graph, out = torch._C._create_graph_by_tracing(
128 wrapper,
129 in_vars + module_state,
130 _create_interpreter_name_lookup_fn(),
131 self.strict,
132 self._force_outplace,
133 )
135 if self._return_inputs:
136 return graph, outs[0], ret_inputs[0]
File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:118, in ONNXTracedModule.forward.<locals>.wrapper(*args)
116 if self._return_inputs_states:
117 inputs_states.append(_unflatten(in_args, in_desc))
--> 118 outs.append(self.inner(*trace_inputs))
119 if self._return_inputs_states:
120 inputs_states[0] = (inputs_states[0], trace_inputs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
1176 recording_scopes = False
1177 try:
-> 1178 result = self.forward(*input, **kwargs)
1179 finally:
1180 if recording_scopes:
Cell In [11], line 16, in MyModel.forward(self, x, edge_index)
15 def forward(self, x, edge_index):
---> 16 x = self.conv1(x, edge_index).relu()
17 x = self.conv2(x, edge_index)
18 return x
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
1176 recording_scopes = False
1177 try:
-> 1178 result = self.forward(*input, **kwargs)
1179 finally:
1180 if recording_scopes:
File ~/Library/Python/3.9/lib/python/site-packages/torch_geometric/nn/conv/sage_conv.py:132, in SAGEConv.forward(self, x, edge_index, size)
130 # propagate_type: (x: OptPairTensor)
131 out = self.propagate(edge_index, x=x, size=size)
--> 132 out = self.lin_l(out)
134 x_r = x[1]
135 if self.root_weight and x_r is not None:
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
1176 recording_scopes = False
1177 try:
-> 1178 result = self.forward(*input, **kwargs)
1179 finally:
1180 if recording_scopes:
File ~/Library/Python/3.9/lib/python/site-packages/torch_geometric/nn/dense/linear.py:118, in Linear.forward(self, x)
113 def forward(self, x: Tensor) -> Tensor:
114 r"""
115 Args:
116 x (Tensor): The features.
117 """
--> 118 return F.linear(x, self.weight, self.bias)
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-0.1737 0.0572 -0.2492 -0.0007 -0.2551 -0.3244 -0.3505 -0.2479
-0.2926 0.0689 -0.0733 0.1093 -0.3388 0.2230 0.3113 -0.1732
-0.2584 -0.3294 -0.1186 -0.1263 -0.2351 0.1039 0.1304 -0.1237
0.2858 0.1967 0.2776 -0.1998 -0.2199 0.2670 0.1745 0.1999
0.2931 0.1472 -0.3305 -0.3097 0.0389 -0.0562 -0.1534 0.3115
0.0438 -0.1634 0.2131 0.0337 0.1632 -0.1252 0.2368 -0.3067
-0.2482 0.0340 0.2924 0.0272 0.0925 -0.1025 0.3397 -0.3534
0.2330 0.1820 -0.2319 0.0098 0.0659 0.0823 0.1556 0.1783
-0.1823 0.3397 -0.2553 0.0930 -0.2429 -0.3190 -0.1297 -0.1721
0.2961 0.2292 0.1183 0.0785 0.0889 0.1453 -0.1325 -0.1583
-0.2163 0.1580 0.3421 -0.3478 -0.2789 0.2223 0.2199 -0.1646
0.0150 0.1753 0.2263 0.2297 -0.0150 -0.2408 0.2811 0.1594
-0.0436 -0.1363 -0.0851 0.1756 0.1284 -0.0845 0.2577 -0.2410
0.2452 0.1854 0.1490 -0.1863 -0.3367 0.1585 0.2370 0.1055
0.1854 -0.1109 0.3030 0.3058 -0.2230 0.2402 -0.2770 0.0289
-0.1733 0.2713 -0.3123 0.1801 0.3247 -0.0202 -0.3534 -0.3037
[ torch.FloatTensor{16,8} ]
Environment
- PyG version: 2.1.0.post1
- PyTorch version: 1.13.0
- OS: macOS 13.0
- Python version: 3.9.6
- CUDA/cuDNN version: None
- How you installed PyTorch and PyG (
conda
,pip
, source): pip - Any other relevant information (e.g., version of
torch-scatter
):- torch-cluster==1.6.0
- torch-scatter==2.0.9
- torch-sparse==0.6.15
- torch-spline-conv==1.2.1
Issue Analytics
- State:
- Created 10 months ago
- Comments:7 (3 by maintainers)
Top Results From Across the Web
I get an error when exporting to onnx but cannot find what it ...
It is a "torch.onnx.export" of a large model, and of course "shape" is used numerous times. The error does not tell which line...
Read more >torch.onnx — PyTorch 1.13 documentation
The torch. onnx module can export PyTorch models to ONNX. The model can then be consumed by any of the many runtimes that...
Read more >Question - ScatterElements error - Unity Forum
It will be present in 2.3.0 It does require you exporting the model ... Unknown type ScatterElements encountered while parsing layer 878.".
Read more >ONNX export yields Error ! - MATLAB Answers - MathWorks
Hi, I tried to use exportONNXNetwork, I ran this part of code, but i saw this error, could you help me pls?! Usage...
Read more >Running an ONNX model ValueError: not enough values to ...
I am new to python. I would like to convert my .pth file to .onnx file but encounter an error when running the...
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Thank you. I am a bit helpless here since I cannot reproduce this on my end 😦 Is there any chance you can debug on your end what might be causing this?
Hi @rusty1s, I have the following findings after some experiments:
torch.onnx.export
, the error will be gone.However,
with torch.no_grad():
doesn’t work.torch_geometric.nn.dense.linear
inSAGEConv
bytorch.nn.Linear
, the error will be gone as well.