Error with shapes in LocalAttention
Hi, I recently updated to the latest version of the package (0.22.3) and caught a bug with shapes in the forward() method of LocalAttention (reformer_pytorch.py#L502) when both num_mem_kv and n_local_attn_heads are non-zero:
501 out = torch.einsum('bhij,bhje->bhie', attn, bv)
--> 502 out = out.reshape(b, t, e)
503 return out
504
RuntimeError: shape '[2, 2048, 64]' is invalid for input of size 278528
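A quick sanity check on the numbers in the error: 278528 is exactly 2 * (2048 + 128) * 64, i.e. the sequence length plus num_mem_kv. That is consistent with the memory key/values still being attached along the sequence axis when LocalAttention reshapes its output back to (b, t, e). The snippet below only redoes that arithmetic. The default of 8 heads (giving 64 dimensions per head) and the "memory appended to the sequence" reading are assumptions, not something confirmed in the report.

```python
# Values taken from the error message and the reproduction script.
batch, seq_len, num_mem_kv = 1, 2048, 128
n_local_attn_heads = 2
dim, heads = 512, 8          # assumption: ReformerLM's default of 8 heads
dim_head = dim // heads      # 64, matching the last dim in '[2, 2048, 64]'

# Leading dim 2 = batch * n_local_attn_heads (local heads merged into the batch).
merged_batch = batch * n_local_attn_heads

# Number of elements the failing reshape to (b, t, e) expects.
expected = merged_batch * seq_len * dim_head

# Hypothesis: local attention sees seq_len + num_mem_kv positions,
# because the memory key/values are concatenated onto the sequence.
with_memory = merged_batch * (seq_len + num_mem_kv) * dim_head

print(expected, with_memory)  # 262144 278528
```

The extra 16384 elements correspond to 2 * 128 * 64, i.e. one block of memory positions per merged local-attention head.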
When either num_mem_kv or n_local_attn_heads is zero, the model works fine.
Minimal example for reproducing the error:
import torch
from reformer_pytorch import ReformerLM
num_tokens, seq_len = 1024, 2048
model = ReformerLM(
    num_tokens=num_tokens,
    dim=512,
    depth=6,
    max_seq_len=seq_len,
    num_mem_kv=128,
    n_local_attn_heads=2,
).cuda()
x = torch.randint(0, num_tokens, (1, seq_len)).long().cuda()
y = model(x)
Hope that helps! Thanks for your work on the package!
Top GitHub Comments
@lucidrains I haven't experimented much with local attention yet; I just got back to Reformer and updated to the latest version. If I notice any strange or positive behaviour, I will report it.
@ilya16 please do!