
Numerical stability of attention layers

See original GitHub issue

Hello!

Thanks for a cool implementation! I’ve been playing around with the code on CIFAR-10, trying to replicate SOTA FID values. However, I get nasty loss explosions out of nowhere in the middle of fp16 training with sensible configs, e.g.,

from imagen_pytorch import Unet  # assuming lucidrains' imagen-pytorch Unet

# base Unet config for the 32x32 CIFAR-10 runs
class BaseUnet32(Unet):
    def __init__(self, *args, **kwargs):
        kwargs.update(
            dict(
                dim=192,
                dim_mults=(1, 2, 4),
                num_resnet_blocks=2,
                layer_attns=(False, True, True),
                layer_cross_attns=False,
                attn_heads=1,
                ff_mult=2.0,
            )
        )
        super().__init__(*args, **kwargs)

I’ve looked into the attention implementation in guided-diffusion, and it seems they apply the 1/sqrt(dim) scaling a bit differently, redistributing it as 1/sqrt(sqrt(dim)) onto both q and k (link), and they also compute the softmax on float32 inputs. Testing it at the moment; will let you know if it helps.
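
For reference, a minimal sketch of that trick, assuming standard multi-head attention tensors of shape (batch, heads, seq, dim_head); the function name is illustrative, not the actual guided-diffusion code:

import torch

def stable_attention(q, k, v):
    # q, k, v: (batch, heads, seq, dim_head), possibly in fp16
    # fold the 1/sqrt(dim_head) scaling into both q and k as dim_head ** -0.25,
    # so the logits are accumulated from smaller-magnitude operands
    scale = q.shape[-1] ** -0.25
    sim = torch.einsum('b h i d, b h j d -> b h i j', q * scale, k * scale)
    # run the softmax in float32 to avoid fp16 overflow, then cast back
    attn = sim.float().softmax(dim=-1).to(sim.dtype)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

Applying the scale before the matmul rather than after keeps the accumulated values inside the fp16 range, and the float32 softmax avoids overflow when the logits saturate half precision.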

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 1
  • Comments: 18 (11 by maintainers)

Top GitHub Comments

1 reaction
KhrulkovV commented, Jun 22, 2022

@lucidrains Hey, I set the attention heads to 4 and lowered the learning rate; it’s going well so far (the fp32 and fp16 runs are roughly identical).
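
A sketch of that adjusted config, assuming it is the BaseUnet32 from the issue body with only attn_heads changed (the lowered learning rate belongs to the training loop, and its exact value isn’t given in the thread):

from imagen_pytorch import Unet  # assuming lucidrains' imagen-pytorch package

class BaseUnet32MoreHeads(Unet):
    # hypothetical variant of the BaseUnet32 config above, with 4 attention heads
    def __init__(self, *args, **kwargs):
        kwargs.update(
            dict(
                dim=192,
                dim_mults=(1, 2, 4),
                num_resnet_blocks=2,
                layer_attns=(False, True, True),
                layer_cross_attns=False,
                attn_heads=4,  # was 1 in the original config
                ff_mult=2.0,
            )
        )
        super().__init__(*args, **kwargs)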

1 reaction
KhrulkovV commented, Jun 21, 2022

But the fp16 boost was really nice: from 1200 img/sec to 1800 img/sec, until it diverged 😹

Read more comments on GitHub >

Top Results From Across the Web

  • 5.4. Numerical Stability and Initialization
    Xavier initialization suggests that, for each layer, variance of any output is not affected by the number of inputs, and variance of any...
  • [1911.03584] On the Relationship between Self-Attention and ...
    Our numerical experiments then show that self-attention layers attend to pixel-grid patterns similarly to CNN layers, corroborating our ...
  • Kernelized Attention with Relative Positional Encoding
    Stable, Fast and Accurate: Kernelized Attention with ... For each layer, the hidden size is set to 768, and the number of attention...
  • neural networks - What is numerical stability?
    Numerical stability refers to how a malformed input affects the execution of an algorithm. In a numerically stable algorithm, errors in the ...
  • Write your own custom Attention layer: Easy, intuitive guide
    The goal of this layer is that it should give us the ‘attention weights’ — one for each word. Number of attention weights...
