
Jax GMRES on GPU largely slower than its scipy counterpart

See original GitHub issue

Hello, I used the following script to compare the performance of the JAX GMRES solver with its SciPy counterpart:

from time import time

import jax
import jax.scipy.sparse.linalg  # gmres lives in this submodule
import numpy as np
import scipy.sparse.linalg
from jax import jit


def solve_jax(A, b):
    # Matrix-free GMRES: A is applied through a closure as a linear operator.
    return jax.scipy.sparse.linalg.gmres(
        lambda v: A @ v, b, solve_method='batched', atol=1e-5)


solve_jax_jit = jit(solve_jax)

A = np.random.rand(30, 30)  # scipy.random is deprecated; numpy is equivalent here
b = np.random.rand(30)

t1 = time()
for i in range(20):
    x = scipy.sparse.linalg.gmres(A, b, atol=1e-5, restart=20)
    # print(i, x[0][0])
t2 = time()
for i in range(20):
    # The first call also includes JIT compilation, amortized over the 20 runs.
    x_jax = solve_jax_jit(A, b)
    # print(i, x_jax[0][0])
t3 = time()

print(f"{(t2 - t1)/20:.2e} vs {(t3 - t2)/20:.2e}")

Here was the result:

1.09e-01 vs 2.04e+00

Nearly a 20x difference! I tried playing around with the GMRES parameters (see the sketch below), but it did not change the SciPy/JAX time ratio.
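For concreteness, here is a sketch of the kind of parameter variation meant here; the keywords (solve_method, restart, maxiter, atol) are the ones exposed by jax.scipy.sparse.linalg.gmres, and the specific values are purely illustrative:

import jax
import jax.scipy.sparse.linalg


def solve_jax_incremental(A, b):
    # 'incremental' is the alternative to the default 'batched' solve method.
    return jax.scipy.sparse.linalg.gmres(
        lambda v: A @ v, b,
        solve_method='incremental',
        restart=20,      # Krylov subspace size before restarting
        maxiter=None,    # no cap on the number of outer iterations
        atol=1e-5,
    )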

I am using a GPU backend (could this explain the performance difference?) on Windows (my GPU is a GTX 1650), and I have CUDA 11.6 installed. The JAX version I'm using is 0.2.26.
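As a quick sanity check (not from the original report), the devices JAX actually picked up can be printed like this:

import jax

# Lists the devices visible to JAX; on this setup one would expect a single
# GPU device, with platform 'gpu'.
print(jax.devices())
print(jax.devices()[0].platform)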

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

shoyer commented, Jan 20, 2022 (4 reactions)

This is a known issue with JAX’s iterative methods on GPUs. XLA’s loops on GPU sync back to the host CPU on each loop iteration, so they are slow if the function being solved is quick to evaluate. I actually made a benchmark for this issue that we passed off to the XLA GPU team.

The work-around is either to run this sort of computation on the CPU instead (e.g., by copying all arrays explicitly to the CPU with jax.device_put), or to try using TPUs, for which XLA is able to run loops entirely on device.
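A minimal sketch of that CPU work-around, reusing the 30x30 toy problem from the report (the names here are illustrative): jax.device_put with an explicit CPU device commits the arrays to the host, and the jitted solve then runs there.

import jax
import jax.scipy.sparse.linalg
import numpy as np

cpu = jax.devices("cpu")[0]


@jax.jit
def solve(A, b):
    return jax.scipy.sparse.linalg.gmres(
        lambda v: A @ v, b, solve_method='batched', atol=1e-5)


A = np.random.rand(30, 30)
b = np.random.rand(30)

# Committed CPU copies: the compiled GMRES loop runs entirely on the host,
# avoiding the per-iteration GPU<->host synchronization described above.
A_cpu = jax.device_put(A, cpu)
b_cpu = jax.device_put(b, cpu)
x_cpu, info = solve(A_cpu, b_cpu)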

We could also consider writing/leveraging custom CUDA kernels or iterative solvers on GPUs, but that would be a major undertaking and likely would have some disadvantages (like losing pytree support).

Azercoco commented, Jan 20, 2022 (0 reactions)

I’m using

  • Python 3.9
  • Cuda 11.6
  • jaxlib 0.1.75
  • Jax 0.2.26
  • GPU version
  • I had the issue with both single and double precision

I installed the Windows build following these instructions (using the jaxlib-0.1.75+cuda111-cp39-none-win_amd64.whl wheel from https://github.com/cloudhan/jax-windows-builder), so maybe the issue is within this build.

Read more comments on GitHub.

