
Custom VJPs for external functions


Hi! I want to define custom gradients for a simulation for sensitivity analysis. I have been using autograd for this, but since it is not actively being developed anymore I wanted to switch to jax. In autograd I would write something like this:

from autograd import grad
from autograd.extend import primitive, defvjp
import simulation

@primitive
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

defvjp(sim, sim_vjp)

In autograd, this worked fine and I was able to chain this together with some other differentiable transformations and get gradients out of the whole thing. From what I was able to gather, the above would be written in jax as follows:

import jax
import simulation

@jax.custom_transforms
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

jax.defvjp_all(sim, sim_vjp)

However, this throws "Exception: Tracer can't be used with raw numpy functions.", which I assume is because the simulation code does not use jax. Are custom gradients in jax no longer black boxes as they were in autograd, i.e. is this a fundamental limitation, or have I screwed something up? Do I need to implement this using lax primitives, and if so, how?
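A minimal way to reproduce this kind of failure, assuming a recent JAX version, is to hand a traced value to plain NumPy. The helper raw_numpy_sin below is hypothetical. Recent JAX releases report the problem as jax.errors.TracerArrayConversionError rather than the older message quoted above, but the cause appears to be the same: JAX traces the Python function, and NumPy cannot operate on a tracer.

```python
import numpy as onp
import jax

def raw_numpy_sin(x):
    # onp.sin is plain NumPy: it needs a concrete array, so it cannot
    # consume the tracer that jax.grad passes in.
    return onp.sin(x)

try:
    jax.grad(raw_numpy_sin)(1.0)
    raised = False
except Exception as err:  # recent JAX: jax.errors.TracerArrayConversionError
    raised = True
    print(type(err).__name__)
```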

I would be grateful for a minimal example implementing this for some arbitrary non-jax function. This code here for example works in autograd:

from autograd import grad
from autograd.extend import primitive, defvjp
from scipy.ndimage import gaussian_filter

@primitive
def filter(img):
    return gaussian_filter(img, 1)

def filter_vjp(ans, img):
    def vjp(g):
        return gaussian_filter(g, 1)
    return vjp

defvjp(filter, filter_vjp)

How would one translate this so it works in jax? Thanks so much!
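For readers on a current JAX release, one possible translation of this filter example is sketched below. It is an assumption-based sketch, not the approach discussed in this thread: it uses jax.pure_callback (added to JAX after this issue was opened) to run SciPy on the host, and jax.custom_vjp to attach the same backward rule as the Autograd version, which reuses the filter itself as its VJP. The names blur, _host_filter and _call_filter are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.ndimage import gaussian_filter

def _host_filter(x):
    # Runs on the host with plain NumPy/SciPy; JAX never traces into this.
    x = np.asarray(x)
    return np.asarray(gaussian_filter(x, 1), dtype=x.dtype)

def _call_filter(x):
    # Declare the output shape/dtype so JAX can trace around the opaque call.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_filter, out_spec, x)

@jax.custom_vjp
def blur(img):
    return _call_filter(img)

def blur_fwd(img):
    # No residuals needed: the backward rule only depends on the cotangent.
    return blur(img), None

def blur_bwd(_, g):
    # Same rule as the Autograd version: filter the incoming cotangent.
    return (_call_filter(g),)

blur.defvjp(blur_fwd, blur_bwd)

# Usage: differentiate a scalar loss through the opaque filter, under jit.
w = jnp.arange(25.0).reshape(5, 5)
loss = lambda img: jnp.sum(blur(img) * w)
g = jax.jit(jax.grad(loss))(jnp.ones((5, 5)))
```

Since the cotangent reaching blur is just w, the gradient here is the filtered w, computed by the same SciPy call as the forward pass.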

Issue Analytics

  • State: open
  • Created 4 years ago
  • Reactions: 5
  • Comments: 14 (9 by maintainers)

Top GitHub Comments

8 reactions
mattjj commented, Aug 18, 2019

Hi all, sorry for the slow response! @tpr0p @mrbaozi

The issue here is the difference between a custom_transforms function and a Primitive. You want a Primitive.

From the custom_transforms docstring (emphasis mine):

A primary use case of custom_transforms is defining custom VJP rules (aka custom gradients) for a Python function, while still supporting other transformations like jax.jit and jax.vmap.

Let me unpack that, because it’s not very detailed.

The custom_transforms function is useful when you have a Python function that JAX can handle just fine (to compile, differentiate, batch, etc.) but you still want to override how it behaves under one (or more) of those transformations while retaining the default behavior for the others. So a custom_transforms function isn’t totally opaque to the tracing/transforming machinery: in fact, if you don’t override any of its transformation rules, it is traced and transformed just like a regular function. That’s different from Autograd’s primitives, which were always totally opaque. The main use case for custom_transforms is where you have a Python function implemented with jax.numpy and you like how it behaves under jit, but you want to control how it behaves under grad.
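custom_transforms was later replaced in JAX by jax.custom_jvp and jax.custom_vjp (the custom_jvp/vjp mentioned in the April 2020 comment). As a hedged sketch of the use case described here, the function below is ordinary jax.numpy code that jit and vmap handle as usual, while jax.custom_vjp overrides only its backward pass so that the incoming cotangent is clipped. The name clipped_identity and the clipping range [-1, 1] are illustrative choices, not from the thread.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_identity(x):
    # Forward pass: plain jax.numpy-compatible code (here just the identity).
    return x

def clipped_identity_fwd(x):
    # No residuals are needed for the backward pass.
    return clipped_identity(x), None

def clipped_identity_bwd(_, g):
    # Overridden backward pass: clip the incoming cotangent to [-1, 1].
    return (jnp.clip(g, -1.0, 1.0),)

clipped_identity.defvjp(clipped_identity_fwd, clipped_identity_bwd)

def f(x):
    return 10.0 * clipped_identity(x)

print(jax.grad(f)(2.0))              # 1.0: the cotangent 10.0 is clipped to 1.0
print(jax.jit(f)(2.0))               # 20.0: the forward pass is unchanged
print(jax.vmap(f)(jnp.arange(3.0)))  # batching still works as usual
```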

In contrast, a JAX Primitive (defined in core.py) is directly analogous to Autograd’s primitive, in that it sets up an opaque function. When you define a Primitive, you need to define a rule for every transformation you want to use (not just the ones you want to override). Most of JAX’s primitives are in the lax package, and we implement everything else on top of those.

We haven’t documented how to set up your own Primitives yet (it’s the venerable issue #116), but it’s not too hard. Here’s an adaptation of @tpr0p’s example:

from jax import core
import numpy as onp  # I changed this name out of habit

# Set up a Primitive, using a handy level of indirection
def foo(x):
  return foo_p.bind(x)
foo_p = core.Primitive('foo')

At this point there are no rules defined for foo_p, not even an evaluation rule (we consider eval to be just another transformation!). Here’s the error we get if we try to call it:

In [2]: foo(3)
NotImplementedError: Evaluation rule for 'foo' not implemented

Let’s define an evaluation rule in terms of onp, a totally opaque un-traceable call into C code:

foo_p.def_impl(onp.square)

And now:

In [4]: foo(3)
Out[4]: 9

Woohoo! But we can’t do anything else with it yet. We can add a VJP rule like this (for all of our own primitives we actually define JVP rules instead, but a VJP rule might be more familiar; cf. #636):

from jax.interpreters import ad
ad.defvjp(foo_p, lambda g, x: 2 * x * g)

And now:

In [5]: from jax import grad
In [6]: grad(foo)(3.)
Out[6]: DeviceArray(6., dtype=float32)
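The VJP rule registered above follows from differentiating the square: the JVP is linear in the tangent, and transposing that linear map gives the VJP. Writing y = x^2, with a tangent \dot x and a cotangent \bar y (the g in the lambda above):

```latex
\begin{aligned}
y       &= x^2, \\
\dot y  &= 2x\,\dot x  && \text{(JVP: linear in } \dot x\text{)}, \\
\bar x  &= 2x\,\bar y  && \text{(VJP: transpose of the JVP map)}.
\end{aligned}
```

At x = 3 with \bar y = 1 this gives 6, matching the grad(foo)(3.) result above.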

There’s also an API closer to the one in @tpr0p's original example:

def f_vjp(x):
  return foo(x), lambda g: (2 * g * x,)
ad.defvjp_all(foo_p, f_vjp)

To use jit you’ll need to define a translation rule.

Does that make sense? What’d I miss?
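On newer JAX releases, one way to get jit support for an opaque call without writing a translation rule is sketched below. This is an assumption-based alternative, not the Primitive approach from this comment: it requires jax.pure_callback, which was added to JAX after this thread, and it attaches a JVP rule with jax.custom_jvp rather than a VJP rule. The helper _host_square is hypothetical, and foo here is a fresh definition, not the foo_p-based one above.

```python
import numpy as onp
import jax
import jax.numpy as jnp

def _host_square(x):
    # Opaque host-side NumPy call; JAX never traces into it.
    x = onp.asarray(x)
    return onp.asarray(onp.square(x), dtype=x.dtype)

@jax.custom_jvp
def foo(x):
    x = jnp.asarray(x)
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_square, result_shape, x)

@foo.defjvp
def foo_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Tangent rule written in jax.numpy and linear in x_dot, so JAX can
    # transpose it to obtain reverse-mode gradients.
    return foo(x), 2.0 * x * x_dot

print(jax.jit(foo)(3.0))            # 9.0
print(jax.grad(foo)(3.0))           # 6.0
print(jax.jit(jax.grad(foo))(3.0))  # 6.0
```

Because the tangent rule only uses jax.numpy operations, only the primal value goes through the callback; this matches the preference for JVP rules mentioned earlier in this comment.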

3 reactions
mattjj commented, Apr 16, 2020

Actually, for external functions a new primitive should be used, not custom_jvp/vjp stuff. That is, external functions fall into case 2 articulated at the top of the Custom derivative rules for JAX-transformable Python functions tutorial.

I think this topic is important enough that it needs its own tutorial explanation (i.e. I don’t think the “How JAX primitives work” is quite the right explanation for people looking to solve this particular issue, just because we should have more direct examples for this use case).


