
Jax on GPU and jit is 11x slower than Pytorch on GPU

See original GitHub issue

Description

I migrated from PyTorch to JAX, but I am noticing an 11x slowdown in JAX. To test more generally, I used a simple function that sums the first three powers of a matrix:

import numpy as np
import jax.numpy as jnp
from jax import jit

def fn(x):
    return x + x*x + x*x*x

x = np.random.randn(10000, 10000).astype('float32')
jax_fn = jit(fn)
x = jnp.array(x)  # moves the data onto the default (GPU) device
%timeit -n5 jax_fn(x).block_until_ready()

JAX takes 5.48 ms. This is running on the GPU (verified by checking print(device_put(1, jax.devices()[0]).device_buffer.device())).
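As a quick aside (not in the original report), a simpler sanity check is to list the devices JAX can see:

import jax
print(jax.devices())  # should list a GPU device on this setup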

The same code in PyTorch on the same GPU runs in 459 microseconds, which is about 11x faster.
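The PyTorch side of the benchmark was not posted. A minimal sketch of how such a timing is commonly written, assuming the same fn and x as above, might look like this (and, as the comments below point out, the missing torch.cuda.synchronize() is the likely source of the discrepancy):

import torch

def fn_torch(x):
    return x + x*x + x*x*x

x_torch = torch.tensor(x, device='cuda')  # same 10000x10000 float32 matrix
# Without torch.cuda.synchronize(), %timeit mostly measures kernel launch
# overhead rather than the actual GPU execution time.
%timeit -n5 fn_torch(x_torch)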

I'm wondering where the slowdown is coming from and whether there are any ways to speed it up.

Thanks a lot for your help

What jax/jaxlib version are you using?

pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10.6, running under WSL.

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.03    Driver Version: 516.25       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro P5200        On   | 00000000:01:00.0  On |                  N/A |
| N/A   39C    P8     7W /  N/A |  14497MiB / 16384MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A        28      G   /Xwayland                       N/A      |
|    0   N/A  N/A      4039      C   /python3.10                     N/A      |
|    0   N/A  N/A      4389      C   /python3.10                     N/A      |
+-----------------------------------------------------------------------------+

Issue Analytics

  • State: open
  • Created: 10 months ago
  • Comments: 10 (4 by maintainers)

Top GitHub Comments

4 reactions
yhtang commented, Nov 14, 2022

Just as we need block_until_ready() to properly profile GPU code in JAX, for PyTorch we need torch.cuda.synchronize().

On a Colab T4 instance:

# Setup not shown in the original comment; assumes fn and x as defined in
# the issue above, with f_jit = jax.jit(fn) and x_jax = jnp.array(x).
import torch

%timeit -n 100 f_jit(x_jax).block_until_ready()  # measure JAX runtime

3.85 ms ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

x_torch = torch.tensor(x, device='cuda')

def fn_torch(x):
    r = x + x*x + x*x*x
    torch.cuda.synchronize()  # wait for the GPU kernels to finish

with torch.no_grad():
    %timeit -n 10 fn_torch(x_torch)

21.3 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Hence JAX is indeed faster 😄

2 reactions
jakevdp commented, Nov 9, 2022

A note on repeated computations: jaxprs don’t contain any deduplication logic (they’re just intermediate representations of the computations you write out in Python); all deduplication is done at the compiler level. You can confirm this by printing the compiled HLO:

print(jax.jit(fn).lower(x).compile().as_text())
HloModule jit_fn, entry_computation_layout={(f32[10]{0})->f32[10]{0}}

%fused_computation (param_0.2: f32[10]) -> f32[10] {
  %param_0.2 = f32[10]{0} parameter(0)
  %multiply.1 = f32[10]{0} multiply(f32[10]{0} %param_0.2, f32[10]{0} %param_0.2), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  %add.1 = f32[10]{0} add(f32[10]{0} %param_0.2, f32[10]{0} %multiply.1), metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  %multiply.0 = f32[10]{0} multiply(f32[10]{0} %multiply.1, f32[10]{0} %param_0.2), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
  ROOT %add.0 = f32[10]{0} add(f32[10]{0} %add.1, f32[10]{0} %multiply.0), metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
}

ENTRY %main.7 (Arg_0.1: f32[10]) -> f32[10] {
  %Arg_0.1 = f32[10]{0} parameter(0)
  ROOT %fusion = f32[10]{0} fusion(f32[10]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(fn)/jit(main)/add" source_file="<ipython-input-5-82aba86e1c4e>" source_line=5}
}

The output is a bit hard to read, but you can see here that multiply(param_0, param_0) is only computed once in the compiled version.
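As a small addendum (not from the original comment), the un-deduplicated form can be seen in the jaxpr itself:

print(jax.make_jaxpr(fn)(x))
# The jaxpr records mul(a, a) twice, mirroring the Python expression;
# only the XLA compiler collapses the duplicates into a single multiply.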

Read more comments on GitHub >
