How do I alter the parameter cotangents in a custom derivative?
Yesterday, I started learning Haiku in order to port my codebase over, but I'm running into some showstoppers, and I'm wondering if anyone could offer some helpful pointers. My main issue right now is how to port over a custom gradient that has this form in my code:
```python
@custom_vjp
def f(..., weights): ...

def fwd(..., weights):
    _, internal_vjp = vjp(g, weights)
    return f(...), internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    weights_bar, = internal_vjp(y_bar)
    # In fact, I split y_bar into a variety of pieces, then vmap internal_vjp
    # over those pieces, and finally assemble the weight cotangents. This
    # allows me to funnel different cotangents from y_bar to different
    # parameters.
    return ..., weights_bar

f.defvjp(fwd, bwd)
```
(This ability to store one VJP in the residuals of another custom VJP was something that I added to JAX https://github.com/google/jax/pull/3705.)
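For concreteness, here is a minimal runnable version of that pattern with a toy `g` (the names `toy_f` etc. are my own, and the `y_bar` splitting and `vmap` step are elided):

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp, vjp

def g(weights, x):
    # Toy stand-in for the internal function.
    return jnp.tanh(weights @ x)

@custom_vjp
def toy_f(x, weights):
    return g(weights, x)

def toy_fwd(x, weights):
    y, internal_vjp = vjp(lambda w: g(w, x), weights)
    # The residuals hold the VJP closure itself; jax.vjp returns a
    # tree_util.Partial, which is a valid residual after google/jax#3705.
    return y, (internal_vjp, x)

def toy_bwd(residuals, y_bar):
    internal_vjp, x = residuals
    weights_bar, = internal_vjp(y_bar)
    return jnp.zeros_like(x), weights_bar  # no interesting x cotangent here

toy_f.defvjp(toy_fwd, toy_bwd)

weights = jnp.ones((2, 3))
x = jnp.ones(3)
print(jax.grad(lambda w: toy_f(x, w).sum())(weights))
```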
The problem with porting this over to Haiku is that the forward pass is not an explicit function of the weights, and so the backward pass doesn’t have the opportunity to pass cotangents to the weights.
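To make the mismatch concrete, my understanding is that a Haiku function pulls its weights from the surrounding frame rather than taking them as an argument, roughly like this:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def g(x):
    # The weight is fetched implicitly from the current frame, so wrapping g
    # in jax.custom_vjp gives the backward pass no weights input to return
    # cotangents for.
    w = hk.get_parameter("w", shape=[x.shape[-1], 3], init=jnp.zeros)
    return x @ w

g_t = hk.transform(g)
params = g_t.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))
# The weights live in the params mapping, not in g's signature.
```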
I’m new to Haiku, but I wonder if it would be possible to do something like this:
```python
def f(...):
    # g is the internal function that implements f.
    return internal_f(..., hk.get_relevant_parameters(g))

@custom_vjp
def internal_f(..., weights: hk.Params):
    t = hk.transform(g)
    return t.apply(weights)

def fwd(..., weights: hk.Params):
    t = hk.transform(g)
    primal, internal_vjp = vjp(t.apply, weights)
    return primal, internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    weights_bar, = internal_vjp(y_bar)
    # Fortunately, I can assemble weights_bar thanks to the filter and merge
    # functions in hk.data_structures.
    return ..., weights_bar

internal_f.defvjp(fwd, bwd)
```
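For reference, the assembly I have in mind uses `hk.data_structures` along these lines (toy values and module names):

```python
import haiku as hk
import jax.numpy as jnp

params = {
    "linear_1": {"w": jnp.ones([3, 3]), "b": jnp.zeros([3])},
    "linear_2": {"w": jnp.ones([3, 3]), "b": jnp.zeros([3])},
}

# Pick out the parameters that should receive this piece of the cotangent...
targets = hk.data_structures.filter(
    lambda module_name, name, value: module_name == "linear_1", params)
rest = hk.data_structures.filter(
    lambda module_name, name, value: module_name != "linear_1", params)

# ...and merge everything back into a single Params structure.
weights_bar = hk.data_structures.merge(targets, rest)
```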
Basically, `get_relevant_parameters` would be something like:
```python
def get_relevant_parameters(f: Callable[..., Any], *args: Any, **kwargs: Any) -> Params:
    # The RNG value is irrelevant for the parameters returned by init.
    parameters = transform(f).init(0, *args, **kwargs)
    # Return the parameters from the current frame that are needed by f.
    return {k: {l: current_frame().params[k][l] for l in bundle}
            for k, bundle in parameters.items()}
```
Alternatively, I could just pass in `current_frame().params` to `f`, but it would be annoying to have to pass `None` as the corresponding cotangents for all the parameters in the model in `bwd`.
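Concretely, that would look something like the following sketch (`full_cotangents` is a made-up helper):

```python
import jax
import jax.numpy as jnp
import haiku as hk

def full_cotangents(all_params, weights_bar):
    # Made-up helper: give every parameter in the model a zero cotangent,
    # then let the entries that internal_vjp actually produced win (later
    # structures take precedence in hk.data_structures.merge).
    zeros = jax.tree_util.tree_map(jnp.zeros_like, all_params)
    return hk.data_structures.merge(zeros, weights_bar)
```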
I’m going to keep working on this, but I thought I’d file the issue early in the very likely case I’m missing something. Thanks a lot, and great project by the way!
Of course, feel free to keep this open as long as it is useful for you.

We do have an experimental feature called `lift`. It enables you to make the parameters explicit for a specific function call inside a transform. It is not very well documented, but a few advanced users are making heavy use of it internally (e.g. to make it easier to scan over the application of modules). I've knocked up an example here:
https://colab.research.google.com/gist/tomhennigan/6f1237b5fb268a3d6d2391329ba2d051/example-of-using-hk-experimental-lift.ipynb
I wonder if this will be sufficient for your use case (making relevant parameters explicit inside a haiku transform).
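Roughly, the pattern in the notebook looks like this (a condensed sketch rather than the exact notebook code; the module and shapes are placeholders):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def inner(x):
    return hk.Linear(4)(x)

def outer(x):
    inner_t = hk.transform(inner)
    # lift registers the inner parameters with the outer transform and hands
    # them back as an explicit Params mapping, which can then cross a
    # custom_vjp boundary as a regular argument.
    lifted_init = hk.experimental.lift(inner_t.init)
    params = lifted_init(hk.next_rng_key(), x)
    return inner_t.apply(params, None, x)

f = hk.transform(outer)
x = jnp.ones([1, 3])
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, jax.random.PRNGKey(1), x)
```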