
Reimplementation of split_attention_conv2d, and why is there no BN2/ReLU in Bottleneck?


Hi @zhanghang1989, first of all, thank you very much for providing such an imaginative model.

I studied the source code of ResNeSt and wrote a new implementation of SplitAttentionConv2d. I think its structure may be clearer:

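# -*- coding: utf-8 -*-

"""
@date: 2021/1/4 11:32 AM
@file: split_attention_conv2d.py
@author: zj
@description: 
"""
from abc import ABC

import torch

import torch.nn as nn


# The original file imports a project helper: `from ..init_helper import init_weights`.
# Its code is not shown in the issue; this stand-in is an assumption (Kaiming init
# for convs, ones/zeros for BN) so that the module can run on its own.
def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)


class SplitAttentionConv2d(nn.Module, ABC):
    """
    SplitAttention implementation for ResNeSt, based on:
    1. https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnest.py
    2. https://github.com/zhanghang1989/ResNeSt/blob/73b43ba63d1034dbf3e96b3010a8f2eb4cc3854f/resnest/torch/splat.py
    Partly based on the implementation in ./selective_kernel_conv2d.py
    """

    def __init__(self,
                 # number of input channels
                 in_channels,
                 # number of output channels
                 out_channels,
                 # number of splits within each group
                 radix=2,
                 # cardinality
                 groups=1,
                 # reduction rate of the intermediate layer
                 reduction_rate=4,
                 # default minimum number of channels in the intermediate layer
                 default_channels: int = 32,
                 # dimension (stored but not used below)
                 dimension: int = 2
                 ):
        super(SplitAttentionConv2d, self).__init__()

        # split
        self.split = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
                      groups=groups * radix),
            nn.BatchNorm2d(out_channels * radix),
            nn.ReLU(inplace=True)
        )
        # fuse
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        inner_channels = max(out_channels // reduction_rate, default_channels)
        self.compact = nn.Sequential(
            nn.Conv2d(out_channels, inner_channels, kernel_size=1, stride=1, padding=0, bias=False,
                      groups=groups),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True)
        )
        # select
        self.select = nn.Conv2d(inner_channels, out_channels * radix, kernel_size=1, stride=1, bias=False,
                                groups=groups)
        # softmax across the radix splits; with radix == 1 a softmax would always
        # return 1, so fall back to a sigmoid gate (as rSoftMax in splat.py does)
        self.softmax = nn.Softmax(dim=0) if radix > 1 else nn.Sigmoid()
        self.dimension = dimension
        self.out_channels = out_channels
        self.radix = radix
        self.groups = groups

        init_weights(self.modules())

    def forward(self, x):
        # N, C, H, W = x.shape[:4]
        # split: (N, radix * C, H, W) -> (radix, N, C, H, W)
        out = self.split(x)
        split_out = torch.stack(torch.split(out, self.out_channels, dim=1))
        # fuse
        u = torch.sum(split_out, dim=0)
        s = self.pool(u)
        z = self.compact(s)
        # select: the grouped conv orders its output channels as (groups, radix, C / groups);
        # reorder them to (radix, groups, C / groups) so that chunk r lines up with split r
        c = self.select(z)
        n = c.size(0)
        c = c.view(n, self.groups, self.radix, -1).transpose(1, 2).reshape(n, -1, 1, 1)
        split_c = torch.stack(torch.split(c, self.out_channels, dim=1))
        softmax_c = self.softmax(split_c)

        # attention-weighted sum over the radix splits
        v = torch.sum(split_out.mul(softmax_c), dim=0)
        return v.contiguous()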
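When groups > 1, the channel order of the grouped 1x1 `select` conv matters. The sketch below (plain PyTorch, arbitrary sizes; `select` here is a toy stand-in, not the trained layer) perturbs only the first cardinal group's input and checks which output channels change. They form one contiguous block of radix * C / groups channels, i.e. the output is ordered (groups, radix, C / groups), not (radix, groups, C / groups). The reference `rSoftMax` in splat.py handles this by viewing the attention as (batch, groups, radix, -1) and transposing before the softmax.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
groups, radix, c = 2, 3, 4   # c = channels per cardinal group in each split
inner = 8                    # intermediate channels, divisible by groups

# toy grouped 1x1 "select" conv: inner -> radix * groups * c channels
select = nn.Conv2d(inner, radix * groups * c, kernel_size=1, groups=groups, bias=False)

z = torch.randn(1, inner, 1, 1)
z_perturbed = z.clone()
z_perturbed[:, : inner // groups] += 1.0   # touch cardinal group 0 only

with torch.no_grad():
    delta = (select(z_perturbed) - select(z)).flatten()

# channels whose value changed form one contiguous block at the start
changed = delta.abs() > 1e-6
print(changed.nonzero().flatten().tolist())

# reorder (groups, radix, c) -> (radix, groups, c), the view/transpose used by rSoftMax
reordered = changed.view(groups, radix, c).transpose(0, 1)
print(reordered.int())   # every radix slice now holds group 0's c channels first
```

Without the transpose, splitting the attention into radix chunks of C channels would pair weights computed from different cardinal groups; with groups=1 both orders coincide.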

One of my questions: why is there no need to add bn2/ReLU in Bottleneck when radix > 0? Was this determined through experiments?
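One observation that may help frame the question (the thread does not confirm it is the design reason): in the block above, every split passes through `BatchNorm2d` and `ReLU` inside `self.split`, and the softmax weights are non-negative, so the attention-weighted sum is already non-negative and a ReLU applied right after it would change nothing. The toy tensors below are random stand-ins for the real feature maps.

```python
import torch

torch.manual_seed(0)
radix, n, c = 2, 4, 8

# each split has already gone through BN + ReLU inside the block, so it is >= 0
splits = torch.relu(torch.randn(radix, n, c, 5, 5))
# attention weights: softmax over the radix axis, so every weight lies in [0, 1]
weights = torch.softmax(torch.randn(radix, n, c, 1, 1), dim=0)

out = (splits * weights).sum(dim=0)

# an extra ReLU on the block output leaves it unchanged
print(torch.equal(torch.relu(out), out))
```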

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

zhanghang1989 commented, Jan 17, 2021

Hi @FrancescoSaverioZuppichini, the BN + ReLU is applied after the first conv because it adds non-linearity between the two convs (otherwise the two would be equivalent to a single conv). There is no BN + ReLU after the second conv because the softmax already acts as a non-linearity (activation function).
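The point that two convolutions with no non-linearity between them collapse into one can be checked directly. This sketch uses ungrouped 1x1 convs with arbitrary channel sizes, standing in for the two attention convs: it folds them into a single conv with W = W2 W1 and b = W2 b1 + b2 and compares outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv2d(16, 8, kernel_size=1)    # first 1x1 conv
conv2 = nn.Conv2d(8, 32, kernel_size=1)    # second 1x1 conv

# fold the two linear maps into one 1x1 conv: W = W2 @ W1, b = W2 @ b1 + b2
w1 = conv1.weight.flatten(1)               # (8, 16)
w2 = conv2.weight.flatten(1)               # (32, 8)
merged = nn.Conv2d(16, 32, kernel_size=1)
with torch.no_grad():
    merged.weight.copy_((w2 @ w1).view(32, 16, 1, 1))
    merged.bias.copy_(w2 @ conv1.bias + conv2.bias)

x = torch.randn(2, 16, 1, 1)               # pooled features, as after AdaptiveAvgPool2d
with torch.no_grad():
    stacked = conv2(conv1(x))
    single = merged(x)
print(torch.allclose(stacked, single, atol=1e-5))
```

Putting a ReLU between `conv1` and `conv2` breaks this equivalence, which is why the first conv is followed by an activation.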

FrancescoSaverioZuppichini commented, Jan 14, 2021

Same question, posting my implementation for completeness:

import torch.nn as nn
from torch import Tensor
from einops import reduce
from einops.layers.torch import Rearrange, Reduce


class SplitAtt(nn.Module):
    def __init__(self, in_features: int, features: int, radix: int, groups: int):
        """Implementation of Split Attention proposed in `"ResNeSt: Split-Attention Networks" <https://arxiv.org/abs/2004.08955>`_
        Grouped convolutions have been shown to work better empirically (ResNeXt). The main idea is to apply attention group-wise.
        Einops is used to improve the readability of this module
        Args:
            in_features (int): number of input features per radix split (x has in_features * radix channels)
            features (int): attention's features
            radix (int): number of subgroups (`radix`) in the groups
            groups (int): number of groups, each group contains `radix` subgroups
        """
        super().__init__()
        self.radix, self.groups = radix, groups
        self.att = nn.Sequential(
            # this produces \hat{U}
            Reduce('b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=radix, k=groups),
            # eq 1
            nn.AdaptiveAvgPool2d(1),
            # the two following conv layers are G in the paper; the author's ConvBnAct
            # helper is written out here, assumed to be conv -> BN -> activation
            nn.Conv2d(in_features, features, kernel_size=1, groups=groups, bias=True),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, in_features * radix,
                      kernel_size=1, groups=groups),
            # the grouped conv emits channels in (k r c) order, so split on that order
            Rearrange('b (k r c) h w -> b r (k c) h w', r=radix, k=groups),
            nn.Softmax(dim=1) if radix > 1 else nn.Sigmoid(),
            Rearrange('b r (k c) h w -> b (r k c) h w', r=radix, k=groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        att = self.att(x)
        # eq 2: scale using att and reduce over the radix axis (mean here; the
        # reference implementation sums, which differs by a factor of radix).
        # Out-of-place multiply so that autograd can still use x.
        x = x * att
        x = reduce(x, 'b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=self.radix, k=self.groups)
        return x
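Both implementations in this thread treat the split conv's output as radix-major, i.e. `(r k c)`: the first one uses `torch.split(out, out_channels, dim=1)`, while the einops pattern `'b (r k c) h w -> b (k c) h w'` reduces over `r`. A plain-PyTorch sketch with arbitrary sizes (no einops needed) showing that the two views agree, and that `reduction='mean'` is just the radix sum divided by `radix`:

```python
import torch

torch.manual_seed(0)
b, radix, groups, c, h, w = 2, 3, 2, 4, 5, 5
x = torch.randn(b, radix * groups * c, h, w)   # split conv output in (r k c) order

# first implementation: chunks of out_channels = groups * c along dim 1, then sum
stacked = torch.stack(torch.split(x, groups * c, dim=1))   # (radix, b, groups * c, h, w)
sum_split = stacked.sum(dim=0)

# 'b (r k c) h w -> b (k c) h w' with reduction='mean', written with view
mean_view = x.view(b, radix, groups * c, h, w).mean(dim=1)

print(torch.allclose(sum_split / radix, mean_view, atol=1e-6))
```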

BTW, the bias in the first conv is useless, since it is followed by BatchNorm, which subtracts the per-channel mean. It is present in the original implementation, though; I guess it is an error.
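The redundancy of that bias can be checked in a few lines: in training mode `BatchNorm2d` subtracts the per-channel batch mean, so a constant per-channel bias added by the preceding conv cancels out. A minimal sketch with arbitrary sizes, comparing the same conv weights with and without a bias:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_bias = nn.Conv2d(16, 8, kernel_size=1, bias=True)
conv_nobias = nn.Conv2d(16, 8, kernel_size=1, bias=False)
with torch.no_grad():
    conv_nobias.weight.copy_(conv_bias.weight)   # same weights, bias dropped

bn = nn.BatchNorm2d(8)        # training mode: normalises with batch statistics
x = torch.randn(4, 16, 1, 1)  # a batch of pooled features

with torch.no_grad():
    y_bias = bn(conv_bias(x))
    y_nobias = bn(conv_nobias(x))
print(torch.allclose(y_bias, y_nobias, atol=1e-5))
```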

[Edit] After thinking about it, I think it makes sense: when radix > 1, softmax is applied (in rSoftmax), while when radix = 1, sigmoid is used, making it the same as SE. But there shouldn’t be a BatchNorm and a ReLU after the second conv.
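The radix-dependent activation described in this comment can be sketched as a small module. It is modelled on the behaviour of the `rSoftmax` module mentioned above, not copied from it; `RSoftmaxSketch` is a hypothetical name. With `radix > 1`, the attention is viewed as (batch, groups, radix, -1) and transposed so the softmax runs across the radix splits within each cardinal group; with `radix == 1`, a sigmoid gate is used, as in SE.

```python
import torch
import torch.nn as nn


class RSoftmaxSketch(nn.Module):
    """Hypothetical sketch: softmax across radix splits per cardinal group, sigmoid if radix == 1."""

    def __init__(self, radix: int, groups: int):
        super().__init__()
        self.radix, self.groups = radix, groups

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)
        if self.radix > 1:
            # (batch, groups * radix * c, 1, 1) in (groups, radix, c) order -> (batch, radix, groups, c)
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = torch.softmax(x, dim=1)      # normalise across the radix splits
            return x.reshape(batch, -1)      # flattened in (radix, groups, c) order
        return torch.sigmoid(x).reshape(batch, -1)   # SE-style gate


torch.manual_seed(0)
logits = torch.randn(2, 2 * 3 * 4, 1, 1)          # groups=2, radix=3, c=4
att = RSoftmaxSketch(radix=3, groups=2)(logits)
# weights for each (group, channel) sum to 1 across the 3 radix splits
print(torch.allclose(att.view(2, 3, 8).sum(dim=1), torch.ones(2, 8)))

gate = RSoftmaxSketch(radix=1, groups=2)(logits)  # radix == 1: plain sigmoid
print(torch.equal(gate, torch.sigmoid(logits).reshape(2, -1)))
```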
