[Bug?] Jax shuffle speed versus numpy random shuffle speed

I was re-examining a training loop I’d written in JAX for a talk. When I tried to demonstrate that jax.random.shuffle gives a speedup over numpy.random.shuffle, I found that the NumPy version was actually significantly faster. I then reproduced this in a small script and found that jax.random.shuffle gets really slow for arrays of around 1e6 and 1e7 elements, both of which are sizes we routinely encounter.

The demonstration code and timings are in this gist. In short, at around 1e7 elements numpy.random.shuffle takes a neat 3 seconds, while jax.random.shuffle clocks in at over 4 minutes!
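For readers who want to reproduce the comparison outside a notebook, here is a minimal timing sketch. It is not the code from the gist (which is not reproduced here). It assumes a recent JAX, where jax.random.permutation replaces the deprecated jax.random.shuffle, and it uses NumPy’s Generator.shuffle as the modern counterpart of numpy.random.shuffle. The helper names `time_call` and `compare` and the array sizes are illustrative choices. Because JAX dispatches work asynchronously, the sketch calls block_until_ready() so the measured time includes the shuffle itself.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_call(fn, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` calls of fn()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def compare(n, seed=0):
    """Time NumPy's in-place shuffle against jax.random.permutation for n float32 values."""
    rng = np.random.default_rng(seed)
    x_np = rng.random(n, dtype=np.float32)
    x_jax = jnp.asarray(x_np)
    key = jax.random.PRNGKey(seed)

    # Warm-up call so that JAX compilation time is not counted.
    jax.random.permutation(key, x_jax).block_until_ready()

    numpy_s = time_call(lambda: rng.shuffle(x_np))  # shuffles x_np in place
    # block_until_ready() waits for the asynchronously dispatched result.
    jax_s = time_call(lambda: jax.random.permutation(key, x_jax).block_until_ready())
    return numpy_s, jax_s


if __name__ == "__main__":
    for n in (10_000, 100_000, 1_000_000):
        numpy_s, jax_s = compare(n)
        print(f"n={n:>9,}  numpy: {numpy_s * 1e3:8.1f} ms  jax: {jax_s * 1e3:8.1f} ms")
```

Absolute numbers will depend on the backend (CPU vs GPU) and JAX version, so treat any output as specific to your machine.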

Is this intentional? Is it unavoidable because JAX can’t shuffle in place? I read into the JAX shuffle implementation a bit, and it seems like a lot of effort goes into making it fast, whereas numpy.random.shuffle has the advantage of doing everything in place, so that explanation would make sense to me.
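The in-place distinction raised here can be seen directly in a small sketch. NumPy’s shuffle reorders its argument and returns None, while JAX arrays are immutable, so jax.random.permutation (the replacement for the deprecated jax.random.shuffle) returns a new array and leaves its input untouched. This only illustrates the semantics; it is not a measurement and does not by itself account for the size of the slowdown.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: shuffle mutates the array in place and returns None.
x_np = np.arange(10)
result = np.random.shuffle(x_np)
print(result)  # None; x_np itself now holds the reordered values

# JAX: arrays are immutable, so permutation returns a new array.
key = jax.random.PRNGKey(0)
x_jax = jnp.arange(10)
shuffled = jax.random.permutation(key, x_jax)
print(x_jax)  # unchanged: [0 1 2 3 4 5 6 7 8 9]

# Passing an integer returns a shuffled index array (a permutation of range(10)),
# which can then be used to gather elements or rows.
perm = jax.random.permutation(key, 10)
gathered = x_jax[perm]
```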

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments:5 (3 by maintainers)

Top GitHub Comments

2 reactions
jakevdp commented, Jun 24, 2022

Revisiting this: jax.random.shuffle has been deprecated in favor of jax.random.permutation, but the performance issue remains in the current main branch of JAX (run on a Colab CPU):

```python
import numpy as np
from jax import random
import jax.numpy as jnp
import jax

x = np.random.rand(1_000_000).astype('float32')

%timeit np.random.permutation(x)
# 10 loops, best of 5: 42.5 ms per loop

x_jax = jnp.asarray(x)
key = random.PRNGKey(7548923)
_ = random.permutation(key, x_jax)
%timeit random.permutation(key, x_jax).block_until_ready()
# 1 loop, best of 5: 1.24 s per loop

# routine is already JIT-compiled, so another JIT doesn't help much:
perm_jit = jax.jit(random.permutation)
_ = perm_jit(key, x_jax)
%timeit perm_jit(key, x_jax).block_until_ready()
# 1 loop, best of 5: 1.19 s per loop
```

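On GPU, JAX looks better (the same three measurements, in the same order: NumPy, random.permutation, jitted random.permutation):

10 loops, best of 5: 57.6 ms per loop
10 loop, best of 5: 16.4 ms per loop
100 loops, best of 5: 10.6 ms per loop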

I’m going to assign this to @froystig, as he’s been working on PRNG updates and may have ideas about how to improve performance.
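Since jax.random.shuffle is deprecated, code like the original training loop would call jax.random.permutation instead. The sketch below shows that migration, plus one workaround sometimes used in data-loading loops: drawing the index permutation on the host with NumPy and gathering the rows on the device. The workaround is an assumption offered for illustration, not something the maintainers recommend in this thread; the helper names `device_permute` and `host_permute` are hypothetical, and whether the host-side version is faster depends on your array sizes and hardware, so measure it on your own workload.

```python
import jax
import jax.numpy as jnp
import numpy as np


def device_permute(key, x):
    """Shuffle the rows of x on the device with the non-deprecated JAX API."""
    # Replaces the deprecated jax.random.shuffle(key, x); permutes along axis 0,
    # moving whole rows together.
    return jax.random.permutation(key, x, axis=0)


def host_permute(rng, x):
    """Hypothetical workaround: draw row indices with NumPy, gather rows in JAX.

    rng is a numpy.random.Generator, so this cannot run inside jax.jit and the
    shuffle is not reproducible from a JAX PRNG key.
    """
    idx = rng.permutation(x.shape[0])  # permutation of 0..n-1 computed on the host
    return jnp.take(x, jnp.asarray(idx), axis=0)  # gather rows on the device


x = jnp.arange(12).reshape(6, 2)
a = device_permute(jax.random.PRNGKey(0), x)
b = host_permute(np.random.default_rng(0), x)
```

The trade-off: device_permute keeps everything inside JAX (jit-compatible, keyed randomness), while host_permute gives that up in exchange for using NumPy’s shuffle on the index array only.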

0 reactions
akbir commented, Jul 26, 2022

Keen to hear if there is any update on this. (Happy to also help if directed!!)
