Regularize jax.numpy API
I’ve been picking away at this for the past year or so (since #3038), but I wanted to track it a bit more formally here.
We want the jax.numpy API to have two properties:
- Static arguments should generally be checked with `core.concrete_or_error()`. This helps localize concretization errors and provides a more uniform user experience.
- Dynamic arguments should generally be validated with `_check_arraylike`. Passing lists to JAX functions can be a quiet source of performance degradation, because lists are treated as pytrees. For example:
```
In [1]: import jax.numpy as jnp
In [2]: from jax import jit
In [3]: f = jit(lambda x: jnp.mean(jnp.asarray(x)))
In [4]: %timeit f([float(i) for i in range(1000)])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
1.79 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: %timeit f(jnp.array([float(i) for i in range(1000)]))
375 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
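A quick way to see the pytree effect behind the timings above is to count leaves and traced inputs rather than time anything. The following is a minimal sketch using only public `jax.tree_util` and `jax.make_jaxpr` calls; it illustrates the mechanism and does not reproduce the benchmark numbers.

```python
import jax
import jax.numpy as jnp

values = [float(i) for i in range(1000)]


def f(x):
    # Same computation as the jitted lambda in the session.
    return jnp.mean(jnp.asarray(x))


# A Python list is a pytree container: every float is a separate leaf.
n_list_leaves = len(jax.tree_util.tree_leaves(values))
# A JAX array is a single leaf.
n_array_leaves = len(jax.tree_util.tree_leaves(jnp.array(values)))

# Tracing shows the consequence: one jaxpr input per leaf.
n_list_inputs = len(jax.make_jaxpr(f)(values).jaxpr.invars)
n_array_inputs = len(jax.make_jaxpr(f)(jnp.array(values)).jaxpr.invars)

print(n_list_leaves, n_array_leaves)  # 1000 1
print(n_list_inputs, n_array_inputs)  # 1000 1
```

Under `jit`, each of those 1000 leaves is a separate argument that must be flattened and passed on every call, which is consistent with the list input being the slower case in the session.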
This change may break users who pass lists to `jax.numpy` functions, but this input-type restriction has long been documented. Still, with each change I plan to run the full test suite and fix downstream packages where necessary.
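To make the two rules concrete, here is a minimal sketch of what such validation might look like in a `jnp`-style function. The names `check_arraylike`, `concrete_int`, and `sum_along` are hypothetical: they are simplified stand-ins for JAX's private `_check_arraylike` and `core.concrete_or_error()`, not their actual implementations. The array check here only rejects lists and tuples, and the static check relies on `operator.index`, which raises a `TypeError` subclass when handed a traced value under `jit`.

```python
import operator

import jax
import jax.numpy as jnp


def check_arraylike(fun_name, *args):
    """Hypothetical, simplified stand-in for jax.numpy's private _check_arraylike.

    Only rejects Python lists and tuples, which jit would otherwise
    flatten into one argument per element.
    """
    for pos, arg in enumerate(args):
        if isinstance(arg, (list, tuple)):
            raise TypeError(
                f"{fun_name} requires ndarray or scalar arguments, "
                f"got {type(arg).__name__} at position {pos}."
            )


def concrete_int(value, context):
    """Hypothetical stand-in for core.concrete_or_error(int, value, context).

    operator.index accepts Python/NumPy integers but raises a TypeError
    subclass for an abstract tracer, e.g. a non-static argument under jit.
    """
    try:
        return operator.index(value)
    except TypeError as err:
        raise TypeError(
            f"{context}: expected a static integer, got {type(value).__name__}."
        ) from err


def sum_along(x, axis):
    """Hypothetical jnp-style function applying both validation rules."""
    check_arraylike("sum_along", x)               # dynamic argument
    axis = concrete_int(axis, "sum_along(axis)")  # static argument
    return jnp.sum(x, axis=axis)


x = jnp.ones((2, 3))
print(sum_along(x, 0))                        # [2. 2. 2.]
print(sum_along(jnp.asarray([1.0, 2.0]), 0))  # 3.0: convert lists explicitly
print(jax.jit(sum_along, static_argnums=1)(x, 1))  # [3. 3.]
```

For affected users, the migration is the one shown in the second call: convert list data with `jnp.asarray` once, outside the function, rather than relying on the function to accept a list. Marking `axis` as static with `static_argnums` keeps the concrete-integer check satisfied under `jit`.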
@PhilipVinc - thanks for the response. I don’t think we will bump the minor version due to this work. At this point, virtually every JAX release includes backward-incompatible breakages at some level. So far we have bumped the minor version only once, when we landed the omnistaging change that significantly re-designed JAX’s staging and dispatch behavior. This change is essentially just stricter input validation, and brings JAX’s implementation more fully in line with its documented behavior. In my judgment, that doesn’t warrant a minor version bump.
I appreciate the feedback… I think this is the first time in my life someone has thanked me for breaking their code 😁