Auxiliary arguments inside lax.root's solve
Consider the signature of lax.root, for simplicity omitting the tangent_solve argument:
def root(f, initial_guess, solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
The essence of root is that it calculates solve(f, initial_guess) with a custom JVP rule defined via implicit differentiation of the function f.
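As a reminder of what implicit differentiation buys us, here is a minimal self-contained sketch (using only jax.grad, independent of lax.root): for f(y, theta) = y**3 - theta, the implicit function theorem gives dy/dtheta = -(df/dtheta) / (df/dy), which matches differentiating an explicit solver directly.

import jax

# f(y, theta) == 0 implicitly defines y(theta); here the root is y = theta ** (1/3).
f = lambda y, theta: y ** 3 - theta
solve = lambda theta: theta ** (1.0 / 3.0)

theta = 8.0
y = solve(theta)  # 2.0
# Implicit differentiation: dy/dtheta = -(df/dtheta) / (df/dy), evaluated at the root.
dy_implicit = -jax.grad(f, argnums=1)(y, theta) / jax.grad(f, argnums=0)(y, theta)
dy_direct = jax.grad(solve)(theta)
# Both equal 1 / (3 * theta ** (2/3)) ≈ 0.0833, but the implicit version never
# differentiates through the solver itself.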
An implicit requirement of lax.root is that solve(f, initial_guess) must be a pure function, without any closed-over variables. This would suffice for the generic scipy.optimize.root, but unfortunately, we really do want data-dependent solvers in many interesting cases. How can we use these with lax.root, or with the similarly designed lax.linear_solve?
Unfortunately, as soon as we turn solve into a closure, it breaks when we use tracers, e.g., for jit:
from functools import partial

import jax
from jax import lax
import jax.numpy as np

simple_root = partial(lax.root, tangent_solve=None)

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, np.zeros_like(b), lambda *args: x0)

a = np.eye(2)
b = np.ones((2,))

linear_solve(a, b)  # DeviceArray([1., 1.], dtype=float32)
jax.jit(linear_solve)(a, b)  # TypeError: No constant handler for type: <class 'jax.interpreters.partial_eval.JaxprTracer'>
The problem is that our closed-over tracer leaks into the inside of the root primitive.

One answer is to compute the solution outside of solve, and simply pass it in as the initial_guess, e.g.,
def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, x0, lambda f, x0: x0)
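As a quick sanity check of the pattern (not verified here, but consistent with what follows), the jit'd call should now trace cleanly, since solve no longer closes over a tracer:

jax.jit(linear_solve)(a, b)  # expected: DeviceArray([1., 1.], dtype=float32)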
This version works! In fact, it suggests that perhaps we should have a different, simpler interface for lax.root:
def root(f, x):
  """Define gradients via implicit differentiation.

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    x: zero of f, i.e., all(f(x) == 0).

  Returns:
    x, but with gradients defined via implicit differentiation of f.
  """
Unfortunately, this introduces another problem: if we calculate x with non-differentiable primitives, we won't even get to calling root. JAX will raise an error about undefined JVP rules, e.g., for while_loop. This is why we added the solve argument to root in the first place.
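To make that failure mode concrete, a rough sketch (again assuming the hypothetical root(f, x) above): a Newton iteration written with while_loop fails under jax.grad while computing x, before root is ever reached.

def newton_sqrt(a):
  # Solve y**2 == a by Newton iteration; while_loop has no JVP rule.
  cond = lambda y: abs(y * y - a) > 1e-6
  body = lambda y: 0.5 * (y + a / y)
  return lax.while_loop(cond, body, a)

def sqrt_with_implicit_grad(a):
  f = lambda y: y * y - a
  x = newton_sqrt(a)  # under jax.grad, this raises about while_loop's missing JVP rule...
  return root(f, x)   # ...so the implicit differentiation rule never gets a chance

# jax.grad(sqrt_with_implicit_grad)(2.0)  # expected: undefined-JVP error for while_loop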
I see several possible ways to resolve this, none of which are entirely satisfactory:
- Add extra optional arguments to lax.root, for explicitly passing auxiliary arguments into solve. These will get passed into the root primitive directly, allowing the closure problem to be side-stepped, but it's still surprising that you can't use a closure in solve.
- Change the signature of root to root(f, x), removing the solve argument. Encourage liberal use of lax.stop_gradient to avoid attempting to compute uncomputable or expensive-to-compute JVP terms (see the sketch after this list). Maybe we can make this easier by adding either a higher-order function or a context manager to stop gradients, e.g., lax.stop_grad(lax.while_loop)(...) or with lax.disable_gradients(): ....
- Somehow evaluate solve into a jaxpr inside root (the function, not the primitive) without evaluating its JVP. This looks sort of like the context manager solution, except contained inside root.
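For the second option, "liberal use of stop_gradient" on the running example might look roughly like this (still assuming the hypothetical root(f, x)); it avoids computing an unnecessary JVP through the inner solver, though on its own it does not help with primitives that have no JVP rule at all, which is what the proposed lax.stop_grad / lax.disable_gradients helpers would be for:

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  # Mark the solver inputs as constants so no tangents flow through the
  # (possibly expensive) inner solve; gradients of the result come from
  # implicit differentiation of f instead.
  x = np.linalg.solve(lax.stop_gradient(a), lax.stop_gradient(b))
  return root(f, x)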
Created 4 years ago · Comments: 24 (22 by maintainers)
I’ve been learning more about how JAX’s transformations work.
Transformations like forward-mode automatic differentiation, batching and JIT compilation work by overloading normal Python evaluation, with arrays replaced by symbolic tracers. For example, in forward-mode auto-diff, every variable is replaced by a tracer that keeps track of both the original variable (the "primal") and its derivative (the "tangent"). There's no graph, just normal Python code.
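For instance, a tiny illustration with the public jax.jvp API:

import jax

# Forward-mode auto-diff evaluates ordinary Python on (primal, tangent) pairs:
f = lambda x: x ** 2
primal_out, tangent_out = jax.jvp(f, (3.0,), (1.0,))
# primal_out == 9.0, tangent_out == 6.0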
This is really nice when it comes to debugging, because it means you can get really sensible error messages and tracebacks from Python, pointing back exactly to the line where things went wrong. For example, if you try to differentiate a function using while_loop, you may get a long traceback, but somewhere in there it will point to a line of code that you wrote. You can even drop into a debugger.

In contrast, backwards-mode auto-diff requires evaluating an abstract computation graph. If something breaks, the Python traceback will not be very interpretable, because it will point to an evaluation deep inside JAX's auto-diff machinery, not your original Python function.
One reason why JAX is such a pleasure to use is that such abstract graph evaluation is only done to the bare minimum extent possible. Backwards-mode auto-diff is decomposed into a series of simpler transformations, namely forward-mode auto-diff followed by transposition. Only the latter part requires evaluating a graph rather than normal Python code.
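A rough illustration of that decomposition, using the public jax.linearize and jax.linear_transpose APIs (a sketch of the idea, not how grad is literally spelled internally):

import jax

f = lambda x: x ** 3

x = 2.0
y, f_jvp = jax.linearize(f, x)          # forward-mode part: runs as ordinary Python
f_vjp = jax.linear_transpose(f_jvp, x)  # transposition part: processes a linear jaxpr
(grad_x,) = f_vjp(1.0)
# grad_x == 12.0, matching jax.grad(f)(x)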
Making JVP evaluation lazy would negate this advantage. Instead of evaluating normal Python code for JVP rules, we’d be evaluating a stored computation graph.
We run into a similar issue if we try to make errors from missing JVP rules lazy. Now legitimately missing JVP rules result in error messages that no longer point back to user code, so it’s no longer obvious where things went wrong.
For these reasons, I think it's a non-starter to change JAX's JVP evaluation to be lazy, even for error reporting. It would be nice for this niche use case, but would make debugging harder for everything else.
Sounds like we are in agreement on all points. I'm happy with the discussion and I can't think of any new points to bring up. If we can't reasonably get define_implicit_gradient to work with vjps without needing stop_gradient_fun (which, as I understand it, is not possible/worth it), then I'm happy with preferring custom_root.

Looking forward to seeing this PR merged! This is quality work 😃