Possible bug in enc-dec attention?

See original GitHub issue

In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention. Here (https://github.com/lucidrains/reformer-pytorch/blob/5f5bbf4fd5806f45d2cb3b7373021786b3b34e5b/reformer_pytorch/reformer_pytorch.py#L598) you are concatenating the keys with x (where x is the decoder input) and then applying self-attention. Does it make sense to do self-attention over the decoder input and the encoder outputs together? Even in the trax code these two are handled separately (https://github.com/google/trax/blob/c7c47a14ef8ea5b260ac78c22cbadd6dc1fb605b/trax/models/reformer/reformer.py#L968): first self-attention is applied to the decoder input, and then a separate encoder-decoder attention is applied between the new decoder representation and the keys.
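
The difference between the two wirings can be sketched as follows. This is a minimal illustration only: attend here is plain softmax attention standing in for the library's LSH attention, and causal masks, multiple heads, and residual/feed-forward layers are all omitted.

import torch
import torch.nn.functional as F

def attend(q, k, v):
    # plain scaled dot-product attention (illustrative stand-in for LSH attention)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

# toy shapes: batch 2, encoder length 5, decoder length 7, model dim 8
enc_out = torch.randn(2, 5, 8)   # encoder output ("keys" handed to the decoder)
dec_in  = torch.randn(2, 7, 8)   # decoder input

# (a) reformer-pytorch wiring: concatenate the encoder output onto the decoder
# sequence and run one self-attention over the joint sequence, keeping only
# the decoder positions of the result
joint = torch.cat((dec_in, enc_out), dim=1)
out_a = attend(joint, joint, joint)[:, :dec_in.shape[1]]

# (b) trax wiring: self-attention over the decoder input first, then a separate
# encoder-decoder attention with queries from the decoder and keys/values from
# the encoder output
dec_self = attend(dec_in, dec_in, dec_in)
out_b = attend(dec_self, enc_out, enc_out)

print(out_a.shape, out_b.shape)  # both: torch.Size([2, 7, 8])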

I don't know if this is the reason, but I have a simple copy-reverse task where the loss gets stuck at 2.08, whereas with the trax code the loss drops close to 0 after a few steps.

import numpy as np
import torch
import tqdm
from reformer_pytorch import ReformerEncDec

# hyperparameters (values were not given in the original post; placeholders)
NUM_BATCHES = 10000
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4

def cycle():
    while True:
        source = torch.randint(2, 10, (32, 768)).long().cuda()
        target_np = np.flip(source.cpu().numpy(), axis=1).copy()  # reversed copy of the source sequence
        target = torch.from_numpy(target_np).long().cuda()

        mask = torch.ones(32, 768).bool().cuda()

        yield (source, target, mask)

# First example: Copy Reverse: 768 tokens - vocab size: 256

model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads=1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads=1,
).cuda()

#model = TrainingWrapper(model)
model.cuda()


# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, mask = next(cycle())
        loss = model(seq_in=source, seq_out=target, return_loss = True, enc_input_mask=mask)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 18 (13 by maintainers)

Top GitHub Comments

1 reaction
py4 commented, May 1, 2020

Ok. Thank you 😃

1 reaction
lucidrains commented, May 1, 2020

If you run the encoder / decoder from my Sinkhorn project against the Reformer with full attention, you will see they take about the same number of iterations for the task you have, even though Sinkhorn keeps self-attention and contextual attention separate.
