Possible bug in enc-dec attention?
In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention. Here (https://github.com/lucidrains/reformer-pytorch/blob/5f5bbf4fd5806f45d2cb3b7373021786b3b34e5b/reformer_pytorch/reformer_pytorch.py#L598) you are concatenating the keys with x (where x is the decoder input) and then applying self-attention. Does it make sense to do self-attention over the decoder input and the encoder outputs together? Even in the trax code these two are handled separately (https://github.com/google/trax/blob/c7c47a14ef8ea5b260ac78c22cbadd6dc1fb605b/trax/models/reformer/reformer.py#L968): first self-attention is applied to the decoder input, and then a separate encoder-decoder attention is applied between the new decoder representation and the keys.
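For concreteness, here is a minimal sketch of the two wirings. It uses plain torch.nn.MultiheadAttention in place of LSH attention, and the concatenation order, tensor shapes, and omission of causal masking are simplifications of mine, not the actual internals of either library:

import torch
import torch.nn as nn

dim, heads = 512, 1
x = torch.randn(32, 768, dim)     # decoder input
keys = torch.randn(32, 768, dim)  # encoder output passed in as keys

# (a) concatenate keys with x, then run a single self-attention over the joint sequence
self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
joint = torch.cat((x, keys), dim=1)                  # (32, 1536, dim)
joint_out, _ = self_attn(joint, joint, joint)
dec_out_a = joint_out[:, :x.shape[1]]                # keep only the decoder positions

# (b) trax-style: self-attention on the decoder input, then a separate encoder-decoder attention
causal_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
h, _ = causal_attn(x, x, x)                          # decoder self-attention (causal mask omitted here)
dec_out_b, _ = cross_attn(h, keys, keys)             # queries from decoder, keys/values from encoder

In (a) the decoder tokens and the encoder memory mix inside one attention pass; in (b) the decoder first builds its own representation and only then queries the encoder output.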
I don't know if this is the reason or not, but I have a simple copy-reverse task where the loss gets stuck at 2.08, whereas with the trax code the loss drops close to 0 after a few steps.
import numpy as np
import torch
import tqdm
from reformer_pytorch import ReformerEncDec

# training constants (placeholder values; not specified in the original post)
NUM_BATCHES = int(1e5)
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4

def cycle():
    while True:
        source = torch.randint(2, 10, (32, 768)).long().cuda()
        target_np = np.flip(source.cpu().numpy(), axis=1).copy()  # reversed copy of the source as a numpy array
        target = torch.from_numpy(target_np).long().cuda()
        mask = torch.ones(32, 768).bool().cuda()
        yield (source, target, mask)

# First example: copy-reverse, 768 tokens, vocab size 256
model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads = 1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads = 1,
).cuda()

#model = TrainingWrapper(model)
model.cuda()

# optimizer
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, mask = next(cycle())  # draw a fresh random batch
        loss = model(seq_in = source, seq_out = target, return_loss = True, enc_input_mask = mask)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
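One way to read the 2.08 plateau (my own observation, not from the original report): the source tokens are drawn from torch.randint(2, 10, ...), i.e. 8 distinct values, and ln 8 ≈ 2.079, so a loss stuck there is consistent with the decoder predicting at chance level over the tokens that actually occur. A quick sanity check:

import math
import torch
import torch.nn.functional as F

# cross-entropy of a uniform guess over the 8 token values produced by torch.randint(2, 10, ...)
print(math.log(8))  # ~2.0794, matching the reported plateau of 2.08

# same check with an explicit uniform prediction over a 256-token vocab,
# with all probability mass on the 8 used values
logits = torch.full((1, 256), -1e9)
logits[0, 2:10] = 0.0                   # uniform over tokens 2..9
target = torch.tensor([5])              # any token in that range
print(F.cross_entropy(logits, target))  # ~2.0794 as well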
Top GitHub Comments
Ok. Thank you 😃
If you run the encoder / decoder from my Sinkhorn project against the Reformer with full attention on the task you have, you will see they take about the same number of iterations, even though Sinkhorn keeps self-attention and contextual attention separate.
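For reference, the Reformer side of such a comparison could be set up by switching the same copy-reverse model to full attention. This is only a sketch: it assumes ReformerEncDec forwards the enc_/dec_-prefixed use_full_attn flag to its internal ReformerLM modules, which you should verify against the installed version of reformer-pytorch.

from reformer_pytorch import ReformerEncDec

# same copy-reverse setup as above, but with LSH attention replaced by full attention
# (assumed kwargs: enc_use_full_attn / dec_use_full_attn -- check your library version)
full_attn_model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads = 1,
    enc_use_full_attn = True,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads = 1,
    dec_use_full_attn = True,
).cuda()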