jax.numpy array indexing has different out-of-bounds behavior from NumPy
import jax
import jax.numpy as np  # the report aliases jax.numpy as np
x = np.arange(10)
x = jax.device_put(x)
print(x[np.array([13])])  # the report used x[[13]]; recent JAX rejects Python-list indices
This prints [3], but it should raise an out-of-bounds error, as the original NumPy does.
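To make the contrast concrete, here is a minimal NumPy-only sketch. NumPy raises IndexError for index 13 on a length-10 array. The two helpers emulate non-error policies a library could choose instead: wrapping modulo the length (which reproduces the [3] reported above) and clamping to the last valid index (which the JAX "Sharp Bits" documentation describes for out-of-bounds retrieval in current releases). The helper names take_wrap and take_clamp are hypothetical; they are not JAX APIs.

```python
import numpy as np

x = np.arange(10)

# NumPy reference behavior: an out-of-bounds integer index raises IndexError.
try:
    x[[13]]
    numpy_raised = False
except IndexError:
    numpy_raised = True

def take_wrap(a, idx):
    """Hypothetical helper: wrap indices modulo the length of axis 0."""
    idx = np.asarray(idx)
    return a[np.mod(idx, a.shape[0])]

def take_clamp(a, idx):
    """Hypothetical helper: normalize negative indices NumPy-style,
    then clamp everything into [0, len - 1]."""
    n = a.shape[0]
    idx = np.asarray(idx)
    idx = np.where(idx < 0, idx + n, idx)
    return a[np.clip(idx, 0, n - 1)]

print(numpy_raised)          # True
print(take_wrap(x, [13]))    # [3], the value reported in this issue
print(take_clamp(x, [13]))   # [9]
```

Neither policy is "correct"; the point is that once an array library refuses to raise inside compiled code, it has to pick some in-bounds value, and the two choices above give different answers for the same index.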
P.S.: Why does np.arange return a host array? Is this intended behavior, or shouldn't it behave like np.array and return a device array?
Issue created 5 years ago; 4 reactions; 16 comments (14 by maintainers).
You might find jax.experimental.checkify useful for catching out-of-bounds (OOB) indexing.

Notably, even non-fancy indexing doesn't raise a bounds error anymore (it wraps instead). Using the example in https://github.com/google/jax/issues/278#issuecomment-457829164:
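For the checkify suggestion, a minimal sketch, assuming a recent JAX release in which jax.experimental.checkify provides checkify.checkify and the checkify.index_checks error set. The wrapped function returns an error value alongside its normal output instead of raising inside the traced computation; err.get() returns the message (or None), and err.throw() raises it on the Python side. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
    # Plain indexing: without checkify, an out-of-bounds i does not raise.
    return x[i]

# Wrap f so that out-of-bounds indexing is recorded as an error value.
checked_f = jax.jit(checkify.checkify(f, errors=checkify.index_checks))

x = jnp.arange(10)

err, out = checked_f(x, 13)
print(err.get())          # a message describing the out-of-bounds access

err_ok, out_ok = checked_f(x, 3)
print(err_ok.get())       # None: no error was recorded
print(out_ok)             # 3

# Raise the recorded error (if any) as a Python exception.
try:
    err.throw()
    oob_raised = False
except Exception:
    oob_raised = True
```

Because the error travels as a value, this works under jax.jit, at the cost of some extra runtime checking for each indexing operation.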