
Support for torch.sum

See original GitHub issue

Describe the issue:

I noticed that NNI currently does not support the torch.sum operation, yet torch.sum does appear in some network models, such as ResNeSt.
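
For context, this is the kind of pattern that trips up the speedup pass. Below is a minimal sketch of a module that hits aten::sum with an explicit dim; the SumBlock module is hypothetical, written only to loosely mimic the reduction in ResNeSt's split-attention block, and is not taken from resnest itself:

import torch
import torch.nn as nn

class SumBlock(nn.Module):
    # hypothetical module: reduces over a radix dimension with torch.sum,
    # loosely mimicking the split-attention block in ResNeSt
    def forward(self, x):
        b, c, h, w = x.shape
        x = x.reshape(b, 2, c // 2, h, w)
        return torch.sum(x, dim=1)  # traced as aten::sum with a dim list

print(SumBlock()(torch.randn(1, 64, 8, 8)).shape)  # torch.Size([1, 32, 8, 8])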

I wrote my own support for torch.sum, but it doesn't seem to work correctly.

from functools import partial

import torch

# meant for nni/compression/pytorch/speedup/jit_translate.py, where translate_list is defined
def sum_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    # inputs[1] is the dim list, inputs[2] the keepdim flag (dim_IntList overload)
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    new_sum = partial(torch.sum, dim=tuple(dim_list), keepdim=keep_dim)
    return new_sum

With this implementation, some layers' masks are omitted (see the attached screenshot).
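
To see which aten::sum overload the trace actually contains (which may explain why masks get dropped by the translation above), the jit nodes can be inspected directly. A minimal sketch, reusing the hypothetical SumBlock from above:

import torch

traced = torch.jit.trace(SumBlock(), torch.randn(1, 64, 8, 8))
for node in traced.graph.nodes():
    if node.kind() == 'aten::sum':
        # the dim_IntList overload carries 4 inputs (self, dim, keepdim, dtype);
        # the plain overload carries only 2 (self, dtype)
        print(node.kind(), len(list(node.inputs())))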

Environment:

  • NNI version: the latest
  • Training service (local|remote|pai|aml|etc):
  • Client OS: CentOS 7
  • Server OS (for remote mode only):
  • Python version: 3.8.8
  • PyTorch/TensorFlow version: 1.8.0
  • Is conda/virtualenv/venv used?: yes
  • Is running in Docker?: no

How to reproduce it?: Below is a simple script; you can download mmclassification to reproduce it. Note that the PyTorch version should be 1.8.0 or higher.

import torch
from argparse import ArgumentParser

from mmcls.apis import inference_model, init_model, show_result_pyplot

from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import SlimPruner, L1NormPruner, FPGMPruner
from nni.compression.pytorch.utils import not_safe_to_prune

device = 'cuda:0'
config = 'configs/resnest/resnest50_8xb16_cifar10.py'
checkpoint = None
img_file = 'demo/demo.JPEG'

# build the model from a config file and a checkpoint file
model = init_model(config, checkpoint, device=device)

model.forward = model.dummy_forward

pre_flops, pre_params, _ = count_flops_params(model, torch.randn([128, 3, 32, 32]).to(device))

im = torch.ones(1, 3, 128, 128).to(device)
out = model(im)

# with torch.no_grad():
#     input_name = ['input']
#     output_name  = ['output']
#     onnxname = 'resnest.onnx'
#     torch.onnx.export(model, im, onnxname, input_names = input_name, output_names = output_name,
#                     opset_version=11, training=False, verbose=False, do_constant_folding=False)
#     print(f'successful export onnx {onnxname}')
# exit()

# scores = model(return_loss=False, **data)
# scores = model(return_loss=False, **im)

# test a single image
# result = inference_model(model, img_file)

# Start to prune and speed up
print('\n' + '=' * 50 + ' START TO PRUNE THE BEST ACCURACY PRETRAINED MODEL ' + '=' * 50)
not_safe = not_safe_to_prune(model, im)

print('\n' + '=' * 50 + ' not_safe ' + '=' * 50, not_safe)
cfg_list = []
for name, module in model.named_modules():
    print(name)
    if name in not_safe:
        continue
    if isinstance(module, torch.nn.Conv2d):
        cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.2, 'op_names':[name]})

print('cfg_list')
for i in cfg_list:
    print(i)

pruner = FPGMPruner(model, cfg_list)
_, masks = pruner.compress()
pruner.show_pruned_weights()
pruner._unwrap_model()
pruner.show_pruned_weights()

ModelSpeedup(model, dummy_input=im, masks_file=masks, confidence=32).speedup_model()
torch.jit.trace(model, im, strict=False)
print(model)
flops, params, results = count_flops_params(model, torch.randn([128, 3, 32, 32]).to(device))
print(f'Pretrained model FLOPs {pre_flops/1e6:.2f} M, #Params: {pre_params/1e6:.2f}M')
print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
model.forward = model.forward_
torch.save(model, 'chek/prune_model/resnest50_8xb16_cifar10_sparsity_0.2.pth')

The config file resnest50_8xb16_cifar10.py is:

_base_ = [
    '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py', 
    '../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNeSt',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=2048,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            num_classes=10,
            reduction='mean',
            loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False))

train_cfg = dict(mixup=dict(alpha=0.2, num_classes=10))

lr_config = dict(policy='step', step=[120, 170])
runner = dict(type='EpochBasedRunner', max_epochs=200)

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 29 (17 by maintainers)

Top GitHub Comments

1 reaction
Louis-J commented, Jul 19, 2022

You can try it by downloading the file and moving it to [your_env_name]\lib\site-packages\nni\compression\pytorch\speedup\jit_translate.py.
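
For reference, the translation functions in that file are dispatched by op kind, so a hand-patched sum_python also needs an entry in the lookup table. A rough sketch; the dict name trans_from_jit_to_python matches the NNI 2.x layout of jit_translate.py and may differ in other versions:

# inside jit_translate.py; dict name taken from the NNI 2.x source layout,
# verify it against your installed version
trans_from_jit_to_python['aten::sum'] = sum_python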

1 reaction
Louis-J commented, Jul 8, 2022

torch.sum has 3 overloads, and I can't tell which one is used; maybe the overload with no dim argument is being hit.
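
In a traced graph the common overloads can be told apart by input count: aten::sum(Tensor, *, dtype) carries 2 inputs, while aten::sum.dim_IntList(Tensor, dim, keepdim, *, dtype) carries 4. A minimal sketch of an overload-aware translator, following the sum_python layout from the issue (translate_list is assumed to be in scope, as in jit_translate.py):

from functools import partial

import torch

def sum_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
    if len(inputs) <= 2:
        # aten::sum(Tensor self, *, ScalarType? dtype): no dim argument
        return torch.sum
    # aten::sum.dim_IntList(Tensor self, int[]? dim, bool keepdim, *, dtype)
    dim_list = translate_list(inputs[1], speedup)
    keep_dim = inputs[2].toIValue()
    return partial(torch.sum, dim=tuple(dim_list), keepdim=keep_dim)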

We will test the auto-convert-op feature on the dev-speedup-auto-op branch. I think this issue can be solved by that feature.

Read more comments on GitHub
