
RNN computation graph does not include cloned RNN parameters

See original GitHub issue

So I have a model with an RNN and a linear layer afterwards:

import torch.nn as nn

class BedPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.RNN = nn.LSTM(NUM_VARIABLES, HIDDEN_DIM)
        self.hidden2Pred = nn.Linear(HIDDEN_DIM, NUM_VARIABLES)

    def forward(self, x, hidden=None):
        x = x.view(x.shape[0], -1, NUM_VARIABLES)

        if hidden is None:
            out, hidden = self.RNN(x)
        else:
            out, hidden = self.RNN(x, hidden)

        out = out.squeeze(1)
        pred = self.hidden2Pred(out)

        return pred, hidden

It learns perfectly well without meta-learning on my data, but as soon as I bring in learn2learn, adapting the learner fails with an error about unused tensors. Here’s my training loop:

for task in tasks:
    learner = maml.clone()
    learner.train()

    for _ in range(ADAPTATION_STEPS):
        seq = get_task(task)

        prediction, hidden = learner(seq)

        loss = criterion(prediction[:-1], seq[1:])
        learner.adapt(loss)

I’ve simplified it, removing the parts I’m confident the bug does not come from. This makes the program crash unless I specify allow_unused=True when instantiating MAML, but I shouldn’t have to do that.
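
For context, this is roughly what that workaround looks like. It’s a hedged sketch: the post doesn’t show how maml or criterion are constructed, so BedPredictor is reused from above, while FAST_LR and the MSE criterion are illustrative placeholders.

import torch.nn as nn
import learn2learn as l2l

FAST_LR = 0.01            # illustrative inner-loop learning rate, not from the post
criterion = nn.MSELoss()  # illustrative loss; the post doesn't say which one is used

# The workaround mentioned above: allow_unused=True tells MAML's adapt() to tolerate
# parameters that never show up in the loss's computation graph.
maml = l2l.algorithms.MAML(BedPredictor(), lr=FAST_LR, allow_unused=True)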

I decided to investigate the computation graph. The unused tensors from learner, the ones that never appear in the graph, are exactly the LSTM parameters: module.RNN.weight_ih_l0, module.RNN.weight_hh_l0, module.RNN.bias_ih_l0, module.RNN.bias_hh_l0.
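
A quick way to see which parameters end up unused, without plotting anything, is the small hedged sketch below. It assumes learner and loss are the ones from the loop above and is run just before learner.adapt(loss); the printed names are only an example.

import torch

params = dict(learner.named_parameters())
grads = torch.autograd.grad(
    loss,
    list(params.values()),
    retain_graph=True,   # keep the graph so adapt()/backward can still run afterwards
    allow_unused=True,   # return None for parameters that are not in the graph
)
unused = [name for name, grad in zip(params, grads) if grad is None]
print(unused)  # e.g. ['module.RNN.weight_ih_l0', 'module.RNN.weight_hh_l0', ...]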

I also plotted the backwards graphs for both the normal approach (no meta-learning) and with l2l using torchviz:

[Image: backward graph of the normal approach (no meta-learning)]

[Image: backward graph when using l2l]

The Variables are shown in blue, and display their name if they come from the maml, the learner, or the original model used to define the MAML. The second graph, the one using l2l, uses RNN parameters which aren’t named and therefore come from somewhere else, though I have no idea how to trace them.

Therefore I think there is an issue with MAML.clone() when dealing with RNNs, probably due to how RNN modules work internally. The issue persists with nn.GRU and nn.RNN.

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments:6 (2 by maintainers)

Top GitHub Comments

4 reactions
farid-fari commented, Apr 24, 2020

Leaving this here as well in case anyone runs across the same issue: when using the cuDNN backend, double differentiation for RNNs is not supported.

RuntimeError: derivative for _cudnn_rnn_backward is not implemented

It’s an issue with PyTorch (stemming from cuDNN limitations) and has an open bug report: https://github.com/pytorch/pytorch/issues/5261.

The fix is to disable the backend, which seems to have a significant performance hit (but it still runs on GPU):

torch.backends.cudnn.enabled = False
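
A possibly narrower variant (untested here, and assuming your PyTorch version exposes the torch.backends.cudnn.flags context manager) is to disable cuDNN only around the RNN forward pass and the adaptation step, rather than globally:

import torch

with torch.backends.cudnn.flags(enabled=False):
    # The forward pass must run with cuDNN disabled so that the autograd graph
    # is built from ops whose double backward is implemented.
    prediction, hidden = learner(seq)
    loss = criterion(prediction[:-1], seq[1:])
    learner.adapt(loss)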

1 reaction
seba-1511 commented, Apr 23, 2020

Thanks a lot for the minimal example @farid-fari.

The bug is specific to RNN modules; essentially, we need to reset the RNN’s self._flat_weights attribute, since those cached weights are the ones used in the forward pass, not the weights in self._parameters. Such a reset is performed, for example, whenever self._apply() is called.
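
The caching can be seen on a bare LSTM. Here is a small illustrative sketch; it pokes at private attributes (_flat_weights, _parameters, _apply), and the manual swap only imitates what a differentiable clone does, so treat it purely as an illustration:

import torch
import torch.nn as nn

lstm = nn.LSTM(3, 5)

# The cached flat-weights list initially references the same tensor objects
# as the registered parameters.
print(lstm._flat_weights[0] is lstm.weight_ih_l0)   # True

# Crudely imitate a differentiable clone: swap the entry in _parameters.
lstm._parameters['weight_ih_l0'] = lstm.weight_ih_l0.clone()

# The cache still points at the old tensor, so forward() would ignore the clone.
print(lstm._flat_weights[0] is lstm.weight_ih_l0)   # False

# _apply() rebuilds the cache from the current parameters (the workaround below).
lstm._apply(lambda x: x)
print(lstm._flat_weights[0] is lstm.weight_ih_l0)   # True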

I’ll push a fix shortly, and cut a new minor release as this is an important issue. In the meantime, you should be able to fix your example by calling:

learner.module._apply(lambda x: x)

before your forward pass. See this Colab for an example. If you are on master, you can even move that line to right after maml.clone(), thanks to the newer clone_module implementation.
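
Put together, a hedged sketch of the workaround applied to the adaptation loop from the question looks like this (the loop itself is unchanged):

for task in tasks:
    learner = maml.clone()
    learner.train()

    for _ in range(ADAPTATION_STEPS):
        seq = get_task(task)

        # Rebuild the RNN's cached _flat_weights from the cloned parameters before
        # each forward pass; on master, this line can instead move to right after
        # maml.clone(), as noted above.
        learner.module._apply(lambda x: x)

        prediction, hidden = learner(seq)
        loss = criterion(prediction[:-1], seq[1:])
        learner.adapt(loss)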

