Overrides of NumPy functions on JAX arrays

NumPy has protocols, based on the __array_ufunc__ and __array_function__ methods, that allow for overriding what NumPy functions like np.sin() and np.concatenate() do when called on other array types.

In practice, this means users can write import numpy as np to get NumPy functions that work on JAX arrays instead of needing to write import jax.numpy as np.
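To make the two protocols concrete, here is a minimal sketch on a hypothetical wrapper class, WrappedArray (an illustrative name, not JAX's API and not the code in PR #611), that wraps a NumPy ndarray. Its __array_ufunc__ method (NEP 13) makes ufuncs such as np.sin() return the wrapper, and its __array_function__ method (NEP 18) does the same for np.concatenate(). A JAX array implementing these methods would presumably forward the calls to jax.numpy instead; that forwarding is an assumption and is not shown here.

```python
import numpy as np


class WrappedArray:
    """Hypothetical duck array that wraps an ndarray and opts into
    NumPy's override protocols (NEP 13 and NEP 18)."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __repr__(self):
        return f"WrappedArray({self.data!r})"

    # NEP 13: NumPy calls this for ufuncs such as np.sin or np.add.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only plain, single-output calls like np.sin(x) are handled here.
        if method != "__call__" or "out" in kwargs or ufunc.nout != 1:
            return NotImplemented
        raw = [x.data if isinstance(x, WrappedArray) else x for x in inputs]
        return WrappedArray(ufunc(*raw, **kwargs))

    # NEP 18: NumPy calls this for non-ufunc functions such as np.concatenate.
    def __array_function__(self, func, types, args, kwargs):
        # Only np.concatenate is overridden; NotImplemented makes NumPy
        # raise TypeError for every other function.
        if func is not np.concatenate:
            return NotImplemented
        arrays = [a.data if isinstance(a, WrappedArray) else a for a in args[0]]
        return WrappedArray(np.concatenate(arrays, *args[1:], **kwargs))


x = WrappedArray([0.0, 1.0])
print(type(np.sin(x)).__name__)                             # WrappedArray
print(type(np.concatenate([x, np.array([2.0])])).__name__)  # WrappedArray
```

Returning NotImplemented for anything the sketch does not cover means NumPy raises a TypeError instead of silently coercing the wrapper, which is the fallback both NEPs specify when no override accepts the call.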

It might make sense to implement these methods on JAX’s array objects. A working prototype of this can be found in https://github.com/google/jax/pull/611.

Reasons to do this:

  • This would make it possible to write generic code that works with NumPy/JAX/Dask/sparse/whatever, at least in simple cases: you can just use import numpy as np and it will probably work. This is particularly advantageous for third-party libraries (e.g., projects like opt-einsum or xarray) that want to support multiple backends in a clean, composable way.
  • By opting into NumPy’s API, JAX gets an override API “for free”. This could be useful even if all you care about is supporting operations on JAX arrays. For example, you could write a library that wraps JAX and adds PyTorch 1.3-style named tensors.
  • JAX’s JIT compilation allows for powerful “zero-cost abstractions” like C++, but in Python. Projects like xarray could potentially make use of this in a really compelling way: e.g., you could write a simulation with labeled multi-dimensional arrays and unit checking, without any extra performance cost!
  • More generally: it’s a nice integration point with the third-party SciPy/PyData ecosystem. There’s assuredly loads of other cool stuff you could do with it.
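The named-tensor idea from the second bullet can be sketched with the same protocol. NamedArray is a hypothetical class, not an existing library, and it wraps a plain NumPy ndarray so the example runs without JAX; wrapping JAX arrays the same way would rely on JAX arrays honouring these protocols, which is exactly what this issue proposes. Only elementwise ufunc calls are handled, so reductions and non-ufunc functions such as np.sum() or np.concatenate() are not supported by this sketch.

```python
import numpy as np


class NamedArray:
    """Hypothetical wrapper that attaches dimension names to an ndarray,
    loosely in the spirit of PyTorch 1.3 named tensors."""

    def __init__(self, data, dims):
        self.data = np.asarray(data)
        if len(dims) != self.data.ndim:
            raise ValueError("need exactly one name per dimension")
        self.dims = tuple(dims)

    # Elementwise ufuncs (np.add, np.exp, ...) keep the names, but only if
    # every NamedArray operand uses the same names in the same order.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__" or "out" in kwargs or ufunc.nout != 1:
            return NotImplemented
        names = {x.dims for x in inputs if isinstance(x, NamedArray)}
        if len(names) != 1:
            raise ValueError(f"dimension names do not match: {sorted(names)}")
        raw = [x.data if isinstance(x, NamedArray) else x for x in inputs]
        return NamedArray(ufunc(*raw, **kwargs), names.pop())


a = NamedArray([[1.0, 2.0]], dims=("batch", "feature"))
print(np.multiply(a, a).dims)  # ('batch', 'feature')

b = NamedArray([[1.0, 2.0]], dims=("feature", "batch"))
try:
    np.add(a, b)
except ValueError as err:
    print(err)  # reports the two mismatched name tuples
```

Rejecting mismatched names is a deliberate simplification; a fuller design might instead align operands by name, as xarray does with its labeled dimensions.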

Reasons not to do this:

  • This breaks existing code that relies upon NumPy functions coercing arguments to NumPy arrays. Large projects using JAX will probably need to add some explicit calls to onp.asarray(). https://github.com/google/jax/pull/611 includes a handful of examples of this internally in JAX.
  • The implementation is rather complex and a little fragile, especially if we want to accommodate a flag that allows for switching it on and off. This imposes an additional maintenance burden on the JAX team.
  • We don’t yet have any concrete examples of end-user use-cases for this functionality. It would let us easily wrap JAX with xarray, but what would that be good for?
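The coercion problem in the first bullet above can be reproduced without JAX. The sketch uses two hypothetical classes: CoercibleArray only defines __array__, so NumPy converts it to an ndarray (roughly how JAX arrays behave without the override), while OverridingArray adds __array_ufunc__, so np.sin() now returns the wrapper. legacy_checksum() is a made-up stand-in for existing code that assumed an ndarray result, and the explicit asarray() call is the kind of fix the bullet describes (onp was the alias JAX's own source used for original NumPy; the sketch uses plain np).

```python
import numpy as np


class CoercibleArray:
    """Hypothetical array type NumPy can convert via __array__, but which
    does not override any NumPy functions."""

    def __init__(self, data):
        self.data = np.asarray(data)

    # `copy` is accepted for compatibility with NumPy 2's __array__ protocol.
    def __array__(self, dtype=None, copy=None):
        arr = self.data if dtype is None else self.data.astype(dtype)
        return arr.copy() if copy else arr


class OverridingArray(CoercibleArray):
    """Same type after opting into NEP 13: ufuncs now return the wrapper."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__" or "out" in kwargs or ufunc.nout != 1:
            return NotImplemented
        raw = [x.data if isinstance(x, CoercibleArray) else x for x in inputs]
        return OverridingArray(ufunc(*raw, **kwargs))


def legacy_checksum(x):
    # Existing code that silently assumed np.sin() always returns an ndarray.
    return np.sin(x).tobytes()


def fixed_checksum(x):
    # Coercing explicitly restores the old behaviour once overrides exist.
    return np.sin(np.asarray(x)).tobytes()


old = CoercibleArray([0.0, 1.0])
new = OverridingArray([0.0, 1.0])
print(type(np.sin(old)).__name__)  # ndarray
print(type(np.sin(new)).__name__)  # OverridingArray
# legacy_checksum(new) would now fail: OverridingArray has no .tobytes().
print(fixed_checksum(new) == legacy_checksum(old))  # True
```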

Decision by @mattjj and myself: We’re not going to merge this yet, because it’s not clear that anyone would even use it, and it imposes a maintenance burden.

If you have compelling use-cases, please speak up. We could relatively easily make this happen, but would need someone who could commit to being a passionate user first.

Issue Analytics

  • State: open
  • Created 4 years ago
  • Reactions: 9
  • Comments: 8 (2 by maintainers)

Top GitHub Comments

6 reactions
lukasheinrich commented, Mar 6, 2020

Adherence to NEP 13 and NEP 18 would make it useful to integrate jax into projects that rely on them for portability. Specifically, we’re looking to integrate jax with scale-out systems like dask and with particle physics libraries like https://github.com/scikit-hep/awkward-array. @jpivarski can probably comment better on the technical details, but we’d very much be a passionate user 😃

3 reactions
Hoeze commented, Apr 14, 2020

I love the idea of xarray with jax as the backend… That would be so awesome! Also, it’s quite unfortunate that TensorFlow/JAX/… all have different APIs compared to NumPy.
