
How do I alter the parameter cotangents in a custom derivative?

See original GitHub issue

Yesterday, I started learning Haiku in order to port my codebase over, but I’m running into some showstoppers and I’m wondering if anyone could offer some helpful pointers. My main issue right now is how to port over a custom gradient that has this form in my code:

@custom_vjp
def f(..., weights): ...

def fwd(..., weights):
    # vjp returns (primal, vjp_fn); the vjp_fn closure itself is stored as the residual.
    _, internal_vjp = vjp(g, weights)
    return f(...), internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    (weights_bar,) = internal_vjp(y_bar)
    # In fact I split y_bar into a variety of pieces, then vmap internal_vjp over those pieces,
    # and finally assemble the weight cotangents.  This lets me funnel different cotangents from
    # y_bar to different parameters.
    return ..., weights_bar

f.defvjp(fwd, bwd)

(This ability to store one VJP in the residuals of another custom VJP was something that I added to JAX https://github.com/google/jax/pull/3705.)
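To make the pattern concrete, here is a minimal runnable sketch of storing one VJP in the residuals of a custom VJP. The function `g`, the scalar weight, and the array shapes are illustrative stand-ins, not from the original codebase:

```python
import jax
import jax.numpy as jnp

# g stands in for the internal function; weights is a single scalar here.
def g(weights, x):
    return jnp.tanh(weights * x)

@jax.custom_vjp
def f(x, weights):
    return g(weights, x)

def f_fwd(x, weights):
    # jax.vjp returns (primal, vjp_fn); the vjp_fn closure is the residual.
    primal, internal_vjp = jax.vjp(g, weights, x)
    return primal, internal_vjp

def f_bwd(internal_vjp, y_bar):
    # The stored VJP maps the output cotangent back to (weights_bar, x_bar).
    weights_bar, x_bar = internal_vjp(y_bar)
    return x_bar, weights_bar  # cotangents in the order of f's arguments

f.defvjp(f_fwd, f_bwd)

x_bar, w_bar = jax.grad(lambda x, w: jnp.sum(f(x, w)), argnums=(0, 1))(
    jnp.ones(3), 2.0)
```

This works because the `vjp_fn` returned by `jax.vjp` is a pytree (the point of the linked PR), so it can pass through `custom_vjp` residuals unchanged.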

The problem with porting this over to Haiku is that the forward pass is not an explicit function of the weights, and so the backward pass doesn’t have the opportunity to pass cotangents to the weights.

I’m new to Haiku, but I wonder if it would be possible to do something like this:

def f(...):
    return internal_f(..., hk.get_relevant_parameters(g))  # g is the internal function that implements f.

@custom_vjp
def internal_f(..., weights: hk.Params):
    t = hk.transform(g)
    return t.apply(weights, None, ...)  # apply takes (params, rng, *args)

def fwd(..., weights: hk.Params):
    t = hk.transform(g)
    primal, internal_vjp = vjp(lambda w: t.apply(w, None, ...), weights)
    return primal, internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    (weights_bar,) = internal_vjp(y_bar)
    # Fortunately, I can assemble weights_bar thanks to the filter and merge
    # functions in hk.data_structures.
    return ..., weights_bar

internal_f.defvjp(fwd, bwd)

Basically, get_relevant_parameters would be something like:

def get_relevant_parameters(f: Callable[..., Any], *args: Any, **kwargs: Any) -> Params:
  # The RNG value is irrelevant for the parameter *structure* returned by init.
  parameters = transform(f).init(jax.random.PRNGKey(0), *args, **kwargs)
  # Return parameters from the current frame that are needed by f.
  return {k: {l: current_frame().params[k][l] for l in bundle}
          for k, bundle in parameters.items()}
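The comprehension above is essentially a two-level dict filter. Its logic can be sketched with plain nested dicts standing in for `hk.Params` (module name → parameter name → array); `current_params` plays the role of `current_frame().params` and all names are hypothetical:

```python
# Params in Haiku are a two-level mapping: module name -> param name -> array.
current_params = {
    "linear_1": {"w": [[1.0]], "b": [0.0]},
    "linear_2": {"w": [[2.0]], "b": [1.0]},
}

# Suppose transform(g).init(...) reported that g only creates linear_2's
# parameters; the values are throwaway, only the structure matters here.
needed = {"linear_2": {"w": 0, "b": 0}}

# Select exactly the entries of current_params that g needs.
relevant = {mod: {name: current_params[mod][name] for name in bundle}
            for mod, bundle in needed.items()}
```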

Alternatively, I could just pass current_frame().params to f, but it would be annoying to have to pass None as the corresponding cotangent for every other parameter in the model in bwd.

I’m going to keep working on this, but I thought I’d file the issue early in the very likely case I’m missing something. Thanks a lot, and great project by the way!

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
tomhennigan commented, Feb 5, 2021

Of course, feel free to keep this open as long as is useful for you.

1 reaction
tomhennigan commented, Feb 5, 2021

We do have an experimental feature called lift. This enables you to make the parameters explicit for a specific function call inside a transform. It is not very well documented but a few advanced users are making heavy use of it internally (e.g. to make it easier to scan over the application of modules).

I’ve knocked up an example here:

https://colab.research.google.com/gist/tomhennigan/6f1237b5fb268a3d6d2391329ba2d051/example-of-using-hk-experimental-lift.ipynb

I wonder if this will be sufficient for your use case (making relevant parameters explicit inside a haiku transform).
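For readers without access to the notebook, here is a pure-JAX sketch of the *idea* behind lift (this is not the real `hk.experimental.lift` API): parameters that would otherwise be implicit inside a transform are re-exposed as an explicit value, so a `vjp` can route cotangents to them directly. All names below are illustrative:

```python
import jax
import jax.numpy as jnp

def init_inner(rng):
    # Stand-in for transform(inner).init: materialise the inner params.
    return {"mlp": {"w": jax.random.normal(rng, (3, 3))}}

def apply_inner(params, x):
    # Stand-in for transform(inner).apply: a pure function of explicit params.
    return x @ params["mlp"]["w"]

def outer(x, rng):
    params = init_inner(rng)  # "lifted": params are now an explicit value...
    y, vjp_fn = jax.vjp(lambda p: apply_inner(p, x), params)
    (params_bar,) = vjp_fn(jnp.ones_like(y))  # ...so cotangents can reach them
    return y, params_bar

y, params_bar = outer(jnp.ones(3), jax.random.PRNGKey(0))
```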

