
jnp.argsort much slower than the numpy version


Here’s a comparison of the JAX and NumPy versions of argsort on CPU:

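# Run in a Jupyter/IPython notebook: each %%timeit block below is a separate cell.
import numpy as np
import jax.numpy as jnp
from jax import config, random
config.update('jax_platform_name', 'cpu')  # force the CPU backend

key = random.PRNGKey(42)
key, subkey = random.split(key)

x_jnp = random.uniform(subkey, (100, 10000))  # float32 values in [0, 1)
x_np = np.array(x_jnp)  # same data as a NumPy array

%%timeit
np.argsort(x_np, axis=0)

%%timeit
# block_until_ready() waits for JAX's asynchronous dispatch to finish
jnp.argsort(x_jnp, axis=0).block_until_ready()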

In this case jnp.argsort is ~5x slower than np.argsort, and I’m seeing a >20x difference with more realistic arrays. Why is there such a large difference in performance between the two implementations?
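Outside a notebook, the same measurement can be done with the standard-library timeit module. This is a minimal sketch that assumes a CPU-only run. The helper names run_numpy and run_jax are hypothetical. A warm-up call is added so that one-time compilation of jnp.argsort is not counted. Absolute timings and the ratio will vary with hardware and library versions.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp
from jax import random

# Restrict JAX to the CPU backend (the issue uses the older jax_platform_name option).
jax.config.update("jax_platforms", "cpu")

key = random.PRNGKey(42)
key, subkey = random.split(key)
x_jnp = random.uniform(subkey, (100, 10000))  # float32 values in [0, 1)
x_np = np.asarray(x_jnp)                      # same data as a NumPy array


def run_numpy():
    return np.argsort(x_np, axis=0)


def run_jax():
    # block_until_ready() waits for JAX's asynchronous dispatch to finish,
    # so the timing covers the sort itself rather than just enqueueing it.
    return jnp.argsort(x_jnp, axis=0).block_until_ready()


run_jax()  # warm-up: keep one-time compilation out of the measurement

n = 10
t_np = timeit.timeit(run_numpy, number=n) / n
t_jax = timeit.timeit(run_jax, number=n) / n
print(f"np.argsort:  {t_np * 1e3:.2f} ms per call")
print(f"jnp.argsort: {t_jax * 1e3:.2f} ms per call")
print(f"ratio (jax / numpy): {t_jax / t_np:.1f}x")
```

Both versions should produce indices that put the values in the same sorted order. Ties are possible among float32 samples, so comparing the sorted values is safer than comparing the index arrays directly.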

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

jakevdp commented, Apr 24, 2022 (1 reaction)

You might find this FAQ entry helpful: “Is JAX Faster Than NumPy?”

hawkinsp commented, Apr 25, 2022 (0 reactions)

I also note that the slowness is specific to floating-point values. Sorting int32 values is significantly faster. The only difference between the two as far as I can tell is the comparison function.
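The sketch below reproduces the dtype comparison by timing the same jit-compiled argsort on float32 data and on int32 data of the same shape. It also tries a possible workaround. This workaround is an assumption and was not proposed in the thread. When all values are non-negative and not NaN, as with random.uniform in [0, 1), the IEEE-754 bit patterns reinterpreted as int32 sort in the same order as the floats. Under that condition, lax.bitcast_convert_type lets the integer sort path be used. The names argsort0 and bench are hypothetical. Whether the integer path is actually faster depends on the JAX version.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_platforms", "cpu")

key_f, key_i = jax.random.split(jax.random.PRNGKey(0))
x_f32 = jax.random.uniform(key_f, (100, 10000), dtype=jnp.float32)  # floats in [0, 1)
x_i32 = jax.random.randint(key_i, (100, 10000), 0, 2**30, dtype=jnp.int32)

# Assumed workaround: for non-negative, non-NaN float32 values, the raw bit
# patterns read as int32 are ordered the same way as the floats themselves.
x_bits = lax.bitcast_convert_type(x_f32, jnp.int32)

argsort0 = jax.jit(lambda a: jnp.argsort(a, axis=0))


def bench(a, n=10):
    argsort0(a).block_until_ready()  # warm-up: compile once per shape/dtype
    return timeit.timeit(lambda: argsort0(a).block_until_ready(), number=n) / n


for name, arr in [("float32", x_f32), ("int32", x_i32), ("float32 bits as int32", x_bits)]:
    print(f"{name:>22}: {bench(arr) * 1e3:.2f} ms per call")

# Sanity check: the bit-pattern argsort must put the float values in sorted order.
x_np = np.asarray(x_f32)
idx = np.asarray(argsort0(x_bits))
assert np.array_equal(np.take_along_axis(x_np, idx, axis=0), np.sort(x_np, axis=0))
```

The bitcast trick breaks as soon as the data contains negative values, because the sign bit makes negative floats compare in reverse order as integers.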
