Custom VJPs for external functions
Hi! I want to define custom gradients for a simulation for sensitivity analysis. I have been using autograd for this, but since it is not actively being developed anymore I wanted to switch to jax. In autograd I would write something like this:
from autograd import grad
from autograd.extend import primitive, defvjp
import simulation

@primitive
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

defvjp(sim, sim_vjp)
In autograd, this worked fine and I was able to chain this together with some other differentiable transformations and get gradients out of the whole thing. From what I was able to gather, the above would be written in jax as follows:
import jax
import simulation

@jax.custom_transforms
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

jax.defvjp_all(sim, sim_vjp)
However, this throws "Exception: Tracer can't be used with raw numpy functions.", which I assume is because the simulation code does not use jax. Are the custom gradients in jax not black-boxes as in autograd anymore, i.e. is this a fundamental limitation or have I screwed something up? Do I need to implement this using lax primitives, and if so, how?
I would be grateful for a minimal example implementing this for some arbitrary non-jax function. This code here for example works in autograd:
from autograd import grad
from autograd.extend import primitive, defvjp
from scipy.ndimage import gaussian_filter

@primitive
def filter(img):
    return gaussian_filter(img, 1)

def filter_vjp(ans, img):
    def vjp(g):
        return gaussian_filter(g, 1)
    return vjp

defvjp(filter, filter_vjp)
How would one translate this so it works in jax? Thanks so much!
Hi all, sorry for the slow response! @tpr0p @mrbaozi
The issue here is the difference between a custom_transforms function and a Primitive. You want a Primitive.
From the custom_transforms docstring (emphasis mine):
Let me unpack that, because it’s not very detailed.
The custom_transforms function is useful when you have a Python function that JAX can handle just fine (to compile, differentiate, batch, etc.) but you still want to override how it behaves under one (or more) of those transformations while retaining the default behavior for the others. So a custom_transforms function isn’t totally opaque to the tracing/transforming machinery: in fact, if you don’t override any of its transformation rules, then it’s traced/transformed just like a regular function. That’s different from Autograd’s primitives, because those were always totally opaque. The main use case for custom_transforms is where you have a Python function implemented with jax.numpy and you like how it behaves under jit, but you want to control how it behaves under grad.
In contrast, a JAX Primitive (defined in core.py) is directly analogous to Autograd’s primitive, in that it sets up an opaque function. When you define a Primitive you need to define a rule for every transformation you want to use (rather than just the ones you want to override). Most of JAX’s primitives are in the lax package, and we implement everything on top of those.
We haven’t documented how to set up your own Primitives yet (it’s the venerable issue #116), but it’s not too hard. Here’s an adaptation of @tpr0p’s example:
At this point there are no rules defined for foo_p, not even an evaluation rule (we consider eval to be just another transformation!). Here’s the error we get if we try to call it:
Let’s define an evaluation rule in terms of onp, a totally opaque un-traceable call into C code:
And now:
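The steps described above (create a Primitive, wrap it in a function that calls bind, then attach an evaluation rule) can be sketched roughly as follows. This is an illustration, not the maintainer's original snippet: the sine function stands in for an arbitrary external call, and the import location of Primitive is an assumption that depends on the installed JAX version.

```python
import numpy as onp

# Assumption: Primitive is exposed under jax.extend.core on recent releases
# and under jax.core on older ones; adjust to your installed version.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Illustrative primitive; "foo" is just a label used in error messages.
foo_p = Primitive("foo")

def foo(x):
    # bind() dispatches to the rule registered for the active transformation.
    return foo_p.bind(x)

# Evaluation rule: a plain NumPy call that JAX never traces into.
# onp.sin is a stand-in for any external, non-JAX function.
foo_p.def_impl(lambda x: onp.sin(x))

# With only an evaluation rule, foo can be called eagerly but not
# differentiated, batched, or jitted.
y = foo(onp.array([0.0, 1.0]))
```

Each further transformation (grad, vmap, jit) needs its own rule registered on foo_p before it can be applied to foo.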
Woohoo! But we can’t do anything else with it. We can add a VJP rule like this (though for all our own primitives we actually define a JVP rule instead; a VJP rule might be more familiar, cf. #636):
And now:
There’s also an API closer to the one in @tpr0p’s original example:
To use jit you’ll need to define a translation rule.
Does that make sense? What’d I miss?
Actually, for external functions a new primitive should be used, not custom_jvp/vjp stuff. That is, external functions fall into case 2 articulated at the top of the Custom derivative rules for JAX-transformable Python functions tutorial.
I think this topic is important enough that it needs its own tutorial explanation (i.e. I don’t think the “How JAX primitives work” tutorial is quite the right explanation for people looking to solve this particular issue, just because we should have more direct examples for this use case).
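For readers on a recent JAX release, one possible alternative to writing a full primitive is to route the external call through jax.pure_callback and attach the gradient with jax.custom_vjp. The sketch below translates the gaussian_filter example from the question under that assumption. The names filter_img, _host_filter and _call_filter are hypothetical, and the backward pass reuses the same filter on the cotangent, exactly as the autograd version in the question does.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.ndimage import gaussian_filter

def _host_filter(x):
    # Runs on the host with plain NumPy/SciPy; JAX never traces into it.
    return gaussian_filter(np.asarray(x), 1).astype(x.dtype)

def _call_filter(x):
    # Declare the output shape/dtype so JAX can treat the call as opaque.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_filter, out_type, x)

@jax.custom_vjp
def filter_img(img):
    return _call_filter(img)

def filter_fwd(img):
    # No residuals are needed because the filter is linear in its input.
    return _call_filter(img), None

def filter_bwd(residuals, g):
    # Same VJP as the autograd version: filter the incoming cotangent.
    return (_call_filter(g),)

filter_img.defvjp(filter_fwd, filter_bwd)

# Chain the external function with ordinary differentiable JAX code.
img = jnp.arange(16.0).reshape(4, 4)
loss = lambda x: jnp.sum(filter_img(x) ** 2)
grads = jax.grad(loss)(img)
grads_jit = jax.jit(jax.grad(loss))(img)
```

Because the callback declares its output type up front, the same function can also be used under jit, as the last line shows. Whether this or a dedicated primitive is the better fit depends on which transformations you need; vmap in particular requires extra care with callbacks.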