Performance in small functions
I am doing some experiments with probabilistic programming in jax. It was quite easy, lexically, to port the project from autograd to jax, but I am finding a performance hit on the CPU, especially for small models. Specifically, I see a ~100x slowdown computing a log probability and a ~1.3x slowdown computing a gradient, both relative to autograd. Those experiments are below.
Putting all this together, I see a ~6x slowdown running Hamiltonian Monte Carlo with jax compared to autograd on the CPU (I wrapped the gradient with jit, but have not spent much more time tuning).
Question
I see #427 suggests that benefits are seen for larger functions that are not dominated by dispatch. Are there any other suggestions or best practices for computations like these? I may go through and restructure the program to allow jit to be used more often, to see how much that helps.
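For what it's worth, one restructuring along those lines (my own sketch, not code from the issue; the name neg_logp is hypothetical) is to fuse the value and the gradient into a single compiled call with jax.value_and_grad, so the Python-side dispatch cost is paid once per HMC step rather than once for the log probability and once for its gradient:

from jax import jit, value_and_grad
import jax.numpy as jnp

def neg_logp(x):
    """Negative log density of N(1., 0.1), the same density as in the benchmarks below."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

# One compiled call returns both the value and the gradient.
neg_logp_and_grad = jit(value_and_grad(neg_logp))

value, gradient = neg_logp_and_grad(0.1)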
from jax import jit
import jax.numpy as jnp
import numpy as onp
import autograd.numpy as anp

def logp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (onp.log(2 * onp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

@jit
def jlogp_jit(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def jlogp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def alogp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (anp.log(2 * anp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)
%timeit logp(0.1)
# 1.14 µs ± 92.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit jlogp_jit(0.1)
# 214 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit jlogp(0.1)
# 674 µs ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit alogp(0.1)
# 2.16 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
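One caveat when reading these numbers (a measurement sketch of my own, not from the issue): the first call to a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, so a slightly stricter timing warms the function up first and blocks on the result:

jlogp_jit(0.1)  # warm-up call: pays the one-time tracing/compilation cost
%timeit jlogp_jit(0.1).block_until_ready()  # time steady-state dispatch + execution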
This holds true to a lesser extent for gradients as well:
from jax import grad as jgrad
from autograd import grad as agrad
adlogp = agrad(alogp)
jdlogp = jgrad(jlogp)
jdlogp_jit = jit(jgrad(jlogp))
jdlogp_jit_jit = jit(jgrad(jlogp_jit))
%timeit adlogp(0.1)
# 156 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit jdlogp(0.1)
# 3.44 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jdlogp_jit(0.1)
# 232 µs ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit jdlogp_jit_jit(0.1)
# 204 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
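To see how much jitting a larger region can help, one natural experiment (a sketch under my own assumptions: neg_logp mirrors the density above, and the step size and step count are arbitrary) is to put a whole leapfrog trajectory inside a single jit using jax.lax.fori_loop, so dispatch overhead is paid once per trajectory rather than once per step:

from functools import partial
from jax import grad, jit, lax
import jax.numpy as jnp

def neg_logp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

dneg_logp = grad(neg_logp)

@partial(jit, static_argnums=2)
def leapfrog(x, p, n_steps):
    """Run n_steps leapfrog updates for the potential neg_logp in one compiled call."""
    step_size = 0.01
    def body(i, state):
        x, p = state
        p = p - 0.5 * step_size * dneg_logp(x)  # half step in momentum
        x = x + step_size * p                   # full step in position
        p = p - 0.5 * step_size * dneg_logp(x)  # half step in momentum
        return x, p
    return lax.fori_loop(0, n_steps, body, (x, p))

x_new, p_new = leapfrog(0.1, 1.0, 1024)  # one Python dispatch for 1024 steps

This is the same idea the iterative NUTS discussion below pushes further: move the jit boundary out so the per-call overhead is amortized over the whole trajectory.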
Thank you for the kind words! I spotted numpyro yesterday and it looks very nice. Possibly the biggest design difference is that I am trying to use this only as a reference and learning implementation, though numpyro looks like you could actually use it for real work! I was impressed at autograd for letting me write concise and readable code, and was hoping not to have to change it too much to use jax.
The iterative NUTS implementation is very interesting - I am taking a closer look at that now. I have been looking at transforming the tree doubling into an iterative scheme, and I was impressed to see you all had done it. I would love to read any thoughts you had on it (in fact, I know a few people who would…)
I had a (very short) conversation about it with one of the Stan developers, and he was not convinced that the recursion ever goes deep enough for there to be a performance benefit to it (I think both Stan and PyMC3 use a default of 10 doublings before issuing a warning, which really is not that deep). That said, it seems like it gives you a bit more control in searching for the right scale for a trajectory, and might make the code more readable/maintainable, which is not a small thing.
I’m happy that you are also interested in iterative NUTS! I will share a note on translating the recursive scheme to an iterative one in a couple of days. 😃
About the benefit: yes, it gives us more control over small models, especially when using JAX. The overhead of JAX is pretty large. For example, if each leapfrog step carries 100 microseconds of overhead, then building a full tree (with depth=10, i.e. ~1024 leapfrog steps) takes about 100ms, so getting 1000 samples takes 100s, which is quite costly 😦 With iterative NUTS, we can jit the whole trajectory, which in turn means the overhead is only in the range of 100 microseconds per trajectory. We did a benchmark here showing that it took 1s to build a full tree with the recursive algorithm but only 500 microseconds with the iterative one. But I also think the algorithm brings less benefit in other frameworks, where the overhead of each leapfrog step is small (say, about 1 microsecond), or when the computational cost of each leapfrog step is in the range of milliseconds (as in the covertype example of the Simple, Distributed, and Accelerated Probabilistic Programming paper).
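To make the structure concrete, here is a heavily simplified sketch of that idea (my own illustration, not numpyro's implementation: the stopping rule is a toy stand-in for the NUTS U-turn criterion, and there is no slice sampling or proposal bookkeeping). The doubling loop becomes a bounded lax.while_loop, so the entire trajectory lives inside one jit:

from jax import grad, jit, lax
import jax.numpy as jnp

def neg_logp(x):
    """Negative log density of N(1., 0.1), as in the benchmarks above."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

dneg_logp = grad(neg_logp)

def leapfrog_step(state):
    x, p = state
    step_size = 0.01
    p = p - 0.5 * step_size * dneg_logp(x)
    x = x + step_size * p
    p = p - 0.5 * step_size * dneg_logp(x)
    return x, p

@jit
def build_trajectory(x0, p0):
    """Iterative doubling: pass k takes 2**k leapfrog steps, all in one compiled loop."""
    max_depth = 10

    def cond(carry):
        depth, x, p, done = carry
        return (depth < max_depth) & ~done

    def body(carry):
        depth, x, p, done = carry
        n_steps = 2 ** depth  # each doubling takes twice as many steps as the last
        x, p = lax.fori_loop(0, n_steps, lambda i, s: leapfrog_step(s), (x, p))
        # Toy termination check: stop once the momentum points back toward the start.
        done = (x - x0) * p < 0.
        return depth + 1, x, p, done

    init = (jnp.array(0), x0, p0, jnp.array(False))
    _, x, p, _ = lax.while_loop(cond, body, init)
    return x, p

x_new, p_new = build_trajectory(0.1, 1.0)

Because the whole loop compiles to a single XLA program, the ~100 microsecond dispatch overhead is paid once per trajectory rather than once per doubling or per step.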