einsum operation in Linear Attention Part
Hi, thanks a lot for your FLASH_pytorch, which helps a lot. I found that there are some differences from the paper in the Linear Attention part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
Here lin_kv is three-dimensional (b d e).
And the code in the paper is:
lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)
Here lin_kv is four-dimensional (bgke).
It seems that the two ways are not equivalent.
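As a quick sanity check, the following sketch (arbitrary random tensors and made-up dimensions, and it ignores the paper's group mask, so it only compares the two reductions) shows that the outputs differ:

import torch

# hypothetical dimensions: batch, groups, tokens per group, key/query dim, value dim
b, g, n, d, e = 2, 4, 8, 16, 16
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v     = torch.randn(b, g, n, e)

# repository form: the key-value statistic is summed over both groups and tokens -> (b, d, e)
lin_kv_shared = torch.einsum('b g n d, b g n e -> b d e', lin_k, v) / n
out_shared    = torch.einsum('b g n d, b d e -> b g n e', lin_q, lin_kv_shared)

# per-group form: the group axis is kept, giving a (b, g, d, e) statistic like the paper's bgke
lin_kv_grouped = torch.einsum('b g n d, b g n e -> b g d e', lin_k, v) / n
out_grouped    = torch.einsum('b g n d, b g d e -> b g n e', lin_q, lin_kv_grouped)

print(out_shared.shape, out_grouped.shape)      # both (b, g, n, e)
print(torch.allclose(out_shared, out_grouped))  # False: the two reductions are not equivalent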
Looking forward to your reply. Best,
Top GitHub Comments
Hi, it is true that there is a reduction over all groups. However, in Code 8: Pseudocode for FLASH on the final page, there is no reduction over groups, so maybe both are OK. (In my opinion, if there were a sum reduction over all groups, the attention results would be quite a bit larger than the quad_part?)

When I read this part of the expressions and formulas, it seems the reduction should be over the group dimension.
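A quick magnitude check (again with arbitrary random tensors, purely illustrative) supports this intuition: the group-summed statistic adds up g * n terms but is still divided only by n, so its entries come out roughly sqrt(g) times larger than the per-group ones for i.i.d. random inputs (and up to g times larger if the groups are strongly correlated):

import torch

# hypothetical dimensions: batch, groups, tokens per group, key dim, value dim
b, g, n, d, e = 2, 8, 16, 32, 32
lin_k = torch.randn(b, g, n, d)
v     = torch.randn(b, g, n, e)

shared  = torch.einsum('b g n d, b g n e -> b d e', lin_k, v) / n    # sums g * n terms
grouped = torch.einsum('b g n d, b g n e -> b g d e', lin_k, v) / n  # sums n terms per group

print(shared.abs().mean() / grouped.abs().mean())  # ~ sqrt(g) ≈ 2.8 for this random example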