Some questions
Hello,
Thank you for your awesome work! Flash attention is going to be used everywhere!
I have a few questions please:
- To use flash attention in an existing PyTorch transformer, it suffices to replace `torch.nn.MultiheadAttention` with `flash_attn.flash_attention.FlashMHA`, is that correct?
- Training is also supported out of the box, I guess? The question also includes mixed-precision training, i.e., compatibility with the `torch.autocast()` context manager.
- I see that you also provide a Fused Softmax implementation. According to the docstrings, this layer is used for auto-regressive models. If I only use the transformer encoder, e.g., vision transformers, then it's not worth using it. Is that correct?
Thank you very much in advance for your answers.
Issue Analytics
- State:
- Created: a year ago
- Reactions: 1
- Comments: 5 (2 by maintainers)

Yes. The two modules do the same thing, though they might have different APIs and arguments. You should read the arguments and documentation to pass in the right things.
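For concreteness, a minimal sketch of swapping the two modules. The `FlashMHA` constructor arguments and forward signature shown here (a single `(batch, seqlen, embed_dim)` input, `batch_first` layout, fp16 parameters) are assumptions based on the flash_attn release at the time; check the docstrings of your installed version.

```python
import torch
from flash_attn.flash_attention import FlashMHA

# Baseline: PyTorch's own multi-head self-attention, in fp32 on the GPU.
torch_mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).cuda()

# Assumed FlashMHA equivalent -- self-attention only, fp16/bf16 inputs required.
# Verify the exact arguments against the library's docstrings.
flash_mha = FlashMHA(embed_dim=512, num_heads=8, batch_first=True,
                     attention_dropout=0.0, causal=False,
                     device="cuda", dtype=torch.float16)

x = torch.randn(4, 128, 512, device="cuda", dtype=torch.float16)

# nn.MultiheadAttention takes separate query/key/value tensors.
out_ref, _ = torch_mha(x.float(), x.float(), x.float())

# FlashMHA (as assumed here) takes a single input tensor and returns
# (output, attention_weights), with weights left as None unless requested.
out_flash, _ = flash_mha(x)
```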
Yes, training works, and mixed-precision training works. `torch.autocast()` will do the right thing and cast q, k, v to either fp16 or bf16. FlashAttention does not support fp32.
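A sketch of what that looks like in a training loop; `model`, `loader`, `optimizer`, and `criterion` are hypothetical placeholders, and the autocast/GradScaler usage is standard PyTorch AMP rather than anything specific to this repo.

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for x, y in loader:  # hypothetical dataloader yielding CUDA tensors
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, matmuls and attention run in fp16 (or bf16 if requested),
    # which is what FlashAttention expects; fp32 attention is not supported.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model(x)
        loss = criterion(out, y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```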
The Fused Softmax was taken from apex/megatron purely for benchmarking. It is only useful if you have either a causal mask before the softmax (e.g. autoregressive models) or a key padding mask before the softmax (e.g. BERT, where sequences in a batch have different lengths). If you're using a transformer encoder and all sequences in the batch have the same length, then it won't apply to your case.
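To illustrate the padding case, a small sketch (plain PyTorch, nothing repo-specific) of building a boolean key padding mask from per-sequence lengths; with equal-length sequences this mask is all False, so the masked-softmax path adds nothing.

```python
import torch

lengths = torch.tensor([7, 5, 3])   # true lengths of 3 sequences in the batch
max_len = 8                         # padded sequence length of the batch

# key_padding_mask[i, j] == True marks position j of sequence i as padding,
# so the corresponding attention scores get masked out before the softmax.
positions = torch.arange(max_len).unsqueeze(0)          # shape (1, max_len)
key_padding_mask = positions >= lengths.unsqueeze(1)    # shape (batch, max_len)
```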
@tridao Great, thanks! Will create a PR soon!