Split-Attention Module in PyTorch
Just in case someone wants to use the Split-Attention Module, the module is provided here:
```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class rSoftMax(nn.Module):
    """Softmax over the radix splits (plain sigmoid when radix == 1)."""

    def __init__(self, radix, cardinality):
        super().__init__()
        assert radix > 0
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # reshape to (batch, cardinality, radix, -1), then normalize across the radix axis
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x
```
```python
class Splat(nn.Module):
    """Split-Attention block: splits the input into `radix` groups along the
    channel axis and recombines them with channel-wise attention weights."""

    def __init__(self, channels, radix, cardinality, reduction_factor=4):
        super(Splat, self).__init__()
        self.radix = radix
        self.cardinality = cardinality
        self.channels = channels  # total channel count of the input feature map
        inter_channels = max(channels * radix // reduction_factor, 32)
        # the pooled descriptor is the sum over the radix splits, so fc1 sees channels // radix inputs
        self.fc1 = nn.Conv2d(channels // radix, inter_channels, 1, groups=cardinality)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU(inplace=True)
        # one attention logit per input channel; rSoftMax normalizes them across the radix splits,
        # so the per-split chunks in forward() line up with the input splits
        self.fc2 = nn.Conv2d(inter_channels, channels, 1, groups=cardinality)
        self.rsoftmax = rSoftMax(radix, cardinality)

    def forward(self, x):
        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # split the input into radix groups of rchannel // radix channels each
            splited = torch.split(x, rchannel // self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        # global context: pool, then a small bottleneck that produces the attention logits
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.bn1(gap)
        gap = self.relu(gap)
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
        if self.radix > 1:
            # weight each split by its attention map and sum the splits back together
            attens = torch.split(atten, rchannel // self.radix, dim=1)
            out = sum([att * split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
```
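As a quick smoke test, here is a usage sketch for the module above. It assumes channels equals the channel count of the input feature map and is divisible by both radix and cardinality; the configuration numbers are illustrative, not taken from the issue:

```python
import torch

# Illustrative configuration: 128 input channels, radix=2, cardinality=2.
splat = Splat(channels=128, radix=2, cardinality=2)
x = torch.randn(4, 128, 16, 16)                   # (batch, channels, H, W)
out = splat(x)
print(out.shape)                                  # torch.Size([4, 64, 16, 16]): channels // radix output maps

# The rSoftMax weights are normalized across the radix splits:
logits = torch.randn(4, 128)                      # (batch, cardinality * radix * c)
w = rSoftMax(radix=2, cardinality=2)(logits)      # flattened back to (batch, 128), radix-major layout
w = w.view(4, 2, -1)                              # (batch, radix, cardinality * c)
print(torch.allclose(w.sum(dim=1), torch.ones(4, 64)))  # True: weights sum to 1 across radix
```

The block returns channels // radix output channels, which mirrors how split attention sits inside the ResNeSt bottleneck, where the preceding convolution widens the feature map by the radix factor.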
Yes, that should work.
Hi! I have two questions about the channels parameter and the fc1 layer in the Splat class. First, what is the meaning of the channels parameter? Second, should the input channel count of fc1 be channels//radix, or channels?
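For what it's worth, a minimal shape trace (my own reading, assuming channels is the total channel count of the tensor passed to Splat.forward and is divisible by radix and cardinality) suggests fc1 does see channels//radix inputs: gap is the sum over the radix splits of x, so it carries channels//radix channels by the time it reaches fc1. The numbers below are illustrative only and assume the Splat and rSoftMax classes defined above:

```python
# Hypothetical configuration for illustration: channels=64, radix=2, cardinality=1.
splat = Splat(channels=64, radix=2, cardinality=1)

# fc1 consumes the pooled sum of the radix splits, i.e. channels // radix inputs.
print(splat.fc1.in_channels)        # 32 == 64 // 2

x = torch.randn(4, 64, 8, 8)        # (batch, channels, H, W)
print(splat(x).shape)               # torch.Size([4, 32, 8, 8]): each split carries channels // radix maps
```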