Question about the mask in shifted window attention
Nice work! I have been reading your code recently, but I cannot quite understand the implementation of the mask in shifted window attention.
I drew a simple picture of what I expected: the red cells mark the mask, with window_size = 2 and shift_size = 1.
I think the mask should look like that, but when I run your code it generates the mask below:
import torch
import torch.nn as nn


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


window_size = 2
shift_size = 1
H, W = 4, 4
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.unsqueeze(1).unsqueeze(0)
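The printout below is what printing attn_mask gives; after the two unsqueeze calls its shape is (1, 4, 1, 4, 4):

print(attn_mask)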
"""
tensor([[[[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]],
[[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]]],
[[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]]],
[[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]]]])
"""
I cannot understand this result. Could you do me a favor and explain it?
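For reference, my current understanding is that this attn_mask is later broadcast-added to each window's attention logits before the softmax, roughly like the sketch below (the shapes and variable names here are my own illustration, not the repository's code):

import torch

# continuing from the code above: attn_mask has shape (1, nW, 1, N, N)
# with nW = 4 windows and N = window_size * window_size = 4 tokens per window
B, num_heads = 1, 1
nW = attn_mask.shape[1]
N = window_size * window_size

attn = torch.randn(B * nW, num_heads, N, N)             # dummy per-window attention logits
attn = attn.view(B, nW, num_heads, N, N) + attn_mask    # -100 where tokens come from different regions
attn = attn.view(-1, num_heads, N, N).softmax(dim=-1)   # masked pairs end up with ~0 attention weight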
Top GitHub Comments
Hi, it seems all right. I simply visualized the img_mask and attn_mask using image-size = 14x14, window-size = 7x7, shift = 3. Below is the visualization code (see the sketch after these comments).

emmm, I think I was wrong. I have drawn a new picture for the attention mask.
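The original visualization snippet is not reproduced here; the following is only a minimal sketch of one way to plot img_mask and the per-window attn_mask for image-size = 14x14, window-size = 7, shift = 3 (it assumes matplotlib and reuses the window_partition function from the issue code above):

import matplotlib.pyplot as plt
import torch

H = W = 14
window_size = 7
shift_size = 3

img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

nW = attn_mask.shape[0]  # 4 windows for a 14x14 map with 7x7 windows
fig, axes = plt.subplots(1, nW + 1, figsize=(3 * (nW + 1), 3))
axes[0].imshow(img_mask[0, :, :, 0].numpy())   # region labels 0..8 on the feature map
axes[0].set_title("img_mask")
for i in range(nW):
    axes[i + 1].imshow(attn_mask[i].numpy())   # 49x49 mask for window i
    axes[i + 1].set_title(f"attn_mask[{i}]")
plt.tight_layout()
plt.show()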