Indexing numpy array with DeviceArray: index interpreted as tuple

When you index a NumPy ndarray with a JAX DeviceArray, NumPy interprets the JAX array as a tuple of per-axis indices rather than as an array index.

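import numpy as onp
import jax.numpy as np

x = onp.zeros((5, 7))
np_idx = onp.array([1, 2, 3])
jax_idx = np.array([1, 2, 3])

x[np_idx]   # selects rows 1, 2, 3 -> shape (3, 7)
x[jax_idx]  # <- raises IndexError (NumPy < 1.23): read like x[1, 2, 3], too many indices for a 2-D array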
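The two readings can be reproduced with plain NumPy, no JAX required: an ndarray index selects rows along the first axis, while a tuple supplies one index per axis, so a 3-tuple on a 2-D array is one index too many. A minimal sketch:

```python
import numpy as onp

x = onp.zeros((5, 7))
idx = onp.array([1, 2, 3])

# Array index: picks rows 1, 2 and 3 along axis 0.
rows = x[idx]
print(rows.shape)  # (3, 7)

# Tuple index: one integer per axis. This is how the JAX index was being read,
# and three integers cannot index a 2-D array.
raised = False
try:
    x[tuple(idx.tolist())]
except IndexError as err:
    raised = True
    print("IndexError:", err)
```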

Workaround: wrap jax_idx in a singleton tuple, x[(jax_idx,)], so that NumPy treats it as a single index for the first axis.
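Another way to sidestep the ambiguity is to convert the index to a real ndarray before indexing: onp.asarray accepts any object implementing the __array__ protocol, which JAX arrays do. The helper names below are hypothetical, and the sketch is exercised with a plain list standing in for the JAX index:

```python
import numpy as onp

def index_rows_tuple(x, idx):
    """Workaround from the issue: the singleton tuple makes idx the axis-0 index."""
    return x[(idx,)]

def index_rows_asarray(x, idx):
    """Hypothetical alternative: convert idx to an ndarray so NumPy sees an array index."""
    return x[onp.asarray(idx)]

x = onp.arange(35).reshape(5, 7)
idx = [1, 2, 3]  # stand-in for a JAX array; any array-like index works here

a = index_rows_tuple(x, idx)
b = index_rows_asarray(x, idx)
print(a.shape, (a == b).all())  # (3, 7) True
```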

This bug resulted in a confusing situation where my function worked when decorated by jax.jit but had a shape mismatch when called on a numpy array.

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 10 (7 by maintainers)

Top GitHub Comments

hawkinsp commented, Feb 14, 2022 (3 reactions)

The upstream NumPy PR was merged! So I’m going to declare this fixed, even though you’re going to have to wait for NumPy 1.23 to get it…
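Since the fix only lands in NumPy 1.23, code that still supports older releases may want to check the installed version, for example to decide whether the singleton-tuple workaround can be dropped or to gate a regression test. A minimal sketch that compares the major and minor components; the helper name is hypothetical:

```python
import numpy as onp

def numpy_major_minor(version=onp.__version__):
    """Return (major, minor) from a NumPy version string such as '1.22.4' or '2.0.0rc1'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)

# Per the issue thread, the upstream NumPy fix first ships in 1.23.
HAS_UPSTREAM_FIX = numpy_major_minor() >= (1, 23)
print(HAS_UPSTREAM_FIX)
```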

levskaya commented, Jul 10, 2020 (1 reaction)

Just remarking that I just helped an intern debug this issue again, so it’s still happening circa July 2020.

Top Results From Across the Web

indexing into numpy array with jax array: faulty error messages
In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result....

jax.numpy.take - JAX documentation - Read the Docs
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier...

Working with Numpy Arrays: Indexing | by Kurtis Pykes
On that note, we can describe numpy arrays as a grid of the same type values that is indexed via a tuple of...

Common Gotchas in JAX - Colaboratory - Google Colab
If we try to update a JAX device array in-place, however, we get an error! ... In Numpy, you are used to errors...

Troubleshooting and tips — Numba 0.50.1 documentation
If however you call it with a tuple and a number, Numba is unable to say what the ... gdb_init import numpy as...