[Question!] How to Inject Rotary Positional Embeddings in Linear Transformers
Hello Phil,
Would you mind showing how to inject rotary positional embeddings into a linear transformer? Here is the linear attention module I am using:
import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map


class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on
    the feature map, however, the complexity of the attention might be
    limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()
Thanks!
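
For anyone else landing here, a minimal self-contained sketch of one way to do the injection: rotate the raw queries and keys with RoPE before they go through the feature map in LinearAttention.forward above. The helper names (build_rotary_cache, rotate_half, apply_rotary) are made up for this example rather than taken from any library, and note that the RoFormer paper proposes instead rotating Φ(Q) and Φ(K) in the numerator only, so treat this as one plausible variant, not the canonical recipe. It assumes queries/keys of shape (batch, seq_len, heads, head_dim) with an even head_dim, matching the module above.

import torch

def build_rotary_cache(seq_len, head_dim, device, base=10000.0):
    # One frequency per pair of channels, as in the RoFormer paper.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum("l,d->ld", positions, inv_freq)   # (seq_len, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)                # (seq_len, head_dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    # The usual "rotate half" trick on the last dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, cos, sin):
    # x: (batch, seq_len, heads, head_dim); cos/sin: (seq_len, head_dim)
    return x * cos[None, :, None, :] + rotate_half(x) * sin[None, :, None, :]

# Inside LinearAttention.forward, before Q and K are passed through the
# feature map, the change would look roughly like this:
#     cos, sin = build_rotary_cache(queries.shape[1], queries.shape[-1], queries.device)
#     queries = apply_rotary(queries, cos, sin)
#     cos, sin = build_rotary_cache(keys.shape[1], keys.shape[-1], keys.device)
#     keys = apply_rotary(keys, cos, sin)

Because the rotation is norm-preserving, the kernelised dot product Φ(q)·Φ(k) then carries relative-position information while the rest of the O(N D^2) computation stays unchanged.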
Issue Analytics
- Created 2 years ago
- Comments: 5 (4 by maintainers)
Top GitHub Comments
@gaceladri make sure to turn off absolute positional embeddings when you try it! it conflicts with rotary for some unknown reason - more research needed
Amazing! You are amazing! Thanks a lot, I will try it!!
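
Regarding the note above about turning off absolute positional embeddings: in practice this just means the token embeddings should enter the attention stack without an added learned (or sinusoidal) position table once rotary is applied to the queries and keys inside attention. A rough sketch of that switch, with a made-up module name rather than code from any particular library:

import torch
import torch.nn as nn

class TokenFrontend(nn.Module):
    # Hypothetical embedding front-end: the absolute position table is kept
    # optional so it can be disabled when rotary embeddings are used instead.
    def __init__(self, vocab_size, dim, max_len=2048, use_absolute_pos=False):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, dim)
        self.pos = nn.Embedding(max_len, dim) if use_absolute_pos else None

    def forward(self, token_ids):
        h = self.tok(token_ids)
        if self.pos is not None:
            positions = torch.arange(token_ids.shape[1], device=token_ids.device)
            h = h + self.pos(positions)[None]
        return h

# With rotary injected inside the attention layers, construct the front-end as
# TokenFrontend(vocab_size, dim, use_absolute_pos=False) so the two position
# schemes do not interfere.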