Wrapping a slow Python function in an asynchronous DeviceArray
I have a slow Python computation that produces a pytree (in the motivating case, it loads weights over a network). I know the shapes and dtypes in advance, and would like to be able to wrap up my function as a `jnp.DeviceArray` so that it can be used as if it were a normal array by further JAX computation. Assuming for the moment that instead of a pytree our function returns a single array, the code would look like:
```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def async_wrap(f, *, shape, dtype):
    """Produces a jnp.DeviceArray of the given shape and dtype whose value is f().

    `f()` is run on a separate thread, so that this function returns immediately.
    If `f()` can't be converted to an array of the given shape and dtype, an
    exception is thrown.
    """
    ...


def slow():
    time.sleep(10)  # Note that this releases the GIL
    return jnp.arange(4, dtype=np.int32)


@jax.jit
def square(x):
    return x * x


def main():
    # Returns quickly
    x = async_wrap(slow, shape=(4,), dtype=np.int32)

    # Add more jax computation on top of x.
    # The compilation of square is overlapped with slow().
    # Returns either immediately or once compilation completes
    # (I forget what happens normally).
    y = square(x)

    # Blocks for about 10 seconds, then prints roughly [0, 1, 4, 9]
    print(y)
```
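The threading half of `async_wrap` could be sketched with `concurrent.futures` alone. This is a minimal sketch, not the proposed interface: it returns a plain thunk rather than an array-like object, uses NumPy in place of `jnp` so it stays self-contained, and the shape/dtype validation strategy is an assumption.

```python
import concurrent.futures

import numpy as np


def async_wrap(f, *, shape, dtype):
    """Start f() on a worker thread; return a thunk that yields its result.

    Calling the thunk blocks until f() finishes, converts the result to an
    array, and validates it against the declared shape and dtype. A full
    version would hand the validated array to jnp.asarray instead.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(f)
    # shutdown(wait=False) lets the already-submitted job finish while
    # allowing the worker thread to exit afterwards.
    executor.shutdown(wait=False)

    def thunk():
        out = np.asarray(future.result())
        if out.shape != tuple(shape) or out.dtype != np.dtype(dtype):
            raise TypeError(
                f"expected shape={tuple(shape)} dtype={np.dtype(dtype)}, "
                f"got shape={out.shape} dtype={out.dtype}"
            )
        return out

    return thunk
```

`async_wrap(slow, shape=(4,), dtype=np.int32)` then returns immediately, and only the first call of the thunk blocks on `slow()`.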
Naively, I’d expect that `async_wrap` (which is a terrible name) could mostly reuse existing asynchrony machinery inside JAX, but I’m not confident of that.

The actual `async_wrap` interface would want to support pytrees. Ignoring the number of threads created, the pytree version is implementable on top of the single-array version, but we probably don’t want to ignore the number of threads created.
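One way the pytree version can avoid a thread per leaf is to run `f` once on a single worker and hand every leaf a thunk that blocks on the same shared future. The sketch below is a simplification under assumed names (`async_wrap_tree`, a dict of per-key `(shape, dtype)` specs): a real version would presumably use `jax.tree_util` to handle arbitrary pytrees rather than flat dicts.

```python
import concurrent.futures

import numpy as np


def async_wrap_tree(f, *, specs):
    """Run f() once on one worker thread; f must return a dict of arrays.

    `specs` maps each key to its expected (shape, dtype). Returns a dict
    mapping each key to a thunk; every thunk blocks on the one shared future,
    so only a single extra thread is created regardless of the number of leaves.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(f)
    executor.shutdown(wait=False)  # the submitted job still runs to completion

    def make_thunk(key, shape, dtype):
        def thunk():
            out = np.asarray(future.result()[key])
            if out.shape != tuple(shape) or out.dtype != np.dtype(dtype):
                raise TypeError(
                    f"{key}: got shape={out.shape} dtype={out.dtype}, "
                    f"expected shape={tuple(shape)} dtype={np.dtype(dtype)}"
                )
            return out
        return thunk

    return {k: make_thunk(k, s, d) for k, (s, d) in specs.items()}
```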
@froystig Should we synchronize some of the offline discussion into this bug thread? I’m also not sure what the correct path forward is based on that discussion.
The off-thread discussion suggested defining a thunk-backed JAX array, without saying the word “future” at first. Specifically, think of a Python object (call it `LazyDeviceArray`) with the same interface as JAX’s (say, `DeviceArray`), but whose construction takes a function. An instance of `LazyDeviceArray` carries the information needed to correspond to a JAX type (typically shape and dtype). The moment it is involved in an operation that requires its data, it defers to `thunk()` (say, also checking that `thunk`’s output indeed matches the shape/dtype info it carries).

The idea is that it might be easy enough to back `thunk` by a future array, whether that’s done using `concurrent.futures` or something else. Intuitively, this seems doable: kick off the asynchronous computation prior to constructing the `LazyDeviceArray`, and supply a `thunk` that blocks until the future array is ready.

I suspect that `concurrent` or `asyncio` might help make a driving example, although I understand we may not want to build on those further than that. Also, whether this is a proposed addition to JAX, and whether JAX ought to offer utilities for futures alongside it, are questions for later (my first guess is “no” on the latter, but we have more to learn first).