
Using `grad` on `vmap` on `map` on a function containing `sinc` results in an error


Hi,

I ran into an error while trying to take the `grad` of a `vmap` of a `lax.map` of a function that contains `sinc`. The following code reproduces the error:

```python
import jax
import jax.numpy as jnp

N_batch = 5
N_elem = 3

# A (N_batch, N_elem) array of dummy inputs.
batch_data = jnp.reshape(jnp.arange(N_batch * N_elem), (N_batch, N_elem))

def fail(x):
    return jnp.sinc(x)

def mapfail(x):
    # lax.map over the elements of one batch entry.
    return jax.lax.map(fail, x)

vmapmapfail = jax.vmap(mapfail)

def loss(param):
    pos = param * batch_data
    desc = vmapmapfail(pos)
    return jnp.mean(desc)

jax.value_and_grad(loss)(0.1)
```
  • Error: `JaxStackTraceBeforeTransformation: TypeError: broadcast_in_dim broadcast_dimensions must have length equal to operand ndim; got broadcast_dimensions () for operand ndim 1.`
  • JAX version: 0.3.13

I suspect this is related to the custom derivative of `_sinc_maclaurin`, since commenting out both `@partial(custom_jvp, nondiff_argnums=(0,))` and `@_sinc_maclaurin.defjvp` in the JAX source “solves” the problem. It is also specific to `lax.map`: using a `vmap` instead results in no error, as in the sketch below.
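A sketch of that `vmap` variant, reusing `fail` and `batch_data` from the snippet above (the names `nofail` and `loss_vmap` are made up for this illustration):

```python
def nofail(x):
    # Inner vmap instead of lax.map: no error is raised.
    return jax.vmap(fail)(x)

vmapvmapfail = jax.vmap(nofail)

def loss_vmap(param):
    return jnp.mean(vmapvmapfail(param * batch_data))

jax.value_and_grad(loss_vmap)(0.1)  # runs without error
```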

P.S.: A colleague of mine also reminded me that the documentation of `sinc` does not describe the x = 0 behaviour correctly.
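(For context on that P.S.: `jnp.sinc` follows NumPy's normalized convention, so the x = 0 value is the limit of sin(πx)/(πx):)

```python
import jax.numpy as jnp

jnp.sinc(0.0)  # -> 1.0, the limit of sin(pi*x)/(pi*x) as x -> 0
```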

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 11 (7 by maintainers)

Top GitHub Comments

2 reactions

mattjj commented, May 19, 2022

This is a fantastic bug! Thanks for raising it. (And thanks for the tip about the docstring too.)

I think the best fix may be to add a `full_like` primitive. (Alternatives include defensively batching more outputs of `custom_jvp` functions, or probing custom JVP rules’ vmap data dependence, but neither of those seems as good.)

A quick fix, which is basically the same thing, is to replace these lines with

```python
if k % 2:
  return x * 0
else:
  # Multiplying by zero keeps a data dependence on x, so the primal
  # output comes out batched whenever x is.
  return x * 0 + lax.full_like(x, (-1) ** (k // 2) / (k + 1))
```

The issue has to do with data dependence, and an assumption baked into `custom_jvp` about how the input-output data dependence of a custom JVP rule relates to the input-output data dependence of the function for which it is the rule. That is, if we have `f = custom_jvp(f_orig)` and `f.defjvp(f_jvp)`, our batching rule for `custom_jvp` roughly assumes that for any `in_axes: Tuple[Optional[int], ...]` we have `vmap_out_axes(f_orig, in_axes) == vmap_out_axes(lambda xs: jax.jvp(f, xs, xs)[1], in_axes)`, which in turn roughly means that the input-output data dependence of `f_jvp` looks like the input-output data dependence of `f_orig`.
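To make that assumption concrete, here is a minimal sketch (purely illustrative, not JAX source) of a `custom_jvp` pair whose rule does satisfy it: the tangent output depends on the inputs exactly where the primal output does, so both come out batched together under `vmap`:

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def f(x):
    return jnp.sin(x)  # primal output depends on x

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Tangent output depends on x (and t), matching the primal's
    # data dependence, so the batching assumption holds.
    return jnp.sin(x), jnp.cos(x) * t

# Both vmap of f and vmap of its JVP produce batched outputs, as assumed.
xs = jnp.arange(3.0)
print(jax.vmap(f)(xs))
print(jax.vmap(lambda x: jax.jvp(f, (x,), (x,))[1])(xs))
```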

But for `sinc` that’s not the case: the `_sinc_maclaurin` function has no data dependence on its input, while in its JVP rule the tangent output has a data dependence on the tangent input (which it must, for linearity to hold). What resulted was ultimately a type error: we’d get a `custom_jvp_call_jaxpr` application with an `f32[]` output (and downstream operations, like `broadcast_in_dim`, which were set up for that `f32[]` output), but when differentiated we’d get an `f32[5]` primal (and `f32[5]` tangent), which was then type-incompatible with downstream applications.
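For reference, the helper in question looks roughly like this (paraphrased from the JAX source before the fix; exact details may vary by version):

```python
from functools import partial
from jax import custom_jvp

@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
  # k-th Maclaurin coefficient of sin(x)/x: a constant, with *no*
  # data dependence on x.
  if k % 2:
    return 0.
  else:
    return (-1) ** (k // 2) / (k + 1)

@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
  (x,), (t,) = primals, tangents
  # The tangent output *does* depend on the tangent input t, as it must
  # for linearity -- this is the mismatch described above.
  return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
```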

There’s a bit more going on here that was necessary to exhibit this bug; in particular, `scan` (i.e. `lax.map`) was necessary because it causes the JVP rule to be run in a later pass, rather than the JVP happening “on the fly”, where we could’ve noticed that the output of the JVP rule was batched. That’s why replacing the `lax.map` application with a `jnp.stack([jnp.sinc(x_) for x_ in x])` did not exhibit the bug. (Sorry, this paragraph is probably even more inside-baseball than the preceding one…)
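Concretely, applying that substitution to the reproducer above gives a version that runs without error (a sketch; `mapok` and `loss_ok` are names made up for this illustration):

```python
def mapok(x):
    # Python-level loop + stack instead of lax.map: the JVP happens
    # "on the fly", so the batched tangent is noticed in time.
    return jnp.stack([jnp.sinc(x_) for x_ in x])

vmapmapok = jax.vmap(mapok)

def loss_ok(param):
    return jnp.mean(vmapmapok(param * batch_data))

jax.value_and_grad(loss_ok)(0.1)  # runs without error
```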

This analysis leads directly to two of the possible solutions briefly mentioned above:

  1. if we had a `full_like` primitive, then we could write `f` so as to model the same data dependence as `f_jvp` here;
  2. if we defensively assume more stuff will come out batched in the `custom_jvp_call_jaxpr` vmap rule, we won’t get this issue of JVPs looking batched where primals aren’t.

I kind of like the first approach but I’m not sure. I might land the quick fix in the meantime, since multiplying by zero isn’t so bad (and it encodes the data dependence we want to maintain).

1 reaction

Ph03n1xdust commented, May 19, 2022

Thank you for the quick fix, and also for the detailed explanation. With those changes it works perfectly.

And thanks again for your work on jax!
