Numerical Stability Concerns

Hi, this seems like a super promising direction for making large language model training more scalable and accessible. However, as was also noted in #10, there seems to be a slight but potentially significant numerical mismatch with the reference PyTorch attention implementation. A quick check of the computed values with a dimension of 768 and a sequence length of 4,096 yields:

# Reference PyTorch attention vs. FlashAttention on the same inputs (dropout disabled).
plain_fwd = attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
flash_fwd = flash_attn_func(qkv[0], cu_seqlens, dropout_p, max_seqlen_in_batch, causal=causal)
# Element-wise difference between the two outputs.
delta = plain_fwd - flash_fwd
total_value = lambda val: val.abs().sum().item()  # sum of absolute values, as a Python float
print(f'Total Activations: plain={total_value(plain_fwd):.2f}, flash={total_value(flash_fwd):.2f}, delta: {total_value(delta):.2f}')
# Total Activations: plain=65408.00, flash=65408.00, delta: 7.32

That is, while both functions yield the same total activation, small element-wise differences accumulate into a meaningful delta. An inspection of the difference shows that about 86% of the values in delta are “true” zeros, with the remainder off by about 1.7e-5 on average. The locations where the delta was non-zero did not stand out in terms of absolute activation magnitude.
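For completeness, here is a minimal sketch of how those statistics can be reproduced; it assumes the delta tensor from the snippet above, and the figures in the trailing comment are the ones reported in this issue, not guaranteed outputs.

# Fraction of exact zeros in `delta`, and the mean absolute error over the rest.
nonzero_mask = delta != 0
zero_fraction = 1.0 - nonzero_mask.float().mean().item()
mean_nonzero_error = delta[nonzero_mask].abs().mean().item()
print(f'exact zeros: {zero_fraction:.1%}, mean abs error where non-zero: {mean_nonzero_error:.2e}')
# Roughly 86% exact zeros and a mean error around 1.7e-05 in the setup described here.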

Note that I removed all padding and replaced qkv_unpad with qkv[0] to make these more comparable; using qkv_unpad even without any padding tokens present causes the delta to grow to 93.31 and the total activations to diverge. I also disabled dropout to ensure this is not due to randomness in the function itself.

Do you have any leads on where this instability might originate and whether it can be remedied? If it can be fixed, this optimization could be really impactful.

-Vincent

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 13 (8 by maintainers)

Top GitHub Comments

5 reactions
tridao commented, Jul 2, 2022

Btw, we’re working on implementing FlashAttention for the bf16 format. The folk wisdom seems to be that bf16 training is generally more stable than fp16 training.
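As background on that folk wisdom (not part of the original comment): bf16 keeps fp32's exponent range, so large intermediate values are far less likely to overflow, at the cost of a coarser mantissa. A quick sketch of the trade-off:

# Compare the numeric limits of fp16 and bf16 with torch.finfo.
import torch

for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f'{dtype}: max={info.max:.3e}, smallest normal={info.tiny:.3e}, eps={info.eps:.3e}')
# Approximately: float16 max ~6.55e+04, eps ~9.77e-04; bfloat16 max ~3.39e+38, eps ~7.81e-03.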

4 reactions
tridao commented, Jul 10, 2022

bf16 is now supported.
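A hypothetical usage sketch, assuming the flash_attn_func wrapper from the snippet above simply accepts bf16 inputs once bf16 support is enabled; the cast below is an illustration, not the library's documented API.

# Cast the packed QKV tensor to bf16 before calling the same wrapper as above
# (hypothetical: the exact argument handling of the released bf16 path may differ).
qkv_bf16 = qkv.to(torch.bfloat16)
flash_bf16 = flash_attn_func(qkv_bf16[0], cu_seqlens, dropout_p, max_seqlen_in_batch, causal=causal)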
