
TypeError: forward() missing 1 required positional argument: 'x_body'



But I already have 'x_body' in the network definition.

Question

I use the torch.onnx.export() function to convert models from .pth format to .onnx format. I converted the two ResNet models to .onnx and that worked fine. Now I have a problem converting the model defined in the code below. This model takes the outputs of the two ResNet models as its inputs.

Traceback

Traceback (most recent call last):
  File "E:/pythonProject/Pytorch2TFLite/PyTorch2ONNX.py", line 61, in <module>
    torch.onnx.export(model_emotic,
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\utils.py", line 458, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args,
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\utils.py", line 422, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\onnx\utils.py", line 373, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\jit\_trace.py", line 1160, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch-1.9.0-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'x_body'
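The last frames show the mechanism: the tracer calls self.inner(*trace_inputs), and nn.Module.__call__ hands those positional inputs straight to forward(). torch.onnx.export treats a bare tensor passed as args like a 1-tuple, so forward() receives only one argument. Below is a simplified pure-Python model of that path. The Module, Emotic and trace names are hypothetical stand-ins, not PyTorch's actual implementation.

```python
class Module:
    """Minimal stand-in for nn.Module: calling the object runs forward()."""

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)


class Emotic(Module):
    """Toy model with the same two-argument forward() signature."""

    def forward(self, x_context, x_body):
        return ("fused", x_context, x_body)


def trace(model, args):
    """Hypothetical simplification of the tracer: a lone input is wrapped
    in a 1-tuple, then every element is passed positionally."""
    if not isinstance(args, tuple):
        args = (args,)
    return model(*args)


model = Emotic()

# One input, as in the question: forward() receives only x_context.
try:
    trace(model, "context_features")
    error = None
except TypeError as exc:
    error = str(exc)

# A tuple supplies one element per forward() parameter.
result = trace(model, ("context_features", "body_features"))
print(error)
print(result)
```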

My code

import os.path as osp
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torchvision
import torch.nn as nn

# Added for model_emotic1.pth: model class definition
class Emotic(nn.Module):
    ''' Emotic Model
        Not needed when converting model_body1.pth and model_context1.pth.
        Needed when converting model_emotic1.pth.
    '''

    def __init__(self, num_context_features, num_body_features):
        super(Emotic, self).__init__()
        self.num_context_features = num_context_features
        self.num_body_features = num_body_features
        self.fc1 = nn.Linear((self.num_context_features + num_body_features), 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.d1 = nn.Dropout(p=0.5)
        self.fc_cat = nn.Linear(256, 26)
        self.fc_cont = nn.Linear(256, 3)
        self.relu = nn.ReLU()

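    '''forward() is called automatically: e.g. with model=Emotic(); y=model(x,y), the forward() method defined in the network model is invoked [rather than calling model.forward(x,y) explicitly].
        In other words, when the defined network model is called like a function, its forward() method is called automatically.
        This happens through nn.Module's __call__ method: calling the model amounts to calling its forward function directly, so in y=model(x), x is passed straight to forward's x parameter.
    '''
    def forward(self, x_context, x_body):   # define the forward pass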
        context_features = x_context.view(-1, self.num_context_features)
        body_features = x_body.view(-1, self.num_body_features)
        # torch.cat() concatenates context_features and body_features ("cat" is short for "concatenate").
        # With torch.cat((A, B), dim), every dimension except dim must have the same size
        # so the tensors line up; dim=1 joins them horizontally (along the feature axis).
        fuse_features = torch.cat((context_features, body_features), 1)
        fuse_out = self.fc1(fuse_features)
        fuse_out = self.bn1(fuse_out)
        fuse_out = self.relu(fuse_out)
        fuse_out = self.d1(fuse_out)
        cat_out = self.fc_cat(fuse_out)
        cont_out = self.fc_cont(fuse_out)
        return cat_out, cont_out


test_arr=np.random.randn(10,3,224,224).astype(np.float32)   # astype(): returns a copy of the array cast to the given type
dummy_input=torch.tensor(test_arr)
body_test_arr=np.random.randn(10,3,128,128).astype(np.float32)
body_dummy_input=torch.tensor(body_test_arr)

model_body=torch.load("./models/model_body1.pth")
model_context=torch.load("./models/model_context1.pth")
pred_context=model_context(torch.from_numpy(test_arr))  # torch.from_numpy(): creates a Tensor from a numpy.ndarray
pred_body=model_body(torch.from_numpy(body_test_arr))

model_emotic=torch.load("./models/model_emotic1.pth")
model_emotic.eval()
torch_output=model_emotic(pred_context,pred_body)

input_names=["input"]
output_names=["cat","cont"]
torch.onnx.export(model_emotic,
                  dummy_input,
                  "./models/model_emotic1.onnx",
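                  verbose=False,    # if verbose is set, prints a debug description of the trace being exported. Default: False
                  input_names=input_names,  # names assigned, in order, to the graph's input nodes
                  output_names=output_names)    # names assigned, in order, to the graph's output nodes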

model=onnx.load("./models/model_emotic1.onnx")
ort_session=ort.InferenceSession("./models/model_emotic1.onnx")
onnx_outputs=ort_session.run(None,{'input':test_arr})
print('Export ONNX!')
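The comment on torch.cat in forward() states a shape rule that is easy to check directly. A small sketch with made-up sizes (a batch of 10 with 512 context and 256 body features, all hypothetical):

```python
import torch

# Hypothetical feature sizes; only the batch dimension (dim 0) must match.
context_features = torch.randn(10, 512)
body_features = torch.randn(10, 256)

# dim=1 joins along the feature axis: 512 + 256 = 768 columns.
fused = torch.cat((context_features, body_features), 1)
print(fused.shape)  # torch.Size([10, 768])

# A different batch size breaks the "all other dims must match" rule.
try:
    torch.cat((context_features, torch.randn(8, 256)), 1)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)
print(mismatch_error is not None)  # True
```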

I hope you can help clear up my confusion soon. Thank you very much!
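The thread does not record the exact fix. Going by the traceback, a likely cause is that torch.onnx.export receives only dummy_input (a single 10×3×224×224 image tensor) as its args, so tracing calls model_emotic(dummy_input) and x_body is never supplied. The example inputs should instead be a tuple holding both forward() inputs, i.e. the two ResNet outputs. The sketch below reproduces this with torch.jit.trace, which accepts example inputs the same way (a lone tensor or a tuple unpacked positionally). The 512-dimensional feature sizes are a hypothetical assumption.

```python
import torch
import torch.nn as nn


class Emotic(nn.Module):
    """Same fusion head as in the question: two inputs, two outputs."""

    def __init__(self, num_context_features, num_body_features):
        super().__init__()
        self.num_context_features = num_context_features
        self.num_body_features = num_body_features
        self.fc1 = nn.Linear(num_context_features + num_body_features, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.d1 = nn.Dropout(p=0.5)
        self.fc_cat = nn.Linear(256, 26)
        self.fc_cont = nn.Linear(256, 3)
        self.relu = nn.ReLU()

    def forward(self, x_context, x_body):
        context_features = x_context.view(-1, self.num_context_features)
        body_features = x_body.view(-1, self.num_body_features)
        fuse = torch.cat((context_features, body_features), 1)
        fuse = self.d1(self.relu(self.bn1(self.fc1(fuse))))
        return self.fc_cat(fuse), self.fc_cont(fuse)


# Hypothetical feature sizes standing in for the two ResNet outputs.
NUM_CONTEXT, NUM_BODY = 512, 512
model = Emotic(NUM_CONTEXT, NUM_BODY).eval()  # eval(): deterministic dropout/BN for tracing
pred_context = torch.randn(10, NUM_CONTEXT)
pred_body = torch.randn(10, NUM_BODY)

# What the question effectively does: one tensor, unpacked as model(*(tensor,)).
try:
    model(*(pred_context,))
    single_input_error = None
except TypeError as exc:
    single_input_error = str(exc)

# The fix: pass every forward() input, in order, as one tuple.
example_args = (pred_context, pred_body)
traced = torch.jit.trace(model, example_args)
cat_out, cont_out = traced(*example_args)
print(single_input_error)
print(cat_out.shape, cont_out.shape)
```

For the export itself, the same tuple would replace dummy_input, for example torch.onnx.export(model_emotic, (pred_context, pred_body), "./models/model_emotic1.onnx", input_names=["context", "body"], output_names=["cat", "cont"]), where the two input names are illustrative. The ONNX Runtime feed then needs one entry per input name, such as ort_session.run(None, {"context": pred_context.detach().numpy(), "body": pred_body.detach().numpy()}), rather than {'input': test_arr}.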

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
RyanShun-511 commented, Jan 18, 2022

@jcwchen Thank you very much for your reply. I have solved the problem. Sorry for any trouble I caused you.

0 reactions
Leonardo0325 commented, Nov 17, 2022

I have the same problem, see this: TypeError: forward() takes 1 positional argument but 2 were given.


Top Results From Across the Web

TypeError: forward() missing 1 required positional argument ...
I think the error message is pretty straight forward. You have two positional arguments input_tokens and hidden for your forward() .
forward() missing 1 required positional argument: 'target' when ...
I am trying to build a next word prediction model with pytorch in google colab. As my vocabulary size is over 1.5 million,...
TypeError: forward() missing 1 required positional argument: 'x'
This exception may be raised by user code to indicate that an attempted operation on an object is not supported, and is not...
TypeError: forward() missing 1 required positional argument: 'x ...
Current repo: run git fetch && git status -uno to check and git pull to update repo; Common dataset: coco.yaml or coco128.yaml; Common ......
forward() missing 1 required positional argument - Xilinx Support
I changed resnet18_quant.py to quantize my model, the .py file like this: import os import argparse import random from pytorch_nndct.apis import ...
