O(1) forward computation requires potentially unbounded time to compute gradient
So this is a fun one.
The context here is that I’m implementing #5642, which I’m thinking of as a reverse-mode autodifferentiable while loop subject to a maximum number of iterations.
The good news is that the forward pass works perfectly (correct asymptotics; handles issues to do with `vmap` and in-place updates, c.f. #8192). However, the backward pass can be arbitrarily expensive. I've already tried staring at the jaxpr without anything jumping out at me as being obviously wrong. I'm not completely sure whether to regard this as a bug in JAX (maybe XLA), or as something I can work around on the user side of things.
The following is the most minimal MWE I've been able to put together (e.g. this version won't vmap efficiently; I've cut out the special handling for that). Even so, it's more of a "moderately sized working example".
First of all, here is the (simplified) code for `bounded_while_loop`:
```python
import jax
import jax.lax as lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """API as `lax.while_loop`, except that it takes an integer `max_steps` argument."""
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if max_steps == 0:
        return init_val
    if max_steps & (max_steps - 1) != 0:
        raise ValueError("max_steps must be a power of two")
    init_data = (cond_fun(init_val), init_val)
    _, val = _while_loop(cond_fun, body_fun, init_data, max_steps)
    return val


def _while_loop(cond_fun, body_fun, data, max_steps):
    if max_steps == 1:
        # Base case: run the body once, but keep the old value if the
        # predicate says the loop has already terminated.
        pred, val = data
        new_val = body_fun(val)
        keep = lambda a, b: lax.select(pred, a, b)
        new_val = jax.tree_map(keep, new_val, val)
        return cond_fun(new_val), new_val
    else:
        # Recursive case: a scan of length 2 over a half-size loop, so the
        # staged program is a nesting of scans of depth log2(max_steps).
        def _call(_data):
            return _while_loop(cond_fun, body_fun, _data, max_steps // 2)

        def _scan_fn(_data, _):
            _pred, _ = _data
            return lax.cond(_pred, _call, lambda x: x, _data), None

        return lax.scan(_scan_fn, data, xs=None, length=2)[0]
```
Then the test harness:
```python
import functools as ft
import jax
import jax.experimental.stax as stax
import jax.numpy as jnp
import jax.random as jrandom
import time

_key = jrandom.PRNGKey(0)
_init, _apply = stax.serial(stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1))
expensive_fn = ft.partial(_apply, _init(_key, (1,))[1])


def cond_fun(val):
    x, step = val
    return step < 8


def body_fun(val):
    x, step = val
    return (expensive_fn(x), step + 1)


def timed(fn):
    def timer(*a, **kw):
        start = time.time()
        fn(*a, **kw).block_until_ready()
        end = time.time()
        print(end - start)
    return timer


@timed
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def f(val, max_steps):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (val, 0), max_steps)[0])


val = jnp.array([1.])
f(val, 8)
f(val, 8)   # 0.037744998931884766
f(val, 16)
f(val, 16)  # 0.05941605567932129
f(val, 32)
f(val, 32)  # 0.10239911079406738
```
As you can see, the runtime of this gradient operation increases in proportion to the `max_steps` bound. This is despite the forward operation running in constant time (just comment out the `jax.grad` and rerun; see the sketch below). Indeed, the overall number of steps taken in the while loop is always exactly 8, by the choice of `cond_fun`.
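For reference, the forward-only measurement might look like the following (a sketch; `f_fwd` is a name introduced here, not part of the original harness):

```python
# Same as `f` but without jax.grad: the primal computation alone runs in
# (roughly) constant time regardless of max_steps.
@timed
@ft.partial(jax.jit, static_argnums=1)
def f_fwd(val, max_steps):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (val, 0), max_steps)[0])

f_fwd(val, 8)
f_fwd(val, 8)
f_fwd(val, 32)
f_fwd(val, 32)  # comparable to the max_steps=8 time
```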
(Incidentally, compile times are exponential in the depth of nested `scan`s because of #8193 and #8184, but based on the commentary in #8184 it sounds like a fix for that issue might be on the horizon.)
These times were obtained on the CPU. I've verified that the same behaviour also occurs on the GPU (with a more expensive `expensive_fn`).
This kind of adaptive computation is the sort of thing for which I'm still usually reaching for PyTorch, what with the static-graph requirement of XLA. I'd love to get the above working, as IMO it's the one arena in which JAX still hasn't caught up.
So what you describe is actually what I'm doing at the moment – I basically have a Python while loop and `jax.jit` the interior. The various limitations of this that I've been bumping up against are:

**Interpreter overhead** The first problem is the one you indicated in your comment – performance. A normal pattern is to stack `jit`-`grad`-`vmap`; this incurs very little runtime overhead. Switching things around to `grad`-`vmap`-`jit` means frequently passing through JAX internals, and this gets really expensive (a sketch of the two orderings follows after the flame-graph discussion below). Here are a couple of examples.

First, consider the flame graph for an operation doing just `vmap`-`jit`. The regions in blue are the times that XLA is actually being executed; everything else is just overhead. I've also highlighted a green region: this is the cost of crossing a JIT API boundary, in which a bunch of complicated objects are partitioned into trace/static. It's so large because this is happening repeatedly inside a while loop. [For this example I don't think the `vmap` changes anything – it's just the cost of passing back and forth through `jax.jit` so many times.] [I realise blue/green may be an issue if you're colour-blind – I can look into how to re-colour things appropriately if so; let me know.]
Second, an operation doing `grad`-`vmap`-`jit`. I'm a little less certain about this one (I'm not quite as familiar with the internals of `jax.grad`), but I think in this case the blue region is the forward pass and the purple region is the backward pass. Everything else looks to be things like jaxpr manipulation, i.e. interpreter overhead that wouldn't be there if the operation could be `jit`'d.
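For concreteness, the two orderings look roughly like this (a sketch with a hypothetical `model`, not the code behind the flame graphs):

```python
import jax
import jax.numpy as jnp

def model(w, x):
    return jnp.tanh(w * x)

def loss(w, xs):
    return jnp.sum(jax.vmap(lambda x: model(w, x))(xs))

# jit-grad-vmap: jit is outermost, so the differentiated, vectorised program
# is staged out to XLA once; grad and vmap only run at trace time.
fast = jax.jit(jax.grad(loss))

# grad-vmap-jit: only the innermost call is jitted; grad and vmap then run
# as Python-level interpreters that repeatedly cross the jit boundary.
inner = jax.jit(model)
slow = jax.grad(lambda w, xs: jnp.sum(jax.vmap(lambda x: inner(w, x))(xs)))

w, xs = jnp.array(0.5), jnp.arange(4.0)
fast(w, xs)
slow(w, xs)
```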
**Developer ergonomics** Without exception, every JAX operation has to be jitted; op-by-op mode simply incurs too much overhead. (A fact arrived at by staring at flame graphs of the call stack, much like those above.) This means writing code with lots of little mini-jit-functions every time a JAX array must be interacted with; in an actual example from my code, introducing one such JIT produced a measurable improvement in performance.
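A minimal sketch of the pattern (hypothetical names, not the snippet from my code):

```python
import jax
import jax.numpy as jnp

# A tiny helper jitted in isolation, because the surrounding Python-level
# control flow cannot itself be jitted.
@jax.jit
def _accept_step(y, y_candidate, keep):
    return jnp.where(keep, y_candidate, y)

# ...then called from inside a plain Python while loop, e.g.:
y = _accept_step(jnp.zeros(3), jnp.ones(3), True)
```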
**User ergonomics** I'm developing a software library. It's pretty frustrating for a user to be told that they can't JIT `mylibrary.myfunction`, and it basically subjects the user to the same thing as above: at minimum one has to write a "before `mylibrary.myfunction` JIT" and an "after `mylibrary.myfunction` JIT" (see the sketch below). Fixing this is actually my primary concern, as it fundamentally breaks composability with the rest of the JAX ecosystem.
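Roughly, the pattern users are forced into (a sketch; `myfunction`, `before` and `after` are hypothetical stand-ins):

```python
import jax
import jax.numpy as jnp

def myfunction(x):  # stand-in for a library function with Python-level control flow
    return x * 2

@jax.jit
def before(x):  # the "before myfunction" JIT
    return jnp.sin(x)

@jax.jit
def after(y):  # the "after myfunction" JIT
    return jnp.sum(y ** 2)

result = after(myfunction(before(jnp.arange(4.0))))
```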
**In-place updates** As I recall (it's been a while since I figured out how to work around this issue), it's not possible to make in-place updates to the same buffer in different `jax.jit` regions; passing back to Python-land forces a copy. I'm aware of `donate_argnums`, but it hasn't seemed to help – possibly things are somehow too complex for the compiler to figure out? And either way, `donate_argnums` isn't yet supported on the CPU (spitting out a bunch of warnings instead). This one's pretty important for efficiency purposes.
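For reference, `donate_argnums` is used like this (a minimal sketch; `update` is a hypothetical function):

```python
import functools as ft
import jax
import jax.numpy as jnp

# Donating argument 0 tells XLA it may reuse buf's buffer for the output
# instead of allocating a copy. (At the time of writing this was unsupported
# on CPU, which emits warnings instead.)
@ft.partial(jax.jit, donate_argnums=0)
def update(buf, i):
    return buf.at[i].set(0.0)

buf = jnp.ones(8)
buf = update(buf, 3)  # the old `buf` must not be reused after donation
```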
**What to trace/static** The main routine I'm writing takes an instance of `jax.tree_util.Partial` as an argument. This is some parameterised function specified by the user of the software library. The parameterisation may include a mixture of some things worth JIT-tracing and some things worth JIT-static'ing. When the `jax.jit` call happens inside the software library, it falls on the developer to make an opinionated choice about what to JIT (e.g. "trace all JAX arrays"); a user may want something slightly different. This one can obviously be worked around in various ways – have extra arguments for the "JIT-trace args" and the "JIT-static args", or let a user specify some partition function – but that's less elegant at the API level, not to mention the library internals now need to pipe multiple arguments around the place.
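To illustrate the tension (a sketch; `vector_field` and `solve` are hypothetical):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

# `theta` is an array the user presumably wants traced; `degree` is a Python
# int that might be better treated as static.
def vector_field(theta, degree, y):
    return -theta * y ** degree

fn = Partial(vector_field, jnp.array(2.0), 3)

# Partial is a pytree, so it can cross a jit boundary – but the library's
# jit call then decides on the user's behalf which leaves get traced.
@jax.jit
def solve(fn, y0):
    return fn(y0)

solve(fn, jnp.array(1.0))
```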
Overall, it'd be preferable to avoid making this the library's problem at all: just make it possible for the user to `jax.jit` their entire operation, `mylibrary.myfunction` and all.

/walloftext!
Sounds good! Thank you for the write-up. (Crossing my fingers for the glorious dynamic-shape future.)