
Feature request: utility functions to allow stopping meta-gradient propagation


Hi there, it would be really useful if higher’s API allowed blocking gradients at a point in the graph.

Use Cases:

  • When working with actor-critic models, I’d want gradients to flow through the critic when training it, but to be blocked when using the critic to train the actor. It would be useful if I could make a detached copy of the critic for the second part.

  • When working with meta-gradients that don’t span the whole training loop (e.g., meta-gradients over only K unrolled steps), it would be useful if we could detach the models so gradients don’t flow past step K.

Proposed API

The simplest and most intuitive way, in my opinion, is to make the differentiable optimizers work with PyTorch’s existing .state_dict() and .load_state_dict(), so we can call them on PyTorch’s normal modules and optimizers like:

# copy the patched module's params back into the plain nn.Module
nn_module_model.load_state_dict(functional_model.state_dict())
# proposed: copy the differentiable optimizer's state back into the plain optimizer
normal_optimizer.load_state_dict(differentiable_optimizer.state_dict())

To continue training with meta-gradients, we can then re-create the higher versions of these. The functional models already support this, but the differentiable optimizers don’t. I’ve hacked together a patch that works for my use case (2-step unrolling), but I doubt it would work in general.
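
For concreteness, here is a rough sketch of the truncated K-step unrolling this would enable. model and optimizer are a plain nn.Module and torch optimizer, and num_windows, K, compute_loss, and compute_meta_loss are placeholders; the commented-out optimizer line is the proposed, not-yet-existing part:

import higher

for window in range(num_windows):  # outer loop over K-step windows (placeholder)
    # copy_initial_weights=False so meta-gradients can reach model.parameters()
    with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (fmodel, diffopt):
        for k in range(K):  # K differentiable inner steps
            loss = compute_loss(fmodel)  # placeholder task loss
            diffopt.step(loss)

        meta_loss = compute_meta_loss(fmodel)  # placeholder meta-objective
        meta_loss.backward()  # meta-gradients flow through this window only
        ...  # meta-update step

        # Truncate: copy state back so the next window starts from detached tensors.
        model.load_state_dict(fmodel.state_dict())  # patched modules already support this
        # optimizer.load_state_dict(diffopt.state_dict())  # the proposed optimizer-side call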

An extension for convenience

Lastly, a more convenient (but less intuitive) extension to this idea would be to have the existing .detach() and .detach_() methods of torch tensors apply to the functional models. This didn’t make sense for nn.Modules, since they weren’t differentiable, but it would be very convenient for functional models, since they (their parameters) are now differentiable too:

  • detach() should create a fresh copy of the whole functional model that is cut from the computation graph, which is really useful for the actor-critic case:
import torch.nn.functional as F

# The critic needs to be trained and to backprop its meta-gradients.
critic_loss = F.mse_loss(
    fmodel_critic(state, action),
    target_q_value,
)

# The actor also needs to be trained and to backprop its meta-gradients, but the
# critic used to train it should NOT receive the policy loss's gradients into its
# own parameters.
fmodel_critic_for_training_actor = fmodel_critic.detach()  # proposed: detached copy, original params untouched
actor_loss = -fmodel_critic_for_training_actor(
    state,
    fmodel_actor(state),
)

# Backpropagation and updates now ignore the detached copy.
diffopt_critic.step(critic_loss)
diffopt_actor.step(actor_loss)
  • detach_() should detach all parameter tensors in place. This would also be really useful, and more efficient, for the K-step unroll use case, since we wouldn’t need to deep-copy back to a PyTorch nn.Module model and then copy again to a functional model to continue. With this .detach_() version, we can just do it in place:
with higher.innerloop_ctx(model, optim) as (fmodel, diffoptim):
    for i in range(num_steps):  # training loop
        for k in range(meta_gradient_steps):  # inner loop I need to backprop through
            ...  # get losses
            train_step(fmodel, loss, diffoptim)
        ...  # do meta updates
        fmodel = fmodel.detach_()  # block backprop here by detaching all params in place

Of course, these are just for convenience and efficiency, since the state_dict methods could functionally do the same things, albeit more slowly.

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Comments: 11 (5 by maintainers)

Top GitHub Comments

1 reaction
egrefen commented, Oct 25, 2019

Thanks, this is super. Just FYI, I am on leave all of November, and most of December. If this is urgent, please flag this here. I will try to find time when I return in early December to implement an appropriate and robust state_dict method for patched modules and differentiable optimizers that does the right thing.

In the meantime, if you fancy having a stab at doing this yourself, we welcome contributions and it would be very helpful indeed.

0 reactions
thomas0809 commented, Oct 23, 2020

@ihexx Note that the patched modules already support this, i.e. you can do something like:

import higher
import torch

module = ...
fmodule = higher.monkeypatch(module)
with torch.no_grad():
    for p in module.parameters():
        p.add_(1)  # modify original params so that fmodule params are different
module.load_state_dict(fmodule.state_dict())  # restore params from fmodule

Does this fit your needs on the module side of things?

I’ll look at the optimizers later this week or early next week as there are other bugs that have higher priority.

Hi,

I am wondering if there are any updates on state_dict() for the optimizers?
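
In the meantime, one possible stop-gap is just a sketch, not an official higher recipe (MyModel, K, compute_loss, and compute_meta_loss are placeholders): reload the base module from the patched one and rebuild fresh higher objects with higher.monkeypatch and higher.get_diff_optim. Optimizer statistics such as Adam’s moments are reset at the cut, which is exactly why a proper state_dict() on the differentiable optimizer would still help:

import higher
import torch

model = MyModel()  # placeholder nn.Module
opt = torch.optim.Adam(model.parameters())

fmodel = higher.monkeypatch(model)
diffopt = higher.get_diff_optim(opt, model.parameters(), fmodel=fmodel)

for k in range(K):  # K differentiable steps (placeholder count)
    loss = compute_loss(fmodel)  # placeholder task loss
    diffopt.step(loss)

meta_loss = compute_meta_loss(fmodel)  # placeholder meta-objective
meta_loss.backward()

# Cut the graph: copy params back into the base module, then rebuild the
# patched model and differentiable optimizer from scratch.
model.load_state_dict(fmodel.state_dict())
fmodel = higher.monkeypatch(model)
diffopt = higher.get_diff_optim(opt, model.parameters(), fmodel=fmodel)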
