Iterating through hk modules
Let’s say I want to iterate through all modules inside an hk model and replace all hk.Linears with my own custom Module, or monkey-patch some of their properties. Does Haiku currently support something along these lines?
That sounds like it would work 😄. Re map functions, you probably want something like
params = jax.tree_util.tree_map(apply_mask, params)
We also have a pruning example which implements https://arxiv.org/abs/1710.01878. I suspect it could be a useful reference for you.
It sounds like what you want is
w = jax.lax.stop_gradient(w)
(which is basically what you describe: identity in the forward pass and 0 in the backward pass). If you put that in your custom getter, the gradients of those parameters will be zero.

Something to watch out for: in other frameworks (e.g. TF), stop_gradient causes a None to be returned as the gradient, which optimizers then skip. In JAX it causes zeros to be returned. Put another way, other frameworks' AD systems let you tell the difference between “gradient disabled” and “0 gradient”; in JAX you can only do this by looking at the value of the gradient and conditionally updating the parameter and optimizer state based on that. With some optimizers (for example ones that use momentum or weight decay), this means a non-zero update can be applied to your parameters even when their gradients are zero. That is usually not what people want when applying stop_gradient to parameters, since the point is to keep the values of those parameters fixed.
If you want to skip updating some params entirely, I would suggest not doing this with custom getters and stop_gradient, but rather partitioning your parameters into the ones you want to update and the ones you want to hold fixed, and computing gradients and updates only for the former.
You would probably want to structure the train step so that it can be jitted, but it would look basically the same. You could even close over the constant parameters (rather than passing them in each time), which would allow XLA to potentially do some constant folding and run your code even faster.