
Encoder-decoder fails at KMeans attention

See original GitHub issue

I haven’t been able to dig into the root cause here yet, but I’m getting the following error when trying to run an encoder-decoder:

 File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/encoder_decoder.py", line 77, in generate
    return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autoregressive_wrapper.py", line 71, in generate
    logits, _ = self.net(x, input_mask=input_mask, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autopadder.py", line 33, in forward
    return self.net(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 614, in forward
    x, loss = self.routing_transformer(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 592, in forward
    x, loss = self.layers(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 200, in forward
    out, f_loss, g_loss =  _ReversibleFunction.apply(x, blocks, args)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 137, in forward
    x, f_loss, g_loss = block(x, **kwarg)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 80, in forward
    f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 53, in forward
    return self.net(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 121, in forward
    return self.fn(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 524, in forward
    global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 390, in forward
    dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 339, in forward
    self.init(x)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 325, in init
    self.means.data.copy_(means)
RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 1
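
The last frame is a plain shape mismatch: Tensor.copy_() requires the source to be broadcastable into the destination, so a k-means centroid buffer expecting 64 clusters cannot be filled from only 2 sampled vectors. Here is a self-contained sketch of the same failure; the shapes are illustrative and not necessarily the ones used inside routing_transformer:

import torch

# Illustrative shapes only: a centroid buffer expecting 64 clusters per head,
# but only 2 candidate vectors available to initialize it from.
means = torch.zeros(8, 64, 64)     # e.g. (heads, num_clusters, dim_head)
sampled = torch.randn(8, 2, 64)    # too few vectors to fill 64 clusters

means.data.copy_(sampled)
# RuntimeError: The size of tensor a (64) must match the size of tensor b (2)
# at non-singleton dimension 1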

Here are my model params:

model = RoutingTransformerEncDec(
    enc_num_tokens=7000,
    dec_num_tokens=7000,
    dim=512,
    enc_ff_mult=4,
    dec_ff_mult=4,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=8192,
    dec_max_seq_len=8192,
    enc_window_size=128,
    dec_window_size=128,
    enc_causal=False,
    # dec_causal=True,  # the decoder is always causal
    enc_ff_dropout=0.05,
    dec_ff_dropout=0.05,
    enc_reversible=True,
    dec_reversible=True,
)
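
The traceback shows the error is raised from generate() on a freshly constructed model, before any training forward pass has run (the comments below confirm that was the trigger). A minimal sketch of such a call, assuming the README-style generate(seq_in, seq_out_start, seq_len) signature; the dummy tensors and start-token id are illustrative only:

import torch

# Purely illustrative inputs: (batch, seq_len) token ids and a one-token decoder prompt.
src = torch.randint(0, 7000, (1, 8192))
out_start = torch.full((1, 1), 1, dtype=torch.long)  # hypothetical start-token id

# Calling generate() on a model that has never seen a training batch hits the
# k-means initialization path shown in the traceback above.
generated = model.generate(src, out_start, 8192)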

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 16 (16 by maintainers)

Top GitHub Comments

1 reaction
tomweingarten commented, Jul 2, 2020

@lucidrains Quick update: Running with the new version fixed my training loss problem! Unfortunately I’m seeing some weird results for predictions that I can’t quite explain yet, but it’s going to take me a bit longer to dig into why that is. I’m also going to play around with mixed attention head locality, thanks for the tip!

1 reaction
tomweingarten commented, Jun 23, 2020

Thanks! You’re right that it is silly to run the generate() method before fitting. I do it just as a last check to make sure I haven’t done anything weird like accidentally load a checkpoint when I shouldn’t have. Thanks for the fix!
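
For reference, the ordering that sidesteps this on versions without the fix is to run at least one training step before sampling. A rough sketch; the keyword names (return_loss, etc.) and the return type are assumed from the library’s README rather than verified against this version:

import torch

src = torch.randint(0, 7000, (1, 8192))
tgt = torch.randint(0, 7000, (1, 8192))

# Assumed call signature; the forward pass initializes the k-means centroids
# from real data before generate() is ever called.
out = model(src, tgt, return_loss=True)
loss = out[0] if isinstance(out, tuple) else out  # hedge on the return type
loss.backward()

# Only now sample from the model.
generated = model.generate(src, tgt[:, :1], 8192)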

Read more comments on GitHub.

