Using `grad` on `vmap` on `lax.map` on a function containing `sinc` results in an error
Hi,

I ran into an error while trying to take the `grad` of a `vmap` of a `lax.map` over a function containing `sinc`. The following code reproduces the error:
```python
import jax
import jax.numpy as jnp

N_batch = 5
N_elem = 3
batch_data = jnp.reshape(jnp.arange(N_batch * N_elem), (N_batch, N_elem))

def fail(x):
    return jnp.sinc(x)

def mapfail(x):
    return jax.lax.map(fail, x)

vmapmapfail = jax.vmap(mapfail)

def loss(param):
    pos = param * batch_data
    desc = vmapmapfail(pos)
    return jnp.mean(desc)

jax.value_and_grad(loss)(0.1)
```
- Error:

  ```
  JaxStackTraceBeforeTransformation: TypeError: broadcast_in_dim broadcast_dimensions must have length equal to operand ndim; got broadcast_dimensions () for operand ndim 1.
  ```

- JAX version: 0.3.13
I suspect this is related to the custom derivatives of `_sinc_maclaurin`, since commenting out both `@partial(custom_jvp, nondiff_argnums=(0,))` and `@_sinc_maclaurin.defjvp` in the JAX source "solves" the problem. It is also specific to `lax.map`: using a `vmap` instead results in no error (a variant is sketched below).

P.S.: A colleague of mine also reminded me that the documentation of `sinc` does not describe the x=0 behaviour correctly.
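For reference, here is a minimal variant of the snippet above with the inner `lax.map` swapped for a `vmap` (the helper names are mine); this formulation does not raise the error:

```python
# Same setup as the reproduction above, but mapping the inner axis with vmap
# instead of lax.map; this version runs without error.
def mapok(x):
    return jax.vmap(fail)(x)

def loss_ok(param):
    return jnp.mean(jax.vmap(mapok)(param * batch_data))

jax.value_and_grad(loss_ok)(0.1)  # no error
```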
This is a fantastic bug! Thanks for raising it. (And thanks for the tip about the docstring too.)
I think the best fix may be to add a `full_like` primitive. (Alternatives include defensively batching more outputs of `custom_jvp` functions, or probing custom JVP rules' vmap data dependence, but neither of those seems as good.) A quick fix, which is basically the same thing, is to replace these lines with a version that multiplies by zero.
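Roughly, the "multiply by zero" idea is just to produce a value that is numerically constant but still records a data dependence on its input. A tiny sketch of that idea (illustrative only, not the actual patched lines):

```python
import jax.numpy as jnp

def one_like_with_dependence(x):
    # Equals 1 everywhere, but the output still depends on x in the jaxpr,
    # so under vmap it is batched whenever x is.
    return 1.0 + x * 0
```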
The issue has to do with data dependence, and an assumption baked into `custom_jvp` about how the input-output data dependence of a custom JVP rule relates to the input-output data dependence of the function for which it is the rule. That is, if we have `f = custom_jvp(f_orig)` and `f.defjvp(f_jvp)`, our batching rule for `custom_jvp` roughly assumes that for any `in_axes: Tuple[Optional[int], ...]` we have `vmap_out_axes(f_orig, in_axes) == vmap_out_axes(lambda xs: jax.jvp(f, xs, xs)[1], in_axes)`, which in turn roughly means that the input-output data dependence of `f_jvp` looks like the input-output data dependence of `f_orig`.
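For concreteness, here is a minimal example of that pattern (nothing sinc-specific, just the shape of the assumption): the JVP rule's outputs depend on the inputs in the same way the original function's do, so their vmap out_axes agree.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)                      # output depends on x

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return f(x), jnp.cos(x) * t            # tangent output depends on x and t

xs = jnp.arange(3.0)
# Batched inputs give batched outputs for both the function and its JVP,
# which is what the custom_jvp batching rule assumes.
jax.vmap(f)(xs)
jax.jvp(jax.vmap(f), (xs,), (jnp.ones(3),))
```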
But for `sinc` that's not the case: the `_sinc_maclaurin` function has no data dependence on its input, while in its JVP rule the tangent output has a data dependence on the tangent input (which it must, for linearity to hold). What resulted was ultimately a type error: we'd get a `custom_jvp_call_jaxpr` application with an `f32[]` output (and downstream operations, like `broadcast_in_dim`, which were set up for that `f32[]` output), but when differentiated we'd get an `f32[5]` primal (and `f32[5]` tangent), which was then type-incompatible with downstream applications.
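That's the mismatch which ultimately surfaces as the `broadcast_in_dim` error in the traceback above. As a standalone illustration of that error (my construction, not taken from the trace):

```python
import jax.numpy as jnp
from jax import lax

# Set up for a scalar (f32[]) operand, this is fine and produces an f32[5]:
lax.broadcast_in_dim(jnp.array(1.0), shape=(5,), broadcast_dimensions=())

# But handing the same call an f32[5] operand with broadcast_dimensions=()
# raises the "broadcast_dimensions must have length equal to operand ndim"
# TypeError seen in the report.
lax.broadcast_in_dim(jnp.ones(5), shape=(5,), broadcast_dimensions=())
```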
There's a bit more going on here which was necessary to exhibit this bug; in particular, `scan` (i.e. `lax.map`) was necessary because it causes the JVP rule to be run in a later pass, rather than the JVP happening "on the fly" where we could've noticed the output of the JVP rule was batched. That's why replacing the `lax.map` application with `jnp.stack([jnp.sinc(x_) for x_ in x])` did not exhibit the bug. (Sorry, this paragraph is probably even more inside-baseball than the preceding one…)

This analysis directly leads to two of the possible solutions briefly mentioned above:
1. With a `full_like` primitive, we could write `f` so as to model the same data dependence as `f_jvp` here.
2. If we defensively batch more outputs in the `custom_jvp_call_jaxpr` vmap rule, we won't get this issue of JVPs looking batched where primals aren't.

I kind of like the first approach, but I'm not sure. I might land the quick fix in the meantime, since multiplying by zero isn't so bad (and it encodes the data dependence we want to maintain).
Thank you for the quick fix, and also for the detailed explanation. With those changes it works perfectly.
And thanks again for your work on jax!