
change `scan` and `while_loop` impls to Python versions

See original GitHub issue

Currently lax.while_loop (and as a consequence lax.fori_loop and lax.scan) incur compilation time every time they’re evaluated in op-by-op mode, making them seem slow to execute without being inside an @jit (since the @jit will handle the caching). We should remedy that, and make the op-by-op impl rules fast. (Separately, while we’re looking at this code, we might be able to replace _while_loop_translation_rule by calling xla.lower_fun on its new impl.)
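To make the caching distinction concrete, here is a minimal sketch (not from the issue; the trivial body function is hypothetical) contrasting an op-by-op call to lax.fori_loop with the same call wrapped in @jit, where the compiled loop is cached after the first invocation:

```python
import jax
from jax import jit, lax

def body(i, x):
    # Hypothetical trivial loop body, just to exercise fori_loop.
    return x + i

def run():
    # Evaluated op-by-op: per this issue, the loop pays compilation
    # cost on every call when it is not under a jit.
    return lax.fori_loop(0, 100, body, 0)

# Under jit, compilation happens once; later calls hit the cache.
run_jitted = jit(run)

print(int(run()))         # 4950 (sum of 0..99)
print(int(run_jitted()))  # 4950, compiled once and cached
```

Both variants compute the same value; only the compilation/caching behavior differs.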

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 6
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Aug 7, 2019

We haven’t forgotten about this! We’re working on a rewrite to remove tuples from the jaxpr language (it’s an internals-only change and won’t affect the API anywhere), and as part of that change we’re revising the control flow pretty heavily (because it uses a lot of tuples). Once that other change lands we should fix these performance bugs once and for all!

1 reaction
fehiepsi commented, Jul 6, 2019

Hi @mattjj, previously we observed that jit(lax.fori_loop, ...) is faster than lax.fori_loop, depending on body_fn. I think the following code (a little verbose, to make explicit what I want to illustrate) is a good benchmark because it exposes some problems with the latest PR:

import time
from jax import grad, jit, lax

def dual_averaging(t0=10, kappa=0.75, gamma=0.05):
    def init_fn(prox_center=0.):
        x_t = 0.
        x_avg = 0.  # average of primal sequence
        g_avg = 0.  # average of dual sequence
        t = 0
        return x_t, x_avg, g_avg, t, prox_center

    def update_fn(g, state):
        x_t, x_avg, g_avg, t, prox_center = state
        t = t + 1
        g_avg = (1 - 1 / (t + t0)) * g_avg + g / (t + t0)
        x_t = prox_center - (t ** 0.5) / gamma * g_avg
        weight_t = t ** (-kappa)
        x_avg = (1 - weight_t) * x_avg + weight_t * x_t
        return x_t, x_avg, g_avg, t, prox_center

    return init_fn, update_fn

# Plain lax.fori_loop, no jit anywhere: the loop runs op-by-op.
def optimize(f):
    da_init, da_update = dual_averaging(gamma=0.5)
    init_state = da_init()
    
    def body_fn(i, state):
        x = state[0]
        g = grad(f)(x)
        return da_update(g, state)

    last_state = lax.fori_loop(0, 1000, body_fn, init_state)
    x_avg = last_state[1]
    return x_avg

# Same loop, but with the body function jitted inside lax.fori_loop.
def optimize_v1(f):
    da_init, da_update = dual_averaging(gamma=0.5)
    init_state = da_init()
    
    @jit
    def body_fn(i, state):
        x = state[0]
        g = grad(f)(x)
        return da_update(g, state)

    last_state = lax.fori_loop(0, 1000, body_fn, init_state)
    x_avg = last_state[1]
    return x_avg

# Jitted body called from a plain Python for loop (no lax.fori_loop).
def optimize_v2(f):
    da_init, da_update = dual_averaging(gamma=0.5)
    init_state = da_init()
    
    @jit
    def body_fn(i, state):
        x = state[0]
        g = grad(f)(x)
        return da_update(g, state)

    last_state = init_state
    for i in range(1000):
        last_state = body_fn(i, last_state)
    x_avg = last_state[1]
    return x_avg

# The whole lax.fori_loop jitted, with body_fn marked static.
def optimize_v3(f):
    da_init, da_update = dual_averaging(gamma=0.5)
    init_state = da_init()
    
    def body_fn(i, state):
        x = state[0]
        g = grad(f)(x)
        return da_update(g, state)

    last_state = jit(lax.fori_loop, static_argnums=(2,))(0, 1000, body_fn, init_state)
    x_avg = last_state[1]
    return x_avg

f = lambda x: (x + 1) ** 2

The target is to run the various versions of optimize on f. Here are the results with the latest PR:

tic = time.time()
print(optimize(f).copy())
print("time before compiling:", time.time() - tic)

fn = jit(optimize, static_argnums=(0,))
tic = time.time()
print(fn(f).copy())
print("time with compiling:", time.time() - tic)

tic = time.time()
print(fn(f).copy())
print("time after compiling:", time.time() - tic)

print("===v1===")

tic = time.time()
print(optimize_v1(f).copy())
print("time before compiling:", time.time() - tic)

fn = jit(optimize_v1, static_argnums=(0,))
tic = time.time()
print(fn(f).copy())
print("time with compiling:", time.time() - tic)

tic = time.time()
print(fn(f).copy())
print("time after compiling:", time.time() - tic)

print("===v2===")

tic = time.time()
print(optimize_v2(f).copy())
print("time before compiling:", time.time() - tic)

fn = jit(optimize_v2, static_argnums=(0,))
tic = time.time()
print(fn(f).copy())
print("time with compiling:", time.time() - tic)

tic = time.time()
print(fn(f).copy())
print("time after compiling:", time.time() - tic)

print("===v3===")

tic = time.time()
print(optimize_v3(f).copy())
print("time before compiling:", time.time() - tic)

fn = jit(optimize_v3, static_argnums=(0,))
tic = time.time()
print(fn(f).copy())
print("time with compiling:", time.time() - tic)

tic = time.time()
print(fn(f).copy())
print("time after compiling:", time.time() - tic)
-0.99569756
time before compiling: 1.5503859519958496
-0.99569756
time with compiling: 1.3272857666015625
-0.99569756
time after compiling: 0.0006554126739501953
===v1===
-0.99569726
time before compiling: 15.485527992248535
-0.99569726
time with compiling: 15.443722009658813
-0.99569726
time after compiling: 0.000667572021484375
===v2===
-0.99569726
time before compiling: 0.29051971435546875
-0.99569726
time with compiling: 0.2970407009124756
-0.99569726
time after compiling: 0.0004899501800537109
===v3===
-0.99569726
time before compiling: 0.034581899642944336
-0.99569726
time with compiling: 0.038028717041015625
-0.99569726
time after compiling: 0.0004868507385253906

While before the PR, I get:

-0.99569726
time before compiling: 0.1041407585144043
-0.99569726
time with compiling: 0.039406776428222656
-0.99569726
time after compiling: 0.00051116943359375
===v1===
-0.99569726
time before compiling: 0.04027819633483887
-0.99569726
time with compiling: 0.04513883590698242
-0.99569726
time after compiling: 0.0003654956817626953
===v2===
-0.99569726
time before compiling: 0.2962920665740967
-0.99569726
time with compiling: 0.2989203929901123
-0.99569726
time after compiling: 0.0006308555603027344
===v3===
-0.99569726
time before compiling: 0.03446555137634277
-0.99569726
time with compiling: 0.03931307792663574
-0.99569726
time after compiling: 0.0003647804260253906

I can observe that the behaviour of lax.fori_loop outside jit has changed with the last PR, and it seems worse than before. (By the way, while playing with some benchmark code, I observed that in recent versions of jax (e.g. v0.1.39), lax.fori_loop seems a bit faster than jit(lax.fori_loop, ...), and I am unable to construct an example showing jit(lax.fori_loop, ...) being faster than lax.fori_loop.)
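As a side note on the jit(lax.fori_loop, static_argnums=(2,)) pattern used in optimize_v3 above: the body function is positional argument 2 of fori_loop(lower, upper, body_fun, init_val), and since a Python callable can't be traced as an array argument, it must be marked static. A minimal sketch with a made-up body function:

```python
from jax import jit, lax

def body_fn(i, x):
    # Hypothetical body: accumulate 2*i on each iteration.
    return x + 2 * i

# body_fn is positional argument 2 of fori_loop(lower, upper, body_fun, init_val),
# so it is declared static; jit specializes on the function object itself.
fast_fori = jit(lax.fori_loop, static_argnums=(2,))

result = fast_fori(0, 10, body_fn, 0)
print(int(result))  # 90, same value as the un-jitted lax.fori_loop call
```

Passing a different body function triggers a fresh trace and compile, since static arguments are part of the jit cache key.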
