
O(1) forward computation requires potentially unbounded time to compute gradient


So this is a fun one.

The context here is that I’m implementing #5642, which I’m thinking of as a reverse-mode autodifferentiable while loop subject to a maximum number of iterations.

The good news is that the forward pass works perfectly (correct asymptotics; it handles the issues around vmap and in-place updates, cf. #8192). However, the backward pass can be arbitrarily expensive. I’ve already tried staring at the jaxpr without anything jumping out at me as being obviously wrong. I’m not completely sure whether to regard this as a bug in JAX (maybe XLA), or if this is something I can work around on the user side of things.

The following is the most minimal MWE I’ve been able to put together (e.g. this version won’t vmap efficiently; I’ve cut out the special handling for that). Even so, it’s more of a “moderately sized working example”.

First of all, here is the (simplified) code for bounded_while_loop:

import jax
import jax.lax as lax

def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Same API as `lax.while_loop`, except that it also takes an integer `max_steps` argument."""
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if max_steps == 0:
        return init_val
    if max_steps & (max_steps - 1) != 0:  # power-of-two check
        raise ValueError("max_steps must be a power of two")

    # Carry the predicate alongside the value, so each level of the recursion
    # can skip its work once the loop has terminated.
    init_data = (cond_fun(init_val), init_val)
    _, val = _while_loop(cond_fun, body_fun, init_data, max_steps)
    return val

def _while_loop(cond_fun, body_fun, data, max_steps):
    if max_steps == 1:
        # Base case: run the body once, keeping the old value if the loop
        # has already terminated.
        pred, val = data
        new_val = body_fun(val)
        keep = lambda a, b: lax.select(pred, a, b)
        new_val = jax.tree_map(keep, new_val, val)
        return cond_fun(new_val), new_val
    else:

        def _call(_data):
            return _while_loop(cond_fun, body_fun, _data, max_steps // 2)

        def _scan_fn(_data, _):
            # Skip the half-sized sub-loop entirely once the predicate is False.
            _pred, _ = _data
            return lax.cond(_pred, _call, lambda x: x, _data), None

        # Two iterations, each of half the remaining steps.
        return lax.scan(_scan_fn, data, xs=None, length=2)[0]

(Structurally, the recursion builds log2(max_steps) nested scans of length two, with lax.cond pruning each half-sized sub-loop at runtime once cond_fun returns False; hence the power-of-two requirement, and a forward pass whose cost essentially depends only on the iterations actually taken.)

Then the test harness:

import functools as ft
import jax
import jax.experimental.stax as stax
import jax.numpy as jnp
import jax.random as jrandom
import time

_key = jrandom.PRNGKey(0)
_init, _apply = stax.serial(stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1024),
                            stax.elementwise(jnp.tanh),
                            stax.Dense(1))
expensive_fn = ft.partial(_apply, _init(_key, (1,))[1])

def cond_fun(val):
    x, step = val
    return step < 8  # the loop always takes exactly 8 steps, regardless of max_steps

def body_fun(val):
    x, step = val
    return (expensive_fn(x), step + 1)

def timed(fn):
    def timer(*a, **kw):
        start = time.time()
        fn(*a, **kw).block_until_ready()
        end = time.time()
        print(end - start)
    return timer

@timed
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def f(val, max_steps):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (val, 0), max_steps)[0])

val = jnp.array([1.])
# First call at each max_steps compiles; the second measures steady-state runtime.
f(val, 8)
f(val, 8)  # 0.037744998931884766
f(val, 16)
f(val, 16)  # 0.05941605567932129
f(val, 32)
f(val, 32)  # 0.10239911079406738

As you can see, the runtime (of this gradient operation) increases in proportion to the max_steps bound. This is despite the forward operation running in constant time (just comment out the jax.grad and rerun). Indeed, the overall number of steps taken in the while loop is always exactly 8, by the choice of cond_fun.
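One quick way to see where the extra work goes is to look at how big the program jax.grad builds actually is. A rough sketch of my own (string length is only a crude proxy for jaxpr size):

def loss(v, n):
    return jnp.sum(bounded_while_loop(cond_fun, body_fun, (v, 0), n)[0])

for n in (8, 16, 32):
    jaxpr = jax.make_jaxpr(jax.grad(loss), static_argnums=1)(val, n)
    print(n, len(str(jaxpr)))  # grows with n, even though only 8 steps ever run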

(Incidentally compile times are exponential in the depth of nested scans because of #8193, #8184, but based on the commentary in #8184 it sounds like a fix for that issue might be on the horizon.)

These times were obtained on the CPU. I’ve verified that the same behaviour also occurs on the GPU (with a more expensive expensive_fn).

This kind of adaptive computation is the sort of thing for which I’m still usually reaching for PyTorch, given the static-graph requirement of XLA. I’d love to get the above working, as IMO it’s the one arena in which JAX still hasn’t caught up.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
patrick-kidger commented, Oct 16, 2021

So what you describe is actually what I’m doing at the moment – I basically have a Python while loop and jax.jit the interior. The various limitations of this that I’ve been bumping up against are:

Interpreter overhead

The first problem is the one you indicated in your comment – performance.

A normal pattern is to stack jit-grad-vmap. This incurs very little runtime overhead.
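Concretely, the two stackings look like this (a minimal sketch; f stands in for the real computation):

import jax
import jax.numpy as jnp

def f(x):  # stand-in for the real per-example computation
    return jnp.tanh(x)

# The cheap pattern: jit outermost, so everything compiles into one XLA program.
jit_grad_vmap = jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))

# What an un-jittable outer structure forces: jit innermost, with grad and
# vmap interpreted at the Python level on every call.
grad_vmap_jit = jax.grad(lambda x: jax.vmap(jax.jit(f))(x).sum())

x = jnp.arange(4.)
jit_grad_vmap(x), grad_vmap_jit(x)  # same gradient, very different overhead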

Switching things around to grad-vmap-jit means frequently passing through JAX internals. This gets really expensive. Here are a couple of examples.

First, the flame graph for an operation doing just vmap-jit:

[flame graph: an operation doing vmap-jit]

The regions in blue are the times that XLA is actually being executed. Everything else is just overhead. I’ve also highlighted a green region: this is the cost of crossing a JIT API boundary, in which a bunch of complicated objects are partitioned into trace/static. It’s so large because this is happening repeatedly inside a while loop. [For this example I don’t think the vmap changes anything – it’s just the cost of passing back-and-forth through jax.jit so many times.]

[I realise blue/green may be an issue if you’re colour-blind – I can look into how to re-colour things appropriately if so; let me know.]

Second, an operation doing grad-vmap-jit:

[flame graph: an operation doing grad-vmap-jit]

I’m a little less certain about this one (I’m not quite as familiar with the internals of jax.grad), but I think in this case the blue region is the forward pass, and the purple region is the backward pass. Everything else looks to be things like jaxpr manipulation, i.e. interpreter overhead that wouldn’t be there if the operation could be jit’d.

Developer ergonomics

Without exception, every JAX operation has to be jitted. Op-by-op mode simply incurs too much overhead. (A fact arrived at by staring at flame graphs of the call stack, much like those above.) This means writing code like

@jax.jit
def _jit_lt(a, b):
    return a < b

while _jit_lt(a, b):
    ...

with lots of little mini-jit-functions every time a JAX array must be interacted with.

The above is an actual example from my code – introducing this JIT produced a measurable improvement in performance.
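To get a rough feel for the gap, something like the following micro-benchmark (a sketch; absolute numbers will vary by machine and backend):

import time
import jax
import jax.numpy as jnp

a, b = jnp.array(1.), jnp.array(2.)

@jax.jit
def _jit_lt(x, y):
    return x < y

_jit_lt(a, b)  # warm-up call, so compilation isn't timed

for label, fn in [("op-by-op", lambda: a < b), ("jitted", lambda: _jit_lt(a, b))]:
    start = time.time()
    for _ in range(10_000):
        fn().block_until_ready()
    print(label, time.time() - start)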

User ergonomics

I’m developing a software library. It’s pretty frustrating for a user to be told that they can’t JIT mylibrary.myfunction, and it basically subjects the user to the same thing as above: at minimum one has to write a “before mylibrary.myfunction JIT” and an “after mylibrary.myfunction JIT”.

Fixing this is actually my primary concern, as this fundamentally breaks composability with the rest of the JAX ecosystem.
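For illustration (hypothetical names: myfunction stands in for the un-jittable library routine, preprocess/postprocess for whatever the user does around it):

import jax
import jax.numpy as jnp

def myfunction(x):  # imagine this internally escapes to Python, so it can't be jitted
    return x * 2

preprocess = jax.jit(lambda x: x + 1)     # the "before myfunction JIT"
postprocess = jax.jit(lambda x: x.sum())  # the "after myfunction JIT"

x = jnp.arange(3.)
y = postprocess(myfunction(preprocess(x)))  # instead of one end-to-end jax.jit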

In-place updates

As I recall (it’s been a while since I figured out how to work around this issue), it’s not possible to make in-place updates to the same buffer in different jax.jit regions; passing back to Python-land forces a copy.

I’m aware of donate_argnums but this hasn’t seemed to help – possibly things are somehow too complex for the compiler to figure out? And either way donate_argnums isn’t yet supported on the CPU (spitting out a bunch of warnings instead).

This one’s pretty important for efficiency purposes.
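The pattern in question looks something like this (a sketch with a hypothetical update function; donate_argnums=0 is the part that, at the time, only produced warnings on CPU):

import functools as ft
import jax
import jax.numpy as jnp

# One buffer threaded through repeated jit calls. donate_argnums asks XLA to
# reuse the input buffer for the output; without effective donation, each
# boundary crossing may force a copy.
@ft.partial(jax.jit, donate_argnums=0)
def update(buf, i):
    return buf.at[i].set(0.)

buf = jnp.ones(1024)
for i in range(3):
    buf = update(buf, i)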

What to trace/static

The main routine I’m writing takes an instance of jax.tree_util.Partial as an argument. This is some parameterised function specified by the user of the software library. The parameterisation may include a mixture of some things worth JIT-tracing and some things worth JIT-static’ing.

When the jax.jit call happens inside the software library, it falls on the developer to make an opinionated choice about what to JIT (e.g. “trace all JAX arrays”). A user may want something slightly different.

This one can obviously be worked around in various ways – have extra arguments for the “JIT-trace args” and the “JIT-static args”, or let a user specify some partition function – but that’s less elegant at the API level, not to mention the library internals now need to pipe multiple arguments around the place.

Overall, it’d be preferable to avoid making this the library’s problem at all: just make it possible for the user to jax.jit their entire operation, mylibrary.myfunction and all.
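As a hypothetical illustration of the partitioning problem (user_fn and library_routine are made-up names):

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def user_fn(weights, flag, x):
    return jnp.where(flag, weights * x, x)

fn = Partial(user_fn, jnp.ones(3), False)  # `False` arguably wants to be static

# Library-side, the developer must make an opinionated choice, e.g.
# "trace every leaf of `fn`"; this traces the flag too, whether or not
# that's what this particular user wanted.
@jax.jit
def library_routine(fn, x):
    return fn(x)

print(library_routine(fn, jnp.arange(3.)))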

/walloftext!

0 reactions
patrick-kidger commented, May 24, 2022

Sounds good! Thank you for the write-up. (Crossing my fingers for the glorious dynamic-shape future.)
