
Performance in small functions

See original GitHub issue

I am doing some experiments with probabilistic programming in jax. Porting the project from autograd to jax required only minor syntactic changes, but I am seeing a performance hit on the CPU, especially for small models. Specifically, I see a ~100x slowdown computing a log probability compared to autograd, and a 1.3x slowdown computing a gradient. Those experiments are below.

Putting all this together, I see a ~6x slowdown running Hamiltonian Monte Carlo with jax compared to autograd on the CPU (I wrapped the gradient with jit, but have not spent much more time tuning).

Question

I see that #427 suggests the benefits show up for larger functions that are not dominated by dispatch overhead. Are there any other suggestions or best practices for computations like these? I may restructure the program so that jit can cover more of the work, to see how much that helps; a sketch of what I have in mind follows.
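For concreteness, here is a minimal sketch of the kind of restructuring I mean: pulling a whole leapfrog trajectory under a single jit, so the dispatch cost is paid once per trajectory rather than once per step. The potential U and the integration parameters are hypothetical placeholders, not working HMC code:

from functools import partial
from jax import grad, jit
import jax.numpy as jnp

def U(q):
    """Hypothetical potential: negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((q - 1.) / 0.1) ** 2)

dU = grad(U)

@partial(jit, static_argnums=(3,))
def leapfrog(q, p, step_size, n_steps):
    # With n_steps static, the Python loop is unrolled at trace time,
    # so XLA compiles the whole trajectory into one call instead of
    # dispatching per step.
    p = p - 0.5 * step_size * dU(q)
    for i in range(n_steps):
        q = q + step_size * p
        if i < n_steps - 1:
            p = p - step_size * dU(q)
    p = p - 0.5 * step_size * dU(q)
    return q, p

Anyway, here are the original experiments: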

from jax import jit
import jax.numpy as jnp

import numpy as onp

import autograd.numpy as anp


def logp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (onp.log(2 * onp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

@jit
def jlogp_jit(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def jlogp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def alogp(x):
    """Negative log density of N(1., 0.1)."""
    return 0.5 * (anp.log(2 * anp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

%timeit logp(0.1)
# 1.14 µs ± 92.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%timeit jlogp_jit(0.1)
# 214 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit jlogp(0.1)
# 674 µs ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit alogp(0.1)
# 2.16 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
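One measurement caveat (an assumption about the setup here, not something the numbers above control for): JAX dispatches work asynchronously, so timing a jitted call can measure dispatch rather than the computation itself unless the result is explicitly waited on:

# block_until_ready() waits for the computation to finish, so the
# timing covers the actual work rather than just the dispatch.
%timeit jlogp_jit(0.1).block_until_ready()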

This holds true to a lesser extent for gradients as well:

from jax import grad as jgrad
from autograd import grad as agrad

adlogp = agrad(alogp)
jdlogp = jgrad(jlogp)
jdlogp_jit = jit(jgrad(jlogp))
jdlogp_jit_jit = jit(jgrad(jlogp_jit))

%timeit adlogp(0.1)
# 156 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit jdlogp(0.1)
# 3.44 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit jdlogp_jit(0.1)
# 232 µs ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit jdlogp_jit_jit(0.1)
# 204 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
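Since HMC needs both the log density and its gradient at every step, one further option (assuming jax.value_and_grad is available in the version being used) is to fuse the two into a single jitted call, halving the number of dispatches:

from jax import value_and_grad

# One dispatch returns both the value and the gradient,
# instead of separate calls to jlogp and its grad.
jvlogp_jit = jit(value_and_grad(jlogp))

v, g = jvlogp_jit(0.1)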

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

2 reactions
ColCarroll commented, Apr 8, 2019

Thank you for the kind words! I spotted numpyro yesterday and it looks very nice. Possibly the biggest design difference is that I am aiming only for a reference and learning implementation, whereas numpyro looks like something you could actually use for real work! I was impressed by how autograd let me write concise and readable code, and was hoping not to have to change it too much to use jax.

The iterative NUTS implementation is very interesting - I am taking a closer look at that now. I have been looking at transforming the tree doubling into an iterative scheme, and I was impressed to see you all had done it. I would love to read any thoughts you had on it (in fact, I know a few people who would…)

I had a (very short) conversation about it with one of the Stan developers, and he was not convinced that the recursion ever goes deep enough for there to be a performance benefit to it (I think both Stan and PyMC3 use a default of 10 doublings before issuing a warning, which really is not that deep). That said, it seems like it gives you a bit more control in searching for the right scale for a trajectory, and might make the code more readable/maintainable, which is not a small thing.

1 reaction
fehiepsi commented, Apr 8, 2019

I’m happy that you are also interested in iterative NUTS! I will share a note on translating the recursive formulation into an iterative one in a couple of days. 😃

About the benefit, yes, it gives us more control over small models, especially when using JAX, because the overhead of JAX is pretty large. For example, if each leapfrog step carries 100 µs of overhead, then building a full tree (depth 10, i.e. ~1024 leapfrog steps) takes about 100 ms, so 1000 samples take about 100 s, which is very costly 😦 With iterative NUTS we can jit the whole trajectory, which in turn means the overhead is only in the range of 100 µs per tree. We did a benchmark here showing that building a full tree took 1 s with the recursive algorithm but only 500 µs with the iterative one. But I also think the algorithm is less beneficial in other frameworks where the per-step overhead is small (say, about 1 µs), or when the computational cost of each leapfrog step is in the range of milliseconds (as in the covertype example of the Simple, Distributed, and Accelerated Probabilistic Programming paper).
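For reference, the mechanism that makes jitting the whole trajectory pay off is structured control flow such as lax.fori_loop, which keeps the loop inside the compiled XLA program. A minimal sketch of the idea (this is not numpyro's actual implementation; the potential and integrator are simplified placeholders):

from functools import partial
from jax import grad, jit, lax

# Gradient of a hypothetical quadratic potential (constant term dropped).
dU = grad(lambda q: 0.5 * ((q - 1.) / 0.1) ** 2)

@partial(jit, static_argnums=(3,))
def trajectory(q, p, step_size, n_steps):
    # lax.fori_loop stays inside the compiled program, so the dispatch
    # overhead is paid once per trajectory, not once per leapfrog step.
    def body(i, state):
        q, p = state
        p = p - step_size * dU(q)  # simplified integrator step, not full leapfrog
        q = q + step_size * p
        return (q, p)
    return lax.fori_loop(0, n_steps, body, (q, p))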
