parallel execution of read/write operations (paper vs implementation)
In Algorithm 1 of FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, the inner loop starts by loading a tile of the tensor O from HBM into SRAM as Oi. It then performs some computations and writes Oi from SRAM back into the HBM tensor O. For each iteration of the outer loop, the inner loop does a full pass, so there is a full read/write of the O tensor.
If the outer loop (j) is executed serially, it should provide the correct result.
But when we execute the outer loop in parallel, it seems that we would have concurrent read and write accesses to O.
Therefore, if we strictly follow Algorithm 1 and "just" parallelize the outer loop, the output should not be correct.
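To make the access pattern concrete, here is a minimal Python sketch of the paper's loop order. It is schematic only: the running-max/denominator rescaling of the real Algorithm 1 is omitted and plain exp is used, so it does not compute correct attention; it is just meant to show which tiles are read and written.

```python
import numpy as np

def flash_attention_paper_order(Q, K, V, Br, Bc):
    # Schematic of the paper's loop order; softmax rescaling omitted.
    N, d = Q.shape
    O = np.zeros((N, d))
    Tr, Tc = N // Br, N // Bc
    for j in range(Tc):                          # outer loop over K/V blocks
        Kj = K[j * Bc:(j + 1) * Bc]              # load K_j, V_j once per j
        Vj = V[j * Bc:(j + 1) * Bc]
        for i in range(Tr):                      # inner loop over Q/O blocks
            Qi = Q[i * Br:(i + 1) * Br]
            Oi = O[i * Br:(i + 1) * Br].copy()   # READ O_i from "HBM"
            S = Qi @ Kj.T
            Oi = Oi + np.exp(S) @ Vj             # (rescaling omitted)
            O[i * Br:(i + 1) * Br] = Oi          # WRITE O_i back to "HBM"
            # If two values of j run concurrently, this read-modify-write
            # of the same O_i is a data race.
    return O
```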
In the Triton implementation, they have swapped the inner and outer loops, so they don't need to load Oi: they start with a zeroed accumulator, do all the computation, and write the final result to HBM. Therefore they don't have any concurrent-access issue on this variable.
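For contrast, a sketch of the swapped (Triton-style) loop order under the same simplifications as above:

```python
import numpy as np

def flash_attention_swapped_order(Q, K, V, Br, Bc):
    # Schematic of the swapped loop order; softmax rescaling omitted.
    N, d = Q.shape
    O = np.zeros((N, d))
    Tr, Tc = N // Br, N // Bc
    for i in range(Tr):                      # outer loop over Q/O blocks
        Qi = Q[i * Br:(i + 1) * Br]
        Oi = np.zeros((Br, d))               # start from a zeroed accumulator
        for j in range(Tc):                  # inner loop over K/V blocks
            Kj = K[j * Bc:(j + 1) * Bc]      # K_j, V_j reloaded for every i
            Vj = V[j * Bc:(j + 1) * Bc]
            S = Qi @ Kj.T
            Oi = Oi + np.exp(S) @ Vj         # (rescaling omitted)
        O[i * Br:(i + 1) * Br] = Oi          # single write per O_i, no race
    return O
```

Because each O_i is produced by exactly one outer iteration, parallelizing over i is safe.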
Can you please clarify how the parallelization is implemented in CUDA? Is there something missing in the description above that would explain why Algorithm 1 works correctly in parallel?

Hi tridao, sorry for bothering you in a closed issue again. As Triton swaps the order of the loops (the original order in the paper vs. the swapped one in Triton), it seems that the loading of K/V increases to T_r times for the Triton version?
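For concreteness, a rough tally of block loads for the two orderings (the values of T_r and T_c below are arbitrary placeholders, and only O traffic is counted on top of Q/K/V):

```python
# Rough tally of tile loads from HBM for the two loop orders.
Tr, Tc = 64, 16   # hypothetical numbers of Q/O blocks and K/V blocks

paper_order = {
    "K/V block loads": Tc,          # each K_j, V_j loaded once
    "Q block loads": Tr * Tc,       # each Q_i reloaded for every j
    "O block reads+writes": Tr * Tc,
}
swapped_order = {
    "K/V block loads": Tr * Tc,     # each K_j, V_j loaded T_r times
    "Q block loads": Tr,            # each Q_i loaded once
    "O block writes": Tr,           # each O_i written once
}
print(paper_order)
print(swapped_order)
```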
That's a great observation! The CUDA code implements the algorithm as written in the paper; we do not parallelize the outer loop. Instead, we parallelize over the batch and nheads dimensions (each threadblock computes 1 head). This is sufficient if batch * nheads is large enough (around >= 80), so that there is enough parallel work across threadblocks to keep the GPU busy.
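Schematically, this parallelization looks like the sketch below: one worker (a threadblock in the CUDA kernel) per (batch, head) pair, each running the Algorithm 1 loops serially on its own slice. The names and shapes here are illustrative, not the actual flash-attn launch code.

```python
# Illustrative sketch of parallelizing over batch and heads.
batch, nheads = 16, 8

grid = [(b, h) for b in range(batch) for h in range(nheads)]
print(len(grid), "independent threadblocks")   # want roughly >= 80 of them

def attention_one_head(b, h):
    # Run the serial Algorithm 1 loops on the Q/K/V slice for (b, h).
    # No other threadblock reads or writes this head's O, so no race.
    pass

for b, h in grid:      # conceptually these all run in parallel on the GPU
    attention_one_head(b, h)
```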
One could also swap the order of the inner and outer loops (as done in Triton).