`jnp.array` is much slower than `np.array` when called on a nested list
Consider the following example (I am using JAX version 0.3.10):
>>> import jax.numpy as jnp
>>> import numpy as np
>>> x = [list(range(1000)) for i in range(1000)] # A simple nested list
>>> %timeit np.array(x)
79.4 ms ± 6.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit jnp.array(x)
3.22 s ± 174 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
NumPy is roughly 40 times faster than JAX here (79.4 ms vs. 3.22 s). A similar problem for non-nested lists (#2919) was addressed in PR #3350.
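One possible workaround, assuming the nested list holds only concrete Python numbers (no JAX tracers), is to let NumPy do the nested-list conversion and then hand the finished array to JAX. This is a sketch, not something proposed in the issue, and `nested_list_to_jax` is a hypothetical helper name. The slow part then presumably runs at NumPy's speed, but NumPy's promotion rules (not JAX's) choose the dtype, and the helper cannot be used on traced values inside `jax.jit`.

```python
import numpy as np
import jax.numpy as jnp


def nested_list_to_jax(nested):
    """Hypothetical helper: convert with NumPy first, then wrap for JAX.

    Assumes `nested` contains only concrete numbers. It will fail if the
    list contains JAX tracers (e.g. inside jax.jit), and NumPy's dtype
    promotion rules, not JAX's, decide the resulting dtype.
    """
    host_array = np.array(nested)   # NumPy walks the nested list in C
    return jnp.asarray(host_array)  # copy the contiguous buffer to the default device


x = [list(range(1000)) for i in range(1000)]  # same input as in the issue
y = nested_list_to_jax(x)
print(y.shape, y.dtype)
```

The dtype follows NumPy's choice for the list (typically int64 for Python ints), which JAX then canonicalizes, for example to int32 when 64-bit mode is disabled.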
Sure, one important reason we can’t just call `np.array` when the argument is a list is that JAX transforms inject tracers in place of values, so `np.array` will result in an error. Another, more subtle, example is a list with a mix of dtypes, where type promotion behavior differs between NumPy and JAX.
JAX treats unspecified Python scalars as “weakly typed”, in that they conform to the types of other “strongly typed” values they interact with. NumPy treats Python scalars as 64-bit, and tends to promote things to 64-bit indiscriminately. This is problematic when working on accelerators like GPUs or TPUs, where it is often preferable to work in lower precision for performance reasons, but users still want the convenience of writing things like `0` and `1` without an explicit dtype declaration.

I’ll note also that even the previous “fast path” that you reference in that PR would slow down by orders of magnitude once jit-compiled, and in many cases would return different results between jit and non-jit execution, both of which are behaviors we try to avoid.