
Jax scans are slower than expected

See original GitHub issue

I am implementing the tridiagonal matrix algorithm (TDMA) to solve many tridiagonal systems of the same shape in two sweeps (one forward and one backward pass).

The shape of each diagonal is something like (100_000, 100), and I vectorize over the leading axis, so this should be reasonably efficient.

In pure NumPy, I would do it like this:

import numpy as np


def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.

    All inputs have the same shape; the last axis indexes the rows of each
    system. Note that b and d are overwritten in place.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    # Forward sweep: eliminate the sub-diagonal.
    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    # Backward substitution.
    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out
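
To sanity-check this routine, it can be compared against a dense solve on small random systems. Here is a rough sketch (sizes and tolerance are arbitrary, and copies of b and d are passed in because the routine overwrites them):

import numpy as np

rng = np.random.default_rng(0)
batch, n = 16, 100

a = rng.random((batch, n))          # sub-diagonal (a[..., 0] is unused)
b = rng.random((batch, n)) + 10.0   # main diagonal, made dominant for stability
c = rng.random((batch, n))          # super-diagonal (c[..., -1] is unused)
d = rng.random((batch, n))          # right-hand sides

x = tdma_naive(a, b.copy(), c, d.copy())

# Dense reference solution, one system at a time.
for k in range(batch):
    A = np.diag(b[k]) + np.diag(a[k, 1:], -1) + np.diag(c[k, :-1], 1)
    np.testing.assert_allclose(x[k], np.linalg.solve(A, d[k]), rtol=1e-8)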

The JAX implementation looks like this:

import jax
import jax.numpy as jnp


def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        # One step of the forward sweep: update the modified coefficients
        # (c', d') for all systems at once.
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    # scan only iterates over the leading axis, so transpose the system axis
    # to the front.
    diags = (a.T, b.T, c.T, d.T)
    # The carry must have the same shape as one slice of the transposed
    # diagonals; this initialization assumes two batch axes ahead of the
    # system axis.
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    # Backward pass over the reversed coefficients.
    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T
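
To time such a kernel fairly, something along these lines can be used. This is a rough sketch: the shape is illustrative (and chosen to match the layout that the kernel's transposes and carry initialization imply, i.e. two batch axes with the system axis last), and block_until_ready is needed because JAX dispatches work asynchronously:

import numpy as np
import jax
import jax.numpy as jnp

tdma_jax = jax.jit(tdma_jax_kernel)

# Roughly 100,000 systems of size 100, stored as two batch axes plus the
# system axis.
shape = (500, 200, 100)
rng = np.random.default_rng(0)
a = jnp.asarray(rng.random(shape))
b = jnp.asarray(rng.random(shape)) + 10.0   # diagonally dominant for stability
c = jnp.asarray(rng.random(shape))
d = jnp.asarray(rng.random(shape))

sol = tdma_jax(a, b, c, d).block_until_ready()   # first call includes compilation

# Steady-state timing only, e.g. with timeit / %timeit:
# %timeit tdma_jax(a, b, c, d).block_until_ready()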

I implemented the algorithm in a handful of backends (including a sloppily written CUDA kernel). You can see the results in this Gist:

https://gist.github.com/dionhaefner/a97ef80b77e02b36e4b248bb97541161

The executive summary is that Jax is 2.5x slower than Numba on CPU, and 3x slower than my amateurish CUDA kernel on GPU (but is on par with Numba here).

If I eliminate the transposes from the Jax implementation and transpose the inputs beforehand, the implementation gains a factor of 2 in performance on GPU (a sketch of this variant follows below), so it would be nice if scan supported scanning over arbitrary axes.
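
A sketch of that transpose-free variant (a reconstruction, not the exact code from the Gist) could look like this; the caller keeps the data with the system axis first, so both scans run directly over the leading axis and no transposes appear inside the kernel:

import jax
import jax.numpy as jnp

def tdma_jax_kernel_pretransposed(a_t, b_t, c_t, d_t):
    # All inputs are expected with the system axis first, i.e. already
    # transposed relative to tdma_jax_kernel.
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x
        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom
        return (cp, dp), (cp, dp)

    # Carry shape is simply one slice of the (already transposed) diagonals.
    init = jnp.zeros(a_t.shape[1:])
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), (a_t, b_t, c_t, d_t))

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))
    return sol[::-1]  # solution, still with the system axis first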

Is this behavior expected, and is there anything else I can do to make the Jax implementation more efficient?

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 20 (12 by maintainers)

Top GitHub Comments

10 reactions
mattjj commented, Jul 4, 2020

@clemisch Indeed the control flow is staged entirely out to XLA and the trip count is known at compile time, so at the XLA HLO level we’ve completely eliminated dynamic control flow. But there’s more to the story on GPUs.

Warning: I’m not an expert! This is just my best understanding and I hope others will correct me / fill in the gaps where I mess things up.

An XLA:GPU program is itself ultimately lowered to a hybrid CPU/GPU program, where the GPU parts are kernels (a mix of CUDA / cuDNN kernels and XLA-codegenned ones) and the CPU parts just handle launching the kernels and perhaps other runtime calls. ~(It’s all compiled, i.e. the CPU part isn’t interpreted like TF or PyTorch, which is why the CPU-side overheads can be much lower as in the NumPyro benchmarks, though for some workloads that doesn’t make any difference.)~ (EDIT: removed because IMO it’s impossible to define “interpreted” vs “compiled” precisely.)

In this case, jax.lax.scan generates a single XLA HLO While loop with a fixed trip count, so how does that get turned into such a hybrid CPU/GPU program? Perhaps the best thing XLA:GPU could do would be to lower it into a single kernel, since that would minimize overheads and maximize optimization opportunities. But XLA:GPU can’t (yet) generate a single kernel for whole loops. Instead, the loop has both CPU and GPU parts. The second best thing we could hope XLA:GPU would do is generate a single GPU kernel for the loop body, so that the CPU part of the program would just be a CPU loop with a fixed trip count launching those kernels, and we’d only pay one launch overhead cost per iteration. Unfortunately, XLA:GPU often has to generate multiple kernels for the loop body, meaning we pay several kernel launch overheads per iteration. (Moreover, I believe that sometimes, but not always, XLA:GPU may generate extra copy operations for the loop carry.)
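
This staging is easy to see by inspecting a scanned function. The sketch below uses a toy scan rather than the TDMA kernel, and the inspection calls are version-dependent; the point is that the scan appears as a single loop construct with its trip count fixed at trace time:

import jax
import jax.numpy as jnp

def cumsum_scan(xs):
    # Toy stand-in for the TDMA kernel: one lax.scan over the leading axis.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = jax.lax.scan(step, jnp.zeros(xs.shape[1:]), xs)
    return ys

xs = jnp.ones((100, 1000))

# The jaxpr contains a single scan primitive with length=100 baked in.
print(jax.make_jaxpr(cumsum_scan)(xs))

# The lowered program can be inspected as well (API depends on the JAX
# version); the scan shows up as a single while-style loop.
print(jax.jit(cumsum_scan).lower(xs).as_text())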

The upshot of all this is that XLA:GPU doesn’t (yet) do the best with some loops. There could be some fundamental limits based on the GPU programming model or the tools NVIDIA provides for generating GPU programs, but I suspect we’re not at those limits yet and more can be done with more investment in XLA:GPU. So the best policy is to send love and support towards XLA:GPU developers (both on Google compiler teams and in open source, including at NVIDIA) so we can make this thing we love even better!

(One reason I’m optimistic for the future here is from seeing what XLA:TPU can do, since it’s the most developed XLA backend. With XLA:TPU, the whole program is staged out to the TPU, including the control flow for scans and other loops, so things like kernel launch overheads just don’t exist.)

Does that make sense?

2 reactions
shoyer commented, Jul 4, 2020

XLA GPU currently always executes dynamic control flow on the CPU. So small loop iterations (like what you have here) end up much slower, due to the need to frequently synchronize between the CPU/GPU.

You can find a similar example of this sort of slow down in https://github.com/google/jax/pull/3076
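
One knob worth mentioning in this context is the unroll argument that lax.scan accepts in more recent JAX versions: it folds several scan steps into each iteration of the staged-out loop, which can amortize the per-iteration launch and synchronization overhead described above. Whether (and how much) it helps depends on the backend and the size of the loop body; a minimal sketch:

import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry

xs = jnp.ones((100, 100_000))
init = jnp.zeros(xs.shape[1:])

# Baseline: one iteration of the compiled loop (and its kernel launches)
# per scan step.
_, ys = jax.lax.scan(step, init, xs)

# With unroll=10, each loop iteration performs 10 scan steps, so the loop
# runs ~10x fewer iterations.
_, ys_unrolled = jax.lax.scan(step, init, xs, unroll=10)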
