Odd output from jax.ops.index_update when jitted

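import jax
import jax.numpy as jnp

def demo(n=8):
  # Set x[1:] to 1 + x[:-1]; index_update returns a new array and leaves x unchanged.
  fn = lambda x: jax.ops.index_update(x, slice(1, None), 1 + x[:-1])
  y = jnp.zeros(n)
  print(fn(y))           # eager call
  print(jax.jit(fn)(y))  # same function, jitted

demo()

# Eager output (expected, since 1 + y[:-1] is all ones):
# [0. 1. 1. 1. 1. 1. 1. 1.]
# Jitted output (unexpected):
# [0. 1. 2. 3. 4. 5. 6. 7.]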
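
One way to read the two outputs above, offered as an assumption rather than a diagnosis confirmed in the thread: the eager result matches a functional update, where all of x[:-1] is read before anything is written, while the jitted result looks like an in-place update in which each write is visible to the next read. The sketch below reproduces both patterns in plain NumPy. The helper names functional_update and aliased_sequential_update are hypothetical and do not correspond to anything in JAX or XLA.

```python
import numpy as np

def functional_update(x):
    # Out-of-place semantics: the right-hand side is computed from the
    # original x, then written into a copy.
    out = x.copy()
    out[1:] = 1 + x[:-1]
    return out

def aliased_sequential_update(x):
    # Hypothetical in-place variant: element i is computed from the
    # already-updated element i - 1, so the increments accumulate.
    out = x.copy()
    for i in range(1, len(out)):
        out[i] = 1 + out[i - 1]
    return out

x = np.zeros(8)
expected = functional_update(x)
aliased = aliased_sequential_update(x)
print(expected)  # [0. 1. 1. 1. 1. 1. 1. 1.]
print(aliased)   # [0. 1. 2. 3. 4. 5. 6. 7.]
```

If this reading is right, the jitted program behaves as though source and destination share a buffer, but the sketch only illustrates the pattern and says nothing about where in the compiler such aliasing would arise.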

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 14 (12 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Oct 7, 2021

We have a candidate fix for this issue that hopefully should land soon.

1 reaction
hawkinsp commented, Sep 29, 2021

Debugging progress is happening. We suspect that this isn’t a CPU-only bug, it’s just much more likely to exhibit on CPU. Watch this space!

