FastAttention doesn't give results in agreement with standard attention?

See original GitHub issue

Hi there,

I ran this code to compare the results of standard attention with fast attention. Surprisingly, I’m getting very large errors (about 80%). Any idea as to where this comes from?

import torch
import numpy as np
from performer_pytorch import FastAttention

num_nodes = 24
feat_dim = 128
nb_features = 8 * feat_dim

num_its = 5
errs = []

for _ in range(num_its):
    Q = torch.randn(1, 1, num_nodes, feat_dim)
    K = torch.randn(1, 1, num_nodes, feat_dim)
    V = torch.randn(1, 1, num_nodes, feat_dim)

    # fast (Performer) attention
    attn = FastAttention(dim_heads=feat_dim,
                         nb_features=nb_features,
                         causal=False)
    fast = attn(q=Q, k=K, v=V)

    # flatten away the batch and head dimensions
    Q = Q.reshape(-1, feat_dim)
    K = K.reshape(-1, feat_dim)
    V = V.reshape(-1, feat_dim)

    # standard attention: softmax(Q K^T / sqrt(d)) V, with the softmax
    # written out as exp(...) followed by row normalization
    A = torch.exp(torch.matmul(Q, K.transpose(0, 1)) / feat_dim ** 0.5)
    ones = torch.ones(num_nodes)
    D_inv = torch.diag(1 / torch.matmul(A, ones))
    slow = torch.matmul(D_inv, torch.matmul(A, V))

    # mean relative error in percent
    err = (abs(slow - fast).mean() / abs(slow).mean() * 100).item()
    errs.append(err)

mean_err = np.mean(errs)
std_err = np.std(errs)

print("Error is (%.2f +/- %.2f)%%" % (mean_err, std_err))  # prints Error is (73.28 +/- 1.99)%

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 7

Top GitHub Comments

1 reaction
simonaxelrod commented, Jul 1, 2021

I haven’t tried it with Jax yet, but I’ll give that a shot.

0 reactions
bngcode commented, Aug 24, 2021

Consider equation (7) of the paper. In that lemma, SM+ is defined as an expectation of exponential terms, and the expectation is of course an integral over R^d. What the code (and the paper) does to approximate this integral is draw m = nb_features samples and orthogonalize them. This introduces a large error: the integral cannot be approximated well enough with only m points from R^d.

Could this explain the errors @simonaxelrod is seeing, or am I missing something here?
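
For intuition, here is a minimal sketch of that random-feature estimator (an illustration under my reading of the paper, not performer_pytorch's actual code; the function and variable names below are made up for the example). It estimates exp(q . k) with the positive features phi(x) = exp(w^T x - ||x||^2 / 2) / sqrt(m) for w ~ N(0, I_d), and shows the relative error shrinking as m grows:

import torch

def softmax_kernel_estimate(q, k, m):
    # E[phi(q) . phi(k)] = exp(q . k), estimated by averaging over m Gaussian draws
    d = q.shape[-1]
    w = torch.randn(m, d)
    phi_q = torch.exp(q @ w.T - (q ** 2).sum(-1, keepdim=True) / 2) / m ** 0.5
    phi_k = torch.exp(k @ w.T - (k ** 2).sum(-1, keepdim=True) / 2) / m ** 0.5
    return phi_q @ phi_k.T  # (n, n) estimate of exp(Q K^T)

d = 16
# the 1/d**0.25 normalizer turns exp(q . k) into the softmax logits exp(q . k / sqrt(d))
q = torch.randn(4, d) / d ** 0.25
k = torch.randn(4, d) / d ** 0.25
exact = torch.exp(q @ k.T)

for m in [16, 256, 4096]:
    approx = softmax_kernel_estimate(q, k, m)
    rel_err = ((approx - exact).abs().mean() / exact.abs().mean()).item()
    print(m, rel_err)  # error shrinks roughly like 1/sqrt(m)

The variance of each kernel estimate also grows quickly with the norms of q and k, so for unnormalized 128-dimensional Gaussian inputs (as in the snippet above) even m = 8 * 128 features can leave a large Monte Carlo error.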

Read more comments on GitHub >
