`jit(pmap(f))` causes inefficient behavior
Combining `jit` with `pmap` produces some undesirable and surprising behaviors.
As one example, any lazy intermediate constants used by the function get instantiated and copied to every device. For instance:
import jax
import jax.numpy as jnp

@jax.jit
@jax.pmap
def foo(x):
    z = jnp.zeros((500_000_000,))
    return jax.lax.tie_in(z, x)

# The leading axis of size 8 assumes 8 available devices.
foo(jnp.arange(16).reshape((8, 2)))
This causes 2 GB of data to be allocated on each device. (Right now, if this is the only computation you run, this can be verified by looking at `list(list(jax.pxla.parallel_callable.__closure__[1].cell_contents.items())[0][1].values())[-1][0].__closure__[0].cell_contents`, but that might break.)
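The 2 GB figure follows from the array's shape and dtype. As a minimal sketch, the size can be computed with `jax.eval_shape` without allocating anything. The explicit `dtype=jnp.float32` is an assumption that matches the `jnp.zeros` default when 64-bit mode is disabled, and the names `spec` and `n_bytes` are only illustrative:

```python
import math

import jax
import jax.numpy as jnp

# eval_shape traces the function abstractly, so no 2 GB buffer is created.
# float32 is spelled out here; it is the jnp.zeros default unless x64 is enabled.
spec = jax.eval_shape(lambda: jnp.zeros((500_000_000,), dtype=jnp.float32))

# bytes = number of elements * bytes per element
n_bytes = math.prod(spec.shape) * spec.dtype.itemsize
print(f"{n_bytes / 1e9:.1f} GB per device")  # 2.0 GB per device
```

With 8 devices, as in the example above, that would come to roughly 16 GB in total if every device holds its own copy.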
Relatedly, the `jit` causes the return value to be copied back to a single host instead of staying as a `ShardedDeviceArray`.
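For comparison, here is a minimal sketch of what `pmap` on its own returns. It assumes a JAX release where results expose `addressable_shards` (in recent releases the result is a `jax.Array` rather than a `ShardedDeviceArray`), and it sizes the input from `jax.local_device_count()` so that it also runs on a single device. The function `double` is only illustrative:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 2).reshape((n_dev, 2))

def double(v):
    return v * 2

# pmap alone: one slice of the leading axis is computed on each device.
out = jax.pmap(double)(x)

# Collect the devices that hold a piece of the result.
devices = {shard.device for shard in out.addressable_shards}
print(type(out).__name__, "spread over", len(devices), "device(s)")
```

On a multi-device machine, the set of devices should have one entry per device, which is the layout that wrapping the call in `jit` was reported to lose.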
Ideally, adding `jit` would not make behavior worse. But having a warning when such a situation occurs would also be useful here, since `pmap` on its own does the right thing.
Good idea - I updated some of the documentation in #10757.
Since the original issue has been addressed by #3426, and there is now a warning for `jit` of `pmap`, I'll close this issue. Please feel free to reopen or file a new issue!
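To check whether an installed JAX version emits that warning, one option is to record Python warnings around a `jit`-of-`pmap` call. This is only a sketch: it assumes the warning goes through Python's `warnings` module, and whether anything is recorded (and with what text) depends on the JAX version. The function `double` is illustrative:

```python
import warnings

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 2).reshape((n_dev, 2))

def double(v):
    return v * 2

# Record every warning raised while tracing and running jit(pmap(double)).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = jax.jit(jax.pmap(double))(x)

# Print whatever was recorded; the list may be empty on some versions.
for w in caught:
    print(w.category.__name__, str(w.message)[:80])
```

If a warning about `jit` of `pmap` shows up, dropping the outer `jit` and calling the `pmap`-ed function directly is the simplest way to avoid the extra data movement described in this issue.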