jax.scipy.sparse.linalg.cg inconsistent results between runs

Hi all,

The conjugate gradient function in jax.scipy.sparse.linalg gives very inconsistent results on GPU. I'm new to jax, so I'm not sure whether this has already been addressed somewhere. I believe it is related to #565 and #9784.

To rule out differences in the inputs, I saved both A and b and reuse the same saved inputs on every run. No preconditioning is applied.

I tested on three platforms: CPU (Colab and local), GPU (Colab and local) and TPU (Colab).

Across all the runs I have done, the three platforms each produce different results, but only GPU gives inconsistent results between runs.

  • On my local machine, jax on CPU produces exactly the same result as Colab CPU, and it is CONSISTENT between runs.
  • On Colab, jax on TPU is also CONSISTENT between runs.
  • On GPU, both Colab and my local machine show large INCONSISTENCY between runs, sometimes even returning an all-NaN array.
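
The comparison described above can be sketched roughly as follows. This is a minimal, hedged example: the helper names `make_spd_system` and `max_run_to_run_diff` are hypothetical, and the small random symmetric positive-definite system is only a stand-in for the saved A and b from the report, which are not available here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import cg


def make_spd_system(n=64, seed=0):
    """Build a small symmetric positive-definite system (stand-in for the saved A and b)."""
    rng = np.random.default_rng(seed)
    m = rng.standard_normal((n, n)).astype(np.float32)
    a = m @ m.T + n * np.eye(n, dtype=np.float32)  # SPD by construction
    b = rng.standard_normal(n).astype(np.float32)
    return jnp.asarray(a), jnp.asarray(b)


def max_run_to_run_diff(a, b, runs=5):
    """Solve the same system several times; return the largest deviation from the first run."""
    solutions = [np.asarray(cg(a, b, tol=1e-6, maxiter=200)[0]) for _ in range(runs)]
    return max(float(np.max(np.abs(s - solutions[0]))) for s in solutions[1:])


a, b = make_spd_system()
diff = max_run_to_run_diff(a, b)
print("backend:", jax.default_backend(), "max run-to-run difference:", diff)
```

Running the same script on each backend and comparing the printed difference gives a quick signal of whether a given platform is reproducible from run to run. A NaN difference means at least one run returned NaN values.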

I have seen people mention problems tied to the CUDA version, so I tested CUDA 11.1, 11.2 and 11.4. All of them show the same issue.

To show how much the output changes, here is the output of three different runs:

DeviceArray([ 9.28246680e+03, 1.50545068e+04, 1.90608145e+04, 2.23634746e+04, 2.50702012e+04, 2.76033926e+04, 2.99257559e+04, 3.21613457e+04, 3.42872852e+04,...

DeviceArray([-8.13425984e-03, -1.17020588e-02, -1.27483038e-02, -1.18785836e-02, -9.67487786e-03, -6.41405629e-03, -2.11878261e-03, 3.24898120e-03, 9.95288976e-03,...

DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,...
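
One way to judge whether any of these outputs is a valid solution is to check for NaNs and compute the relative residual ||Ax - b|| / ||b||. The following is only a sketch: `relative_residual` is a hypothetical helper, and the 2x2 system is a toy stand-in for the real A and b.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def relative_residual(a, x, b):
    """Return ||a @ x - b|| / ||b|| as a Python float."""
    return float(jnp.linalg.norm(a @ x - b) / jnp.linalg.norm(b))


# Toy symmetric positive-definite system (stand-in for the saved inputs).
a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

x, _ = cg(a, b, tol=1e-6)

has_nan = bool(jnp.isnan(x).any())  # True would indicate a failed solve
res = relative_residual(a, x, b)    # small value means x actually solves the system
print("contains NaN:", has_nan, "relative residual:", res)
```

A run whose residual is far above the requested tolerance, or that contains NaN, did not converge to a solution, whatever the other runs returned.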

I am using jax 0.3.4, jaxlib 0.3.2+cuda11.cudnn82, scipy 1.8.0 and numpy 1.22.3.

Issue Analytics

  • State:open
  • Created a year ago
  • Comments:19 (9 by maintainers)

Top GitHub Comments

1 reaction
marsaev commented, Oct 20, 2022

@tlu7 Just another remark: jax currently passes CUSPARSE_MV_ALG_DEFAULT as the algorithm parameter to cuSPARSE SpMV in https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse_kernels.cc and https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse.cc . That value is deprecated and may in general select a non-deterministic algorithm. I would suggest using CUSPARSE_SPMV_COO_ALG2 instead. According to the docs at https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmv :

Provides deterministic (bit-wise) results for each run. If opA != CUSPARSE_OPERATION_NON_TRANSPOSE, it is identical to CUSPARSE_SPMV_COO_ALG1

1 reaction
tlu7 commented, Oct 19, 2022

Thanks for looking into this! We recently enabled the lowering of BCOO dot_general to cuSparse (https://github.com/google/jax/pull/12138). Yes, indices_sorted=True is one of the requirements for using cuSparse.
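
As a rough illustration of the indices_sorted requirement, the sketch below builds a BCOO matrix with the experimental jax.experimental.sparse API, sorts its indices, and passes its matrix-vector product to cg. The tridiagonal matrix is a made-up stand-in, and whether the product is actually lowered to cuSparse depends on the backend, the jax version and the other requirements from the linked PR, so treat this as an assumption-laden example rather than a guaranteed code path.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

# Small SPD tridiagonal matrix as a stand-in for a real sparse system.
n = 8
dense = (
    2.0 * np.eye(n, dtype=np.float32)
    - 0.5 * np.eye(n, k=1, dtype=np.float32)
    - 0.5 * np.eye(n, k=-1, dtype=np.float32)
)

# Convert to BCOO and sort the indices so that indices_sorted is True.
a_sp = sparse.BCOO.fromdense(jnp.asarray(dense)).sort_indices()
b = jnp.ones(n, dtype=jnp.float32)

# cg accepts a function that computes the linear map v -> A @ v.
x, _ = cg(lambda v: a_sp @ v, b, tol=1e-6, maxiter=100)

residual = float(jnp.linalg.norm(a_sp @ x - b))
print("indices_sorted:", a_sp.indices_sorted, "residual norm:", residual)
```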
