Suggestions for constructing a ResNet with the revlib

See original GitHub issue

Hi

I would like to use this library to build a ResNet-20 model. I've tried several times, but I keep getting a mismatched-dimension error. My model is shown below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

import revlib  # used below for revlib.ReversibleSequential

hidden_size = [16, 32, 64]

class View(nn.Module):
    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, option='A'):
        super(BasicBlock, self).__init__()
        
        self.bn1 = norm_layer(in_planes)
        self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        print(1, shortcut.size())
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        print(2, out.size())
        out += shortcut
        return out

class ResNet(nn.Module):

    def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn',
                 share_affine=False, track_running_stats=True):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.bn_type = bn_type
        if bn_type == 'bn':
            norm_layer = lambda n_ch: nn.BatchNorm2d(n_ch, track_running_stats=track_running_stats)
        elif bn_type == 'gn':
            norm_layer = lambda n_ch: nn.GroupNorm(4, n_ch)  # 4 groups; the group count can be changed
        else:
            raise RuntimeError(f"Not support bn_type={bn_type}")
        conv_layer = nn.Conv2d
        first = conv_layer(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False)
        layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1,
                                       norm_layer=norm_layer, conv_layer=conv_layer)
        layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2,
                                       norm_layer=norm_layer, conv_layer=conv_layer)
        layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2,
                                       norm_layer=norm_layer, conv_layer=conv_layer)
        
        self.rev_layers = revlib.ReversibleSequential(*[layer1, layer2, layer3])
        
        norm = norm_layer(hidden_size[2] * block.expansion)
        linear = nn.Linear(hidden_size[2] * block.expansion, num_classes)
        
        self.full_model = nn.Sequential(first, self.rev_layers, nn.ReLU(), norm, \
                                        nn.AdaptiveAvgPool2d((None, 1)), View(), linear)

    def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, norm_layer, conv_layer))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.full_model(x)
        return out

def init_param(m):
    """Special init for ResNet"""
    if isinstance(m, (_BatchNorm, _InstanceNorm)):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.bias.data.zero_()
    return m

def resnet20(**kwargs):
    model = ResNet(hidden_size, BasicBlock, [3,3,3], **kwargs)
    model.apply(init_param)
    return model

I've tried modifying self.in_planes = 8 and hidden_size = [8, 16, 32], respectively, but it still does not work. Could you provide any hints? Also, is it possible to build the model in forward() instead of wrapping the reversible model with non-reversible layers, like model = nn.Sequential(conv, rev_layer, conv)? I appreciate your help.

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
ClashLuke commented, Aug 9, 2022

The main problem you're facing is that RevNet requires all inputs and outputs to have the same size. Since the second stage has more output features than the first, RevNet would have to add a tensor with 32 features to one with 16, which isn't possible.
Think of it like a ResNet: in a ResNet, the residual path only exists within each resolution/feature size, not across them. To carry the residual stream across stages, you usually downsample (for example with AvgPool2d) and add that output to the output of your "residual" block. In RevNet, that second part doesn't exist. Instead, you would have to use PixelShuffle and feature padding to arrive at a similar result (see #2).
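For reference, PixelShuffle and its inverse PixelUnshuffle are the lossless way to trade resolution against channels, which is what makes them usable inside a reversible model. A minimal demonstration in plain PyTorch (nothing here is revlib-specific):

import torch
import torch.nn as nn

x = torch.randn(2, 16, 32, 32)
down = nn.PixelUnshuffle(2)   # (2, 16, 32, 32) -> (2, 64, 16, 16): half the resolution, 4x the channels
up = nn.PixelShuffle(2)       # the exact inverse

y = down(x)
print(y.shape)                # torch.Size([2, 64, 16, 16])
print(torch.equal(up(y), x))  # True: nothing is lost, so the operation stays invertible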

The easiest way forward would be to have multiple ReversibleSequential modules, one for each _make_layer() call, and put these into a standard nn.Sequential container. This is how the original RevNet did it. Their method uses marginally more parameters but otherwise gives the same results.
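A rough sketch of that layout, under two assumptions that are not spelled out in this issue: ReversibleSequential splits its input in half along the channel dimension (so each stage carries twice its block width), and the strided convolutions between stages are ordinary, non-reversible placeholders chosen for illustration:

import torch
import torch.nn as nn
import revlib


def res_block(width: int) -> nn.Module:
    # Shape-preserving block: width -> width, stride 1, as RevNet requires.
    return nn.Sequential(
        nn.BatchNorm2d(width), nn.ReLU(),
        nn.Conv2d(width, width, 3, padding=1, bias=False),
        nn.BatchNorm2d(width), nn.ReLU(),
        nn.Conv2d(width, width, 3, padding=1, bias=False),
    )


def stage(width: int, blocks: int) -> nn.Module:
    # One reversible stage, analogous to one _make_layer() call.
    return revlib.ReversibleSequential(*[res_block(width) for _ in range(blocks)])


widths, blocks = [16, 32, 64], 3
model = nn.Sequential(
    nn.Conv2d(3, 2 * widths[0], 3, padding=1, bias=False),                        # stem feeds both streams
    stage(widths[0], blocks),
    nn.Conv2d(2 * widths[0], 2 * widths[1], 3, stride=2, padding=1, bias=False),  # non-reversible transition
    stage(widths[1], blocks),
    nn.Conv2d(2 * widths[1], 2 * widths[2], 3, stride=2, padding=1, bias=False),  # non-reversible transition
    stage(widths[2], blocks),
    nn.BatchNorm2d(2 * widths[2]), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(2 * widths[2], 10),
)

out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # expected: torch.Size([2, 10])

Only the blocks inside each stage get the reversible memory savings; the transition convolutions store their activations as usual.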

Another alternative would be to avoid this multi-stage assembly and construct one large ReversibleSequential module instead. Using one large block saves memory, and i-RevNet documented how they achieved marginally worse ImageNet accuracy with this kind of architecture.
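One possible (hypothetical) reading of the single-module variant: keep every block shape-preserving, as RevNet requires, but let each block work at reduced resolution by squeezing space into channels with PixelUnshuffle and undoing it afterwards. The widths and layer choices below are illustrative only, not taken from revlib or i-RevNet:

import torch
import torch.nn as nn
import revlib


def squeezed_block(width: int) -> nn.Module:
    # Shape-preserving overall, but the convolution runs at half resolution.
    return nn.Sequential(
        nn.PixelUnshuffle(2),                                        # (C, H, W) -> (4C, H/2, W/2)
        nn.Conv2d(4 * width, 4 * width, 3, padding=1, bias=False),
        nn.BatchNorm2d(4 * width),
        nn.ReLU(),
        nn.PixelShuffle(2),                                          # back to (C, H, W)
    )


width, depth = 16, 9
rev = revlib.ReversibleSequential(*[squeezed_block(width) for _ in range(depth)])
model = nn.Sequential(
    nn.Conv2d(3, 2 * width, 3, padding=1, bias=False),               # stem feeds both streams
    rev,
    nn.BatchNorm2d(2 * width), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(2 * width, 10),
)
print(model(torch.randn(2, 3, 32, 32)).shape)                        # expected: torch.Size([2, 10])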


Yes, you can define the reversible architecture in forward. However, I’d advise against it, as ReversibleSequential is a thin wrapper around things you have to do anyway.
If you want to do what ReversibleSequential would usually handle for you, you’d have to wrap your modules in ReversibleModules like so: https://github.com/HomebrewNLP/revlib/blob/34dad19318e2f861ea6b0ce263506625a934b568/revlib/core.py#L471-L487

and call these modules one-by-one, just like in a normal nn.Sequential module: https://github.com/HomebrewNLP/revlib/blob/34dad19318e2f861ea6b0ce263506625a934b568/revlib/core.py#L509-L511
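For completeness, this is roughly the additive coupling that ReversibleSequential handles for you, written out by hand in plain PyTorch. f and g are hypothetical shape-preserving sub-modules, and this naive version does not give the memory savings; recomputing activations in backward is exactly the part revlib adds:

import torch
import torch.nn as nn

width = 16
f = nn.Conv2d(width, width, 3, padding=1, bias=False)
g = nn.Conv2d(width, width, 3, padding=1, bias=False)

x = torch.randn(2, 2 * width, 32, 32)
x1, x2 = x.chunk(2, dim=1)           # split the stream along channels
y1 = x1 + f(x2)                      # additive coupling, step 1
y2 = x2 + g(y1)                      # additive coupling, step 2
out = torch.cat([y1, y2], dim=1)

# The inverse exists in closed form, which is what makes activation
# recomputation (and hence the memory savings) possible:
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
print(torch.allclose(torch.cat([x1_rec, x2_rec], dim=1), x, atol=1e-6))  # True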

0 reactions
ClashLuke commented, Aug 10, 2022

Sorry, I’m not planning to add these, as the most common functions (pooling, pixelshuffle, upsample) are already part of PyTorch.
