Numerical Stability Concerns
Hi, this seems like a super promising direction for making large language model training more scalable and accessible. However, as was also noted in #10, there seems to be a small but potentially significant numerical mismatch with the reference PyTorch attention implementation. A quick check of the computed values with a 768-dimension, 4,096-sequence-length setup yields:
# Reference PyTorch attention vs. FlashAttention on the same inputs
plain_fwd = attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
flash_fwd = flash_attn_func(qkv[0], cu_seqlens, dropout_p, max_seqlen_in_batch, causal=causal)
# Element-wise difference; compare the sum of absolute values
delta = plain_fwd - flash_fwd
total_value = lambda val: val.abs().sum().item()
print(f'Total Activations: plain={total_value(plain_fwd):.2f}, flash={total_value(flash_fwd):.2f}, delta: {total_value(delta):.2f}')
# Total Activations: plain=65408.00, flash=65408.00, delta: 7.32
That is, while both functions produce the same total absolute activation, the small element-wise differences sum to a meaningful delta. Inspecting the difference shows that about 86% of the values in delta are exact zeros, while the remainder are off by about 1.7e-5 on average. The locations where the delta was non-zero did not stand out in terms of absolute activation magnitude.
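A minimal sketch of how such per-element statistics can be computed from the delta tensor above (names reused from the snippet; the exact numbers depend on the random inputs):

# delta = plain_fwd - flash_fwd, as computed above
nonzero = delta != 0
frac_zero = 1.0 - nonzero.float().mean().item()                # fraction of exact zeros (~86% here)
mean_abs_nonzero = delta[nonzero].abs().float().mean().item()  # mean |error| over non-zero entries (~1.7e-5 here)
max_abs = delta.abs().max().item()                             # worst-case single-element deviation
print(f'exact zeros: {frac_zero:.2%}, mean non-zero |delta|: {mean_abs_nonzero:.2e}, max |delta|: {max_abs:.2e}')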
Note that I removed all padding and replaced qkv_unpad with qkv[0] to make these more comparable; using qkv_unpad even without any padding tokens present causes the delta to grow to 93.31 and the total activations to diverge. I also disabled dropout to ensure the difference is not due to randomness in the function itself.
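For context on the unpadded path, here is a minimal sketch of how the cumulative sequence lengths are typically built when no padding is present (batch_size, seqlen, and device are illustrative placeholders, not the exact configuration used above):

import torch

batch_size, seqlen = 1, 4096   # illustrative values
device = 'cuda'
# With no padding every sequence has full length, so the cumulative sequence
# lengths are evenly spaced offsets: 0, seqlen, 2*seqlen, ..., batch_size*seqlen.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=device)
max_seqlen_in_batch = seqlen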
Do you have any leads on where this instability might originate and whether it can be remedied? If it can be fixed, this optimization could be really impactful.
-Vincent

Btw we're working on implementing FlashAttention for the bf16 format. The folk wisdom seems to be that bf16 training is more stable than fp16 training in general.
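For a rough sense of why: bf16 keeps the fp32 exponent range, trading precision for a far larger representable range than fp16, whose small maximum value is what typically causes overflow during training. The format properties can be inspected directly in PyTorch (a generic illustration, not specific to FlashAttention):

import torch

for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # eps reflects relative precision; max is the overflow threshold
    print(f'{dtype}: eps={info.eps:.2e}, max={info.max:.3e}')

# torch.float16:  eps=9.77e-04, max=6.550e+04
# torch.bfloat16: eps=7.81e-03, max=3.390e+38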
bf16 is now supported.
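As a usage sketch, assuming the same flash_attn_func call shape as in the snippet at the top of this issue (tensor names carried over from there), running in bf16 only requires casting the inputs:

import torch

# Cast the packed QKV tensor to bf16; cu_seqlens remains an int32 offset tensor.
qkv_bf16 = qkv.to(torch.bfloat16)
flash_fwd_bf16 = flash_attn_func(qkv_bf16[0], cu_seqlens, dropout_p,
                                 max_seqlen_in_batch, causal=causal)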