Question about flash attention dropout
Hi,
I noticed that with attention dropout, besides the dropout applied to P, there is the following code in the final write:
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
    float sum = p_sum_o[jj][0];
    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
    if (Is_dropout && is_final_write) {
        inv_sum *= params.rp_dropout;
    }
    out[jj] = fmha::fmul4(out[jj], inv_sum);
}
The same appears in the backward pass: https://github.com/HazyResearch/flash-attention/blob/main/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h#L689
I did not find this in the paper.
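For reference, here is a minimal standalone C++ sketch of one possible reading of that rescale, assuming params.rp_dropout is 1 / (1 - p_dropout); the names below are illustrative, not the kernel's variables:

// Toy sketch: inverted dropout rescales kept probabilities by 1/(1 - p_drop).
// Because that factor is a scalar, it distributes over the P @ V product,
// so it can equivalently be applied once to the output row, which is what
// the `inv_sum *= params.rp_dropout` line appears to do (folded into the
// same multiply as the softmax row-sum normalization).
#include <cstdio>

int main() {
    const int N = 4;
    float p[N]    = {0.1f, 0.2f, 0.3f, 0.4f};  // attention probs for one row
    float mask[N] = {1.f, 0.f, 1.f, 1.f};      // dropout keep mask
    float v[N]    = {1.f, 2.f, 3.f, 4.f};      // one column of V
    const float p_drop = 0.25f;
    const float rp_dropout = 1.f / (1.f - p_drop);

    // Variant A: rescale every kept probability before the matmul.
    float out_a = 0.f;
    for (int i = 0; i < N; ++i) {
        out_a += (mask[i] * p[i] * rp_dropout) * v[i];
    }

    // Variant B: drop without rescaling, then rescale the output once.
    float out_b = 0.f;
    for (int i = 0; i < N; ++i) {
        out_b += (mask[i] * p[i]) * v[i];
    }
    out_b *= rp_dropout;

    printf("A = %f, B = %f\n", out_a, out_b);  // equal up to float rounding
    return 0;
}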
There's a trick in the implementation here to save storage: this line generates the dropout mask and stores it in the sign of the elements of P (e.g. we store 0.1 if the attention prob is 0.1 and the corresponding dropout mask is 1, and we store -0.2 if the attention prob is 0.2 and the corresponding dropout mask is 0).
This block of code corresponds to lines 18 and 20: if the value we store in p is negative, we get ds = p * d (i.e., -attn_prob * d), because that corresponds to the dropout mask being zero. If the value we store in p is positive, we get ds = p * acc_dp (i.e., attn_prob * (dp - d)), since that corresponds to the dropout mask being 1.
Later on, when we need to use P, we take a relu to remove the dropout mask that was stored in the sign of P.
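A minimal standalone C++ sketch of the trick as described above (the names are illustrative, not the kernel's; acc_dp stands for dp - d):

#include <algorithm>
#include <cstdio>

// Forward: fold the dropout mask into the sign of the stored probability.
float encode(float attn_prob, bool keep) {
    return keep ? attn_prob : -attn_prob;
}

// Wherever the kept probability itself is needed later, a relu strips the
// mask: relu(p_signed) is attn_prob for kept elements and 0 for dropped ones.
float decode(float p_signed) {
    return std::max(p_signed, 0.f);
}

// Backward: the sign selects between the two dS cases (lines 18 and 20 of
// the algorithm, as described above):
//   dropped (p_signed < 0): ds = p_signed * d       == -attn_prob * d
//   kept    (p_signed > 0): ds = p_signed * acc_dp  ==  attn_prob * (dp - d)
float ds_from_signed(float p_signed, float d, float acc_dp) {
    return (p_signed < 0.f) ? p_signed * d : p_signed * acc_dp;
}

int main() {
    float kept    = encode(0.1f, true);   // stored as  0.1
    float dropped = encode(0.2f, false);  // stored as -0.2
    printf("decoded: %.1f %.1f\n", decode(kept), decode(dropped));  // 0.1 0.0
    printf("ds: %f %f\n", ds_from_signed(kept, 0.5f, 0.3f),
                          ds_from_signed(dropped, 0.5f, 0.3f));
    return 0;
}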
If you want to implement it, it's probably simpler to store the dropout mask separately without this trick.
Another question: what code corresponds to line 18 in the algorithm?