Auxiliary arguments inside lax.root's solve
Consider the signature of lax.root, for simplicity omitting the tangent_solve argument:
def root(f, initial_guess, solve):
  """Differentiably solve for a roots of a function.
  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.
  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) == 0. In other words,
      the following is assumed to be true (but not checked)::
        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)
  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
The essence of root is that it calculates solve(f, initial_guess) with a custom JVP rule defined via implicit differentiation of the function f.
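Concretely, the rule follows from implicit differentiation: if x solves f(x) = 0, and f closes over some parameters theta, then differentiating f(x(theta), theta) = 0 gives (∂f/∂x) dx + (∂f/∂theta) dtheta = 0, so the tangent of the solution is dx = -(∂f/∂x)^-1 (∂f/∂theta) dtheta. (Solving this linear system for dx is, as I understand it, the job of the omitted tangent_solve argument.)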
An implicit requirement of lax.root is that solve(f, initial_guess) must be a pure function, without any closed-over variables. This would suffice for the generic scipy.optimize.root, but unfortunately, we really do want data-dependent solvers in many interesting cases. How can we use these with lax.root, or with the similarly designed lax.linear_solve?
Unfortunately, as soon as we turn solve into a closure, it breaks when we use tracers, e.g., for jit:
from functools import partial
import jax
from jax import lax
import jax.numpy as np
simple_root = partial(lax.root, tangent_solve=None)
def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, np.zeros_like(b), lambda *args: x0)
a = np.eye(2)
b = np.ones((2,))
linear_solve(a, b)  # DeviceArray([1., 1.], dtype=float32)
jax.jit(linear_solve)(a, b)  # TypeError: No constant handler for type: <class 'jax.interpreters.partial_eval.JaxprTracer'>
The problem is that our closed-over tracer leaks into the inside of the root primitive.
One answer is to compute the solution outside of solve, and simply pass it in as the initial_guess, e.g.,
def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, x0, lambda f, x0: x0)
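As a quick check, reusing a and b from the snippet above, both the plain and jitted calls should now agree:
linear_solve(a, b)           # DeviceArray([1., 1.], dtype=float32)
jax.jit(linear_solve)(a, b)  # DeviceArray([1., 1.], dtype=float32)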
This version works! In fact, it suggests that perhaps we should have a different, simpler interface for lax.root:
def root(f, x):
  """Define gradients via implicit differentiation.
  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    x: zero of f, i.e., all(f(x) == 0).
  Returns:
    x, but with gradients defined via implicit differentiation of f.
  """
Unfortunately, this introduces another problem: if we calculate x with non-differentiable primitives, we won’t even get to calling root. JAX will raise an error about undefined JVP rules, e.g., for while_loop. This is why we added the solve argument to root in the first place.
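For example, here is a rough sketch of that failure mode, combining the proposed root(f, x) signature (pseudocode, not the existing lax.root) with a Newton iteration written using while_loop, and reusing the imports from above:
def sqrt_newton(c):
  # Solve f(x) = x**2 - c = 0 by Newton iteration inside while_loop.
  f = lambda x: x ** 2 - c
  body = lambda x: x - f(x) / (2 * x)
  x = lax.while_loop(lambda x: np.abs(f(x)) > 1e-6, body, c)
  return root(f, x)  # proposed interface: mark x for implicit differentiation

jax.grad(sqrt_newton)(2.0)
# Fails while tracing while_loop's undefined JVP, before root is ever reached.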
I see several possible ways to resolve this, none of which are entirely satisfactory:
- Add extra optional arguments to lax.root, for explicitly passing auxiliary arguments into solve (see the sketch after this list). These will get passed into the root primitive directly, allowing the closure problem to be side-stepped, but it's still surprising that you can't use a closure in solve.
- Change the signature of root to root(f, x), removing the solve argument. Encourage liberal use of lax.stop_gradient to avoid attempting to compute uncomputable or expensive-to-compute JVP terms. Maybe we can make this easier by adding either a higher-order function or context manager to stop gradients, e.g., lax.stop_grad(lax.while_loop)(...) or with lax.disable_gradients(): ....
- Somehow evaluate solve into a JAXpr inside root (the function, not the primitive) without evaluating its JVP. This looks sort of like the context manager solution, except contained inside root.
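To make the first option concrete, here is a rough sketch with a hypothetical solve_aux_args keyword (not an existing lax.root argument): the auxiliary values are passed to root explicitly and handed to solve as extra positional arguments, so solve no longer needs to close over anything traced.
def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  def solve(f, initial_guess, a, b):
    # Auxiliary arguments arrive explicitly instead of via a closure.
    return np.linalg.solve(a, b)
  return lax.root(f, np.zeros_like(b), solve, tangent_solve=None,
                  solve_aux_args=(a, b))  # hypothetical keyword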
I’ve been learning more about how JAX’s transformations work.
Transformations like forward-mode automatic differentiation, batching and JIT compilation work by overloading normal Python evaluation, with arrays replaced by symbolic tracers. For example, in forward-mode auto-diff, every variable is replaced by a tracer that keeps track of both the original variable (the “primal”) and its derivative (the “tangent”). There’s no graph, just normal Python code.
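A tiny example of that pairing, using jax.jvp directly (just a sketch, nothing specific to root):
import jax

# The primal output (2.0**3 = 8.0) and the tangent (3 * 2.0**2 * 1.0 = 12.0)
# are carried through the same ordinary Python evaluation.
primal_out, tangent_out = jax.jvp(lambda x: x ** 3, (2.0,), (1.0,))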
This is really nice when it comes to debugging, because it means you can get really sensible error messages and tracebacks from Python, pointing back exactly to the line where things went wrong. For example, if you try to differentiate a function using while_loop, you may get a long traceback, but somewhere in there it will point to a line of code that you wrote. You can even drop into a debugger.
In contrast, backwards mode auto-diff requires evaluating an abstract computation graph. If something breaks, the Python traceback will not be very interpretable, because it will point to an evaluation deep inside JAX’s auto-diff machinery, not your original Python function.
One reason why JAX is such a pleasure to use is that such abstract graph evaluation is only done to the bare minimum extent possible. Backwards mode auto-diff is composed of a series of simpler transformations, namely forward mode auto-diff followed by transposition. Only the latter part requires evaluating a graph rather than normal Python code.
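For instance, a gradient can be sketched as linearize-then-transpose (assuming jax.linearize and jax.linear_transpose here):
import jax
import jax.numpy as np

# Tracing and forward-mode linearization run as ordinary Python evaluation.
y, f_jvp = jax.linearize(lambda x: np.sin(x) * x, 2.0)
# Transposition is the step that works on the stored graph rather than Python code.
f_vjp = jax.linear_transpose(f_jvp, 2.0)
(x_bar,) = f_vjp(1.0)  # same value as jax.grad(lambda x: np.sin(x) * x)(2.0)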
Making JVP evaluation lazy would negate this advantage. Instead of evaluating normal Python code for JVP rules, we’d be evaluating a stored computation graph.
We run into a similar issue if we try to make errors from missing JVP rules lazy. Now legitimately missing JVP rules result in error messages that no longer point back to user code, so it’s no longer obvious where things went wrong.
For these reasons, I think it’s a non-starter to change JAX’s JVP evaluation to be lazy, even for error reporting. It would be nice for niche use cases like this one, but would make debugging harder for everything else.
Sounds like we are in agreement on all points. I’m happy with the discussion and I can’t think of any new points to bring up. If we can’t reasonably get define_implicit_gradient to work with vjps without needing stop_gradient_fun (which, as I understand it, is not possible/worth it), then I’m happy with preferring custom_root.
Looking forward to seeing this PR merged! This is quality work 😃