
Question about flash attention dropout

See original GitHub issue

Hi,

I noticed that with attention dropout, besides the dropout applied to P, there is the following code in the final write:

for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
    float sum = p_sum_o[jj][0];
    // Guard against a zero or NaN softmax row sum.
    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
    if (Is_dropout && is_final_write) {
        // Extra rescaling when dropout is enabled, applied only on the final write.
        inv_sum *= params.rp_dropout;
    }
    out[jj] = fmha::fmul4(out[jj], inv_sum);
}

https://github.com/HazyResearch/flash-attention/blob/main/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h#L629

The same is done in the backward pass: https://github.com/HazyResearch/flash-attention/blob/main/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h#L689

I did not find this step in the paper.

[attached screenshot from the paper]
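
For context on the quoted loop, here is a minimal standalone sketch of what it computes, with illustrative names rather than the kernel’s real types: out is treated as one float per element instead of the packed values fmha::fmul4 operates on, and params.rp_dropout is assumed to be the reciprocal of the keep probability.

#include <cstddef>
#include <vector>

// Illustrative sketch of the final write: normalize each output element by its
// softmax row sum and, when dropout is enabled, rescale by rp_dropout
// (inverted-dropout scaling). All names here are stand-ins, not the kernel's.
void final_write_sketch(std::vector<float>& out, const std::vector<float>& row_sum,
                        bool is_dropout, bool is_final_write, float rp_dropout) {
    for (std::size_t jj = 0; jj < out.size(); ++jj) {
        float sum = row_sum[jj];
        // Same zero/NaN guard as the kernel.
        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
        if (is_dropout && is_final_write) {
            inv_sum *= rp_dropout;
        }
        out[jj] *= inv_sum;
    }
}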

Issue Analytics

  • State: closed
  • Created: 9 months ago
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

1 reaction
tridao commented on Dec 21, 2022

There’s a trick in the implementation here to save storage: this line generates the dropout mask and stores it as the sign of the elements in P (e.g. we store 0.1 if the attention prob is 0.1 and the corresponding dropout mask is 1, and we store -0.2 if the attention prob is 0.2 and the corresponding dropout mask is 0).

This block of code corresponds to lines 18 and 20: if the value we store in p is negative, we get ds = p * d (i.e., -attn_prob * d), because that corresponds to the dropout mask being 0. If the value we store in p is positive, we get ds = p * acc_dp (i.e., attn_prob * (dp - d)), since that corresponds to the dropout mask being 1.

Later on, when we need to use P, we take a relu to remove the dropout mask that was stored as the sign of P.
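
Put together, here is a minimal sketch of the trick as described above; all names are illustrative, and d and acc_dp simply stand for the quantities mentioned above, not the kernel’s actual variables.

#include <algorithm>
#include <random>
#include <vector>

// Sketch of the trick: instead of a separate mask tensor, flip the sign of the
// attention probability when the element is dropped. |p| is the probability,
// sign(p) is the dropout mask.
void encode_mask_in_sign(std::vector<float>& attn_prob, float dropout_rate,
                         std::mt19937& gen) {
    std::uniform_real_distribution<float> uniform(0.f, 1.f);
    for (float& p : attn_prob) {
        bool keep = uniform(gen) >= dropout_rate;
        // kept:    0.1 is stored as  0.1  (mask = 1)
        // dropped: 0.2 is stored as -0.2  (mask = 0)
        p = keep ? p : -p;
    }
}

// Negative sign means the element was dropped (mask = 0): ds = p * d, i.e.
// -attn_prob * d. Positive sign means it was kept (mask = 1): ds = p * acc_dp,
// i.e. attn_prob * (dp - d).
inline float ds_from_signed_p(float p_signed, float d, float acc_dp) {
    return (p_signed < 0.f) ? p_signed * d : p_signed * acc_dp;
}

// Recovering P for later use: relu strips the mask stored in the sign
// (dropped elements become 0, kept elements keep their probability).
inline float recover_p(float p_signed) {
    return std::max(p_signed, 0.f);
}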

If you want to implement it, it’s probably simpler to store the dropout mask separately without this trick.
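
For comparison, the straightforward alternative with a separate mask array might look like the sketch below (again only illustrative; rescaling by the reciprocal keep probability can then be applied wherever convenient, e.g. at the output as in the kernel quoted in the question).

#include <cstddef>
#include <random>
#include <vector>

// Simpler alternative: keep the dropout mask in its own array instead of
// encoding it in the sign bit of the attention probabilities.
void dropout_with_separate_mask(std::vector<float>& attn_prob,
                                std::vector<unsigned char>& mask,
                                float dropout_rate, std::mt19937& gen) {
    std::uniform_real_distribution<float> uniform(0.f, 1.f);
    mask.resize(attn_prob.size());
    for (std::size_t i = 0; i < attn_prob.size(); ++i) {
        mask[i] = (uniform(gen) >= dropout_rate) ? 1 : 0;
        attn_prob[i] *= mask[i];  // dropped entries become 0
    }
}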

0 reactions
SeaOfOcean commented on Dec 21, 2022

[attached screenshot of the algorithm]

Another question: which code corresponds to line 18 in the algorithm?

Read more comments on GitHub >

Top Results From Across the Web

Issues · HazyResearch/flash-attention - GitHub
Issues list ; Triton version with dropout. #90 opened 2 days ago ; Support for PyTorch 2.0. #88 opened last week ; Numerical...
Read more >
MedAI #54: FlashAttention: Fast and Memory-Efficient Exact ...
Title: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Speaker: Tri Dao. Abstract: Transformers are slow and ...
Read more >
FlashAttention: Fast and Memory-Efficient Exact ... - DeepAI
An important question is whether making attention faster and more memory-efficient can help Transformer models address their runtime and ...
Read more >
FlashAttention: Fast and Memory-Efficient ... - Papers With Code
We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We ...
Read more >
Fast and Memory-Efficient Exact Attention with IO-Awareness
We propose FlashAttention, an IO-aware exact attention algorithm that uses ... We first report some updates, address a few common questions, ...
Read more >
