
Split-Attention Module in PyTorch

See original GitHub issue

Just in case someone wants to use the Split-Attention module, it is provided here:

import torch
import torch.nn as nn
from torch.nn import functional as F

class rSoftMax(nn.Module):
    """Radix-wise softmax: normalizes the attention logits across the radix
    splits within each cardinal group (falls back to sigmoid when radix == 1)."""
    def __init__(self, radix, cardinality):
        super().__init__()
        assert radix > 0
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # (B, C, 1, 1) -> (B, cardinality, radix, C/(cardinality*radix)),
            # then move radix to dim 1 and take the softmax over it
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

class Splat(nn.Module):
    """Standalone Split-Attention block. `channels` is the channel count of the
    input feature map, which is treated as `radix` splits of channels//radix each."""
    def __init__(self, channels, radix, cardinality, reduction_factor=4):
        super(Splat, self).__init__()
        self.radix = radix
        self.cardinality = cardinality
        self.channels = channels
        inter_channels = max(channels*radix//reduction_factor, 32)
        # squeeze: per-split channels -> bottleneck
        self.fc1 = nn.Conv2d(channels//radix, inter_channels, 1, groups=cardinality)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU(inplace=True)
        # excite: one attention logit per input channel, so the attention vector
        # splits into `radix` chunks that match the splits of x
        self.fc2 = nn.Conv2d(inter_channels, channels, 1, groups=cardinality)
        self.rsoftmax = rSoftMax(radix, cardinality)

    def forward(self, x):
        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # split the input into `radix` groups along the channel dimension
            splited = torch.split(x, rchannel//self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        # global average pooling over the spatial dims: (B, channels//radix, 1, 1)
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.bn1(gap)
        gap = self.relu(gap)

        # per-channel attention weights, normalized across the radix splits
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            # weight each split by its attention chunk and fuse by summation
            attens = torch.split(atten, rchannel//self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
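
Not part of the original post, but a quick smoke test along the following lines (the sizes are my own choices) shows how the block is wired: with radix > 1 the splits are fused, so the output has channels//radix channels.

import torch

# Illustrative smoke test (not from the issue): run a dummy feature map through
# the block; `channels` is taken to be the channel count of the input tensor.
splat = Splat(channels=64, radix=2, cardinality=1)
x = torch.randn(8, 64, 32, 32)   # (batch, channels, height, width)
out = splat(x)
print(out.shape)                 # torch.Size([8, 32, 32, 32]) -- channels//radix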

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
zhanghang1989 commented, Jun 16, 2020

@zhanghang1989 thanks for sharing the split-attention module implementation. Can we integrate this with an FPN module instead of the ResNet backbone?

Yes, that should work.
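
For anyone wondering what that could look like, here is a rough sketch (mine, not from the thread) that wraps each FPN level with its own Splat block; the level names and the 256-channel width are assumptions based on a typical FPN, not anything specified in the issue.

import torch.nn as nn

# Hypothetical wrapper (not from the issue): one Splat block per FPN level.
class SplatFPNHead(nn.Module):
    def __init__(self, in_channels=256, radix=2, cardinality=1,
                 levels=("p2", "p3", "p4", "p5")):
        super().__init__()
        self.splats = nn.ModuleDict(
            {lvl: Splat(in_channels, radix, cardinality) for lvl in levels})

    def forward(self, fpn_feats):
        # fpn_feats: dict mapping level name -> (B, in_channels, H, W) tensor
        return {lvl: self.splats[lvl](feat) for lvl, feat in fpn_feats.items()}

Note that with radix > 1 the block fuses the splits, so each level comes out with in_channels//radix channels; a 1x1 convolution afterwards could restore the original width if the downstream heads expect 256.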

0 reactions
Hi-Jingzhi commented, Jun 27, 2020

Hi! I have some questions about the channels parameter of the fc1 layer in class Splat. First, what does the channels parameter mean? Second, should the input channel count of fc1 be channels//radix, or channels?
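
The thread does not show an answer, but tracing shapes through the code above suggests one reading (my interpretation): channels is the channel count of the tensor passed to forward, and fc1 takes channels//radix inputs because the radix splits are summed before the global pooling.

import torch

# Illustrative shape trace (my own example, not from the thread):
channels, radix = 64, 2
x = torch.randn(1, channels, 16, 16)
splits = torch.split(x, channels // radix, dim=1)  # radix tensors of shape (1, 32, 16, 16)
gap = sum(splits)                                  # (1, 32, 16, 16): fc1 sees channels//radix inputs
print(gap.shape)                                   # torch.Size([1, 32, 16, 16])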

