Numerical stability of attention layers
Hello!
Thanks for the cool implementation! I’ve been playing around with the code on CIFAR-10, trying to replicate SOTA FID values. However, the loss suddenly explodes in the middle of fp16 training even with sensible configs, e.g.,
class BaseUnet32(Unet):
    def __init__(self, *args, **kwargs):
        kwargs.update(
            dict(
                dim=192,
                dim_mults=(1, 2, 4),
                num_resnet_blocks=2,
                layer_attns=(False, True, True),
                layer_cross_attns=False,
                attn_heads=1,
                ff_mult=2.0,
            )
        )
        super().__init__(*args, **kwargs)
I’ve looked into the attention implementation in guided-diffusion, and it seems they apply the 1/sqrt(dim) scaling a bit differently: it is redistributed as 1/sqrt(sqrt(dim)) onto q and k separately (link), and the softmax is computed on float32 inputs.
Testing it at the moment, will let you know if it helps
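For reference, here is a minimal sketch of that trick, assuming a standard scaled-dot-product attention; the function name and tensor layout are illustrative, not guided-diffusion’s exact code:

import math
import torch

def stable_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    dim_head = q.shape[-1]
    # Split the usual 1/sqrt(dim) scaling into 1/sqrt(sqrt(dim)) applied to
    # both q and k, so the pre-softmax logits accumulate from smaller values.
    scale = 1.0 / math.sqrt(math.sqrt(dim_head))
    sim = torch.einsum('bhid,bhjd->bhij', q * scale, k * scale)
    # Do the softmax in float32, then cast back to the working dtype.
    attn = sim.float().softmax(dim=-1).to(sim.dtype)
    return torch.einsum('bhij,bhjd->bhid', attn, v)

Scaling q and k each by dim_head ** -0.25 produces the same logits as applying dim_head ** -0.5 once, but the values entering the matmul are smaller, which helps avoid fp16 overflow.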
Top GitHub Comments
@lucidrains Hey, I set the attention heads to 4 and lowered the learning rate; it’s going well so far (the fp32 and fp16 runs are roughly identical)
But the fp16 speed boost was really nice - 1200 img/sec to 1800 img/sec, until it diverged 😹
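In config terms, the fix described above would look roughly like the sketch below; the import path is an assumption (the issue never shows it), and the lowered learning rate is set on the optimizer side, with no exact value given here:

from imagen_pytorch import Unet  # assumed import, not shown in the issue

class BaseUnet32(Unet):
    def __init__(self, *args, **kwargs):
        kwargs.update(
            dict(
                dim=192,
                dim_mults=(1, 2, 4),
                num_resnet_blocks=2,
                layer_attns=(False, True, True),
                layer_cross_attns=False,
                attn_heads=4,  # raised from 1, per the comment above
                ff_mult=2.0,
            )
        )
        super().__init__(*args, **kwargs)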