Error with custom_vjp + scan + vmap (jax==0.2.9)
Hello,
Using lax.scan over a function with a custom_vjp, then vmapping the resulting function, and then attempting reverse-mode differentiation with jax.vjp raises an error in jax 0.2.9. The same code works in jax 0.2.8.
This is the code to reproduce the error:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def mul(x, coeff): return x * coeff

def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)

def mul_bwd(res, g):
    x, coeff = res
    g_x = g * coeff
    g_coeff = (x * g).sum()
    return g_x, g_coeff

mul.defvjp(mul_fwd, mul_bwd)

def scan_over_mul(x, coeff):
    def f_(x, t):
        return mul(x, coeff), None
    y, _ = jax.lax.scan(f_, x, jnp.arange(3))
    return y

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
coeff = jnp.array(1.)

batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0)
res = batched_scan_over_mul(x_batch, coeff)
res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)
grads = vjp_fun(covector_batch)  # This line throws ValueError
print(grads)
Things work as expected (i.e., no error, and grads[1], which corresponds to coeff, has only one element) when doing any one of the following:
- scanning over a function without a custom_vjp,
- replacing
  res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)
  by
  with jax.disable_jit(): res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)
- replacing scan by a Python for loop,
- using jax 0.2.8 instead of 0.2.9 (the master branch throws the same error as 0.2.9).
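For reference, the Python for loop workaround from the list above could look roughly like the sketch below. The name loop_over_mul and the choice coeff = 2 are only illustrative and not part of the original script. Since three multiplications give y = x * coeff**3, the gradients can be compared against that closed form.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def mul(x, coeff):
    return x * coeff

def mul_fwd(x, coeff):
    return mul(x, coeff), (x, coeff)

def mul_bwd(res, g):
    x, coeff = res
    # Cotangents for x and for the scalar coeff, as in the report.
    return g * coeff, (x * g).sum()

mul.defvjp(mul_fwd, mul_bwd)

def loop_over_mul(x, coeff):
    # Hypothetical replacement for scan_over_mul: a plain Python loop
    # instead of jax.lax.scan, so no scan transpose is involved.
    for _ in range(3):
        x = mul(x, coeff)
    return x

key1, key2 = jax.random.split(jax.random.PRNGKey(0), 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
coeff = jnp.array(2.0)  # illustrative value, the report uses 1.0

batched_loop_over_mul = jax.vmap(loop_over_mul, in_axes=(0, None), out_axes=0)
res, vjp_fun = jax.vjp(batched_loop_over_mul, x_batch, coeff)
grads = vjp_fun(covector_batch)

# Closed-form gradients of y = x * coeff**3 for comparison.
expected_g_x = covector_batch * coeff**3
expected_g_coeff = 3 * coeff**2 * (x_batch * covector_batch).sum()

print(grads[1].shape)  # shape of the cotangent for the unbatched coeff
```

If the workaround behaves as described, grads[1] is a scalar because coeff enters with in_axes=None, and both gradients agree with the closed form up to floating-point tolerance.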
Here is the error:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/Users/florianhopfmueller/code/debug_vmap_error.py", line 38, in <module>
    grads = vjp_fun(covector_batch)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/api.py", line 1834, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 227, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1693, in _scan_transpose
    jaxpr_trans = _transpose_scan_jaxpr(
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1728, in _transpose_scan_jaxpr
    return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1732, in _make_closed_jaxpr
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1186, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1722, in transposed
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 227, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/ad.py", line 687, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/linear_util.py", line 179, in call_wrapped
    ans = gen.send(ans)
  File "/Users/florianhopfmueller/opt/anaconda3/envs/jax029/lib/python3.9/site-packages/jax/interpreters/batching.py", line 70, in _match_axes
    raise ValueError(msg)
ValueError: vmap has mapped output but out_axes is None
Please let me know if this is expected, or if you need any more info to address it! Thanks a lot.
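For context on the final ValueError, here is a minimal sketch (separate from the script above) of the ordinary situation in which vmap complains about out_axes=None: the output depends on the mapped input, so it cannot be returned unbatched. The helper names identity and return_unmapped are made up for illustration, and the exact wording of the message may differ between jax versions. In the traceback above the same check appears to be hit internally, for the cotangent of the unbatched coeff, rather than from user code.

```python
import jax
import jax.numpy as jnp

def identity(x):
    return x

def return_unmapped(x, c):
    # The output ignores the mapped argument x entirely.
    return c

# Mapped output with out_axes=None: vmap refuses to drop the batch axis.
raised = False
try:
    jax.vmap(identity, in_axes=0, out_axes=None)(jnp.ones(3))
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Unmapped output with out_axes=None is fine.
out = jax.vmap(return_unmapped, in_axes=(0, None), out_axes=None)(jnp.ones(3), 2.0)
print(out)
```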
Top GitHub Comments
~After fixing that issue, I also saw a float0 issue pop up!~ (Turns out that was a mistake on my part.)
This is a great test case you provided 😄
Thanks for the excellent report!