jnp.argsort much slower than the numpy version
Here’s a comparison of the JAX and NumPy versions of argsort on a CPU:
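import numpy as np
import jax.numpy as jnp
from jax import config, random

# Force JAX onto the CPU so both libraries run on the same device.
config.update('jax_platform_name', 'cpu')

key = random.PRNGKey(42)
key, subkey = random.split(key)
x_jnp = random.uniform(subkey, (100, 10000))  # float32 values in [0, 1)
x_np = np.array(x_jnp)  # the same data as a NumPy array

%%timeit
# Separate Jupyter cell: NumPy baseline.
np.argsort(x_np, axis=0)

%%timeit
# Separate Jupyter cell: JAX dispatches asynchronously, so wait for the result.
jnp.argsort(x_jnp, axis=0).block_until_ready()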
In this case jnp.argsort is ~5x slower than np.argsort. I’m seeing a >20x difference with more realistic arrays. Why is there such a large difference in performance between the two implementations?
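Outside Jupyter, `%%timeit` is not available. A minimal plain-Python sketch of the same measurement using the standard `timeit` module follows; the input shape mirrors the snippet above (the random data differs), and the single warm-up call, which keeps one-time compilation out of the timing, is our own addition rather than part of the original report.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp

# Run on the CPU, as in the original report.
jax.config.update("jax_platform_name", "cpu")


def make_inputs(shape=(100, 10000), seed=42):
    """Return the same float32 data as a JAX array and as a NumPy array."""
    key = jax.random.PRNGKey(seed)
    x_jnp = jax.random.uniform(key, shape)
    return x_jnp, np.asarray(x_jnp)


def numpy_argsort(x_np):
    return np.argsort(x_np, axis=0)


def jax_argsort(x_jnp):
    # JAX dispatches asynchronously; without blocking, the timer would
    # only measure how long it takes to enqueue the work.
    return jnp.argsort(x_jnp, axis=0).block_until_ready()


def benchmark(number=20):
    """Return mean seconds per call for (NumPy, JAX)."""
    x_jnp, x_np = make_inputs()
    jax_argsort(x_jnp)  # warm-up: the first call includes compilation
    t_np = timeit.timeit(lambda: numpy_argsort(x_np), number=number) / number
    t_jax = timeit.timeit(lambda: jax_argsort(x_jnp), number=number) / number
    return t_np, t_jax


if __name__ == "__main__":
    t_np, t_jax = benchmark()
    print(f"np.argsort:  {t_np * 1e3:.2f} ms per call")
    print(f"jnp.argsort: {t_jax * 1e3:.2f} ms per call ({t_jax / t_np:.1f}x)")
```

The absolute numbers depend on the machine and the JAX/XLA version, so treat any ratio you get as specific to your setup.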
You might find this FAQ entry helpful: "Is JAX Faster Than NumPy?"
I also note that the slowness is specific to floating-point values. Sorting int32 values is significantly faster. The only difference between the two, as far as I can tell, is the comparison function.
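One way to probe that observation is to time the same argsort on float32 and int32 data of equal shape. The sketch below also tries a workaround that is purely our assumption, not something confirmed in this thread: map each float32 to an int32 key whose integer order matches the float order (a standard bit trick using `jax.lax.bitcast_convert_type`), then argsort the keys. The helper names `float32_sort_keys` and `argsort_via_int_keys` are hypothetical, and whether this recovers the int32 speed on your JAX/XLA version is something to measure, not assume.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


def float32_sort_keys(x):
    """Hypothetical helper: map float32 values to int32 keys with the same order.

    Non-negative floats already order correctly when their bits are read as
    signed integers. For negative floats, flipping the 31 non-sign bits
    reverses their order, so larger magnitudes get smaller keys.
    Caveats: -0.0 sorts below +0.0, and NaNs go to either end
    depending on their sign bit.
    """
    bits = lax.bitcast_convert_type(x.astype(jnp.float32), jnp.int32)
    return jnp.where(bits < 0, bits ^ jnp.int32(0x7FFFFFFF), bits)


@jax.jit
def argsort_via_int_keys(x):
    # Hypothetical workaround: sort integer keys instead of the floats.
    return jnp.argsort(float32_sort_keys(x), axis=0)


def time_call(fn, x, number=20):
    """Mean seconds per call, after one warm-up call for compilation."""
    fn(x).block_until_ready()
    return timeit.timeit(lambda: fn(x).block_until_ready(), number=number) / number


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x_f32 = jax.random.uniform(key, (100, 10000), minval=-1.0, maxval=1.0)
    x_i32 = jax.random.randint(key, (100, 10000), -2**30, 2**30)
    plain = jax.jit(lambda x: jnp.argsort(x, axis=0))

    print("float32:             ", time_call(plain, x_f32))
    print("int32:               ", time_call(plain, x_i32))
    print("float32 via int keys:", time_call(argsort_via_int_keys, x_f32))

    # Check that the key-based order sorts the float data correctly.
    x_np = np.asarray(x_f32)
    idx = np.asarray(argsort_via_int_keys(x_f32))
    print("matches np.sort:", np.array_equal(
        np.take_along_axis(x_np, idx, axis=0), np.sort(x_np, axis=0)))
```

For data without NaNs or mixed-sign zeros, equal floats map to equal keys, so the key-based ordering sorts the values exactly as a float sort would.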