A catch-all protocol for numpy-like duck arrays
There are several functions for which I would like to see protocols constructed. I've raised issues for #11074 and #11128, but these are just special cases of a much larger issue that affects many operations. My impression is that changing NumPy takes a while, so I'm inclined to look for a catch-all solution that can serve while things evolve.
To that end, I propose that duck arrays include a method that returns a module mimicking the NumPy namespace:
class ndarray:
    def __array_module__(self):
        import numpy as np
        return np
class DaskArray:
    def __array_module__(self):
        import dask.array as da
        return da
class CuPyArray:
    def __array_module__(self):
        import cupy as cp
        return cp
class SparseArray:
    def __array_module__(self):
        import sparse
        return sparse
...
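A minimal sketch of how library code could consume such a method, so that one function body works against whichever namespace its input advertises. The helper name get_array_module and the example function unit_vector are hypothetical; NumPy arrays do not define __array_module__, so the helper falls back to NumPy when the method is absent.

```python
import numpy as np


def get_array_module(x):
    """Return the namespace a duck array advertises, defaulting to NumPy.

    `get_array_module` is a hypothetical helper; `__array_module__` is the
    method proposed above, not an existing NumPy protocol.
    """
    # Look the method up on the type, as Python does for special methods.
    method = getattr(type(x), "__array_module__", None)
    if method is None:
        return np
    return method(x)


def unit_vector(x):
    """Normalize `x` using only functions looked up on its own namespace."""
    xp = get_array_module(x)
    return x / xp.sqrt(xp.sum(x * x))
```

With a plain NumPy input, unit_vector(np.array([3.0, 4.0])) runs entirely through NumPy; a Dask or CuPy array exposing the method would route xp.sqrt and xp.sum to its own library instead.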
Then, in functions such as stack or concatenate, we check for these modules:
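import numpy
def stack(args, **kwargs):
    # Ask each argument for its module; objects without the method count as numpy.
    modules = {getattr(arg, '__array_module__', lambda: numpy)() for arg in args}
    if len(modules) == 1:
        module = modules.pop()
        if module is not numpy:
            return module.stack(args, **kwargs)
    ...  # otherwise, continue with NumPy's existing implementation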
There are likely several things wrong with the implementation above, but I hope it gets the general point across: we would dispatch wholesale to the module of the provided duck arrays.
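A runnable sketch of the same wholesale dispatch, with a hypothetical toy module standing in for a library such as dask.array or cupy. Raising TypeError when inputs come from different foreign modules is a choice made here for illustration; the proposal leaves that case open.

```python
import types

import numpy as np

# Hypothetical stand-in for a duck-array library such as dask.array or cupy.
toy = types.ModuleType("toy")


class ToyArray:
    """A hypothetical duck array that advertises the `toy` namespace."""

    def __init__(self, data):
        self.data = list(data)

    def __array_module__(self):
        return toy


def _toy_stack(arrays, axis=0):
    # This toy implementation only supports stacking along a new first axis.
    if axis != 0:
        raise ValueError("toy.stack only supports axis=0")
    return ToyArray(a.data for a in arrays)


toy.stack = _toy_stack


def stack(arrays, **kwargs):
    """Dispatch to the one module all inputs share; otherwise use NumPy."""
    arrays = list(arrays)  # materialize so generators can be scanned twice
    modules = {
        a.__array_module__() if hasattr(a, "__array_module__") else np
        for a in arrays
    }
    # Plain NumPy inputs (or no inputs at all) go straight to NumPy.
    if modules <= {np}:
        return np.stack(arrays, **kwargs)
    if len(modules) == 1:
        (module,) = modules
        return module.stack(arrays, **kwargs)
    raise TypeError(f"cannot stack arrays from {len(modules)} different modules")
```

stack([ToyArray([1, 2]), ToyArray([3, 4])]) is handled by toy.stack, while stack([ToyArray([1]), np.zeros(1)]) raises TypeError instead of silently coercing the toy array through NumPy.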
My main concern with this approach is that top-level functions should be raising TypeError rather than returning NotImplemented. For example, consider Python arithmetic (on which __array_ufunc__ was modeled) between two custom types that implement the appropriate special methods (__add__ and __radd__) but that don't know about each other:
However, I do like the idea of a generic method for NumPy functions that aren't ufuncs. I would still make this a method on array objects, though, e.g., __array_function__. NumPy's implementation of func would call arg.__array_function__(func, *args, **kwargs) in turn on each array argument to the function, and return the first result that is not NotImplemented. In most cases, you could write something like the following:
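A hedged sketch of the dispatch rule just described (not the snippet the comment originally contained). The decorator name overridable and the classes Logged and Declines are hypothetical. It uses the call signature from the comment, __array_function__(func, *args, **kwargs), which differs from the (func, types, args, kwargs) signature NumPy later adopted in NEP 18, so it is written against plain Python lists and should not be mixed with real NumPy arrays.

```python
import functools


def overridable(func):
    """Let array arguments take over calls to `func`.

    Calls arg.__array_function__(func, *args, **kwargs) on each positional
    argument in turn (once per type) and returns the first result that is
    not NotImplemented. If some argument defines the hook but every hook
    declines, raise TypeError, as `a + b` does when both __add__ and
    __radd__ return NotImplemented. With no hooks at all, run `func`.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        seen_types = set()
        for arg in args:
            arg_type = type(arg)
            hook = getattr(arg_type, "__array_function__", None)
            if hook is None or arg_type in seen_types:
                continue
            seen_types.add(arg_type)
            # Pass the public wrapper so overrides can test `func is concatenate`.
            result = hook(arg, wrapper, *args, **kwargs)
            if result is not NotImplemented:
                return result
        if seen_types:
            names = ", ".join(t.__name__ for t in seen_types)
            raise TypeError(f"{func.__name__}: no override accepted: {names}")
        return func(*args, **kwargs)

    return wrapper


@overridable
def concatenate(*arrays):
    """Hypothetical default implementation for plain Python sequences."""
    out = []
    for a in arrays:
        out.extend(a)
    return out


class Logged:
    """Hypothetical duck array that overrides only `concatenate`."""

    def __init__(self, items):
        self.items = list(items)

    def __array_function__(self, func, *args, **kwargs):
        known = (Logged, list, tuple)
        if func is not concatenate or not all(isinstance(a, known) for a in args):
            return NotImplemented
        return Logged(x for a in args for x in (a.items if isinstance(a, Logged) else a))


class Declines:
    """Hypothetical duck array that never handles anything."""

    def __array_function__(self, func, *args, **kwargs):
        return NotImplemented
```

concatenate([0], Logged([1, 2])) is handled by Logged even though it is not the first argument, while concatenate(Declines(), Logged([5])) raises TypeError because both hooks decline.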
Why restrict this module protocol to certain types at all? We can follow the same algorithm as, for example, __array_ufunc__. Here's some example code (apologies for the long post):
sandbox.py
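A rough sketch, not the sandbox.py code itself, of what following the __array_ufunc__ algorithm could mean for a module protocol: every argument type may take part, types are consulted subclasses before superclasses and otherwise left to right, each one sees all participating types and may decline with NotImplemented, and TypeError is raised if all decline. The signature __array_module__(self, types), the function array_module, and the classes below are assumptions for illustration.

```python
from types import ModuleType


def _overloaded_args(args):
    """Arguments whose type defines __array_module__, one per type, ordered
    like __array_ufunc__ dispatch: subclasses before superclasses,
    otherwise left to right."""
    ordered = []
    for arg in args:
        arg_type = type(arg)
        if not hasattr(arg_type, "__array_module__"):
            continue
        if any(type(old) is arg_type for old in ordered):
            continue
        index = len(ordered)
        for i, old in enumerate(ordered):
            if issubclass(arg_type, type(old)):
                index = i  # a subclass goes ahead of its superclass
                break
        ordered.insert(index, arg)
    return ordered


def array_module(*args, default=None):
    """Return the module picked by the first type that accepts all types."""
    ordered = _overloaded_args(args)
    if not ordered:
        return default
    types = tuple(type(a) for a in ordered)
    for arg in ordered:
        module = type(arg).__array_module__(arg, types)
        if module is not NotImplemented:
            return module
    names = ", ".join(t.__name__ for t in types)
    raise TypeError(f"no array module accepts all of: {names}")


# Hypothetical stand-ins for two cooperating libraries and an unrelated one.
base_mod = ModuleType("base_mod")
sub_mod = ModuleType("sub_mod")


class Base:
    def __array_module__(self, types):
        # Accept only inputs this hypothetical library knows how to handle.
        if all(issubclass(t, Base) for t in types):
            return base_mod
        return NotImplemented


class Sub(Base):
    def __array_module__(self, types):
        if all(issubclass(t, Base) for t in types):
            return sub_mod
        return NotImplemented


class Other:
    def __array_module__(self, types):
        return NotImplemented
```

Because the subclass is asked first, array_module(Base(), Sub()) returns sub_mod regardless of argument order; objects without the method are ignored, and array_module(Base(), Other()) raises TypeError since neither type accepts the other.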