
Efficient Attention

See original GitHub issue

Feature request

The authors of https://openaccess.thecvf.com/content/WACV2021/papers/Shen_Efficient_Attention_Attention_With_Linear_Complexities_WACV_2021_paper.pdf propose changing the attention module in transformers so that its complexity is linear in the number of processed tokens (instead of the feared quadratic complexity). I’d like to request the following feature: the possibility to choose between standard and efficient attention, for instance via a flag in the corresponding config when building a Hugging Face model.
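
To illustrate the idea, here is a minimal sketch of what such a switch could look like. Note that the use_efficient_attention flag and the EfficientSelfAttention class are purely hypothetical and do not exist in transformers today; they only show the kind of opt-in this request is about.

from transformers import ViTMAEConfig

config = ViTMAEConfig(image_size=592, patch_size=16)
config.use_efficient_attention = True  # hypothetical flag, not a real transformers option

# Inside the modeling code, the attention module could then be selected like this:
# self.attention = (
#     EfficientSelfAttention(config)      # hypothetical class
#     if getattr(config, "use_efficient_attention", False)
#     else ViTMAESelfAttention(config)
# )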

Motivation

Any model with a variable number of input tokens can benefit from this, and it is especially useful when processing many tokens, say the number of words in a text or the number of patches in an image. The authors of the paper show that this can significantly reduce inference time, training time and memory load. In my case, I have implemented this change for the ViTMAE model and noticed that I can now easily process images of size 592x592 (37x37 = 1369 (!!!) tokens), whereas before my machine was capped at 384x384 images (24x24 = 576 tokens). Due to the linear complexity, I could even go for a higher token count/image resolution, trading off against batch size. The model size is not affected (it’s magic, really), and for the ViTMAE model I have confirmed that this type of attention works quite well.
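
To make the memory argument concrete, here is a rough back-of-the-envelope comparison for the two image sizes above. It only counts the dominant per-layer intermediate in float32 and assumes the ViT-Base layout (12 heads, head dimension 64); batch size, number of layers and gradients are ignored.

def attention_map_mb(num_tokens: int, num_heads: int = 12) -> float:
    # standard attention materializes a (num_tokens x num_tokens) map per head
    return num_heads * num_tokens ** 2 * 4 / 2 ** 20

def efficient_context_mb(head_dim: int = 64, num_heads: int = 12) -> float:
    # efficient attention only materializes a (head_dim x head_dim) matrix per head
    return num_heads * head_dim ** 2 * 4 / 2 ** 20

for n in (576, 1369):  # 384x384 and 592x592 images with 16x16 patches
    print(f"{n:>5} tokens: {attention_map_mb(n):7.1f} MB vs {efficient_context_mb():.2f} MB")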

Your contribution

The change is really easy to implement, but might take a bit of work because it most likely has to be done for every model of interest. For the ViTMAE model, I made the following quick-and-dirty tweak to the ViTMAESelfAttention class:

from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import ViTMAEConfig


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
class ViTMAESelfAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        
        # OLD ATTENTION
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        #
        # attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        #
        # # Normalize the attention scores to probabilities.
        # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        #
        # # This is actually dropping out entire tokens to attend to, which might
        # # seem a bit unusual, but is taken from the original Transformer paper.
        # attention_probs = self.dropout(attention_probs)
        #
        # # Mask heads if we want to
        # if head_mask is not None:
        #     attention_probs = attention_probs * head_mask
        #
        # context_layer = torch.matmul(attention_probs, value_layer)

        # NEW EFFICIENT ATTENTION:
        #######################
        # Per the paper, keys are normalized over the token dimension and queries over the
        # feature dimension, so the (seq_len x seq_len) attention map is never materialized.
        key_layer = nn.functional.softmax(key_layer, dim=2)      # softmax over tokens
        query_layer = nn.functional.softmax(query_layer, dim=3)  # softmax over features
        # (b, heads, head_dim, seq_len) @ (b, heads, seq_len, head_dim) -> (b, heads, head_dim, head_dim)
        G = torch.matmul(key_layer.transpose(-1, -2), value_layer)
        # (b, heads, seq_len, head_dim) @ (b, heads, head_dim, head_dim) -> (b, heads, seq_len, head_dim)
        context_layer = torch.matmul(query_layer, G)
        #######################

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # The original return below no longer applies: efficient attention never materializes
        # attention_probs, so output_attentions cannot be honored here.
        # outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return (context_layer,)

As you can see, I commented out 15 lines of code and added 4 new ones, and this worked out of the box. Of course, to make it more accessible, one would need to add a switch and a flag in the config, and so on.
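
For reference, the core trick can also be written as a standalone function. This is only a minimal sketch of the paper’s softmax variant, operating on (batch, heads, tokens, dim) tensors; it is not code from transformers.

import torch
import torch.nn.functional as F

def efficient_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Per the paper: normalize queries over the feature dimension and keys over the
    # token dimension, then compute K^T V first so only a (dim x dim) matrix is formed.
    q = F.softmax(q, dim=-1)
    k = F.softmax(k, dim=-2)
    context = torch.matmul(k.transpose(-1, -2), v)  # (batch, heads, dim, dim)
    return torch.matmul(q, context)                 # (batch, heads, tokens, dim)

# Shape check with the numbers from above: 1369 tokens, 12 heads, head_dim 64.
q = k = v = torch.randn(1, 12, 1369, 64)
print(efficient_attention(q, k, v).shape)  # torch.Size([1, 12, 1369, 64])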

On a side note, many thanks to @NielsRogge for implementing the ViTMAE model, I am loving it!

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
NielsRogge commented, Jul 29, 2022

Hi,

Loved reading that 😄 thanks for your interest in ViTMAE. It would definitely be nice to have one of these more efficient attention models in the library. However, we usually only add models that have dedicated pre-trained weights.

As Transformers is not really a modular toolbox, we treat each model rather independently in the library. This means that, in case we would have a model with an efficient attention variant, it would have its own modeling files, doc page, etc.

0 reactions
github-actions[bot] commented, Aug 27, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


Top Results From Across the Web

  • Efficient Attention: Attention with Linear Complexities (arXiv)
  • Efficient Attention: Attention With Linear Complexities (CVF Open Access)
  • An implementation of the efficient attention module (GitHub)
  • Efficient Attention: Attention with Linear Complexities (WACV 2021 talk, ComputerVisionFoundation Videos)
  • Efficient Attention: Attention with Linear Complexities (write-up by the authors at SenseTime)
