RNN computation graph does not include cloned RNN parameters
So I have a model with an RNN and a linear layer afterwards:
import torch.nn as nn

# NUM_VARIABLES and HIDDEN_DIM are hyperparameters defined elsewhere in the script.
class BedPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.RNN = nn.LSTM(NUM_VARIABLES, HIDDEN_DIM)
        self.hidden2Pred = nn.Linear(HIDDEN_DIM, NUM_VARIABLES)

    def forward(self, x, hidden=None):
        x = x.view(x.shape[0], -1, NUM_VARIABLES)
        if hidden is None:
            out, hidden = self.RNN(x)
        else:
            out, hidden = self.RNN(x, hidden)
        out = out.squeeze(1)
        pred = self.hidden2Pred(out)
        return pred, hidden
It learns perfectly well without meta-learning on my data, but as soon as I include learn2learn, it complains that there are unused tensors when adapting the learner. Here’s my training loop:
for task in tasks:
    learner = maml.clone()
    learner.train()
    for _ in range(ADAPTATION_STEPS):
        seq = get_task(task)
        prediction, hidden = learner(seq)
        loss = criterion(prediction[:-1], seq[1:])
        learner.adapt(loss)
I’ve simplified it, removing parts I’m confident the bug does not come from. The program crashes unless I specify allow_unused=True when instantiating MAML, but I shouldn’t have to do this.
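For reference, a minimal sketch of the workaround (the model construction and learning rate here are placeholders, not my exact setup):

import learn2learn as l2l

model = BedPredictor()
# Workaround: allow_unused=True silences the "unused tensors" error, but since
# the cloned LSTM weights never enter the graph, their adaptation gradients
# come back as None, so they are effectively not adapted.
maml = l2l.algorithms.MAML(model, lr=0.01, allow_unused=True)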
I decided to investigate the computation graph. The unused tensors from learner in the graph are all of the LSTM tensors: module.RNN.weight_ih_l0, module.RNN.weight_hh_l0, module.RNN.bias_ih_l0, module.RNN.bias_hh_l0.
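To reproduce this check, here is a sketch of how the unused parameters can be listed (torch.autograd.grad with allow_unused=True returns None for parameters that do not appear in the graph; loss and learner are taken from the loop above):

import torch

# Parameters whose gradient comes back as None are not part of the loss graph.
names_and_params = list(learner.named_parameters())
grads = torch.autograd.grad(loss,
                            [p for _, p in names_and_params],
                            allow_unused=True,
                            retain_graph=True)
unused = [name for (name, _), g in zip(names_and_params, grads) if g is None]
print(unused)  # e.g. ['module.RNN.weight_ih_l0', 'module.RNN.weight_hh_l0', ...]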
I also plotted the backward graphs for both the normal approach (no meta-learning) and with l2l, using torchviz:
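[Graph images not included in this copy. A sketch of how such a graph can be rendered with torchviz, assuming loss and learner from the loop above and an arbitrary output filename:]

from torchviz import make_dot

# Label the learner's parameters so that named Variables show up in the plot.
dot = make_dot(loss, params=dict(learner.named_parameters()))
dot.render('l2l_backward_graph', format='png')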
The Variables are in blue, and display their name if they come from either the maml, the learner, or the original model used to define the MAML. The second graph, using l2l, uses RNN parameters which aren’t named and therefore come from somewhere else, though I have no idea how to trace them.
Therefore I think there is an issue with MAML.clone() when dealing with RNNs, probably due to the way RNNs work on the inside. This issue persists with nn.GRU and nn.RNN.
Issue Analytics
- Created: 3 years ago
- Comments: 6 (2 by maintainers)
Top GitHub Comments
Leaving this also here in case anyone runs across the same issue: when using the cuDNN backend, double differentiation for RNNs is not supported.
It’s an issue with PyTorch (stemming from cuDNN limitations), and has an open bug report: https://github.com/pytorch/pytorch/issues/5261.
The fix is to disable the backend, which seems to have a significant performance hit (but it still runs on the GPU):
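(The original snippet is not preserved in this copy; disabling cuDNN can be done along these lines, either globally or, more narrowly, with the context-manager form around the meta-learning passes:)

import torch

# Global switch: disables cuDNN everywhere, which allows double backward through RNNs.
torch.backends.cudnn.enabled = False

# Or limit the change to the adaptation forward/backward passes:
with torch.backends.cudnn.flags(enabled=False):
    prediction, hidden = learner(seq)
    loss = criterion(prediction[:-1], seq[1:])
    learner.adapt(loss)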
Thanks a lot for the minimal example @farid-fari.
The bug is specific to RNN modules; essentially, we need to reset the self._flat_weights attribute of the RNN, since those weights are used in the forward pass rather than the weights in self._parameters. For example, such a reset is computed whenever self._apply() is called. I’ll push a fix shortly, and cut a new minor release as this is an important issue. In the meantime, you should be able to fix your example by resetting those flattened weights before your forward pass (sketched below). See this Colab for an example. If you are on master, you can even move that line after maml.clone() thanks to the newer clone_module implementation.
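(The exact one-liner from the comment is not preserved in this copy. Based on the description, the reset amounts to re-pointing the RNN’s _flat_weights at its current, cloned parameters; the attribute names below follow torch.nn.RNNBase internals and may differ between PyTorch versions:)

# Hypothetical reconstruction of the workaround: rebuild the LSTM's _flat_weights
# list so it references the cloned parameters that learner.adapt() differentiates
# through, mimicking what nn.RNNBase._apply() does internally.
rnn = learner.module.RNN
rnn._flat_weights = [getattr(rnn, name) for name in rnn._flat_weights_names]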