Extremely slow GPU execution
The following code is almost instantaneous (<1 ms) on the CPU but extremely slow (7 s) on the GPU. I'm trying to track down the source of the problem. I have pared my code down from 5000 lines to 80 lines, and I don't think I can remove any more. I have added comments at the places I found to have surprising (to me) effects on the GPU run time.
How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?
from functools import partial
from typing import Any

import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import while_loop
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass  # Equivalent to flax.struct.dataclass


class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))


class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out


@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any


shape = (1,)


def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)


def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))


def state_needs_iteration(maximum_iterations, state) -> bool:
    return state.iterations < maximum_iterations


def update_state(weights, state):
    leak_rng, new_rng = split(state.rng)
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + normal(leak_rng, force.shape)
    return SamplerState(new_code_momentum, new_rng, state.iterations + 1)


def find_fixed_point(weights, initial_state, maximum_iterations):
    return while_loop(partial(state_needs_iteration, maximum_iterations),
                      partial(update_state, weights),
                      initial_state)


@partial(jit, static_argnums=())  # Passing maximum_iterations non-statically costs 43%!
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)


with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
            infer_encoding(weights, inference_rng, 8000)
        print(timer.elapsed)
Top GitHub Comments
Thanks for raising this, and for working so hard to minimize it!
The best tool here is to use profiling. If you can get a profile showing a realistic workload, we can really dig in to what improvements can be made (either to your code, to JAX itself, or to XLA:GPU).
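For concreteness, a minimal sketch of how a profile of the repro above could be captured, assuming a recent JAX where jax.profiler and jax.block_until_ready are available (the resulting trace can be viewed in TensorBoard with the profile plugin); infer_encoding, weights, and inference_rng are the names from the snippet above:

import jax

# Trace one call of the jitted function; block on the result pytree so the
# dispatched GPU work is included in the trace, not just the Python-side enqueue.
with jax.profiler.trace("/tmp/jax-trace"):
    result = infer_encoding(weights, inference_rng, 8000)
    jax.block_until_ready(result)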
There’s one effect that would explain one of your comments, though I don’t think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there’s a “for loop optimization” which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.
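As a sketch of what "statically fixed trip count" can look like in user code, here are two variants that reuse the definitions from the snippet above (SamplerState, shape, update_state, find_fixed_point); whether either actually helps on this workload is something the profile would have to confirm:

from functools import partial

from jax import jit
from jax.lax import fori_loop

# Option 1: mark maximum_iterations as a static argument, so the while_loop
# predicate compares against a compile-time constant.
@partial(jit, static_argnums=(2,))
def infer_encoding_static(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

# Option 2: express the fixed iteration count directly with fori_loop, which
# has no host-evaluated stopping predicate at all.
def find_fixed_point_fori(weights, initial_state, maximum_iterations):
    return fori_loop(0, maximum_iterations,
                     lambda i, state: update_state(weights, state),
                     initial_state)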
Could you share a profile of the execution so we can dig in?
First of all, JAX-Triton looks amazing! Yes, it should solve my problem, with quite a bit of work on my side. So thank you for that.
However, I have some thoughts that I’d like to get feedback on.
My problem boils down to an internal scan whose body evaluates something like f followed by f_bwd, where f is a “forwards pass” function of primals and f_bwd is the corresponding backward pass of cotangents.

If f is a simple neural network with noise, then it’s fairly straightforward to write this in Triton. The backward pass can easily be written, but it’s annoying. Why am I doing this? JAX is already calculating the backward pass, and I might make mistakes that I’ll have to debug. That’s what I meant when I asked whether I would have access to JAX’s automatic differentiation. It appears that I’ll have to manually differentiate f and then implement that in Triton.
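To make that concrete, a minimal, self-contained sketch of the pure-JAX behaviour being given up: jax.vjp derives f_bwd from f automatically. The function f and the shapes here are hypothetical stand-ins, not anything from the code above:

import jax
import jax.numpy as jnp

def f(w, x):                  # hypothetical "forwards pass" of primals
    return jnp.tanh(x @ w)

w = jnp.ones((3, 2))
x = jnp.ones((5, 3))

# jax.vjp returns the primal output together with f_bwd, which pulls output
# cotangents back to cotangents for (w, x) -- no manual differentiation needed.
out, f_bwd = jax.vjp(f, w, x)
w_bar, x_bar = f_bwd(jnp.ones_like(out))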
I also thought about how I would write this in Triton. I could just manually write every fused kernel I need, and at the end of it I’d have a library of kernel pieces that I could compose to do what I need. These would probably be extra methods on “modules” (from Haiku or Flax). Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions, triton_forwards and triton_backwards, that would produce a forward or backward Triton function.
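A minimal sketch (written with plain jnp rather than Triton) of the glue that pairing a hand-written forward kernel with a hand-written backward kernel would need on the JAX side; fused_layer is a hypothetical stand-in for one fused module matching the sigmoid(x @ softplus(w)) layer above, and jax.custom_vjp is the standard hook for registering a manual backward pass:

import jax
import jax.numpy as jnp
from jax.nn import sigmoid, softplus

@jax.custom_vjp
def fused_layer(x, w):
    # Forward pass: this is what would become one fused Triton kernel.
    return sigmoid(x @ softplus(w))

def fused_layer_fwd(x, w):
    y = sigmoid(x @ softplus(w))
    return y, (x, w, y)                     # residuals saved for the backward pass

def fused_layer_bwd(residuals, y_bar):
    # Backward pass: the second, manually derived kernel.
    x, w, y = residuals
    z_bar = y_bar * y * (1.0 - y)           # derivative of sigmoid
    x_bar = z_bar @ softplus(w).T
    w_bar = (x.T @ z_bar) * sigmoid(w)      # derivative of softplus is sigmoid
    return x_bar, w_bar

fused_layer.defvjp(fused_layer_fwd, fused_layer_bwd)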
Then I thought: why am I doing all this? Wouldn’t it make much more sense to have a conversion from XLA to Triton?

I understand that Triton is a very limited language, and that it may not be possible to convert everything XLA can do to Triton. But I’m not doing anything that crazy. If the converter wants to bail out when I try to do something like take a hyperbolic sine, that’s fine! I’m just doing ordinary multiplications, exponentiations, additions, etc.
And I remember Matt explaining to me that Nvidia’s kernels (e.g., matrix multiplication) are better optimized than anything the user can write. But I’m pretty sure that, the last time I looked at this, my runtime was dominated by kernel launches. Even if Triton is only 50% as fast as Nvidia’s hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And the reason it’s hundreds is that I have a scan (described above), and each iteration of the scan launches a whole new set of kernels.
So, my question boils down to: why have we decided on JAX-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the JAX we love?