
Extremely slow GPU execution


The following code is almost instantaneous (<1 ms) on the CPU but extremely slow on the GPU (7 s). I’m trying to track down the source of the problem. I have pared my code down from 5000 lines to 80, and I don’t think I can remove any more. I have added comments at the places where I found surprising (to me) effects on the GPU run time.

How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?

from functools import partial
from typing import Any
import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import while_loop
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass  # Equivalent to flax.struct.dataclass

class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out

@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any

shape = (1,)

def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)

def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))

def state_needs_iteration(maximum_iterations, state) -> bool:
    return state.iterations < maximum_iterations

def update_state(weights, state):
    leak_rng, new_rng = split(state.rng)
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + normal(leak_rng, force.shape)
    return SamplerState(new_code_momentum, new_rng, state.iterations + 1)

def find_fixed_point(weights, initial_state, maximum_iterations):
    return while_loop(partial(state_needs_iteration, maximum_iterations),
                      partial(update_state, weights),
                      initial_state)

@partial(jit, static_argnums=())  # Passing maximum_iterations non-statically costs 43%!
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
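    for _ in range(10):
        with Timer() as timer:
            # JAX dispatches work asynchronously; block on the result so the
            # timer measures GPU execution rather than just the dispatch.
            infer_encoding(weights, inference_rng, 8000).code_momentum.block_until_ready()
        print(timer.elapsed)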

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 30 (22 by maintainers)

Top GitHub Comments

mattjj commented, Jun 18, 2021

Thanks for raising this, and for working so hard to minimize it!

The best tool here is profiling. If you can get a profile of a realistic workload, we can really dig into what improvements can be made (to your code, to JAX itself, or to XLA:GPU).

There’s one effect that would explain one of your comments, though I don’t think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there’s a “for loop optimization” which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.
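`jax.lax.fori_loop` illustrates the static versus dynamic trip count distinction. When its bounds are Python integers at trace time, it lowers to `lax.scan`, a loop with a fixed trip count. When a bound is a traced value, it falls back to `lax.while_loop`. The sketch below uses a hypothetical cheap `body` and checks which primitive each version produces. It only shows whether the trip count is known at trace time; it does not show whether XLA:GPU actually applies its loop optimization. This is consistent with the issue's own note that passing `maximum_iterations` non-statically costs 43%.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, x):
    # Hypothetical, deliberately cheap loop body standing in for the issue's MLP step.
    return jnp.sin(x) + 0.1


def static_loop(x):
    # Both bounds are Python ints, so the trip count is known while tracing.
    return lax.fori_loop(0, 8000, body, x)


def dynamic_loop(x, n):
    # `n` is traced, so the trip count is only known at run time.
    return lax.fori_loop(0, n, body, x)


x = jnp.zeros(())
static_jaxpr = str(jax.make_jaxpr(static_loop)(x))
dynamic_jaxpr = str(jax.make_jaxpr(dynamic_loop)(x, 8000))
print("scan" in static_jaxpr, "while" in dynamic_jaxpr)
```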

Could you share a profile of the execution so we can dig in?

NeilGirdhar commented, Sep 19, 2022

First of all, Jax Triton looks amazing! Yes, it should solve my problem with quite a bit of work on my side. So thank you for that.

However, I have some thoughts that I’d like to get feedback on.

My problem boils down to an internal scan that evaluates something like

x[i+1] = x[i] + k * f_bwd(z - f(x[i]))

Here f is a “forward pass” function of primals, and f_bwd is the corresponding backward pass on cotangents.

If f is a simple neural network with noise, then it’s fairly straightforward to write this in Triton. The backward pass is easy to write by hand, but tedious. And why am I doing this at all? Jax already calculates the backward pass, and I might make mistakes that I’ll have to debug. That’s what I meant when I asked whether I would have access to Jax’s automatic differentiation. It appears that I’ll have to differentiate f manually and then implement that in Triton.
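For comparison, the update rule x[i+1] = x[i] + k * f_bwd(z - f(x[i])) can be written directly in Jax, with `jax.vjp` supplying f_bwd so that nothing is differentiated by hand. This minimal sketch assumes a deterministic `f`. A noisy f would need an extra PRNG key threaded through the scan carry. `fixed_point_scan` and the linear test map are hypothetical. The sketch expresses the math with `lax.scan` but does not by itself address the kernel-launch overhead discussed in this thread.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=(0, 4))
def fixed_point_scan(f, x0, z, k, num_steps):
    """Iterate x <- x + k * f_bwd(z - f(x)) for num_steps steps (hypothetical helper)."""

    def step(x, _):
        # Forward pass, plus a function that applies the backward pass (a VJP).
        y, f_bwd = jax.vjp(f, x)
        # Pull the residual z - f(x) back through f as a cotangent.
        (dx,) = f_bwd(z - y)
        return x + k * dx, None

    x_final, _ = lax.scan(step, x0, None, length=num_steps)
    return x_final


# Usage with a hypothetical linear map f(x) = A @ x; the fixed point solves A x = z.
A = 2.0 * jnp.eye(2)
x = fixed_point_scan(lambda v: A @ v, jnp.zeros(2), jnp.array([2.0, 4.0]), 0.1, 100)
print(x)  # approximately [1. 2.]
```

With a linear f, the iteration reduces to gradient descent on 0.5 * ||z - A x||², which is why it converges to the solution of A x = z.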

I also thought about how I would write this in Triton. I could just manually write every fused kernel I need. And at the end of it, I’d have a library of pieces of kernels that I could compose to do what I need. These would probably be extra methods on “modules” (from Haiku or Flax) that would do things like:

  • report the shapes and dtypes of the inputs and outputs,
  • allocate the intermediate storage,
  • sample from a Jax PRNG whatever random numbers are necessary,
  • produce Triton code for the forward and backward pass (just the body, without the decorated function header)

Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions, triton_forwards and triton_backwards, that would produce a forward or a backward Triton function, respectively. Each function

  • has parameters that could be a graph of “modules”,
  • allocates all of the intermediate storage it needs (maybe even reusing space if possible),
  • samples all of the random numbers,
  • then calls a jitted Triton function.
  • The jitted Triton function would contain various sub-calls to ordinary Python functions corresponding to each “module”, and would have Triton pointers to the intermediate data structures, the inputs and outputs, and the intermediate storage.
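The composition idea can be sketched as plain source-code generation. Each hypothetical elementwise "module" contributes one line of Triton code, and a composer stitches those lines into a single kernel body so that the whole chain would run in one launch. Everything here is an illustrative assumption: `ElementwiseModule`, `compose_forward`, and the kernel signature. The sketch covers only elementwise forward passes and omits the intermediate storage, random-number sampling, and backward passes described above. The generated string uses standard `triton.language` calls (`tl.load`, `tl.store`, `tl.sigmoid`), but it is only generated here, not compiled or launched.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ElementwiseModule:
    """Hypothetical module that knows how to emit its own Triton expression."""
    name: str
    template: str  # Triton expression with `{x}` standing for the input value

    def forward_code(self, in_var: str, out_var: str) -> str:
        return f"{out_var} = {self.template.format(x=in_var)}"


def compose_forward(modules, kernel_name="fused_forward"):
    """Emit the source of one Triton kernel applying `modules` in order (hypothetical)."""
    lines = [
        "@triton.jit",
        f"def {kernel_name}(x_ptr, out_ptr, n, BLOCK: tl.constexpr):",
        "    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)",
        "    mask = offs < n",
        "    v0 = tl.load(x_ptr + offs, mask=mask)",
    ]
    # Chain the modules through local values instead of separate kernel launches.
    for i, module in enumerate(modules):
        lines.append(f"    {module.forward_code(f'v{i}', f'v{i + 1}')}  # {module.name}")
    lines.append(f"    tl.store(out_ptr + offs, v{len(modules)}, mask=mask)")
    return "\n".join(lines)


layers = [
    ElementwiseModule("scale", "{x} * 2.0"),
    ElementwiseModule("shift", "{x} + 1.0"),
    ElementwiseModule("sigmoid", "tl.sigmoid({x})"),
]
print(compose_forward(layers))
```

However many modules are chained, the generated kernel loads its input once and stores its output once. That single load/store pair is the property that fusion is meant to buy.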

Then I thought: why am I doing all this? Wouldn’t it make much more sense to have a conversion from XLA to Triton?

I understand that Triton is a very limited language. I understand that it may not be possible to convert everything that XLA can do to Triton. But I’m not doing anything that crazy. If the converter wants to bail out when I try to do something like take a hyperbolic sine, that’s fine! I’m just doing ordinary multiplications, exponentiations, additions, etc.

And I remember Matt explaining to me that Nvidia’s kernels (e.g. matrix multiplication) are better optimized than anything the user can write. But I’m pretty sure that, the last time I looked at this, my runtime was dominated by kernel launches. Even if Triton is 50% as fast as Nvidia’s hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And it’s hundreds because I have a scan (described above), and each iteration of the scan launches a whole new set of kernels.
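The trade-off can be made concrete with a toy cost model. Every number below is a hypothetical assumption, not a measurement. Suppose each of N small kernels spends L µs on launch overhead and w µs on useful work, and a fused kernel does the same work at efficiency e relative to the hand-tuned kernels. Unfused, one iteration costs N·(L + w); fused, it costs L + N·w / e. With e = 0.5, fusion wins roughly whenever launch overhead exceeds per-kernel work.

```python
def iteration_cost_us(n_kernels, launch_us, work_us, fused_efficiency):
    """Toy per-iteration cost (µs): n_kernels unfused vs. one fused kernel."""
    # Unfused: every kernel pays its own launch overhead.
    unfused = n_kernels * (launch_us + work_us)
    # Fused: one launch, but the work runs at a fraction of hand-tuned speed.
    fused = launch_us + n_kernels * work_us / fused_efficiency
    return unfused, fused


# Hypothetical numbers: 100 kernels, 5 µs launch, 1 µs of work each, 50% efficiency.
unfused, fused = iteration_cost_us(100, 5.0, 1.0, 0.5)
print(unfused, fused)  # 600.0 205.0
```

Inside a scan, the same per-iteration gap is multiplied by the number of iterations.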

So, my question boils down to: Why have we decided on Jax-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the Jax we love?
