Feature request: utility functions to allow stopping meta-gradient propagation
Hi there, it would be really useful if higher’s API allowed blocking gradients at a point in the graph.
Use Cases:
- When working with Actor-Critic models, I’d want to allow gradients to flow through the critic when training it, but block them when using the critic to train the actor. It would be useful if I could make a detached copy of the critic for the second part.
- When working with meta-gradients that don’t span the whole training loop (e.g. only using meta-gradients over K unrolled steps), it would be useful to be able to detach the models so gradients don’t flow past step K.
Proposed API
The simplest and most intuitive way, in my opinion, is to make the differentiable optimizers work with PyTorch’s existing .state_dict() and .load_state_dict(), so we can call them on PyTorch’s normal modules and optimizers like:

nn_module_model.load_state_dict(functional_model.state_dict())
normal_optimizer.load_state_dict(differentiable_optimizer.state_dict())
To continue training with meta-gradients, we can then re-create the higher versions of these. The functional models already support this, but the differentiable optimizers don’t. I’ve hacked together a patch that works for my use case (2-step unrolling), but I doubt it would work in general.
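For concreteness, here is a minimal sketch of the K-step truncation under this proposal. It relies on fmodel.state_dict(), which (as noted above) already works, while optimizer.load_state_dict(diffopt.state_dict()) is the proposed, not-yet-existing API; the toy model and data are placeholders:

import torch
import torch.nn as nn
import higher

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
K, num_outer_steps = 2, 10  # truncation length and number of outer iterations

for outer_step in range(num_outer_steps):
    with higher.innerloop_ctx(model, optimizer) as (fmodel, diffopt):
        for k in range(K):  # unroll K differentiable inner steps
            x, y = torch.randn(8, 4), torch.randn(8, 1)  # stand-in data
            inner_loss = ((fmodel(x) - y) ** 2).mean()
            diffopt.step(inner_loss)

        # Meta-updates that backprop through the K unrolled steps go here.

        # Truncate the meta-gradient graph by copying state back into the
        # plain module and optimizer (load_state_dict copies without grad)...
        model.load_state_dict(fmodel.state_dict())
        optimizer.load_state_dict(diffopt.state_dict())  # proposed API, not in higher today
    # ...so the next outer iteration re-creates fmodel and diffopt from
    # detached state.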
An extension for convenience
Lastly, a more convenient (but less intuitive) extension to this idea would be to have the usual .detach() and .detach_() methods for torch tensors also apply to the functional models. This didn’t make sense for nn.Modules since they weren’t differentiable, but it would be very convenient for functional models since they (their parameters) now are:
detach() should create a fresh copy of the whole functional model that is cut from the computation graph, which is really useful for the actor-critic case:
# The critic needs to be trained and to backprop its meta-gradients.
critic_loss = F.mse_loss(
    fmodel_critic(state, action),
    target_q_value,
)

# The actor also needs to be trained and to backprop its meta-gradients, but
# the critic used to train it should NOT backprop the policy's loss into its
# own parameters.
fmodel_critic_for_training_actor = fmodel_critic.detach()  # copy that doesn't alter the original's params
actor_loss = -fmodel_critic_for_training_actor(
    state,
    fmodel_actor(state),
)

# Backpropagation and updates now ignore the detached copy.
diffopt_critic.step(critic_loss)
diffopt_actor.step(actor_loss)
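In the meantime, something close to this can be approximated at call time: if I remember right, higher’s patched modules accept an optional params keyword in forward, so passing detached copies of the critic’s current fast weights blocks gradients into the critic while still letting them flow back to the actor through the action. A rough sketch (reusing the names above), not the proposed .detach() itself:

# Detached view of the critic's current fast weights; the graph through the
# activations (and hence into the actor's parameters) is preserved.
detached_critic_params = [p.detach() for p in fmodel_critic.parameters()]
actor_loss = -fmodel_critic(
    state,
    fmodel_actor(state),
    params=detached_critic_params,
).mean()  # reduce to a scalar before the differentiable step
diffopt_actor.step(actor_loss)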
detach_() should detach all param tensors in place. This would also be more efficient for the K-step unroll use case, since we wouldn’t need to copy back into a plain PyTorch nn.Module and then copy again into a functional model to continue. With the .detach_() version we can just do it in place:
with higher.innerloop_ctx(model, optim) as (fmodel, diffoptim):
    for i in range(num_steps):  # training loop
        for k in range(meta_gradient_steps):  # inner loop I need to backprop through
            ...  # get losses
            train_step(fmodel, loss, diffoptim)
        ...  # do meta updates
        fmodel = fmodel.detach_()  # block backprop here by detaching all params in place
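For reference, here is a rough sketch of what an in-place detach_() might boil down to, assuming it can go through the patched module’s update_params (the same hook the differentiable optimizer uses to write back new fast weights). The helper name is hypothetical, and the differentiable optimizer’s own state (e.g. momentum buffers) would likewise need detaching for the truncation to be complete:

def detach_fast_params_(fmodel):
    # Hypothetical helper approximating the proposed fmodel.detach_(): replace
    # each fast parameter with a detached leaf that still requires grad, so
    # subsequent differentiable steps start a fresh graph segment.
    fmodel.update_params(
        [p.detach().requires_grad_() for p in fmodel.parameters()]
    )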
Of course, these are just for convenience and efficiency, since the state_dict methods could functionally do the same things, albeit more slowly.
Top GitHub Comments
Thanks, this is super. Just FYI, I am on leave all of November, and most of December. If this is urgent, please flag it here. I will try and find time when I return in early December to implement an appropriate and robust state_dict method for patched modules and differentiable optimizers which does the right thing. In the meantime, if you fancy having a stab at doing this yourself, we welcome contributions and it would be very helpful indeed.
Hi,
I am wondering whether there are any updates on the state_dict() for the optimizers?