Jax GMRES on GPU largely slower than its scipy counterpart
Hello, I used the following script to compare the performance of the JAX GMRES solver and the one from SciPy:
```python
from time import time
import numpy as np
import scipy.sparse.linalg
import jax
from jax import jit

def solve_jax(A, b):
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b,
                                         solve_method='batched', atol=1e-5)

solve_jax_jit = jit(solve_jax)

A = np.random.rand(30, 30)
b = np.random.rand(30)

t1 = time()
for i in range(20):
    x = scipy.sparse.linalg.gmres(A, b, atol=1e-5, restart=20)
    #print(i, x[0][0])
t2 = time()
for i in range(20):
    x_jax = solve_jax_jit(A, b)
    #print(i, x_jax[0][0])
t3 = time()
print(f"{(t2 - t1)/20:.2e} vs {(t3 - t2)/20:.2e}")
```
Here was the result:
```
1.09e-01 vs 2.04e+00
```
Nearly a 20x difference! I tried to play around with the GMRES parameters, but it did not change the SciPy/JAX time ratio.
I am using a GPU backend (could this explain the performance difference?) on Windows (my GPU is a GTX 1650) and I have CUDA 11.6 installed. The JAX version I'm using is 0.2.26.
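As an aside, JAX dispatches work asynchronously and the first jitted call includes compilation time, so a fairer per-solve number comes from warming up the jitted function once and blocking on the final result. A minimal sketch of that timing pattern, reusing the same shapes and tolerance as the script above:

```python
from time import perf_counter
import numpy as np
from jax import jit
from jax.scipy.sparse.linalg import gmres

def solve_jax(A, b):
    return gmres(lambda v: A @ v, b, solve_method='batched', atol=1e-5)

solve_jax_jit = jit(solve_jax)

np.random.seed(0)
A = np.random.rand(30, 30)
b = np.random.rand(30)

# Warm-up call: triggers tracing and XLA compilation, which we
# deliberately exclude from the timed loop.
x, _ = solve_jax_jit(A, b)
x.block_until_ready()

t0 = perf_counter()
for _ in range(20):
    x, _ = solve_jax_jit(A, b)
# Block on the last result; dispatched work is ordered, so this
# waits for the whole loop to actually finish on the device.
x.block_until_ready()
print(f"{(perf_counter() - t0)/20:.2e} s per solve")
```

Even measured this way, the GPU slowdown reported here persists; the warm-up only removes compilation overhead from the comparison.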
Issue Analytics: created 2 years ago; 6 comments (1 by maintainers).
This is a known issue with JAX’s iterative methods on GPUs. XLA’s loops on GPU sync back to the host CPU on each loop iteration, so they are slow if the function being solved is quick to evaluate. I actually made a benchmark for this issue that we passed off to the XLA GPU team.
The work-around is to run this sort of computation on the CPU instead (e.g., by copying all arrays explicitly to the CPU with `jax.device_put`). You could also try TPUs, for which XLA is able to run loops entirely on device. We could also consider writing/leveraging custom CUDA kernels or iterative solvers for GPUs, but that would be a major undertaking and would likely have some disadvantages (like losing pytree support).
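The CPU work-around described above can be sketched roughly like this (the shapes and tolerance are just placeholders taken from the original script):

```python
import numpy as np
import jax
from jax.scipy.sparse.linalg import gmres

# Pin the inputs to the host CPU; XLA then runs the GMRES loop
# there, avoiding the per-iteration GPU<->host synchronization.
cpu = jax.devices("cpu")[0]

np.random.seed(0)
A = jax.device_put(np.random.rand(30, 30), cpu)
b = jax.device_put(np.random.rand(30), cpu)

# Computations on CPU-resident inputs stay on the CPU.
x, info = gmres(lambda v: A @ v, b, atol=1e-5, solve_method='batched')
```

On a CPU-only machine `jax.devices("cpu")[0]` is simply the default device, so the `device_put` calls are harmless no-ops there.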
I installed the Windows build following these instructions (using the jaxlib-0.1.75+cuda111-cp39-none-win_amd64.whl wheel from https://github.com/cloudhan/jax-windows-builder), so maybe the issue is within this build.