add an is_ready() for checking async readiness
The asynchronous dispatch page explains the asynchronous model jax uses and why naive benchmarks can be misleading, but it’s difficult to understand what the “jax-y” way of doing more general asynchronous programming is. Are users expected to leverage external tools (e.g. asyncio) or are there mechanisms in jax which are more efficient?
As a motivating example, consider an RL setup that has an agent continuously filling a replay buffer on one device and a learner tweaking weights on another. Using asyncio I would write something along the lines of:
import asyncio

# environment, dataset, params and train_step are assumed to be defined elsewhere.

async def gather_experience(dataset):
    # Runs forever, adding sampled trajectories to the replay buffer.
    while True:
        trajectory = await environment.random_sample()
        dataset.add(trajectory)

async def train(params, dataset):
    for batch in dataset:
        params = await train_step(params, batch)
    return params

loop = asyncio.get_event_loop()
gather_task = loop.create_task(gather_experience(dataset))
train_task = loop.create_task(train(params, dataset))
loop.run_until_complete(train_task)
I’ve tried using id_tap with futures as in the following:
import asyncio
import typing as tp

from jax.experimental.host_callback import id_tap

async def _wait_on(fut: asyncio.Future):
    while not fut.done():
        await asyncio.sleep(0)
    return fut.result()

def as_future(
    x: tp.Any, loop: tp.Optional[asyncio.AbstractEventLoop] = None
) -> asyncio.Future:
    loop = loop or asyncio.get_event_loop()
    fut = loop.create_future()

    def set_result(x_val, _):
        fut.set_result(x_val)

    id_tap(set_result, x)
    return loop.create_task(_wait_on(fut))
but it’s not exactly elegant (I have no idea why the _wait_on is necessary…) and it also forces transfer of data from device to host which isn’t always necessary. An exposed is_ready method akin to block_until_ready which doesn’t actually block would make this much nicer - but even then I feel there’s likely some redundancy between jax’s internal scheduling system and asyncio’s event loop.
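One plausible explanation for why the _wait_on polling loop is needed, assuming the id_tap callback fires on a thread other than the one running the event loop: asyncio futures are not thread-safe, and calling fut.set_result from a foreign thread does not wake a loop that is idle. The usual workaround is loop.call_soon_threadsafe. The sketch below illustrates that pattern only; fake_register is a hypothetical stand-in for id_tap and does not touch jax.

```python
import asyncio
import threading


def as_future_threadsafe(register_callback, loop):
    """Return an asyncio.Future resolved by a callback fired on another thread.

    register_callback is a hypothetical stand-in for something like id_tap:
    it takes a one-argument function and calls it later from a worker thread.
    """
    fut = loop.create_future()

    def set_result(value):
        # Hand the result over to the loop's own thread. Unlike a direct
        # fut.set_result(value), this also wakes the loop if it is idle.
        loop.call_soon_threadsafe(fut.set_result, value)

    register_callback(set_result)
    return fut


def fake_register(callback, value=42, delay=0.05):
    # Assumption for illustration: the callback fires on a timer thread.
    threading.Timer(delay, callback, args=(value,)).start()


async def main():
    loop = asyncio.get_running_loop()
    # No polling task is needed; awaiting the future is enough.
    return await as_future_threadsafe(fake_register, loop)


result = asyncio.run(main())
```

With this pattern the future can be awaited directly, so the extra _wait_on task is not required. It does not remove the device-to-host transfer mentioned above.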
Issue Analytics
- State:
- Created 2 years ago
- Comments: 7 (5 by maintainers)
Top GitHub Comments
I changed the title to scope the issue more narrowly to adding an is_ready. There’s more here, but this part of the question is concrete and actionable.
@jackd thanks for the extra context! Looks like x.is_ready results in busy polling while x.on_done avoids that.
@hawkinsp if there are no hidden problems I’m not seeing, on_done is strictly better.
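The difference between the two approaches can be sketched as follows. FakeAsyncValue is a hypothetical stand-in for a device value that becomes ready later. Its is_ready and on_done methods only mirror the names used in this thread and are not taken from jax, and the completion callback is assumed to fire on a non-event-loop thread.

```python
import asyncio
import threading


class FakeAsyncValue:
    """Hypothetical value that becomes ready after `delay` seconds."""

    def __init__(self, value, delay):
        self._value = value
        self._done = threading.Event()
        self._callbacks = []
        self._lock = threading.Lock()
        threading.Timer(delay, self._finish).start()

    def _finish(self):
        # Runs on the timer thread: mark ready, then notify listeners.
        with self._lock:
            self._done.set()
            callbacks, self._callbacks = self._callbacks, []
        for callback in callbacks:
            callback(self._value)

    def is_ready(self):
        # Non-blocking readiness check.
        return self._done.is_set()

    def on_done(self, callback):
        # Register a completion callback; call it at once if already ready.
        with self._lock:
            if not self._done.is_set():
                self._callbacks.append(callback)
                return
        callback(self._value)

    def get(self):
        return self._value


async def await_by_polling(x):
    # is_ready style: the loop is re-entered on every iteration until ready.
    polls = 0
    while not x.is_ready():
        polls += 1
        await asyncio.sleep(0)
    return x.get(), polls


async def await_by_callback(x):
    # on_done style: the coroutine sleeps until the callback resolves it.
    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    x.on_done(lambda v: loop.call_soon_threadsafe(fut.set_result, v))
    return await fut


async def main():
    polled, polls = await await_by_polling(FakeAsyncValue(1, 0.05))
    notified = await await_by_callback(FakeAsyncValue(2, 0.05))
    return polled, polls, notified


polled, polls, notified = asyncio.run(main())
```

In the polling version the poll count grows for as long as the value is pending, which is the busy polling described above. The callback version does no work until it is notified.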