`lax.scan` ~100x slower than recursion?
Working on inference on network infection cascades, eventually with numpyro. At the moment, I'm trying to make a fast generative model, taking some cues from here on fast sequential loops. Note the use of `lax.scan`.
However, `scan` (surprisingly) seems to run significantly slower than recursion(!).
Here's an example setup; I've tried to comment it as best I could, quickly. Running in a notebook.
```python
import numpy as onp
import jax.numpy as np
from jax.random import PRNGKey
# from jax.config import config
from jax import jit, grad, lax, random, vmap
from jax.ops import index_update, index, index_add
from collections import namedtuple

n_nodes = 50
n_edges = n_nodes*(n_nodes-1)//2

trans_times = onp.random.geometric(
    onp.random.beta(2, 5, size=(n_edges,)),
    size=(n_edges,)
)

@jit
def jax_squareform(edgelist, n=n_nodes):
    """edgelist to adj. matrix"""
    empty = np.zeros((n, n))
    half = index_add(empty, index[np.triu_indices(n, 1)], edgelist)
    full = half + half.T
    return full

a = jax_squareform(trans_times)       # transition times
x0 = np.array([1] + (n_nodes-1)*[0])  # infect state

# (infected?, time-left-per-neighbor?)
InfectState = namedtuple('InfectState', ['x', 's_ij'])

@jit
def infect(state, step=1):
    neighbor_set = state.s_ij*state.x                          # who knows an infected node?
    getting_infected = np.any(neighbor_set == 1, axis=1)       # and is getting infected now?
    x_p = lax.clamp(0, state.x + getting_infected, 1)          # update infections
    s_ij_p = lax.max(state.s_ij - step*getting_infected, 0.)   # and time-left
    return InfectState(x=x_p, s_ij=s_ij_p), step               # new state
```
So at this point, the individual time-steps are running really fast:

```
>>> %timeit infect(InfectState(x0, a))
211 µs ± 3.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Now let’s make the loops:
```python
def pandemic(state, step=1, t=0, T=5):
    state_p, _ = infect(state, step=step)
    if t == T:
        return state_p
    elif (t >= 0) and (t < T):
        return pandemic(state_p, t=t+step)
    else:
        print('INVALID t!')

def pandemic_scan(state, step=1, t=0, T=5):
    # note: scans over T copies of `step`, starting from the
    # globals x0 and a rather than the `state` argument
    return lax.scan(
        infect,
        InfectState(x0, a),
        np.full(T, step)
    )
```
So the difference is pretty stark:
```
>>> %timeit pandemic(InfectState(x0, a), T=5)
1.42 ms ± 4.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
164 ms ± 638 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
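(An aside on methodology, not from the original thread: JAX dispatches work asynchronously, so timings of jitted code are most trustworthy when an output array is explicitly blocked on. It doesn't change the conclusion here, where the gap is dominated by recompilation.)

```python
# Hedged aside: block on an output array so %timeit measures the actual
# computation rather than just the async dispatch.
%timeit pandemic_scan(InfectState(x0, a), T=5)[0].x.block_until_ready()
```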
So here are the key questions:

- Is there a more idiomatic way to use `scan` in this case that avoids whatever slow-down is occurring? The ultimate use-case involves inference around a bunch of `x(T)` observations, to estimate `x0` for each `x(T)` and an overall `s_ij` given all of them, so presumably this needs to be fast.
- Is there a version of `pandemic` within the jax ecosystem that might allow jit-compilation? It seems that its dependence on the boolean comparison `t<T` is causing it to complain about `static_argnums`, etc. (a hedged sketch follows this list).
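On that second question, for illustration only (not from the original thread; the function name here is hypothetical): the complaint arises because `T` controls the length of the scanned sequence, so it has to be a static, Python-level value. That is usually expressed by marking the argument static:

```python
from functools import partial

# Hypothetical sketch: marking T static lets jit trace the function,
# at the cost of one compilation per distinct value of T.
@partial(jit, static_argnums=(2,))
def pandemic_scan_static(state, step, T):
    return lax.scan(infect, state, np.full(T, step))

final_state, steps = pandemic_scan_static(InfectState(x0, a), 1, 5)
```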
You can get a good sense of the problem if you run the computation through a Python profiler, like `%prun` in IPython. You'll see that your code is getting compiled each time it's run, instead of reusing the same compiled code.

The immediate source of the problem here is that `lax.scan` effectively always calls `jit` on its function argument, but no reference to that jitted function is saved. It's the same issue in your "local" version: each `jit` is effectively being run from scratch, which means caching fails.

This is definitely known behavior (and it's likely unavoidable), but it clearly isn't well documented. We can and should fix that! 😃
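To see the caching point in isolation, here's a minimal sketch of my own (not from the thread): jit's compiled-code cache is keyed on the underlying function object, so a function that is re-created on every call can never hit it.

```python
def slow(x):
    # a brand-new lambda object on every call, so jit re-traces
    # and re-compiles on each invocation
    return jit(lambda y: y + 1)(x)

def inc(y):
    return y + 1

def fast(x):
    # `inc` is the same function object every time, so after the
    # first call this reuses the cached compilation
    return jit(inc)(x)
```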
@shoyer Ok, so yeah, that appears to have done it.
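(The snippet that followed here didn't survive extraction; judging from the discussion, the fix presumably amounted to jitting the scan-calling function itself, roughly like this reconstruction, where the name and the baked-in `T=5` are my assumptions:)

```python
# Reconstruction, not verbatim from the thread: jit the outer function so
# the scan (and the function it scans over) is traced and compiled once,
# then reused from the cache on subsequent calls.
@jit
def pandemic_scan_jitted(state):
    return lax.scan(infect, state, np.full(5, 1))  # T=5 fixed at trace time

final_state, steps = pandemic_scan_jitted(InfectState(x0, a))
```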
So I'm going to have to ask: why is jit-compiling the outer function at call time faster? What exactly is going on here, and what prevents a trick like this from being integrated into the default behaviour of `jit`?

For the heck of it I tried a "local" version of this, where the "private" version is only defined in the outer function's scope:
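(This snippet was also lost in extraction; based on the description and the reply above, it was presumably shaped like this reconstruction, with the jitted function re-created inside the outer function on every call:)

```python
# Reconstruction (assumed shape, not verbatim): because _scan is defined
# anew on every call, its jit cache can never be reused, so every call
# re-traces and re-compiles.
def pandemic_scan_local(state, step=1, T=5):
    @jit
    def _scan(init):
        return lax.scan(infect, init, np.full(T, step))
    return _scan(state)
```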
and the result:

```
272 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
So…what exactly is going on here? Is this documented behaviour?
Thanks again for your help!