
Regularize jax.numpy API


I’ve been picking away at this for the past year or so (since #3038), but I wanted to track it a bit more formally here.

We want the jax.numpy API to have two properties:

  • static arguments should generally be checked with core.concrete_or_error(). This helps localize concretization errors and provides a more uniform user experience.

  • dynamic arguments should generally be validated with _check_arraylike. Passing lists to JAX functions can quietly degrade performance, because lists are treated as pytrees: under jit, each list element becomes a separate input. For example:

In [1]: import jax.numpy as jnp

In [2]: from jax import jit

In [3]: f = jit(lambda x: jnp.mean(jnp.asarray(x)))

In [4]: %timeit f([float(i) for i in range(1000)])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
1.79 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit f(jnp.array([float(i) for i in range(1000)]))
375 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
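The two properties above can be sketched from the user side as follows. This is a minimal example, assuming a recent JAX release. It uses jnp.arange as one function whose length argument must be concrete at trace time; make_range and make_range_static are hypothetical helpers for illustration only. The leaf counts show why a list is more expensive than an array under jit.

```python
import jax
import jax.numpy as jnp

# Static argument: the value passed to jnp.arange determines the output
# shape, so it must be a concrete value while tracing.
def make_range(n):
    return jnp.arange(n)

static_error = None
try:
    jax.jit(make_range)(5)  # n is a tracer here, so concretization fails
except TypeError as err:  # ConcretizationTypeError is a TypeError subclass
    static_error = err
print(type(static_error).__name__)

# Marking n as static makes it a concrete Python int during tracing.
make_range_static = jax.jit(make_range, static_argnums=0)
print(make_range_static(5))

# Dynamic argument: a Python list is a pytree, so jit flattens it into one
# input per element, while an array is a single leaf.
values = [float(i) for i in range(1000)]
print(len(jax.tree_util.tree_leaves(values)))
print(len(jax.tree_util.tree_leaves(jnp.asarray(values))))
```

The error raised for the traced n is the kind of localized concretization error the first bullet aims for; static_argnums is the standard way to supply such a value as a concrete Python object instead.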

These changes may be problematic because they break users who pass lists to jax.numpy functions, but this input-type restriction has long been documented. Still, with each change I plan to run the full test suite and fix downstream packages where necessary.
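For affected users, the migration is usually to convert list inputs explicitly. The following is a minimal sketch, assuming a recent JAX release in which jnp.sum validates its input and rejects a plain Python list with a TypeError; data, list_rejected and total are illustrative names only.

```python
import jax.numpy as jnp

data = [1.0, 2.0, 3.0]

# With stricter input validation, passing the list directly is rejected.
try:
    jnp.sum(data)
    list_rejected = False
except TypeError:
    list_rejected = True
print(list_rejected)

# Migration: convert explicitly, so the conversion happens once and is
# visible at the call site.
total = jnp.sum(jnp.asarray(data))
print(float(total))
```

Converting once with jnp.asarray also avoids the per-element pytree overhead described in the second bullet when the value is passed through jit.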

Issue Analytics

  • State: open
  • Created 2 years ago
  • Reactions: 4
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

jakevdp commented, Aug 27, 2021

@PhilipVinc - thanks for the response. I don’t think we will bump the minor version due to this work. At this point, virtually every JAX release includes backward-incompatible breakages at some level. So far we have bumped the minor version only once, when we landed the omnistaging change that significantly re-designed JAX’s staging and dispatch behavior. This change is essentially just stricter input validation, and brings JAX’s implementation more fully in line with its documented behavior. In my judgment, that doesn’t warrant a minor version bump.

jakevdp commented, Oct 20, 2021

I appreciate the feedback… I think this is the first time in my life someone has thanked me for breaking their code 😁
