Iterating through hk modules

Let’s say I want to iterate through all modules inside an hk model and replace every hk.Linear with my own custom Module, or monkey-patch some of their properties. Does Haiku currently support something along these lines?

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 6

Top GitHub Comments

tomhennigan commented, Sep 4, 2020 (1 reaction)

That sounds like it would work 😄. As for map functions, you probably want something like params = jax.tree_map(apply_mask, params).
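As a rough illustration of the mapping idea above, here is a minimal sketch in plain JAX. The `apply_mask` function and its `threshold` are hypothetical (the comment does not define them), and the parameter dictionary is a hand-written stand-in for what a Haiku `init` would return. Newer JAX releases expose the function as `jax.tree_util.tree_map`, which is what the sketch uses.

```python
import jax
import jax.numpy as jnp


def apply_mask(w, threshold=0.5):
    # Hypothetical mask: zero out entries whose magnitude is below `threshold`.
    return jnp.where(jnp.abs(w) < threshold, 0.0, w)


# Stand-in for a Haiku parameter tree: {module_name: {param_name: array}}.
params = {
    "linear": {
        "w": jnp.array([[0.1, -2.0], [3.0, -0.2]]),
        "b": jnp.array([0.0, 1.0]),
    }
}

# tree_map applies `apply_mask` to every leaf array and keeps the tree structure.
masked = jax.tree_util.tree_map(apply_mask, params)
```

Because the function is applied to every leaf, biases are masked too. Filtering by module or parameter name would need a name-aware traversal instead.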

We also have a pruning example which implements https://arxiv.org/abs/1710.01878. I suspect this could be a useful reference for you.

tomhennigan commented, Sep 4, 2020 (1 reaction)

It sounds like what you want is w = jax.lax.stop_gradient(w) (which is basically what you describe: identity in the forward pass and 0 in the backward pass). If you put that in your custom getter, it will cause the gradients of those parameters to be zero.

Something to watch out for is that in other frameworks (e.g. TF) stop_gradient causes a None to be returned as the gradient, which optimizers then skip. In JAX it causes zeros to be returned. Another way to say this is that other frameworks' AD systems allow you to tell the difference between “gradient disabled” and “0 gradient”. In JAX you can only do this if you look at the value of the gradient and conditionally update the parameter and optimizer state based on that.

With some optimizers this can cause a non-zero update to be applied to your parameters (even when gradients are zero). Usually this is not what people want when applying stop gradient on parameters, since the goal is to keep the value of those parameters fixed.

If you want to skip updating some params entirely, I would suggest not doing this with custom getters and stop_gradient, but rather partitioning your parameters into ones you want to update and ones you want to hold fixed:

import haiku as hk
import jax
import optax

my_f = hk.transform(my_f)

def my_loss_fn(train_params, non_train_params, ..):
  params = hk.data_structures.merge(train_params, non_train_params)
  out = my_f.apply(params, ..)
  ..
  return loss

grad_my_loss_fn = jax.grad(my_loss_fn)

def is_trainable(module_name, param_name, param_value):
  # Can be whatever you want..
  return 'linear' in module_name

params = my_f.init(..)
train_params, non_train_params = hk.data_structures.partition(is_trainable, params)
opt_state = opt.init(train_params)  # Only get opt_state for trainable params, saves some memory :)

for batch in dataset:
  grads = grad_my_loss_fn(train_params, non_train_params, ..)
  # NOTE: grads will only be defined for `train_params`.

  # Only updating train_params.
  updates, opt_state = opt.update(grads, opt_state, train_params)
  # optax.apply_updates takes the parameters first, then the updates.
  train_params = optax.apply_updates(train_params, updates)

You would probably want to rework the above so you could jit the train step but it would look basically the same. You could even close over the constant parameters (rather than pass them in each time) which would allow XLA to potentially do some constant folding and run your code even faster.
