Why does DifferentiableOptimizer detach parameters when track_higher_grads = False?
Hi! Thank you for this awesome library, it helps me a lot.
I am not sure whether I'm missing something, but I'm confused about why DifferentiableOptimizer detaches parameters when track_higher_grads=False: this cuts the gradient path back to the original model parameters, even when copy_initial_weights=False. When we set copy_initial_weights=False, we want to allow gradients to flow back to the original model parameters, but line 257 cuts off the gradient flow.
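For reference, the logic being referred to looks roughly like the following (paraphrased from DifferentiableOptimizer.step in higher's optim.py from memory; not an exact copy, and line numbers vary between versions):

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        if self._track_higher_grads:
            new_params[index] = p
        else:
            # With track_higher_grads=False the updated parameter is detached,
            # which severs the autograd path back to the pre-step parameters.
            new_params[index] = p.detach().requires_grad_()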
In my use case, I want to implement something like FOMAML; here is a simplified version of my code:
def inner_loop(self, fmodel, diffopt, train_input, train_target):
    # ...

def outer_loop(self, task_batch):
    self.out_optim.zero_grad()
    for task_data in task_batch:
        support_input, support_target, query_input, query_target = task_data
        with higher.innerloop_ctx(
            self.model, self.in_optim, copy_initial_weights=False, track_higher_grads=False
        ) as (fmodel, diffopt):
            self.inner_loop(fmodel, diffopt, support_input, support_target)
            query_output = fmodel(query_input)
            query_loss = F.cross_entropy(query_output, query_target)
            query_loss.backward()
    for param in self.model.parameters():
        print(param.grad)  # output: None
    self.out_optim.step()
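For context, a minimal sketch of what such an inner_loop might look like (hypothetical: the adaptation-step count self.num_inner_steps and the cross-entropy inner loss are assumptions, not taken from the issue):

def inner_loop(self, fmodel, diffopt, train_input, train_target):
    # Hypothetical sketch: a few differentiable optimizer steps on the support set.
    # self.num_inner_steps is an assumed attribute, not from the original code.
    for _ in range(self.num_inner_steps):
        train_output = fmodel(train_input)
        train_loss = F.cross_entropy(train_output, train_target)
        diffopt.step(train_loss)  # updates fmodel's fast weights in place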
The gradients are not propagated back to the original parameters. My code works as expected after I edit higher's code to:
new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        new_params[index] = p
I know this problem can be solved by manually mapping the gradients (sketched below), but I just wonder why detaching the parameters is necessary here. Thank you for your nice work again!
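A minimal sketch of that manual mapping (an assumption about what is meant, not code from the issue) would replace query_loss.backward() inside the innerloop_ctx block:

# Hypothetical sketch: compute first-order gradients w.r.t. the fast weights
# and copy them onto the original model's .grad fields by hand.
grads = torch.autograd.grad(query_loss, fmodel.parameters())
for param, grad in zip(self.model.parameters(), grads):
    if param.grad is None:
        param.grad = grad.detach()
    else:
        param.grad += grad.detach()  # accumulate across tasks in the batch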
Top GitHub Comments
As a workaround, I think you can use
diff_opt.step(loss, grad_callback=lambda grads: [g.detach() for g in grads])
This gives the same outer-loop gradient as when using torch.autograd.grad to compute gradients with track_higher_grads=False, but .backward() still works. As a bonus, you also get first-order gradients for inner-loop learning rates (if you're learning those); with track_higher_grads=False, you don't get gradients for the learning rates.

The solution is easy: they are doing the detach on the params p, not on the gradients g.
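Applied to the example above, the grad_callback workaround would only change the inner update call. A sketch, reusing the hypothetical inner_loop from earlier and assuming track_higher_grads is left at its default of True (so it is the gradients, not the parameters, that get detached):

def inner_loop(self, fmodel, diffopt, train_input, train_target):
    for _ in range(self.num_inner_steps):  # hypothetical attribute, as above
        train_output = fmodel(train_input)
        train_loss = F.cross_entropy(train_output, train_target)
        # Detach the inner-loop gradients (not the parameters): the outer
        # query_loss.backward() then yields first-order (FOMAML-style)
        # gradients that still reach self.model's parameters.
        diffopt.step(
            train_loss,
            grad_callback=lambda grads: [g.detach() for g in grads],
        )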