Wrapping a slow Python function in an asynchronous DeviceArray

I have a slow Python computation that produces a pytree (in the motivating case, it loads weights over a network). I know the shapes and dtypes in advance, and would like to be able to wrap up my function as a jnp.DeviceArray so that it can be used as if it were a normal array by further JAX computation. Assuming for the moment that our function returns a single array instead of a pytree, the code would look like

import time

import jax
import jax.numpy as jnp
import numpy as np

def async_wrap(f, *, shape, dtype):
  """Produces a jnp.DeviceArray of given shape and dtype whose value is f().
  
  `f()` is run on a separate thread, so that this function returns immediately.
  If `f()` can't be converted to an array of the given shape and dtype, an exception is thrown.
  """
  ...

def slow():
  time.sleep(10)  # Note that this releases the GIL
  return jnp.arange(4, dtype=np.int32)

@jax.jit
def square(x):
  return x * x

def main():
  # Returns quickly
  x = async_wrap(slow, shape=(4,), dtype=np.int32)
  
  # Add more jax computation on top of x.
  # The compilation of square is overlapped with slow().
  # Returns either immediately or once compilation completes (I forget what happens normally)
  y = square(x)

  # Blocks for about 10 seconds, then prints roughly [0, 1, 4, 9]
  print(y)

Naively, I’d expect that async_wrap (which is a terrible name) could mostly reuse the existing asynchrony machinery inside JAX, but I’m not confident of that.

The actual async_wrap interface would want to support pytrees. Ignoring the number of threads created, the pytree version is implementable on top of the single array version, but we probably don’t want to ignore the number of threads created.
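One way to keep the thread count down for the pytree case is to run the slow function once and let every leaf share the resulting future. The sketch below is an illustration only, under several assumptions: `async_wrap_tree`, `specs` and `_POOL` are hypothetical names, a flat dict stands in for a general pytree, NumPy arrays stand in for JAX arrays, and each leaf is returned as a plain blocking thunk rather than as an array-like object.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np

# Hypothetical shared pool: one worker serves every wrapped tree.
_POOL = ThreadPoolExecutor(max_workers=1)


def async_wrap_tree(f, specs):
    """Start f() once in the background; return one blocking thunk per leaf.

    specs maps each leaf name to its expected (shape, dtype).
    """
    future = _POOL.submit(f)  # a single task for the whole tree

    def make_thunk(name, shape, dtype):
        def thunk():
            # Blocks until f() has finished, then validates this leaf.
            value = np.asarray(future.result()[name])
            if value.shape != tuple(shape) or value.dtype != np.dtype(dtype):
                raise TypeError(f"leaf {name!r} does not match its declared spec")
            return value
        return thunk

    return {name: make_thunk(name, s, d) for name, (s, d) in specs.items()}


def slow_load():
    time.sleep(0.2)  # stands in for loading weights over a network
    return {
        "w": np.ones((2, 3), dtype=np.float32),
        "b": np.zeros((3,), dtype=np.float32),
    }


thunks = async_wrap_tree(
    slow_load, {"w": ((2, 3), np.float32), "b": ((3,), np.float32)}
)
w = thunks["w"]()  # blocks until slow_load() is done
b = thunks["b"]()  # returns immediately, the future is already resolved
```

The number of threads here is independent of the number of leaves, at the cost of every leaf becoming available at the same moment.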

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 1
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

girving commented, Jun 6, 2022 (1 reaction)

@froystig Should we synchronize some of the offline discussion into this bug thread? I’m also not sure what the correct path forward is based on that discussion.

froystig commented, Jun 13, 2022 (0 reactions)

The off-thread discussion suggested defining a thunk-backed jax array, without saying the word “future” at first. Specifically, think of a python object (call it LazyDeviceArray) with the same interface as jax’s (say DeviceArray), but whose construction takes a function:

thunk :: () -> DeviceArray

An instance of LazyDeviceArray carries the information needed to correspond to a jax type (typically shape and dtype). The moment it is involved in an operation that requires its data, it defers to thunk() (say, also checking that thunk’s output indeed matches the shape/dtype info it carries).

The idea is that it might be easy enough to back thunk by a future array, whether that’s done using concurrent.futures or something else. Intuitively, this seems doable: kick off the asynchronous computation prior to constructing the LazyDeviceArray, and supply a thunk that blocks until the future array is ready.

I suspect that concurrent.futures or asyncio might help make a driving example, although I understand we may not want to build on those further than that. Also, whether this is a proposed addition to jax, and whether jax ought to offer utilities for futures alongside it, are questions for later (my first guess is “no” on the latter but we have more to learn first).
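The thunk-backed idea can be sketched roughly as follows. This is not JAX's implementation or a proposed API. `LazyArray`, `force` and `_POOL` are hypothetical names, and NumPy arrays stand in for JAX arrays so that the example runs without JAX installed. The object carries its shape and dtype from construction and only calls the thunk when its data is needed, here through NumPy's `__array__` protocol.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np


class LazyArray:
    """Hypothetical thunk-backed array: shape/dtype known now, data later."""

    def __init__(self, thunk, shape, dtype):
        self._thunk = thunk
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)
        self._value = None

    def force(self):
        """Call the thunk once, check its output against the declared type."""
        if self._value is None:
            value = np.asarray(self._thunk())
            if value.shape != self.shape or value.dtype != self.dtype:
                raise TypeError(
                    f"thunk returned {value.dtype}{value.shape}, "
                    f"expected {self.dtype}{self.shape}"
                )
            self._value = value
        return self._value

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this whenever the data is actually required.
        value = self.force()
        return value if dtype is None else value.astype(dtype)


_POOL = ThreadPoolExecutor(max_workers=4)  # hypothetical shared pool


def async_wrap(f, *, shape, dtype):
    # Kick off the computation first, then hand LazyArray a blocking thunk.
    future = _POOL.submit(f)
    return LazyArray(future.result, shape, dtype)


def slow():
    time.sleep(0.2)
    return np.arange(4, dtype=np.int32)


x = async_wrap(slow, shape=(4,), dtype=np.int32)  # returns quickly
y = np.asarray(x) ** 2                            # blocks until slow() is done
```

`force` is not guarded by a lock, so two threads forcing the same object at once could both call the thunk; with a future-backed thunk that is harmless, since `future.result` is safe to call repeatedly.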
