Jax scans are slower than expected
I am implementing the tridiagonal matrix algorithm (TDMA) to solve many tridiagonal systems of the same shape in two sweeps (one forward and one backward pass). The shape of each diagonal is something like (100_000, 100), and I vectorize over the leading axis, so this should be reasonably efficient.
In pure NumPy, I would do it like this:
```python
import numpy as np

def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.
    Note: b and d are modified in place during the forward sweep.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    n = a.shape[-1]

    # Forward sweep: eliminate the sub-diagonal.
    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    # Backward substitution.
    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]
    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out
```
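As a quick sanity check, the reference version can be run on random, diagonally dominant systems (the shapes and test data below are placeholders, not taken from the benchmark):

```python
import numpy as np

rng = np.random.RandomState(0)
shape = (100_000, 100)  # many independent systems along the leading axis

a, c, d = (rng.rand(*shape) for _ in range(3))
b = 4. + rng.rand(*shape)  # |b| > |a| + |c|, so the elimination is well conditioned

# tdma_naive overwrites b and d, so pass copies if the inputs are needed later.
sol = tdma_naive(a.copy(), b.copy(), c.copy(), d.copy())
print(sol.shape)  # (100000, 100)
```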
The JAX implementation looks like this:
```python
import jax
import jax.numpy as jnp

def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    # Scan over the solve axis, which becomes the leading axis after transposing.
    diags = (a.T, b.T, c.T, d.T)
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))
    return sol[::-1].T
```
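For timing, the JAX version needs to be jitted and the result blocked on, since JAX dispatches work asynchronously. A minimal sketch (shapes and tolerances are illustrative; a 3-D shape is used so that the `init` carry matches the per-step slice shape of the transposed diagonals):

```python
import numpy as np
import jax

rng = np.random.RandomState(42)
shape = (360, 160, 115)  # batch x batch x system size

a, c, d = (rng.rand(*shape) for _ in range(3))
b = 4. + rng.rand(*shape)

expected = tdma_naive(a.copy(), b.copy(), c.copy(), d.copy())

tdma_jax = jax.jit(tdma_jax_kernel)

# The first call includes compilation; block_until_ready() waits for the
# device computation instead of just the asynchronous dispatch.
sol = tdma_jax(a, b, c, d).block_until_ready()
np.testing.assert_allclose(sol, expected, rtol=1e-3, atol=1e-3)  # float32 on device

# e.g. in IPython: %timeit tdma_jax(a, b, c, d).block_until_ready()
```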
I implemented the algorithm in a handful of backends (including a sloppily written CUDA kernel). You can see the results in this Gist:
https://gist.github.com/dionhaefner/a97ef80b77e02b36e4b248bb97541161
The executive summary is that Jax is 2.5x slower than Numba on CPU, and 3x slower than my amateurish CUDA kernel on GPU (but is on par with Numba here).
If I eliminate the transposes from the Jax implementation and transpose the inputs beforehand, the implementation gains a factor of 2 in performance on GPU, so it would be nice if `scan` supported scanning over arbitrary axes.
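For illustration, the transpose-free variant could look something like this (a sketch only, not necessarily identical to what I benchmarked; it assumes the caller passes the diagonals already laid out with the solve axis first and returns the solution in the same layout):

```python
def tdma_jax_kernel_pretransposed(a, b, c, d):
    # Diagonals are already transposed: the axis being solved over comes first,
    # so lax.scan can iterate over the leading axis without any transposes.
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x
        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom
        return (cp, dp), (cp, dp)

    init = jnp.zeros(a.shape[1:])
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), (a, b, c, d))

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))
    return sol[::-1]
```

Since the scanned axis has to be the leading one, the transposes just move from inside the traced function to the call site (or disappear entirely if the data is already stored in that layout).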
Is this behavior something that is expected, and is there something else I can do to make the Jax implementation more efficient?
@clemisch Indeed the control flow is staged entirely out to XLA and the trip count is known at compile time, so at the XLA HLO level we’ve completely eliminated dynamic control flow. But there’s more to the story on GPUs.
Warning: I’m not an expert! This is just my best understanding and I hope others will correct me / fill in the gaps where I mess things up.
An XLA:GPU program is itself ultimately lowered to a hybrid CPU/GPU program, where the GPU parts are kernels (a mix of CUDA / cuDNN kernels and XLA-codegenned ones) and the CPU parts just handle launching the kernels and perhaps other runtime calls. ~~(It’s all compiled, i.e. the CPU part isn’t interpreted like TF or PyTorch, which is why the CPU-side overheads can be much lower as in the NumPyro benchmarks, though for some workloads that doesn’t make any difference.)~~ (EDIT: removed because IMO it’s impossible to define “interpreted” vs “compiled” precisely.)
In this case, `jax.lax.scan` generates a single XLA HLO While loop with a fixed trip count, so how does that get turned into such a hybrid CPU/GPU program? Perhaps the best thing XLA:GPU could do would be to lower it into a single kernel, since that would minimize overheads and maximize optimization opportunities. But XLA:GPU can’t (yet) generate a single kernel for whole loops. Instead, the loop has both CPU and GPU parts. The second best thing we could hope XLA:GPU would do is generate a single GPU kernel for the loop body, so that the CPU part of the program would just be a CPU loop with a fixed trip count launching those kernels, and we’d only pay one launch overhead cost per iteration. Unfortunately, XLA:GPU often has to generate multiple kernels for the loop body, meaning we pay several kernel launch overheads per iteration. (Moreover, I believe that sometimes, but not always, XLA:GPU may generate extra copy operations for the loop carry.)

The upshot of all this is that XLA:GPU doesn’t (yet) do the best with some loops. There could be some fundamental limits based on the GPU programming model or the tools NVIDIA provides for generating GPU programs, but I suspect we’re not at those limits yet and more can be done with more investment in XLA:GPU. So the best policy is to send love and support towards XLA:GPU developers (both on Google compiler teams and in open source, including at NVIDIA) so we can make this thing we love even better!
(One reason I’m optimistic for the future here is from seeing what XLA:TPU can do, since it’s the most developed XLA backend. With XLA:TPU, the whole program is staged out to the TPU, including the control flow for scans and other loops, so things like kernel launch overheads just don’t exist.)
Does that make sense?
XLA:GPU currently always executes dynamic control flow on the CPU, so loops with small per-iteration workloads (like the one here) end up much slower, due to the need to frequently synchronize between the CPU and GPU.
You can find a similar example of this kind of slowdown in https://github.com/google/jax/pull/3076
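One way to see this concretely is to dump the HLO that the scan-based kernel lowers to (sketch; `jax.xla_computation` is the API in JAX versions around the time of this issue and may be deprecated in newer releases, where `jax.jit(...).lower(...)` plays a similar role):

```python
import numpy as np
import jax

# Placeholder inputs used only for tracing; the values don't matter, only shapes/dtypes.
a, b, c, d = (np.ones((4, 8, 16)) for _ in range(4))

# The two scans show up as XLA `while` ops with compile-time-constant trip
# counts; on GPU that control flow is currently driven from the host.
hlo = jax.xla_computation(tdma_jax_kernel)(a, b, c, d).as_hlo_text()
print(hlo)
```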