Use case study: Intheon
We’re very happy to see progress in getting the array API standardized – obviously a monumental undertaking when all is considered.
We’ve been using backend-agnostic numpy-style code for over a year, first in research and now in production, and are gradually rolling it out across a multi-megabyte codebase (supporting mainly numpy, cupy, pytorch, tensorflow, and jax; also dask, though we’re not using that at the moment), so I thought I’d share our user story in case it’s helpful. It’s a similar use case to what the array API addresses, but it was built with the current numpy-workalike APIs in mind.
Our codebase has an entrypoint equivalent to `get_namespace`, although we call these backends, and a typical use case looks like:
```python
def sqrt_sum(arr1, arr2):
    be = backend_for(arr1, arr2)
    return be.sqrt(be.asarray(arr1) + be.asarray(arr2))
```
For each of the backends we have a (usually thin) compatibility layer that adds any missing functions or fixes up issues with the function signature. In our case, `backend_for` looks at `__array_priority__` to return the namespace for the highest-priority array, although we rarely use it with more than one array (but it results in the above function accepting multiple array types and promoting according to np < dask < {jax, tf, cupy} < torch). Getting a namespace this way is about as fast as doing `.T` on a 1x1 numpy array (thanks to some caching), so we use it in even the smallest subroutines.
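The dispatch described above can be sketched as follows; this is a hypothetical reconstruction, with only numpy registered as a backend and `_namespace_for_type` standing in for the real compatibility-layer lookup:

```python
import numpy as np

_NAMESPACE_CACHE = {}  # array type -> namespace, so repeated lookups stay cheap


def _namespace_for_type(tp):
    # A real system would return the compat layer for cupy, torch, tf, jax, etc.;
    # here only numpy is registered.
    if issubclass(tp, np.ndarray):
        return np
    raise TypeError(f"no backend registered for {tp.__name__}")


def backend_for(*arrays):
    """Return the namespace of the highest-__array_priority__ argument."""
    best = max(arrays, key=lambda a: getattr(a, "__array_priority__", 0.0))
    tp = type(best)
    ns = _NAMESPACE_CACHE.get(tp)
    if ns is None:
        ns = _NAMESPACE_CACHE[tp] = _namespace_for_type(tp)
    return ns
```

The type-keyed cache is what keeps the lookup in the "about as fast as `.T`" range: after the first call per array type, dispatch is a single dict probe plus a `max` over the arguments.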
We use a lot of the API surface of these backends, and typically the most compute-intensive subroutines have an option to choose a preferred backend (via a function `backend_get(shorthand)`) out of the subset that supports the necessary ops, or to keep the same backend. We’ve been extremely impressed with the compatibility of cupy and the performance of torch (notably also its cpu arrays), and we have a few places where we prefer tf or jax because they might have a faster implementation of a critical op or parallelize better (e.g. jax). We find that even over a half-year time span, things evolve rapidly enough that the ideal backend pick for a function may change from one to another (one of the reasons we aim for that much flexibility).
We traced our API usage and found that perhaps 90% of our backend-agnostic call sites would be covered by the array API as it exists now. There are a few places that use different aliases for the same functionality (due to some forms being more popular with current backends than others), the most frequent issues being `absolute(x)` and `concatenate(x)`; another example is `arccos(x)` (all our backends) vs `acos(x)` (no backend). We also frequently use convenience shorthands like `hstack(x)`, `vstack(x)`, `ravel(x)` or `x.flatten()`, but those could be substituted easily (or provided in a wrapper).
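A thin alias shim of the kind described above might look like the following sketch (the alias table is illustrative, not our actual mapping):

```python
import numpy as np

# Array-API-style name -> numpy-style name (hypothetical, partial list).
_ALIASES = {
    "abs": "absolute",
    "concat": "concatenate",
    "acos": "arccos",
    "asin": "arcsin",
    "atan": "arctan",
}


class CompatNamespace:
    """Delegate attribute access to a wrapped namespace, adding aliases.

    Call sites can then use either spelling regardless of which one the
    underlying backend happens to provide.
    """

    def __init__(self, ns):
        self._ns = ns

    def __getattr__(self, name):
        # Redirect array-API spellings to the backend's spelling; pass
        # everything else through unchanged.
        return getattr(self._ns, _ALIASES.get(name, name))
```

Because `__getattr__` only fires on missing attributes and falls through to the wrapped namespace, the shim stays a few lines per backend rather than a full re-export of the API surface.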
We found a few omissions that would require a bit more code rewriting, among others the functions `minimum(a,b)`, `maximum(a,b)`, and `clip(x, lower, upper)` (presumably the latter would turn into `where(x>upper, upper, where(x<lower, lower, x))`). Also, we frequently use `moveaxis(x,a,b)` and `swapaxes(x,a,b)`, e.g., in linear algebra on stacks of matrices (>30 call sites for us). All of these are already supported by the above 6 backends, and they’re pretty trivial, fortunately. Our code uses `real(x)` in some places since some of the implementations might return spurious complex numbers; that may be reason enough to at least partially specify that function now. Also, `reciprocal(x)` occurs frequently in our code; I guess that stems from a suspicion that writing `1/x` may fail to use the reciprocal instruction where one is available.
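The substitutions mentioned above can be written down directly; this is a minimal sketch with numpy standing in for any backend namespace that provides `where` and elementwise comparisons:

```python
import numpy as np

be = np  # any namespace providing `where` and comparison ops would do


def minimum(a, b):
    # Elementwise min built only from comparison + where.
    return be.where(a < b, a, b)


def maximum(a, b):
    # Elementwise max built only from comparison + where.
    return be.where(a > b, a, b)


def clip(x, lower, upper):
    # The rewrite suggested in the text: where(x>upper, upper, where(x<lower, lower, x)).
    return be.where(x > upper, upper, be.where(x < lower, lower, x))
```

These are semantically adequate but each one materializes intermediate arrays, which is part of why having the real `minimum`/`maximum`/`clip` in the standard is preferable to rewriting call sites this way.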
A few things that we use have no substitute at this point, unfortunately, namely `einsum()`, `lstsq(x,y)`, `eig(x)`, and `sqrtm()` (though the latter could be implemented via `eigh`). We hope that these eventually find their way (back) into the API. I realize that `lstsq` was removed as per a previous discussion (and it’s understandable given that its API is a bit crufty), but our code base has 26 unique call sites of that alone, since we’re dealing mostly with engineering and stats. One might reasonably assume that backends that already have optimized implementations (all 6 do, and torch/tf support batched matrices) will provide it anyway in their array API namespace. However, we do worry that, given that it has been deliberately removed, some of the existing backends may be encouraged to minimize their maintenance surface and drop functions like that from their formerly numpy-compatible (soon array-API-compatible) namespace, forcing users like us to deal with their raw EagerTensor, xla_extension.DeviceArray, or whatever it may be called, and go find the functionality in whichever ancestral tensorflow namespace it may have been buried in before. We’re wondering if a tradeoff could be made where, e.g., some of the rarely-used outputs are marked as a “reserved” placeholder and allowed to hold unspecified values (e.g., None) until perhaps a future version of the API specifies them. There’s also the option to go the same route as with `svd`, where some arguments were removed in the interest of simplicity. On the plus side, it’s good to see `diag` retired in favor of `diagonal` (especially so in the age of batched matrices).
Other than that, for multi-GPU we use a backend-provided context manager where available (torch, tf, cupy), a custom context manager where it’s not (jax), and a no-op context manager for numpy & dask (usage looks like `with be.select_device(id):`). That’s because passing `device=` through all compute functions down into the individual array creation calls (from `arange` to `eye`) just isn’t all that practical with a large and deeply nested codebase, and it’s easy to overlook call sites, causing hidden performance bugs that only turn up on multi-accelerator runs. However, since the user can write their own context manager (with a stack in thread-local storage and wrappers around the array creation functions), that can be worked around with some effort.
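A user-side work-around along the lines described (a thread-local device stack plus wrappers around array-creation functions) might look like this sketch; `select_device`, `current_device`, and `on_current_device` are hypothetical names, and the toy `zeros` stands in for a backend creation function:

```python
import contextlib
import threading

_state = threading.local()  # per-thread stack of active devices


def current_device():
    """Device at the top of this thread's stack ("cpu" if none selected)."""
    return getattr(_state, "stack", ["cpu"])[-1]


@contextlib.contextmanager
def select_device(device):
    """Push a device for the duration of the `with` block."""
    if not hasattr(_state, "stack"):
        _state.stack = ["cpu"]
    _state.stack.append(device)
    try:
        yield
    finally:
        _state.stack.pop()


def on_current_device(create_fn):
    """Wrap an array-creation function so it defaults to the ambient device."""
    def wrapper(*args, **kwargs):
        kwargs.setdefault("device", current_device())
        return create_fn(*args, **kwargs)
    return wrapper


# Toy creation function standing in for a backend's `zeros(shape, device=...)`.
zeros = on_current_device(lambda shape, device: ("zeros", shape, device))
```

Wrapping only the creation functions (`zeros`, `arange`, `eye`, …) is what makes this tractable: compute functions never need a `device=` parameter, so no call site can forget to forward it.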
Lastly, our indexed array accesses in our main codebase (the parts that we hope to eventually port) look like the following:
```python
x[:, :, slice, :]            # similar expressions occur hundreds of times with variable counts of :'s
x[:, :, slice, :, slice, :]  # hundreds of times, but can fall back to chained x[:,:,:,:,slice,:][:,:,slice,:,:,:]
x[:, :, indices, :]          # similar uses occur dozens of times
x[:, an_int, :, :]           # probably a few dozen times
x[:, bools, :, :]            # a few to a dozen times
x[:, reverse_slice, :]       # occasional (not supported by current pytorch)
x[indices, :, indices]       # never (we don't use numpy advanced indexing with multiple arrays)
x[:, None, slice]            # used as equivalent to np.newaxis
```
We use a high-level array wrapper (similar in spirit to xarray) that supports arbitrarily strided views and allows writes into those views, which results in low-level calls (in the guts of the array class) equivalent to the form:
```python
# invoked similarly from hundreds of places (indirectly)
be.reshape(be.transpose(x, order), shape)[:, slice, :, :, slice, :] = y
```
… that’s because we spend much of our time dealing with multi-way tensors (e.g., neural data) that have axes such as space, time, frequency, instance, statistic, or feature (often 3-5 at a time), and most subroutines are agnostic to the presence or order of most axes except for one or two, so they create views that move those few to specific places and then read/write through the transposed view. Our way of dealing with backends that don’t support that is not enabling them for those functions (and having feature flags on the backends for reverse indexing, slice assignment, and mesh indexing support to catch cases where we do).
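A tiny numpy illustration of this write-through-a-view pattern, and of why it needs explicit backend support (hence the feature flags): transposing yields a view, so writes through it land in the original array, but reshaping a non-contiguous view silently copies:

```python
import numpy as np

x = np.zeros((2, 3, 4))
view = np.transpose(x, (2, 0, 1))  # shape (4, 2, 3); still a view of x
view[1, :, :] = 7.0                # write through the transposed view
assert (x[:, :, 1] == 7.0).all()   # the write is visible in the original

# Caution: np.reshape of a non-contiguous view returns a copy, so an
# assignment into reshape(transpose(x), ...) can silently go nowhere;
# a wrapper has to know when the reshape is view-preserving.
```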
I wasn’t sure if this is the right place to report relevant API usage “in the field”, hopefully it is.
Issue Analytics
- Created: 2 years ago
- Reactions: 3
- Comments: 9 (7 by maintainers)
Top GitHub Comments
- Raised in issue: https://github.com/data-apis/array-api/issues/482
- Raised issue: https://github.com/data-apis/array-api/issues/483
Thanks for following up on that! Emailed you.