
pytorch2onnx script failed

See original GitHub issue

Describe the Issue

Running the mmsegmentation pytorch2onnx script on the standard STDC2 512x1024 config failed.

Error: TypeError: forward() got multiple values for argument 'img_metas'

Based on https://github.com/open-mmlab/mmsegmentation/issues/1962, it was suggested to report the error here.

Reproduction

  1. What command, code, or script did you run?
     python mmsegmentation/tools/pytorch2onnx.py /home/unj1szh/xspace/work_dirs/stdc2_4x4_pretrain_80k/stdc2_in1k-pre_512x1024_80k_cityscapes.py --shape 600 600
  2. Did you make any modifications to the code? Did you understand what you modified? No, I have not.

Environment

  1. Please run python -c "from mmcv.utils import collect_env; print(collect_env())" to collect necessary environment information and paste it here.
sys.platform: linux
Python: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA GeForce RTX 3080
CUDA_HOME: /usr/local/cuda-11.1
NVCC: Cuda compilation tools, release 11.1, V11.1.105
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 1.12.1
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.3.2  (built against CUDA 11.5)
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.13.1
OpenCV: 4.5.5
MMCV: 1.6.1
MMCV Compiler: GCC 9.4
MMCV CUDA Compiler: 11.1
MMSegmentation: 0.27.0+dd42fa8
  2. You may add additional information that may be helpful for locating the problem, such as:
    • How you installed PyTorch [e.g., pip, conda, source]: conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback

Traceback (most recent call last):
  File "tools/pytorch2onnx.py", line 386, in <module>
    pytorch2onnx(
  File "tools/pytorch2onnx.py", line 195, in pytorch2onnx
    torch.onnx.export(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/unj1szh/.conda/envs/xcbase/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/unj1szh/xspace/4_labs/mmcv/mmcv/runner/fp16_utils.py", line 116, in new_func
    return old_func(*args, **kwargs)
TypeError: forward() got multiple values for argument 'img_metas'

Bug fix

I have not identified the issue. I tried to debug; the values in img_metas do not contain duplicates, and the error seems ambiguous to me. Based on https://github.com/open-mmlab/mmsegmentation/issues/1962, it was suggested that the issue may be caused by mmcv.
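For context, this TypeError is Python's generic complaint when a function receives the same argument both positionally and as a keyword; it has nothing to do with duplicate values inside img_metas. A minimal sketch (hypothetical names, not actual mmseg or mmcv code) reproduces the failure mode that tracing triggers once img_metas has been pre-bound with functools.partial:

from functools import partial

def forward(img, img_metas=None, return_loss=True):
    return img, img_metas, return_loss

# bind img_metas as a keyword, as the original script does via partial
bound_forward = partial(forward, img_metas=[{'ori_shape': (600, 600, 3)}])

# the tracer later replays the recorded inputs positionally, so img_metas
# arrives a second time:
args = ('img_tensor', [{'ori_shape': (600, 600, 3)}])
bound_forward(*args)
# TypeError: forward() got multiple values for argument 'img_metas'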

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 9

Top GitHub Comments

6 reactions
HAOCHENYE commented, Aug 29, 2022

Hi, you can try the following modified code (it only works for PyTorch 1.12).

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

torch.manual_seed(3)


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def _prepare_input_img(img_path,
                       test_pipeline,
                       shape=None,
                       rescale_shape=None):
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
    # update img and its meta list
    N, C, H, W = img_list[0].shape
    img_meta = img_meta_list[0][0]
    img_shape = (H, W, C)
    if update_ori_shape:
        ori_shape = img_shape
    else:
        ori_shape = img_meta['ori_shape']
    pad_shape = img_shape
    new_img_meta_list = [[{
        'img_shape':
        img_shape,
        'ori_shape':
        ori_shape,
        'pad_shape':
        pad_shape,
        'filename':
        img_meta['filename'],
        'scale_factor':
        # 4-element (w_scale, h_scale, w_scale, h_scale); multiplying the
        # pair by 2 repeats it to match mmcv's scale_factor convention
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip':
        False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list


def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    Args:
        model (nn.Module): Pytorch model we want to export.
        mm_inputs (dict): Contain the input tensors and img_metas information.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axis.
            Default: False.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace original forward function
    origin_forward = model.forward
    # NOTE: binding img_metas to forward via functools.partial (as the
    # original script did) makes the tracer pass img_metas both positionally
    # and as a keyword under PyTorch 1.12, raising "forward() got multiple
    # values for argument 'img_metas'". The bound forward is therefore left
    # disabled, and all inputs are passed positionally to torch.onnx.export.
    # model.forward = partial(
    #     model.forward,
    #     img_metas=img_meta_list,
    #     return_loss=False,
    #     rescale=True)
    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    register_extra_symbolics(opset_version)
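    # NOTE: torch.onnx.export treats a trailing dict in the args tuple as
    # keyword arguments, so the tuple below is equivalent to calling
    # model.forward(img_list, img_meta_list, False, rescale=True) with every
    # other input passed positionally -- this sidesteps the duplicate
    # img_metas binding that raised the TypeError.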
    with torch.no_grad():
        torch.onnx.export(
            model,
            (img_list, img_meta_list, False, dict(rescale=True)),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale image for dynamic shape test
            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
            # concate flip image for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list, test_mode == 'whole')

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import os.path as osp

            import cv2
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between Pytorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='Images for input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width.')
    parser.add_argument(
        '--rescale_shape',
        type=int,
        nargs='+',
        default=None,
        help='output image rescale height and width, work for slide mode.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export onnx with dynamic axis.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    test_mode = cfg.model.test_cfg.mode

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read input or create dummy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        rescale_shape = None
        if args.rescale_shape is not None:
            rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
        mm_inputs = _prepare_input_img(
            args.input_img,
            cfg.data.test.pipeline,
            shape=preprocess_shape,
            rescale_shape=rescale_shape)
    else:
        if isinstance(segmentor.decode_head, nn.ModuleList):
            num_classes = segmentor.decode_head[-1].num_classes
        else:
            num_classes = segmentor.decode_head.num_classes
        mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    # convert model to onnx file
    pytorch2onnx(
        segmentor,
        mm_inputs,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)
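For reference, with the snippet above saved over mmsegmentation/tools/pytorch2onnx.py, the original command from the report should run unchanged (the paths are the reporter's; the output filename is arbitrary):

python mmsegmentation/tools/pytorch2onnx.py /home/unj1szh/xspace/work_dirs/stdc2_4x4_pretrain_80k/stdc2_in1k-pre_512x1024_80k_cityscapes.py --shape 600 600 --output-file stdc2.onnx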

1 reaction
jyang68sh commented, Sep 8, 2022

@HAOCHENYE Hi, sorry for the late reply. I have tried your code and it works. Thanks!
