Indexing a multidimensional array causes a massive slowdown in JAX 0.2.28
On JAX version 0.2.28 (it does not reproduce on, e.g., 0.2.15), the following code runs extremely slowly:
```python
import time

import jax
import jax.numpy as jnp


def tester(carry, unused):
    # Select the first 1000 entries along axis 1 via integer-array indexing.
    selection = jnp.arange(1000)
    selected = carry[:, selection, ...]
    return carry, selected[0, 0, 0, 0, 0]


size = 8
starter = jnp.ones((10, 2000, size, size, 4))
start_time = time.time()
for i in range(200):
    # Cumulative elapsed time; it grows by 4+ seconds each iteration.
    print("time", time.time() - start_time)
    starter, _ = tester(starter, ())
```
Curiously, changing `size` to 7 makes the loop fast again (well under 1 s per iteration, as in 0.2.15). Any idea what's going on here?
Reproduced on several TPU setups: Colab, v2-8, and v3-8.
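To compare the size-7 and size-8 cases in isolation, a sketch like the following times only the indexing step. It makes two assumptions that are not in the original report: the indexing is wrapped in `jax.jit` so that compilation can be excluded with a warm-up call, and each call is followed by `.block_until_ready()`, because JAX dispatches work asynchronously and reading the clock without blocking can attribute time to the wrong iteration. `index_rows` and `time_indexing` are hypothetical helper names.

```python
import time

import jax
import jax.numpy as jnp


def index_rows(x):
    # Same indexing pattern as in the report: take 1000 entries along axis 1.
    selection = jnp.arange(1000)
    return x[:, selection, ...]


def time_indexing(size, iters=5):
    """Mean wall-clock seconds per call of index_rows for a given trailing size."""
    x = jnp.ones((10, 2000, size, size, 4))
    f = jax.jit(index_rows)
    f(x).block_until_ready()  # warm-up call: keeps compilation out of the timing
    start = time.perf_counter()
    for _ in range(iters):
        # Block so the measured time covers the actual device execution.
        f(x).block_until_ready()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    for size in (7, 8):
        print(f"size={size}: {time_indexing(size):.4f} s per call")
```

Run on the affected TPU setup, this prints one steady-state timing per `size`, which separates execution cost from compilation cost.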
Thanks for the report; I agree this is much slower than it should be. I filed a bug against the TPU compiler team (Google bug b/219921462).
@hawkinsp, thanks for the fix, but how can I get that version to work? I tried `pip install "jax[tpu]" -f http:/... -U`, but that gives me the 01/28 version. So, following https://github.com/google/jax/issues/9220, I tried upgrading only the nightly libtpu with `pip install --upgrade libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`, but then the import crashes on the TPU VM.
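When mixing a released `jax[tpu]` install with a nightly libtpu, it can help to record exactly which package versions are installed before importing JAX, so that the information is printed even if the import itself crashes. Below is a minimal sketch that uses only the standard library's `importlib.metadata`. The distribution name `libtpu-nightly` comes from the command above, and `installed_version` is a hypothetical helper.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib", "libtpu-nightly"):
        print(pkg, installed_version(pkg))
    # Import JAX last so the versions above are printed even if this crashes.
    import jax
    print(jax.devices())
```

Sharing this output in the issue makes it easier to tell whether the crash comes from a version mismatch between the installed packages.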