jax.numpy array indexing has different out-of-bounds behavior to numpy

import jax
import jax.numpy as np

x = np.arange(10)
x = jax.device_put(x)
print(x[[13]])

This prints [3], but it should raise an out-of-bounds error, as the original NumPy does.
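For comparison, here is a minimal sketch of the NumPy behavior the report refers to. It uses plain NumPy only; the alias `onp` and the variable names are assumptions chosen to avoid clashing with the `np` alias used for jax.numpy above.

```python
import numpy as onp  # plain NumPy; `onp` is an illustrative alias

x = onp.arange(10)

try:
    x[[13]]          # fancy indexing with an out-of-bounds index
    raised = False
except IndexError as exc:
    raised = True
    message = str(exc)

print(raised)
```

NumPy rejects the index with an IndexError instead of returning a value, which is the behavior the report expects from jax.numpy as well.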

P.S.: why does np.arange return a host array? Is this intended behavior or shouldn’t it rather behave like np.array and return a device array?

Issue Analytics

  • State: open
  • Created: 5 years ago
  • Reactions: 4
  • Comments: 16 (14 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Oct 18, 2022

You might find jax.experimental.checkify useful for catching OOB checks.
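As a rough sketch of what that suggestion could look like, the snippet below wraps an indexing function with `checkify.checkify` and enables `checkify.index_checks`. The function name `lookup` and the variable names are hypothetical, and the exact error text depends on the installed JAX version, so it is printed here without being assumed.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def lookup(x, idx):
    # Plain indexing; on its own this does not raise for out-of-bounds idx.
    return x[idx]

# Wrap the function so that out-of-bounds indexing is recorded as an error value.
checked_lookup = checkify.checkify(lookup, errors=checkify.index_checks)

x = jnp.arange(10)

err_bad, _ = checked_lookup(x, jnp.array([13]))       # out of bounds
err_ok, value_ok = checked_lookup(x, jnp.array([3]))  # in bounds

bad_message = err_bad.get()  # error description, or None if no error occurred
ok_message = err_ok.get()

print(bad_message)
print(ok_message, value_ok)
```

Calling `err_bad.throw()` instead of `err_bad.get()` would turn the recorded error into a Python exception, which is closer to what NumPy does.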

1 reaction
schmrlng commented, Jul 30, 2019

Notably, even non-fancy indexing doesn’t raise a bounds error anymore (it wraps instead). Using the example in https://github.com/google/jax/issues/278#issuecomment-457829164:

import jax
import jax.numpy as np

x = np.arange(10)
x = jax.device_put(x)
print(x[13])  # 3

Top Results From Across the Web

jax.numpy package - JAX documentation - Read the Docs
True if two arrays have the same shape and elements, False otherwise. ... which may have slightly different behavior than numpy.ndarray.astype() in some ......

The Sharp Bits — JAX documentation
In Numpy, you are used to errors being thrown when you index an array ... JAX must choose some non-error behavior for out...

jax.numpy.take - JAX documentation - Read the Docs
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier...

Source code for jax._src.numpy.lax_numpy - JAX documentation
NumPy operations are implemented in Python in terms of the primitive ... we want JAX scalars to have the same type # promotion...

Source code for jax._src.lax.slicing - JAX documentation
For most use cases, you should prefer `Numpy-style indexing ... The behavior for out-of-bounds indices when set to 'promise_in_bounds' is ...
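The documentation pages listed above describe explicit out-of-bounds modes. As a hedged illustration, not taken from the issue itself, the sketch below requests a specific behavior through the `mode` argument of `x.at[...].get` and `jnp.take`. The variable names are assumptions, and which behavior plain `x[idx]` uses by default depends on the JAX version.

```python
import jax.numpy as jnp

x = jnp.arange(10)
idx = jnp.array([13])  # out of bounds for an array of length 10

# Clamp the index to the last valid position.
clipped = x.at[idx].get(mode="clip")

# Return a caller-chosen fill value for out-of-bounds positions.
filled = x.at[idx].get(mode="fill", fill_value=-1)

# Wrap the index around modulo the axis size, as reported in the issue.
wrapped = jnp.take(x, idx, mode="wrap")

print(clipped, filled, wrapped)
```

Choosing the mode explicitly makes the out-of-bounds policy visible at the call site, so the result no longer depends on the library default.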
