Error with custom_vjp + scan + vmap (jax==0.2.9)

Hello,

Using lax.scan over a function with a custom_vjp, vmapping the resulting function, and then attempting reverse-mode differentiation with jax.vjp raises an error with jax 0.2.9. The same code works with jax 0.2.8.

This is the code to reproduce the error:

import jax
import jax.numpy as jnp


@jax.custom_vjp
def mul(x, coeff): return x * coeff
def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)
def mul_bwd(res, g):
    x, coeff = res
    g_x = g * coeff
    g_coeff = (x * g).sum()
    return g_x, g_coeff
mul.defvjp(mul_fwd, mul_bwd)


def scan_over_mul(x, coeff):
    def f_(x, t):
        return mul(x, coeff), None
    y, _ = jax.lax.scan(f_, x, jnp.arange(3))
    return y


key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
coeff = jnp.array(1.)

batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0)
res = batched_scan_over_mul(x_batch, coeff)

res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)

grads = vjp_fun(covector_batch)  # This line throws ValueError
print(grads)
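For reference, the gradients the repro should produce can be worked out by hand. Three scan steps of mul give y = x * coeff**3, so with cotangent g the VJP is (g * coeff**3, sum(3 * coeff**2 * x * g)). The coeff cotangent is summed over the batch because coeff is not vmapped. The sketch below checks this closed form against jax.vjp on a variant without custom_vjp, which the report lists as working. Assumptions: plain_mul and scan_over_plain_mul are hypothetical names introduced here, coeff = 1.5 replaces 1. so that the coeff**k factors are visible, and a recent JAX release is assumed (this was not run against 0.2.9).

```python
import jax
import jax.numpy as jnp


def plain_mul(x, coeff):
    # Hypothetical helper: same forward computation as mul, but no custom VJP
    return x * coeff


def scan_over_plain_mul(x, coeff):
    def f_(carry, t):
        return plain_mul(carry, coeff), None
    y, _ = jax.lax.scan(f_, x, jnp.arange(3))
    return y  # equals x * coeff**3 after three steps


key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
coeff = jnp.array(1.5)  # assumption: 1.5 instead of 1. to expose coeff**k

batched = jax.vmap(scan_over_plain_mul, in_axes=(0, None), out_axes=0)
_, vjp_fun = jax.vjp(batched, x_batch, coeff)
grad_x, grad_coeff = vjp_fun(covector_batch)

# Closed form for y = x * coeff**3
expected_grad_x = covector_batch * coeff**3
# coeff is shared across the batch, so its cotangent sums over every element
expected_grad_coeff = jnp.sum(3 * coeff**2 * x_batch * covector_batch)

print(grad_coeff.shape)  # () : a single cotangent for the shared coeff
```

The same closed form applies to the custom_vjp version, since mul_bwd implements the exact derivative of x * coeff.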

Things work as expected (no error, and grads[1], which corresponds to coeff, has only one element) if I do any one of the following:

  • scan over a function without a custom_vjp;
  • wrap the jax.vjp call in jax.disable_jit(), i.e. replace res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) with with jax.disable_jit(): res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff);
  • replace scan with a Python for loop;
  • use jax 0.2.8 instead of 0.2.9. (The master branch raises the same error as 0.2.9.)
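The for-loop workaround from the list above can be sketched as follows. It keeps the custom_vjp mul from the repro but unrolls the three steps in Python, so the backward pass never has to transpose a scan. This is a sketch of the reporter's workaround, not a fix for the underlying bug. Assumptions: loop_over_mul and its num_steps parameter are hypothetical names chosen here (3 matches the scan length in the repro), and it was written with a recent JAX release in mind.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def mul(x, coeff):
    return x * coeff

def mul_fwd(x, coeff):
    # Save the inputs as residuals for the backward pass
    return mul(x, coeff), (x, coeff)

def mul_bwd(res, g):
    x, coeff = res
    # The coeff cotangent is reduced to a scalar to match coeff's shape
    return g * coeff, (x * g).sum()

mul.defvjp(mul_fwd, mul_bwd)


def loop_over_mul(x, coeff, num_steps=3):
    # Hypothetical helper: a Python loop instead of jax.lax.scan.
    # It is unrolled while tracing, so no scan transpose is involved.
    for _ in range(num_steps):
        x = mul(x, coeff)
    return x


key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
coeff = jnp.array(1.)

batched_loop_over_mul = jax.vmap(loop_over_mul, in_axes=(0, None), out_axes=0)
res, vjp_fun = jax.vjp(batched_loop_over_mul, x_batch, coeff)
grads = vjp_fun(covector_batch)
print(grads[1].shape)  # () : one cotangent for the shared coeff
```

With coeff = 1., grads[0] should equal covector_batch and grads[1] should equal 3 * sum(x_batch * covector_batch), one contribution per unrolled step. Unrolling grows the traced program linearly with num_steps, which is the usual trade-off against scan.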

Here is the error:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/Users/florianhopfmueller/code/debug_vmap_error.py", line 38, in <module>
    grads = vjp_fun(covector_batch)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/api.py", line 1834, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 227, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1693, in _scan_transpose
    jaxpr_trans = _transpose_scan_jaxpr(
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1728, in _transpose_scan_jaxpr
    return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1732, in _make_closed_jaxpr
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1186, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1722, in transposed
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 227, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 687, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/linear_util.py", line 179, in call_wrapped
    ans = gen.send(ans)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/batching.py", line 70, in _match_axes
    raise ValueError(msg)
ValueError: vmap has mapped output but out_axes is None
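The final message comes from vmap's output-axis check. When an output depends on the mapped axis but its out_axes entry is None, vmap refuses to return it unbatched. The traceback shows the check firing inside custom_vjp's backward rule (via _custom_lin_transpose) while scan is transposed. A plausible reading, which is an inference and not stated in the issue, is that the per-example coeff cotangent from mul_bwd was required to be unbatched instead of being summed over the batch. The sketch below triggers the same check in isolation, unrelated to scan or custom_vjp. The exact message wording may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(3.0)

# The output depends on the mapped input, but out_axes=None asks vmap to
# return it unbatched, so vmap raises.
try:
    jax.vmap(lambda x: x * 2.0, out_axes=None)(xs)
    raised = None
except ValueError as err:
    raised = err

print(raised)

# out_axes=None is fine when the output does not depend on the mapped axis.
const = jax.vmap(lambda x: jnp.float32(1.0), out_axes=None)(xs)
```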

Please let me know if this is expected, or if you need any more info to address it! Thanks a lot.

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

mattjj commented, Mar 28, 2021 (1 reaction)

~After fixing that issue, I also saw a float0 issue pop up!~ (Turns out that was a mistake on my part.)

This is a great test case you provided 😄

mattjj commented, Feb 24, 2021 (1 reaction)

Thanks for the excellent report!
