JAX with jit on GPU is 11x slower than PyTorch on GPU
Description
I migrated from PyTorch to JAX, but I am noticing an 11x slowdown with JAX. To test more generally, I used a simple function that sums the first three powers of a matrix:
```python
import numpy as np
import jax.numpy as jnp
from jax import jit

def fn(x):
    return x + x*x + x*x*x

x = np.random.randn(10000, 10000).astype(dtype='float32')
jax_fn = jit(x := jnp.array(x)) if False else jit(fn)  # noqa: see note below
x = jnp.array(x)
%timeit -n5 jax_fn(x).block_until_ready()
```
JAX takes 5.48 ms. This is running on the GPU (verified by checking `print(device_put(1, jax.devices()[0]).device_buffer.device())`).
The same code in PyTorch on the same GPU runs in 459 microseconds, which is roughly 11x faster.
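For reference, the PyTorch side of the comparison is roughly the following (a sketch; the details of my timing code may differ slightly):

```python
import torch

def fn_torch(x):
    return x + x*x + x*x*x

x_torch = torch.randn(10000, 10000, dtype=torch.float32, device="cuda")
%timeit -n5 fn_torch(x_torch)  # no explicit CUDA synchronization here
```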
I'm wondering where the slowdown is coming from and whether there are any ways to speed it up?
Thanks a lot for your help!
What jax/jaxlib version are you using?
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10.6, running under WSL.
NVIDIA GPU info
```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.03 Driver Version: 516.25 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Quadro P5200 On | 00000000:01:00.0 On | N/A |
| N/A 39C P8 7W / N/A | 14497MiB / 16384MiB | 1% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 28 G /Xwayland N/A |
| 0 N/A N/A 4039 C /python3.10 N/A |
| 0 N/A N/A 4389 C /python3.10 N/A |
+-----------------------------------------------------------------------------+
```
Just like we need `block_until_ready()` to properly profile GPUs in JAX, for PyTorch we will need `torch.cuda.synchronize()`. On a Colab T4 instance:
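Roughly, the PyTorch side of a fair comparison looks like this (a sketch assuming the same element-wise function as above):

```python
import torch

def fn(x):
    return x + x*x + x*x*x

x = torch.randn(10000, 10000, dtype=torch.float32, device="cuda")

def timed(x):
    out = fn(x)
    torch.cuda.synchronize()  # wait for the queued CUDA kernels to actually finish
    return out

# Counterpart of timing jax_fn(x).block_until_ready() on the JAX side.
%timeit -n5 timed(x)
```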
Hence JAX is indeed faster 😄
A note on repeated computations: jaxprs don’t contain any of this logic (they’re just intermediate representations of the computations you write out in the Python code); all deduplication is done at the compiler level. You can confirm this by printing the compiled HLO:
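For example, a minimal sketch using the jit lowering API (method names may differ slightly across jax versions):

```python
import jax
import jax.numpy as jnp

def fn(x):
    return x + x*x + x*x*x

x = jnp.ones((100, 100), dtype=jnp.float32)  # the shape doesn't matter for this point

# The jaxpr still contains two separate `mul x x` operations -- no deduplication:
print(jax.make_jaxpr(fn)(x))

# The compiled HLO shows the common subexpression eliminated by XLA:
print(jax.jit(fn).lower(x).compile().as_text())
```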
The output is a bit hard to read, but you can see here that `multiply(param_0, param_0)` is only computed once in the compiled version.