
Error encountered during ONNX export

See original GitHub issue

🐛 Describe the bug

Hi team,

I tried to figure out whether a PyG model can be exported to ONNX for inference, so I tested the example code from the PyG repo.

import os

import torch
import onnx
import onnxruntime as ort
from torch_geometric.nn import SAGEConv


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv(8, 16)
        self.conv2 = SAGEConv(16, 16)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    

model = MyModel()
x = torch.randn(3, 8)  # 3 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])  # source/target node indices
expected = model(x, edge_index)
assert expected.size() == (3, 16)

torch.onnx.export(model, (x, edge_index), 'model.onnx',  # <-- fails here
                  input_names=('x', 'edge_index'), opset_version=16)

model = onnx.load('model.onnx')
onnx.checker.check_model(model)  # validate the exported graph

ort_session = ort.InferenceSession('model.onnx')

out = ort_session.run(None, {  # run with ONNX Runtime, compare to PyTorch
    'x': x.numpy(),
    'edge_index': edge_index.numpy()
})[0]
out = torch.from_numpy(out)
assert torch.allclose(out, expected, atol=1e-6)

os.remove('model.onnx')

However, an error occurs at the ONNX export step:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [11], line 27
     24 expected = model(x, edge_index)
     25 assert expected.size() == (3, 16)
---> 27 torch.onnx.export(model, (x, edge_index), 'model.onnx',
     28                   input_names=('x', 'edge_index'), opset_version=16)
     30 model = onnx.load('model.onnx')
     31 onnx.checker.check_model(model)

File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:504, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    186 @_beartype.beartype
    187 def export(
    188     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    204     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    205 ) -> None:
    206     r"""Exports a model into ONNX format.
    207 
    208     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    501             All errors are subclasses of :class:`errors.OnnxExporterError`.
    502     """
--> 504     _export(
    505         model,
    506         args,
    507         f,
    508         export_params,
    509         verbose,
    510         training,
    511         input_names,
    512         output_names,
    513         operator_export_type=operator_export_type,
    514         opset_version=opset_version,
    515         do_constant_folding=do_constant_folding,
    516         dynamic_axes=dynamic_axes,
    517         keep_initializers_as_inputs=keep_initializers_as_inputs,
    518         custom_opsets=custom_opsets,
    519         export_modules_as_functions=export_modules_as_functions,
    520     )

File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:1529, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
   1526     dynamic_axes = {}
   1527 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1529 graph, params_dict, torch_out = _model_to_graph(
   1530     model,
   1531     args,
   1532     verbose,
   1533     input_names,
   1534     output_names,
   1535     operator_export_type,
   1536     val_do_constant_folding,
   1537     fixed_batch_size=fixed_batch_size,
   1538     training=training,
   1539     dynamic_axes=dynamic_axes,
   1540 )
   1542 # TODO: Don't allocate a in-memory string for the protobuf
   1543 defer_weight_export = (
   1544     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1545 )

File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:1111, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1108     args = (args,)
   1110 model = _pre_trace_quant_model(model, args)
-> 1111 graph, params, torch_out, module = _create_jit_graph(model, args)
   1112 params_dict = _get_named_param_dict(graph, params)
   1114 try:

File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:987, in _create_jit_graph(model, args)
    982     graph = _C._propagate_and_assign_input_shapes(
    983         graph, flattened_args, param_count_list, False, False
    984     )
    985     return graph, params, torch_out, None
--> 987 graph, torch_out = _trace_and_get_graph_from_model(model, args)
    988 _C._jit_pass_onnx_lint(graph)
    989 state_dict = torch.jit._unique_state_dict(model)

File ~/Library/Python/3.9/lib/python/site-packages/torch/onnx/utils.py:891, in _trace_and_get_graph_from_model(model, args)
    889 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    890 torch.set_autocast_cache_enabled(False)
--> 891 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    892     model,
    893     args,
    894     strict=False,
    895     _force_outplace=False,
    896     _return_inputs_states=True,
    897 )
    898 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
    900 warn_on_static_input_change(inputs_states)

File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:1184, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1182 if not isinstance(args, tuple):
   1183     args = (args,)
-> 1184 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1185 return outs

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:127, in ONNXTracedModule.forward(self, *args)
    124     else:
    125         return tuple(out_vars)
--> 127 graph, out = torch._C._create_graph_by_tracing(
    128     wrapper,
    129     in_vars + module_state,
    130     _create_interpreter_name_lookup_fn(),
    131     self.strict,
    132     self._force_outplace,
    133 )
    135 if self._return_inputs:
    136     return graph, outs[0], ret_inputs[0]

File ~/Library/Python/3.9/lib/python/site-packages/torch/jit/_trace.py:118, in ONNXTracedModule.forward.<locals>.wrapper(*args)
    116 if self._return_inputs_states:
    117     inputs_states.append(_unflatten(in_args, in_desc))
--> 118 outs.append(self.inner(*trace_inputs))
    119 if self._return_inputs_states:
    120     inputs_states[0] = (inputs_states[0], trace_inputs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
   1176         recording_scopes = False
   1177 try:
-> 1178     result = self.forward(*input, **kwargs)
   1179 finally:
   1180     if recording_scopes:

Cell In [11], line 16, in MyModel.forward(self, x, edge_index)
     15 def forward(self, x, edge_index):
---> 16     x = self.conv1(x, edge_index).relu()
     17     x = self.conv2(x, edge_index)
     18     return x

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
   1176         recording_scopes = False
   1177 try:
-> 1178     result = self.forward(*input, **kwargs)
   1179 finally:
   1180     if recording_scopes:

File ~/Library/Python/3.9/lib/python/site-packages/torch_geometric/nn/conv/sage_conv.py:132, in SAGEConv.forward(self, x, edge_index, size)
    130 # propagate_type: (x: OptPairTensor)
    131 out = self.propagate(edge_index, x=x, size=size)
--> 132 out = self.lin_l(out)
    134 x_r = x[1]
    135 if self.root_weight and x_r is not None:

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1178, in Module._slow_forward(self, *input, **kwargs)
   1176         recording_scopes = False
   1177 try:
-> 1178     result = self.forward(*input, **kwargs)
   1179 finally:
   1180     if recording_scopes:

File ~/Library/Python/3.9/lib/python/site-packages/torch_geometric/nn/dense/linear.py:118, in Linear.forward(self, x)
    113 def forward(self, x: Tensor) -> Tensor:
    114     r"""
    115     Args:
    116         x (Tensor): The features.
    117     """
--> 118     return F.linear(x, self.weight, self.bias)

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-0.1737  0.0572 -0.2492 -0.0007 -0.2551 -0.3244 -0.3505 -0.2479
-0.2926  0.0689 -0.0733  0.1093 -0.3388  0.2230  0.3113 -0.1732
-0.2584 -0.3294 -0.1186 -0.1263 -0.2351  0.1039  0.1304 -0.1237
 0.2858  0.1967  0.2776 -0.1998 -0.2199  0.2670  0.1745  0.1999
 0.2931  0.1472 -0.3305 -0.3097  0.0389 -0.0562 -0.1534  0.3115
 0.0438 -0.1634  0.2131  0.0337  0.1632 -0.1252  0.2368 -0.3067
-0.2482  0.0340  0.2924  0.0272  0.0925 -0.1025  0.3397 -0.3534
 0.2330  0.1820 -0.2319  0.0098  0.0659  0.0823  0.1556  0.1783
-0.1823  0.3397 -0.2553  0.0930 -0.2429 -0.3190 -0.1297 -0.1721
 0.2961  0.2292  0.1183  0.0785  0.0889  0.1453 -0.1325 -0.1583
-0.2163  0.1580  0.3421 -0.3478 -0.2789  0.2223  0.2199 -0.1646
 0.0150  0.1753  0.2263  0.2297 -0.0150 -0.2408  0.2811  0.1594
-0.0436 -0.1363 -0.0851  0.1756  0.1284 -0.0845  0.2577 -0.2410
 0.2452  0.1854  0.1490 -0.1863 -0.3367  0.1585  0.2370  0.1055
 0.1854 -0.1109  0.3030  0.3058 -0.2230  0.2402 -0.2770  0.0289
-0.1733  0.2713 -0.3123  0.1801  0.3247 -0.0202 -0.3534 -0.3037
[ torch.FloatTensor{16,8} ]
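
For context: this error class generally means the JIT tracer tried to inline a tensor that still requires grad as a constant in the traced graph. A minimal illustration with plain PyTorch (this snippet is not from the issue) is tracing a function that closes over a free requires-grad tensor:

import torch

# A free tensor (not an nn.Parameter) that still requires grad.
w = torch.randn(4, 4, requires_grad=True)

def f(x):
    # `w` is captured by closure, so the tracer has to inline it as a
    # constant -- which is rejected while it still requires grad.
    return torch.nn.functional.linear(x, w)

try:
    torch.jit.trace(f, torch.randn(2, 4))
except RuntimeError as err:
    print(err)  # "Cannot insert a Tensor that requires grad as a constant..."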

Environment

  • PyG version: 2.1.0.post1
  • PyTorch version: 1.13.0
  • OS: macOS 13.0
  • Python version: 3.9.6
  • CUDA/cuDNN version: None
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):
    • torch-cluster==1.6.0
    • torch-scatter==2.0.9
    • torch-sparse==0.6.15
    • torch-spline-conv==1.2.1

Issue Analytics

  • State: open
  • Created: 10 months ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
rusty1s commented, Nov 25, 2022

Thank you. I am a bit helpless here since I cannot reproduce this on my end 😦 Is there any chance you can debug on your end what might be causing this?

1 reaction
wondey-sh commented, Nov 19, 2022

Hi @rusty1s, I have the following findings after some experiments:

  1. If I insert the following code before torch.onnx.export, the error goes away (a complete sketch based on this workaround follows after this list):

for para in model.parameters():
    para.requires_grad = False

However, wrapping the export in with torch.no_grad(): does not work.

  2. If I replace torch_geometric.nn.dense.linear in SAGEConv with torch.nn.Linear, the error goes away as well.
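
Putting finding 1 together with the reproduction script above, a minimal sketch of an export that reportedly succeeds (assembled here for illustration, not taken verbatim from the thread) looks like this:

import torch
from torch_geometric.nn import SAGEConv


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv(8, 16)
        self.conv2 = SAGEConv(16, 16)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


model = MyModel()
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])

# Finding 1: freeze every parameter so the tracer no longer encounters
# tensors that require grad when it inlines them as constants.
for para in model.parameters():
    para.requires_grad = False

torch.onnx.export(model, (x, edge_index), 'model.onnx',
                  input_names=('x', 'edge_index'), opset_version=16)

A plausible reason torch.no_grad() does not help here: the context manager only stops new operations from recording gradients, while the requires_grad flag that the tracer checks still sits on the existing parameter tensors.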
Read more comments on GitHub.
