Efficient Attention
Feature request
The authors of https://openaccess.thecvf.com/content/WACV2021/papers/Shen_Efficient_Attention_Attention_With_Linear_Complexities_WACV_2021_paper.pdf propose changing the attention module in transformers to achieve linear complexity in the number of processed tokens (instead of the feared quadratic complexity). I'd like to request the following feature: the possibility to choose between standard and efficient attention, for instance via a flag in the corresponding config file when building a Hugging Face model.
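To illustrate what I have in mind (the flag name `use_efficient_attention` is just a suggestion, not an existing option):

```python
from transformers import ViTMAEConfig, ViTMAEModel

# "use_efficient_attention" is a hypothetical flag proposed by this issue;
# unknown kwargs are simply stored as attributes on the config object.
config = ViTMAEConfig(use_efficient_attention=True)
model = ViTMAEModel(config)
```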
Motivation
Any model with a variable number of input tokens can benefit from this, and it is especially useful when processing many tokens, say the number of words in a text or the number of patches in an image. The authors of the paper show that this can significantly reduce inference time, training time and memory load. In my case, I have implemented this change for the ViTMAE model and noticed that I can now easily process images of size 592x592 (37x37 = 1369 (!!!) tokens), whereas before my machine was capped at 384x384 images (24x24 = 576 tokens). Due to the linear complexity, I could go to even higher token counts/image resolutions, trading off against batch size. The model size is not affected (it's magic, really), and for the ViTMAE model I have confirmed that this type of attention works quite well.
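A rough back-of-envelope comparison (my own numbers, assuming float32, a single head, and head size 64 as in ViT-Base) shows why the quadratic attention map is the bottleneck:

```python
# Standard attention stores an n x n attention map per head; efficient
# attention only stores a d x d intermediate, independent of token count n.
bytes_per_float = 4

for n in (576, 1369):  # 384x384 vs. 592x592 images with 16x16 patches
    print(f"n={n:4d}: standard attention map {n * n * bytes_per_float / 2**20:6.2f} MiB per head")

d = 64  # head size in ViT-Base (768 hidden / 12 heads)
print(f"d={d}: efficient intermediate {d * d * bytes_per_float / 2**10:6.2f} KiB per head")
```

That is roughly 1.27 MiB vs. 7.15 MiB per head per image for the standard attention map, multiplied across heads, layers, and the batch, while the efficient intermediate stays at 16 KiB regardless of resolution.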
Your contribution
The change is really easy to implement, but might take a bit of work because it most likely has to be done for every model of interest. For the ViTMAE model, I made the following quick-and-dirty tweak to the `ViTMAESelfAttention` class (the code follows the short recap below):
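For context, the trick from the paper is to normalize queries and keys separately and re-associate the matrix product, so the attention map that is quadratic in the token count is never formed:

$$
\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V
\;\;\longrightarrow\;\;
\rho_q(Q)\,\bigl(\rho_k(K)^\top V\bigr),
$$

where $\rho_q$ applies a softmax over the feature dimension of each query and $\rho_k$ over the token dimension of the keys. The bracketed product is only $d \times d$, so memory and compute scale linearly in the number of tokens $n$ rather than quadratically.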
```python
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMAE
# Imports added so the snippet is self-contained; in modeling_vit_mae.py they already exist.
import math  # only needed by the commented-out standard attention path
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import ViTMAEConfig


class ViTMAESelfAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # Unused in the efficient path: there are no attention probabilities to drop out.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # OLD ATTENTION
        # # Take the dot product between "query" and "key" to get the raw attention scores.
        # attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # # Normalize the attention scores to probabilities.
        # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # # This is actually dropping out entire tokens to attend to, which might
        # # seem a bit unusual, but is taken from the original Transformer paper.
        # attention_probs = self.dropout(attention_probs)
        # # Mask heads if we want to
        # if head_mask is not None:
        #     attention_probs = attention_probs * head_mask
        # context_layer = torch.matmul(attention_probs, value_layer)

        # NEW EFFICIENT ATTENTION (Shen et al., WACV 2021)
        #######################
        # Keys are softmax-normalized over the token dimension (dim=2) and queries
        # over the feature dimension (dim=3), as in the paper. Contracting keys
        # with values first yields a small (head_size x head_size) matrix G, so
        # the (seq_len x seq_len) attention map is never materialized.
        key_layer = nn.functional.softmax(key_layer, dim=2)
        query_layer = nn.functional.softmax(query_layer, dim=3)
        G = torch.matmul(key_layer.transpose(-1, -2), value_layer)
        context_layer = torch.matmul(query_layer, G)
        #######################

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # THIS LINE DID NOT MAKE SENSE WITH THIS ATTENTION: the attention
        # probabilities are never computed, so output_attentions cannot be honored.
        # outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return (context_layer,)
```
As you can see, I commented out about 15 lines of code and added four new ones, and this worked out of the box. Of course, to make it more accessible, one would need to add a switch plus a flag in the config, etc.; a sketch of what that branch could look like follows below.
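A minimal sketch of such a branch, assuming a hypothetical `use_efficient_attention` config attribute and assuming the module keeps a reference to its config in `__init__` (neither exists today; both are illustrations of the proposal):

```python
# Hypothetical branch inside ViTMAESelfAttention.forward; "use_efficient_attention"
# is a proposed flag, not an existing ViTMAEConfig option, and self.config is
# assumed to have been stored in __init__.
if getattr(self.config, "use_efficient_attention", False):
    # Linear-complexity path (Shen et al., WACV 2021): no n x n attention map.
    key_layer = nn.functional.softmax(key_layer, dim=2)
    query_layer = nn.functional.softmax(query_layer, dim=3)
    context_layer = torch.matmul(query_layer, torch.matmul(key_layer.transpose(-1, -2), value_layer))
else:
    # Standard quadratic softmax attention, as in the current modeling code.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    context_layer = torch.matmul(attention_probs, value_layer)
```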
On a side note, many thanks to @NielsRogge for implementing the ViTMAE model, I am loving it!
Top GitHub Comments
Hi,
Loved reading that 😄 thanks for your interest in ViTMAE. It would definitely be nice to have one of these more efficient attention models in the library. However, we usually only add models that have dedicated pre-trained weights.
As Transformers is not really a modular toolbox, we treat each model fairly independently in the library. This means that if we were to add a model with an efficient attention variant, it would get its own modeling files, doc page, etc.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.