jitted function is slow with JAX 0.1.42
Here is a repro script:
```python
import time

import jax
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        # JAX 0.1.x signature:
        # lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # commenting out this line makes it fast (see timings below)
z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)
```
which prints (with the jitted wa_update):
0.0958707332611084
0.08851313591003418
0.08699154853820801
0.09005379676818848
0.08801078796386719
0.08790731430053711
0.09052276611328125
0.08893036842346191
0.0877068042755127
0.08900189399719238
while using the non-jitted wa_update, we get
0.12209296226501465
0.002318859100341797
0.0018045902252197266
0.002086162567138672
0.0018031597137451172
0.0017747879028320312
0.0020046234130859375
0.002245187759399414
0.0020575523376464844
0.0024704933166503906
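In these timings every jitted call takes about 0.09 s, while the un-jitted version is slow only on its first call and then drops to about 0.002 s; a roughly constant per-call cost like that is what repeated compilation tends to look like. A minimal sketch of how one might separate compile time from cached-call time, assuming a recent JAX release where `lax.cond` takes `(pred, true_fun, false_fun, *operands)` instead of the five-argument form in the repro. `step` is a hypothetical, cut-down stand-in for `wa_update`, and `.block_until_ready()` is used because jitted calls dispatch asynchronously.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def step(t, state):
    # Hypothetical stand-in for wa_update: reset the state while t < 10,
    # otherwise pass it through (new-style lax.cond signature).
    return lax.cond(t < 10,
                    lambda s: jnp.zeros_like(s),
                    lambda s: s,
                    state)


def timed(t, state):
    tic = time.perf_counter()
    out = step(t, state).block_until_ready()  # wait for async dispatch
    return out, time.perf_counter() - tic


state = jnp.ones(3)
state, first = timed(0, state)   # includes tracing and XLA compilation
times = []
for t in range(1, 10):
    state, dt = timed(t, state)  # same shapes/dtypes, so a cache hit
    times.append(dt)
print(f"first call: {first:.4f}s, slowest later call: {max(times):.4f}s")
```

If the later calls are not much faster than the first one, the jitted function is probably being recompiled on each call.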
I think this is just a typo somewhere (similar to #1237).
cc @neerajprad
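The `welford_covariance` helper in the repro is Welford's online algorithm for a diagonal (per-coordinate) variance: `mean` moves by `delta_pre / n`, `m2` accumulates `delta_pre * delta_post`, and `m2 / (n - 1)` is the unbiased sample variance. A minimal sanity check of that reading, copying the repro's three functions (with `jnp` as the alias) and comparing the streaming result against the two-pass `jnp.var(..., ddof=1)`; the sample data here is synthetic.

```python
import jax.numpy as jnp
from jax import random


def welford_covariance():
    # Same three functions as in the repro (diagonal variance only).
    def init_fn(size):
        return jnp.zeros(size), jnp.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = jnp.sqrt(jnp.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


samples = random.normal(random.PRNGKey(0), (200, 3))  # synthetic data
init, update, final = welford_covariance()
state = init(3)
for x in samples:  # one streaming update per sample (row)
    state = update(x, state)
cov, cov_inv_sqrt = final(state)

# The streaming estimate should match the two-pass unbiased variance.
print(jnp.allclose(cov, jnp.var(samples, axis=0, ddof=1), rtol=1e-4))
```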
It’s hitting a compile every time (not sure why yet). I added a print("COMPILING") before this line and it’s printing for every iteration. Pretty much all of these bad performance issues come from missing the compilation cache.

I hereby promise that in the PR closing this bug I’ll also add at least a print statement there that can be enabled with a flag! That will make it super easy to check when a performance problem is a recompilation issue.
It seems related to an issue I got a few days ago: the code uses all GPU memory while I just run CPU code (with jaxlib 0.1.23 installed from https://storage.googleapis.com/jax-releases). Things are back to normal when I install the jaxlib version from PyPI. I’ll check if #1240 fixes it too.
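On the GPU-memory symptom: by default JAX preallocates a large fraction of GPU memory once a GPU backend is initialized. A minimal sketch for current JAX releases (not necessarily for the jaxlib 0.1.23 build mentioned above) is to set two environment variables before `jax` is first imported: `JAX_PLATFORMS=cpu` restricts JAX to the CPU backend, and `XLA_PYTHON_CLIENT_PREALLOCATE=false` turns off upfront preallocation for runs that do use a GPU.

```python
import os

# Both must be set before jax is imported for the first time in the process.
os.environ["JAX_PLATFORMS"] = "cpu"                    # CPU backend only
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # no upfront GPU grab

import jax
import jax.numpy as jnp

x = jnp.ones(3) + 1.0
print(jax.devices())  # should list only CPU devices
```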