
`lax.scan` ~100x slower than recursion?


I’m working on inference on network infection cascades, eventually with numpyro. At the moment I’m trying to build a fast generative model, taking some cues from here on fast sequential loops. Note the use of lax.scan.

However, scan (surprisingly) seems to run significantly slower than recursion(!).

Here’s an example setup. I’ve tried to comment it as best I could, quickly. It runs in a notebook.

import numpy as onp
import jax.numpy as np
from jax import jit, lax
# jax.ops.index_add / jax.ops.index were removed in later JAX releases;
# the .at[...].add(...) indexed-update syntax below replaces them.

n_nodes = 50
n_edges = n_nodes*(n_nodes-1)//2

trans_times = onp.random.geometric(
    onp.random.beta(2,5,size=(n_edges,)),
    size=(n_edges,)
)

@jit
def jax_squareform(edgelist, n=n_nodes):
    """Edge list (upper triangle, row-major) to symmetric adjacency matrix."""
    empty = np.zeros((n, n))
    half = empty.at[np.triu_indices(n, 1)].add(edgelist)
    full = half + half.T
    return full

a = jax_squareform(trans_times)  # symmetric matrix of transmission times
x0 = np.array([1]+(n_nodes-1)*[0])  # initial infection state: only node 0 infected

from collections import namedtuple
# (infected?, time-left-per-neighbor?)
InfectState = namedtuple('InfectState', ['x', 's_ij'])

@jit
def infect(state, step=1):
    neighbor_set = state.s_ij*state.x  # who knows an infected node?
    getting_infected = np.any(neighbor_set==1, axis=1)  # and is getting infected now?
    x_p = lax.clamp(0, state.x+getting_infected, 1)  # update infections
    s_ij_p = np.maximum(state.s_ij - step*getting_infected, 0.)  # and time-left, floored at 0
    return InfectState(x=x_p, s_ij=s_ij_p), step  # new state (step is the scan output)

So at this point, the individual time steps run really fast:

>>> %timeit infect(InfectState(x0, a))
211 µs ± 3.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Now let’s make the loops:

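def pandemic(state, step=1, t=0, T=5):
    # applies infect for t = 0, 1, ..., T, i.e. T + 1 steps in total
    state_p, _ = infect(state, step=step)

    if t == T:
        return state_p
    elif 0 <= t < T:
        # forward step and T, otherwise they silently reset to their defaults
        return pandemic(state_p, step=step, t=t+step, T=T)
    else:
        raise ValueError('INVALID t!')

def pandemic_scan(state, step=1, t=0, T=5):
    # t is unused; scan runs exactly T steps, feeding `step` to each one
    return lax.scan(
        infect,
        state,
        np.full(T, step)
    )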

So the difference is pretty stark:

>>> %timeit pandemic(InfectState(x0, a), T=5)
1.42 ms ± 4.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
164 ms ± 638 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
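One caveat when reading these numbers: JAX dispatches work asynchronously, so a bare %timeit can return before the computation has finished, and the first call of a jitted function also pays for tracing and compilation. A minimal timing helper that warms up once and then blocks on each result might look like the sketch below. `time_jax` and `toy_step` are hypothetical names for illustration, not part of JAX or of the issue.

```python
import timeit

import jax
import jax.numpy as jnp


def time_jax(fn, *args, number=100):
    """Hypothetical helper: mean seconds per call of fn(*args), after warm-up."""
    # Warm-up call: pays for tracing/compilation so it is not timed below.
    jax.block_until_ready(fn(*args))
    # block_until_ready waits for the asynchronously dispatched result,
    # so each timed call includes the actual device work.
    total = timeit.timeit(lambda: jax.block_until_ready(fn(*args)), number=number)
    return total / number


@jax.jit
def toy_step(m):
    # Stand-in workload for illustration only.
    return jnp.maximum(m @ m - 1.0, 0.0)


per_call = time_jax(toy_step, jnp.ones((50, 50)))
print(f"{per_call * 1e6:.1f} µs per call")
```

In IPython the same idea is one warm-up call followed by %timeit jax.block_until_ready(f(x)).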

So here are the key questions:

  1. Is there a more idiomatic way to use scan in this case that avoids whatever slowdown is occurring? The ultimate use case involves inference over a bunch of x(T) observations, to estimate x0 for each x(T) and an overall s_ij given all of them. So presumably this needs to be fast.

  2. Is there a version of pandemic within the JAX ecosystem that would allow jit compilation? It seems that the dependence on the boolean comparison t<T is causing it to complain about static_argnums, etc.
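On question 2, two standard ways to make the loop jit-friendly are sketched below, assuming the InfectState/infect definitions from the issue (re-declared here so the block runs on its own, using jnp.clip/jnp.maximum for the bounds). Marking T static lets a plain Python loop unroll at trace time; lax.fori_loop instead keeps T as a runtime value. Both apply infect T + 1 times to mirror the recursive pandemic, which runs for t = 0, ..., T. The names pandemic_unrolled and pandemic_fori and the 4-node chain graph are illustrative assumptions.

```python
from collections import namedtuple
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

InfectState = namedtuple("InfectState", ["x", "s_ij"])


def infect(state, step=1):
    # Same update as in the issue: a node is infected if some infected
    # neighbour j has s_ij == 1; then column j of s_ij is reduced by `step`
    # (floored at 0) for every node j that is getting infected.
    neighbor_set = state.s_ij * state.x
    getting_infected = jnp.any(neighbor_set == 1, axis=1)
    x_p = jnp.clip(state.x + getting_infected, 0, 1)
    s_ij_p = jnp.maximum(state.s_ij - step * getting_infected, 0.0)
    return InfectState(x=x_p, s_ij=s_ij_p), step


@partial(jax.jit, static_argnames="T")
def pandemic_unrolled(state, step=1, T=5):
    # T is a static Python int, so this loop is unrolled during tracing;
    # each new value of T triggers a recompile.
    for _ in range(T + 1):
        state, _ = infect(state, step)
    return state


@jax.jit
def pandemic_fori(state, T, step=1):
    # T may be traced here: fori_loop with a non-constant bound lowers to a
    # while_loop, so one compiled function serves every T.
    return lax.fori_loop(0, T + 1, lambda i, s: infect(s, step)[0], state)


# Hypothetical 4-node chain 0-1-2-3 with every transmission time equal to 1.
s0 = jnp.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=jnp.float32)
x0 = jnp.array([1, 0, 0, 0], dtype=jnp.int32)

out_a = pandemic_unrolled(InfectState(x0, s0), T=5)
out_b = pandemic_fori(InfectState(x0, s0), 5)
```

The unrolled version compiles a separate program for each T, and its compile time grows with T; the fori_loop version compiles once but executes as a sequential while loop.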

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

shoyer commented, Feb 19, 2020 (2 reactions)

You can get a good sense of the problem if you run the computation through the Python profiler, e.g. %prun in IPython. You’ll see that your code is getting compiled each time it’s run, instead of reusing the same compiled code.

The immediate source of the problem here is that lax.scan effectively always calls jit on its function argument, but no reference to that function is saved. It’s the same issue in your “local” version. Each jit is effectively being run from scratch, which means caching fails.

This is definitely known behavior (and it’s likely unavoidable), but it clearly isn’t well documented. We can and should fix that! 😃
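To see the caching effect described above, one can count traces: a plain Python side effect inside a jitted function runs only while JAX traces it, not on cached calls. The sketch below compares a module-level jitted wrapper around lax.scan, which is traced once, with a wrapper re-created on every call (as in the "local" version), which is traced on every call. step_fn, trace_count and the two wrapper names are illustrative, not from the issue.

```python
import jax
import jax.numpy as jnp

# Counts how many times each Python function body is traced.
trace_count = {"module": 0, "local": 0}


def step_fn(carry, x):
    # Toy scan body: running sum; emits the carry before each update.
    return carry + x, carry


@jax.jit
def scan_module_level(init, xs):
    # Python side effects run only during tracing, not on cached calls.
    trace_count["module"] += 1
    return jax.lax.scan(step_fn, init, xs)


def scan_local(init, xs):
    # A fresh function object is created on every call, so jit cannot
    # reuse the trace from the previous call.
    @jax.jit
    def _inner(init, xs):
        trace_count["local"] += 1
        return jax.lax.scan(step_fn, init, xs)

    return _inner(init, xs)


init = jnp.zeros((), dtype=jnp.float32)
xs = jnp.arange(5, dtype=jnp.float32)
for _ in range(3):
    scan_module_level(init, xs)
    scan_local(init, xs)

print(trace_count)  # expected: {'module': 1, 'local': 3}
```

Defining the jitted wrapper once, at module level, is what lets the cache hit on repeated calls.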

tbsexton commented, Feb 18, 2020 (0 reactions)

@shoyer Ok, so yeah that appears to have done it:

import functools

@functools.partial(jit, static_argnums=(1,2,3))
def _pandemic_scan(state, step, t, T):
    # step, t and T are static, so each distinct (step, t, T) is traced once
    return lax.scan(
        infect,
        state,
        np.full(T, step)
    )

def pandemic_scan(state, step=1, t=0, T=5):
    return _pandemic_scan(state, step, t, T)

>>> %timeit pandemic_scan(InfectState(x0, a), T=5)
141 µs ± 19.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

So I have to ask: why is jit-compiling the wrapper function faster? What exactly is going on here, and what prevents a trick like this from being integrated into the default behaviour of jit?

For the heck of it I tried a “local” version of this, where the “private” version is only defined in the outer function’s scope:

def pandemic_scan(state, step=1, t=0, T=5):

    @functools.partial(jit, static_argnums=(1,2,3))
    def _pandemic_scan(state, step, t, T):
        return lax.scan(
            infect,
            state,
            np.full(T, step)
        )

    return _pandemic_scan(state, step, t, T)

and the result:

272 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So…what exactly is going on here? Is this documented behaviour?

Thanks again for your help!
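For the use case in question 1 (many observations, each with its own x0, sharing one s_ij), one option is to keep a single module-level jitted function and vmap the scan over a batch of initial states, so that compilation happens once per batch shape and T. This is a sketch under assumptions: run_batch is a hypothetical name, infect is re-declared from the issue with jnp.clip/jnp.maximum, and the random graph and seed nodes are purely illustrative.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import lax

InfectState = namedtuple("InfectState", ["x", "s_ij"])


def infect(state, step):
    # Update from the issue, with jnp.clip/jnp.maximum for the bounds.
    neighbor_set = state.s_ij * state.x
    getting_infected = jnp.any(neighbor_set == 1, axis=1)
    x_p = jnp.clip(state.x + getting_infected, 0, 1)
    s_ij_p = jnp.maximum(state.s_ij - step * getting_infected, 0.0)
    return InfectState(x=x_p, s_ij=s_ij_p), step


@jax.jit
def run_batch(x0_batch, s_ij, steps):
    # steps has shape (T,); its length fixes the number of scan iterations.
    def one_cascade(x0):
        final, _ = lax.scan(infect, InfectState(x0, s_ij), steps)
        return final.x
    # Map over the leading axis of x0_batch; s_ij and steps are shared.
    return jax.vmap(one_cascade)(x0_batch)


n_nodes, batch, T = 50, 8, 5
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Illustrative symmetric matrix of integer transmission times in {0, ..., 3}.
upper = jnp.triu(jax.random.randint(key_a, (n_nodes, n_nodes), 0, 4), k=1)
s_ij = (upper + upper.T).astype(jnp.float32)
# Each row is an initial state with exactly one infected node.
seeds = jax.random.randint(key_b, (batch,), 0, n_nodes)
x0_batch = jax.nn.one_hot(seeds, n_nodes, dtype=jnp.int32)
steps = jnp.ones(T, dtype=jnp.int32)

final_x = run_batch(x0_batch, s_ij, steps)  # shape (batch, n_nodes)
```

Changing the batch size or T changes the input shapes and triggers one new compilation; repeated calls with the same shapes reuse the compiled code.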
