Indexing numpy array with DeviceArray: index interpreted as tuple
When you try to index a NumPy ndarray with a DeviceArray, the NumPy array tries to interpret the JAX array as a tuple of indices.
import numpy as onp
import jax.numpy as np
x = onp.zeros((5,7))
np_idx = onp.array([1,2,3])
jax_idx = np.array([1,2,3])
x[np_idx]  # works: selects rows 1, 2 and 3
x[jax_idx] # <- raises IndexError
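The sketch below uses only NumPy to show why an IndexError is the symptom. It assumes the DeviceArray is unpacked like a tuple of its elements, as the issue describes. For the 2-D array x, that means x[1, 2, 3]: one integer per axis, which is one index per axis too many. The tuple is built by hand, so no JAX is needed. Per the later comment in the thread, NumPy 1.23 and later no longer take this path for x[jax_idx] itself.

```python
import numpy as onp

x = onp.zeros((5, 7))
idx = onp.array([1, 2, 3])

# Simulate the buggy interpretation: one integer per axis, i.e. x[1, 2, 3].
as_tuple = tuple(int(i) for i in idx)
try:
    x[as_tuple]
except IndexError as err:
    error_message = str(err)  # too many indices for a 2-D array
else:
    error_message = None

# Indexing with the array itself selects whole rows instead.
rows = x[idx]
print(error_message)
print(rows.shape)  # (3, 7)
```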
Workaround: put jax_idx in a singleton tuple, x[(jax_idx,)].
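Here is a minimal sketch of the workaround, following the issue's aliases (onp for NumPy, np for jax.numpy). Wrapping the index in a one-element tuple tells NumPy that jax_idx is the index for axis 0, so NumPy converts it to an integer array instead of unpacking it. The explicit onp.asarray conversion is an extra alternative that is not taken from the issue. It relies only on JAX arrays supporting conversion to NumPy arrays.

```python
import numpy as onp
import jax.numpy as np

x = onp.arange(35).reshape(5, 7)  # distinct values make the selected rows visible
np_idx = onp.array([1, 2, 3])
jax_idx = np.array([1, 2, 3])

# Workaround from the issue: a singleton tuple marks jax_idx as the axis-0 index.
rows_tuple = x[(jax_idx,)]

# Alternative (assumption, not from the issue): convert to NumPy before indexing.
rows_converted = x[onp.asarray(jax_idx)]

print(rows_tuple.shape, rows_converted.shape)  # (3, 7) (3, 7)
```

Both forms should give the same rows as x[np_idx]. The tuple form keeps the original index object, while the explicit conversion makes the NumPy/JAX boundary visible at the call site.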
This bug led to a confusing situation: my function worked when decorated with jax.jit but produced a shape mismatch when called directly on a NumPy array.
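The jit/eager difference follows from what x is inside the function. Under jax.jit, array arguments become JAX tracers, so x[idx] is handled by JAX indexing. In an eager call with a NumPy x, the same expression goes through NumPy's own indexing, which is where the tuple interpretation happened. The sketch below shows one way to make both paths agree: convert x with jax.numpy.asarray first. The helper take_rows is hypothetical and is not from the issue.

```python
import jax
import jax.numpy as np
import numpy as onp

def take_rows(x, idx):
    # Hypothetical helper: converting x first means JAX always does the
    # indexing, whether x is a NumPy array (eager) or a tracer (under jit).
    return np.asarray(x)[idx]

x = onp.zeros((5, 7))
idx = np.array([1, 2, 3])

eager = take_rows(x, idx)
jitted = jax.jit(take_rows)(x, idx)
print(eager.shape, jitted.shape)  # (3, 7) (3, 7)
```

The trade-off is that the eager result is now a JAX array rather than a NumPy array. That is usually what you want if the function is also meant to run under jit.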
The upstream NumPy PR was merged! So I’m going to declare this fixed, even though you’re going to have to wait for NumPy 1.23 to get it…
Just remarking that I helped an intern debug this issue again, so it's still happening circa July 2020.