
Auxiliary arguments inside lax.root's solve


Consider the signature of lax.root, for simplicity omitting the tangent_solve argument:

def root(f, initial_guess, solve):
  """Differentiably solve for a roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of root() are defined with respect to closed-over variables from
  the provided function f.

  Args:
    f: function for which to find a root. Should accept a single argument
      and return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """

The essence of root is that it calculates solve(f, initial_guess) with a custom JVP rule defined via implicit differentiation of the function f.
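
To make the mechanism concrete, here is a hand-written sketch of the implicit-function JVP (an illustration only, not lax.root’s actual implementation; it assumes x and theta are 1D arrays and f returns a 1D array). If f(x, theta) == 0 implicitly defines x as a function of theta, then differentiating both sides gives (df/dx) dx + (df/dtheta) dtheta = 0, so the solution tangent dx comes from one linear solve against the Jacobian of f at the solution — the linear solve that the omitted tangent_solve argument customizes:

import jax
import jax.numpy as np

def implicit_jvp(f, x, theta, dtheta):
  # Directional derivative of f with respect to theta at the solution,
  # holding x fixed: rhs = (df/dtheta) @ dtheta
  _, rhs = jax.jvp(lambda t: f(x, t), (theta,), (dtheta,))
  # Jacobian of f with respect to x, evaluated at the solution
  dfdx = jax.jacobian(lambda y: f(y, theta))(x)
  # Solve (df/dx) @ dx = -(df/dtheta) @ dtheta for the solution tangent dx
  return np.linalg.solve(dfdx, -rhs)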

An implicit requirement of lax.root is that solve(f, initial_guess) must be a pure function, without any closed-over variables. This suffices for the generic scipy.optimize.root, but unfortunately we really do want data-dependent solvers in many interesting cases. How can we use these with lax.root, or with the similarly designed lax.linear_solve?

Unfortunately, as soon as we turn solve into a closure, it breaks when we use tracers, e.g., for jit:

from functools import partial
import jax
from jax import lax
import jax.numpy as np

simple_root = partial(lax.root, tangent_solve=None)

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, np.zeros_like(b), lambda *args: x0)

a = np.eye(2)
b = np.ones((2,))
linear_solve(a, b)  # DeviceArray([1., 1.], dtype=float32)
jax.jit(linear_solve)(a, b)  # TypeError: No constant handler for type: <class 'jax.interpreters.partial_eval.JaxprTracer'>

The problem is that our closed over tracer leaks into the inside of the root primitive.

One answer is to compute the solution outside of solve, and simply pass it in as the initial_guess, e.g.,

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  x0 = np.linalg.solve(a, b)
  return simple_root(f, x0, lambda f, x0: x0)

This version works! The jit call that failed above now succeeds, as shown below, because solve no longer closes over a traced value.
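
A quick check, using the same a and b as before:

jax.jit(linear_solve)(a, b)  # DeviceArray([1., 1.], dtype=float32)

In fact, this suggests that perhaps we should have a different, simpler interface for lax.root: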

def root(f, x):
  """Define gradients via implicit differentiation.

  Args:
    f: function for which to find a root. Should accept a single argument
      and return a tree of arrays with the same structure as its input.
    x: zero of f, i.e., all(f(x) == 0).

  Returns:
    x, but with gradients defined via implicit differentiation of f.
  """

Unfortunately, this introduces another problem: if we calculate x with non-differentiable primitives, we won’t even get to calling root. JAX will raise an error about undefined JVP rules, e.g., for while_loop. This is why we added the solve argument to root in the first place.
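
As a concrete, hypothetical illustration (not code from the issue): suppose the solution is computed by a fixed-point iteration under lax.while_loop, and we use the simpler root(f, x) interface sketched above. Differentiation then fails inside while_loop, before root is ever called:

def iterative_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  def body(state):
    i, x = state
    return i + 1, x - 0.5 * f(x)  # damped Richardson iteration for a @ x = b
  _, x = lax.while_loop(lambda state: state[0] < 100, body,
                        (0, np.zeros_like(b)))
  return root(f, x)  # the hypothetical two-argument root from above

# jax.grad(lambda a: iterative_solve(a, b).sum())(a)
# -> raises: no JVP rule defined for while_loop, before root() is reached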

I see several possible ways to resolve this, none of which are entirely satisfactory:

  1. Add extra optional arguments to lax.root, for explicitly passing auxiliary arguments into solve. These would get passed into the root primitive directly, allowing the closure problem to be side-stepped, but it’s still surprising that you can’t use a closure in solve.
  2. Change the signature of root to root(f, x), removing the solve argument. Encourage liberal use of lax.stop_gradient to avoid attempting to compute uncomputable or expensive-to-compute JVP terms. Maybe we can make this easier by adding either a higher order function or a context manager to stop gradients, e.g., lax.stop_grad(lax.while_loop)(...) or with lax.disable_gradients(): ... (see the sketch after this list).
  3. Somehow evaluate solve into a jaxpr inside root (the function, not the primitive) without evaluating its JVP. This looks sort of like the context manager solution, except contained inside root.
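
For what it’s worth, the higher-order function in option 2 could be as simple as the following sketch (stop_gradient_fun is a hypothetical name, not an existing JAX API). Blocking gradients at the array inputs means no tangents ever flow into the wrapped call, so a missing JVP rule is never invoked; note that it does not handle tracers closed over inside cond/body functions:

import jax
from jax import lax
import jax.numpy as np

def stop_gradient_fun(fun):
  # Hypothetical helper: call fun with gradients blocked at its array inputs.
  # Non-array arguments (e.g. while_loop's cond and body functions) pass
  # through untouched.
  def wrapped(*args):
    stop = lambda x: lax.stop_gradient(x) if isinstance(x, np.ndarray) else x
    return fun(*jax.tree_util.tree_map(stop, args))
  return wrapped

# e.g. stop_gradient_fun(lax.while_loop)(cond_fun, body_fun, init_val)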

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 24 (22 by maintainers)

Top GitHub Comments

3 reactions
shoyer commented, Oct 25, 2019

But more generally, issues of finite dev-hours aside, would it make sense (or even be possible) to change JAX’s JVP handling to be completely lazy (maybe with shape inference and memory allocation still eager)? By that I mean do minimal JVP related work until the output of some JVP is needed by some non-JVP operation.

I’ve been learning more about how JAX’s transformations work.

Transformations like forward-mode automatic differentiation, batching and JIT compilation work by overloading normal Python evaluation, with arrays replaced by symbolic tracers. For example, in forward-mode auto-diff, every variable is replaced by a tracer that keeps track of both the original variable (the “primal”) and its derivative (the “tangent”). There’s no graph, just normal Python code.
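
A minimal illustration of that primal/tangent pairing:

import jax

# jax.jvp evaluates the function and its directional derivative together,
# in a single forward pass, with no graph in sight
primal, tangent = jax.jvp(lambda x: x ** 2, (3.0,), (1.0,))
# primal == 9.0, tangent == 6.0 (the derivative of x**2 at x == 3)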

This is really nice when it comes to debugging, because it means you get really sensible error messages and tracebacks from Python, pointing back exactly to the line where things went wrong. For example, if you try to differentiate a function using while_loop, you may get a long traceback, but somewhere in there it will point to a line of code that you wrote. You can even drop into a debugger.

In contrast, backwards mode auto-diff requires evaluating an abstract computation graph. If something breaks, the Python traceback will not be very interpretable, because it will point to an evaluation deep inside JAX’s auto-diff machinery, not your original Python function.

One reason why JAX is such a pleasure to use is that such abstract graph evaluation is kept to the bare minimum. Backwards mode auto-diff is composed out of a series of simpler transformations, namely forward mode auto-diff followed by transposition. Only the latter part requires evaluating a graph rather than normal Python code.
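
That decomposition is visible in JAX’s public API; a minimal sketch using today’s names for the two stages:

import jax
import jax.numpy as np

f = lambda x: np.sin(x) * x

# stage 1, forward mode: linearize f at x == 1.0 into a linear map dx -> dy
y, f_jvp = jax.linearize(f, 1.0)

# stage 2, transposition: turn the linear JVP into a VJP
f_vjp = jax.linear_transpose(f_jvp, 1.0)
(grad_x,) = f_vjp(1.0)  # matches jax.grad(f)(1.0)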

Making JVP evaluation lazy would negate this advantage. Instead of evaluating normal Python code for JVP rules, we’d be evaluating a stored computation graph.

We run into a similar issue if we try to make errors from missing JVP rules lazy: legitimately missing JVP rules would then produce error messages that no longer point back to user code, so it would no longer be obvious where things went wrong.

For these reasons, I think it’s a non-starter to change JAX’s JVP evaluation to be lazy, even for error reporting. It would be nice for this niche use case, but it would make debugging harder for everything else.

0 reactions
gehring commented, Oct 26, 2019

I agree. custom_root is a little clunky, but it is harder to misuse.

Sounds like we are in agreement on all points. I’m happy with the discussion and I can’t think of any new points to bring up. If we can’t reasonably get define_implicit_gradient to work with VJPs without needing stop_gradient_fun (which, as I understand it, is not possible or not worth it), then I’m happy with preferring custom_root.

Looking forward to seeing this PR merged! This is quality work 😃
