Memory Efficiency w.r.t. Sequence Length
I am a bit of a noob when it comes to transformers. If I want to encode a batch of N sequences of maximum length L, my understanding is that I do something like this:
import torch
from x_transformers import Encoder, TransformerWrapper

seqs = ['aba', 'cb', 'abcab']
N = len(seqs)                      # batch size
L = max(len(seq) for seq in seqs)  # maximum sequence length
C = 3                              # vocabulary size ('a', 'b', 'c')

# pad with token 0 and mark real positions (concrete stand-ins for get_padded_seqs / get_seq_mask)
token_ids = {'a': 0, 'b': 1, 'c': 2}
padded_seqs = torch.zeros(N, L, dtype=torch.long)  # N x L long tensor
mask = torch.zeros(N, L, dtype=torch.bool)         # N x L boolean tensor (True = real token)
for i, seq in enumerate(seqs):
    padded_seqs[i, :len(seq)] = torch.tensor([token_ids[ch] for ch in seq])
    mask[i, :len(seq)] = True

encoder = TransformerWrapper(num_tokens=C, max_seq_len=L, attn_layers=Encoder(dim=512, depth=6))  # Encoder needs at least dim and depth
embeddings = encoder(padded_seqs, mask=mask, return_embeddings=True)
In this transformer implementation, would there be a difference in memory usage if all of the sequences were of length L (i.e. all the mask values were True)?
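For context on the question (this sketch is not from the original issue): with standard dense attention and no flash or sparse kernels, each layer materializes an attention logits tensor of shape N x heads x L x L at the padded length, and the mask only overwrites padded positions with -inf before the softmax. So an all-True mask should cost roughly the same memory as a heavily padded batch. A minimal illustration with made-up sizes:

import torch

N, heads, L, d_head = 3, 8, 512, 64          # illustrative sizes, not taken from the issue
q = torch.randn(N, heads, L, d_head)
k = torch.randn(N, heads, L, d_head)
scores = q @ k.transpose(-2, -1)             # attention logits: (N, heads, L, L)
print(scores.shape, f"{scores.numel() * scores.element_size() / 1e6:.1f} MB")

mask = torch.zeros(N, L, dtype=torch.bool)
mask[:, :100] = True                         # pretend only the first 100 tokens are real
masked = scores.masked_fill(~mask[:, None, None, :], float('-inf'))
print(masked.shape)                          # still (N, heads, L, L): padding is masked, not removed

This matches the maintainer's point in the comments below: padding positions are still attended to and then masked out, which is why batching sequences of similar length (rather than relying on the mask) is what actually saves memory.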
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Thanks, that’s a good solution! Will check it out.
@adamoyoung yea, the transformers community went a very different direction than that of graph neural nets and how it is approached with PyG. we typically don’t do it the scatter/gather way, though I have met researchers who were interested in writing CUDA kernels to remove attention on the padding. i think batching by similar lengths is a good middle ground that i’ve seen used by others (one such implementation i came across https://github.com/jonathanking/sidechainnet/blob/4d4f57204c162ab938b8762dfacffb1d992774d0/sidechainnet/dataloaders/SimilarLengthBatchSampler.py#L9 )
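For anyone who wants to try the similar-length batching suggestion, here is a minimal sketch of the idea. It is a simplified stand-in, not the linked sidechainnet SimilarLengthBatchSampler, and the LengthBucketBatchSampler / pad_collate names below are made up for illustration:

import random

class LengthBucketBatchSampler:
    # Yields batches of dataset indices with similar sequence lengths, so each batch
    # only needs to be padded to its own max length (simplified sketch).
    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = list(lengths)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)  # shuffle batch order; lengths stay similar within a batch
        return iter(batches)

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

# usage (names are placeholders): pass as batch_sampler to torch.utils.data.DataLoader,
# together with a collate_fn that pads each batch to its own max length, e.g.
# loader = DataLoader(dataset, batch_sampler=LengthBucketBatchSampler(lengths, 32), collate_fn=pad_collate)

The trade-off is that batches are no longer fully random, but because neighbours in the length-sorted order have similar lengths, almost no memory is spent on padded positions.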