
jitted function is slow with JAX 0.1.42


Here is a repro script:

import jax
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # commenting out this line makes it fast

z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)

which prints (seconds per call):

0.0958707332611084
0.08851313591003418
0.08699154853820801
0.09005379676818848
0.08801078796386719
0.08790731430053711
0.09052276611328125
0.08893036842346191
0.0877068042755127
0.08900189399719238

while with the non-jitted wa_update we get:

0.12209296226501465
0.002318859100341797
0.0018045902252197266
0.002086162567138672
0.0018031597137451172
0.0017747879028320312
0.0020046234130859375
0.002245187759399414
0.0020575523376464844
0.0024704933166503906
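In the jitted run above, every call costs about 0.09 s. A jitted function would normally pay for tracing and compilation only on its first call for a given set of input shapes and dtypes, and then run from the cache. Below is a minimal sketch of a timing harness that makes this visible. It is written against current JAX rather than 0.1.42, and `time_calls` and the toy function `f` are hypothetical names used only for illustration. The harness waits on each result with `block_until_ready()`, because JAX dispatches work asynchronously and the clock could otherwise stop before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp


def time_calls(fn, *args, n=5):
    """Return wall-clock seconds for n successive calls of fn(*args)."""
    times = []
    for _ in range(n):
        tic = time.perf_counter()
        out = fn(*args)
        # JAX dispatches asynchronously; wait for the result before stopping the clock.
        out.block_until_ready()
        times.append(time.perf_counter() - tic)
    return times


# Hypothetical toy function; any jitted function with fixed input shapes works.
f = jax.jit(lambda x: jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

times = time_calls(f, jnp.ones(1000))
# The first entry includes tracing + compilation; later entries should be far smaller.
print(times)
```

If every entry looks like the first one, as in the jitted numbers above, the function is most likely being recompiled on each call.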

I think this is just a typo somewhere (similar to #1237).

cc @neerajprad
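The `welford_covariance` helper in the repro implements Welford's online update. It tracks a running mean and `m2`, the sum of squared deviations from that mean, and `final_fn` turns these into the unbiased per-coordinate (diagonal) variance `m2 / (n - 1)`. Below is a quick sanity check of those update rules against `jnp.var(..., ddof=1)`. It copies the repro's `update_fn` into a standalone function, here given the hypothetical name `welford_update`, and runs it eagerly, so it says nothing about the compilation problem itself.

```python
import jax.numpy as jnp
from jax import random


def welford_update(sample, state):
    """One Welford step; same update rules as update_fn in the repro."""
    mean, m2, n = state
    n = n + 1
    delta_pre = sample - mean
    mean = mean + delta_pre / n
    delta_post = sample - mean
    m2 = m2 + delta_pre * delta_post
    return mean, m2, n


samples = random.normal(random.PRNGKey(0), (100, 3))
state = (jnp.zeros(3), jnp.zeros(3), 0)
for x in samples:  # feed the samples one row at a time
    state = welford_update(x, state)

mean, m2, n = state
var = m2 / (n - 1)  # unbiased per-coordinate variance, as computed in final_fn
print(jnp.allclose(var, jnp.var(samples, axis=0, ddof=1), rtol=1e-4))
```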

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

3 reactions
mattjj commented, Aug 23, 2019

It’s hitting a compile on every call (not sure why yet). I added a print("COMPILING") before this line and it prints on every iteration. Pretty much all of these bad-performance issues come from missing the compilation cache.

I hereby promise that in the PR closing this bug I’ll also add at least a print statement there that can be enabled with a flag! That will make it super easy to check when a performance problem is a recompilation issue.
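The diagnosis here is that the function misses the compilation cache on every call. Below is a minimal sketch of one way to check for that without editing JAX internals. It is written against current JAX and uses hypothetical names (`step`, `trace_count`). Python side effects inside a jitted function run only while JAX traces it, so a counter incremented in the body counts traces, which is a close proxy for cache misses. Current JAX releases also have a `jax_log_compiles` config option that logs compilations. This thread doesn't say whether that option is the flag promised above.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, bumped only when JAX traces `step`


@jax.jit
def step(x, t):
    global trace_count
    # Runs at trace time only; cached calls skip the Python body entirely.
    trace_count += 1
    return x * t


x = jnp.ones(3)
for t in range(5):
    step(x, t)  # same shapes and dtypes each call, so JAX reuses the cached trace
print(trace_count)  # prints 1

step(jnp.ones(4), 0)  # new input shape: cache miss, so the body is traced again
print(trace_count)  # prints 2
```

Applied to the repro, the analogous check would be a counter inside `update_fn` of `warmup_adapter`. If that counter rises on every loop iteration, the slowdown is recompilation.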

0 reactions
fehiepsi commented, Aug 23, 2019

It seems related to an issue I hit a few days ago: the code uses all GPU memory even though I’m only running CPU code (with jaxlib 0.1.23 installed from https://storage.googleapis.com/jax-releases). Things are back to normal when I install the jaxlib version from PyPI. I’ll check whether #1240 fixes it too.
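The symptom described here is a CPU-only run that claims all GPU memory. Below is a minimal sketch of how to pin a run to the CPU and confirm which backend JAX picked. It uses options from current JAX releases rather than the jaxlib 0.1.23 mentioned above. The `JAX_PLATFORMS` environment variable restricts the backends JAX initializes, and it must be set before JAX initializes any backend. For runs that do need the GPU, `XLA_PYTHON_CLIENT_PREALLOCATE=false` stops JAX from preallocating most of the device memory up front. This is a workaround sketch, not what the thread settled on; the workaround reported above was reinstalling jaxlib from PyPI.

```python
import os

# Set before JAX initializes any backend (safest: before importing jax).
os.environ["JAX_PLATFORMS"] = "cpu"  # only initialize the CPU backend

import jax
import jax.numpy as jnp

print(jax.default_backend())  # name of the backend JAX selected
print(jax.devices())          # devices visible on that backend
print(jnp.ones(3).sum())      # a small computation that runs on the CPU
```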

