`jnp.array` is much slower than `np.array` when called on a nested list

Consider the following example (I am using JAX version 0.3.10):

>>> import jax.numpy as jnp
>>> import numpy as np
>>> x = [list(range(1000)) for i in range(1000)]    # A simple nested list
>>> %timeit np.array(x)
79.4 ms ± 6.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit jnp.array(x)
3.22 s ± 174 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

NumPy is roughly 40 times faster than JAX here. A similar problem for non-nested lists (#2919) was addressed in PR #3350.
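One possible workaround, when the nested list holds only concrete Python numbers and the conversion happens outside `jit` or other transformations, is to let NumPy parse the list and hand the resulting ndarray to JAX. This is a minimal sketch, not a recommendation from the thread; the helper name `timed` is hypothetical, and timings will vary by machine and JAX version.

```python
import time

import jax.numpy as jnp
import numpy as np

# A nested list of plain Python ints, the same shape as in the issue.
x = [list(range(1000)) for _ in range(1000)]

# Warm up JAX first so backend initialization is not counted in either timing.
jnp.asarray(np.zeros(1)).block_until_ready()


def timed(fn, *args):
    """Hypothetical helper: run fn(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    out = fn(*args)
    # JAX dispatches asynchronously, so wait for the result before stopping the clock.
    out.block_until_ready()
    return out, time.perf_counter() - start


direct, t_direct = timed(jnp.array, x)
# Workaround sketch: NumPy builds the array from the list, then JAX takes the ndarray.
via_numpy, t_via_numpy = timed(lambda v: jnp.asarray(np.array(v)), x)

print(f"jnp.array(list):             {t_direct:.3f} s")
print(f"jnp.asarray(np.array(list)): {t_via_numpy:.3f} s")
print(direct.dtype, via_numpy.dtype)
```

This shortcut bypasses JAX's own handling of tracers and its type promotion rules, so it only suits lists of plain Python numbers; for lists that mix JAX values and Python scalars, the result can differ from `jnp.array`.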

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 10 (4 by maintainers)

Top GitHub Comments

1 reaction
jakevdp commented, May 13, 2022

Sure. One important reason we can’t just call np.array when the argument is a list is that JAX transformations such as jit replace concrete values with tracers, and np.array fails on them:

import jax.numpy as jnp
import numpy as np
from jax import jit

# x is a float16 scalar, matching the dtype shown in the error message below
x = jnp.float16(0)

@jit
def f(x):
  return jnp.array([x, 0, 0])
print(f(x))
# [0. 0. 0.]

@jit
def g(x):
  return np.array([x, 0, 0])
g(x)
# ---------------------------------------------------------------------------
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float16[])>with<DynamicJaxprTrace(level=0/1)>
# While tracing the function g at <ipython-input-13-53206b980741>:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Another, more subtle, example is a list with mixed dtypes, where type promotion differs between NumPy and JAX:

import jax.numpy as jnp
import numpy as np

x = jnp.float16(0)

print(np.array([x, 0, 0]).dtype)
# float64

print(jnp.array([x, 0, 0]).dtype)
# float16

JAX treats Python scalars without an explicit dtype as “weakly typed”: they conform to the dtype of the “strongly typed” values they interact with. NumPy treats Python scalars as 64-bit and tends to promote results to 64-bit indiscriminately. This is a problem on accelerators such as GPUs and TPUs, where lower precision is often preferable for performance, yet users still want the convenience of writing literals like 0 and 1 without declaring a dtype.
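A short sketch of weak typing in practice, assuming a default JAX configuration (64-bit mode disabled); the variable name `x16` is only illustrative. Python scalars adopt the array's dtype, while a JAX scalar created with an explicit dtype is strongly typed and participates in ordinary promotion.

```python
import jax.numpy as jnp

x16 = jnp.ones(3, dtype=jnp.float16)

# Python scalars are weakly typed: they take on the float16 dtype of the array.
print((x16 + 1).dtype)    # float16
print((x16 * 0.5).dtype)  # float16

# jnp.float32(1) is strongly typed, so float16 and float32 promote to float32.
print((x16 + jnp.float32(1)).dtype)  # float32
```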

1 reaction
jakevdp commented, May 11, 2022

I’ll also note that even the previous “fast path” from the PR you reference would slow down by orders of magnitude once jit-compiled, and in many cases would return different results under jit and non-jit execution; we try to avoid both behaviors.
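To make the jit/non-jit mismatch concrete, here is a sketch built around a hypothetical helper, `hypothetical_fast_array`, which is not the code from PR #3350 and not a JAX API. It assumes a shortcut that routes concrete lists through `np.array` and falls back to `jnp.array` when tracing. Eagerly the list is promoted by NumPy's rules; under `jit` the fallback applies JAX's rules, so the two paths can produce different dtypes for the same input. The exact eager dtype depends on the NumPy version and on whether JAX's 64-bit mode is enabled, so it is printed rather than stated.

```python
import jax
import jax.numpy as jnp
import numpy as np


def hypothetical_fast_array(values):
    """Hypothetical shortcut (not a JAX API): NumPy for concrete values,
    jnp.array when the list contains tracers."""
    try:
        # NumPy's promotion rules decide the dtype of the concrete list.
        return jnp.asarray(np.array(values))
    except jax.errors.TracerArrayConversionError:
        # Under jit the elements are tracers, which NumPy cannot convert.
        return jnp.array(values)


x = jnp.float16(0)

eager = hypothetical_fast_array([x, 0, 0])
traced = jax.jit(lambda v: hypothetical_fast_array([v, 0, 0]))(x)

print("eager dtype: ", eager.dtype)
print("jitted dtype:", traced.dtype)
```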
