jax.scipy.sparse.linalg.cg inconsistent results between runs
Hi all,
The conjugate gradient function in jax.scipy.sparse seems to be very inconsistent on GPU. I'm new to jax, so I'm not sure whether this issue has been addressed somewhere already. I believe it is somewhat related to #565 and #9784.
To see the full picture, I saved both inputs A and b so that every run uses exactly the same data. No preconditioning is applied.
I tested on three platforms: CPU [colab + local], GPU [colab + local] and TPU [colab].
Across all the runs I have done, the three platforms produce different results from one another, but only GPU is inconsistent between runs.
- On my local machine, jax on CPU produces exactly the same result as colab CPU, and it is CONSISTENT between runs.
- On colab, jax on TPU is also CONSISTENT between runs.
- On GPU, both colab and my local machine show large INCONSISTENCY between runs, sometimes even returning an array of nan.
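The comparison described above can be sketched as follows. This is only a minimal sketch, not the script used for the report: the original A and b are not included in the issue, so a small random symmetric positive definite system stands in for them, and the helper names make_spd_system and max_run_to_run_diff are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import cg


def make_spd_system(n=64, seed=0):
    # Hypothetical stand-in for the saved A and b from the report:
    # a dense, well-conditioned symmetric positive definite matrix.
    rng = np.random.default_rng(seed)
    m = rng.standard_normal((n, n)).astype(np.float32)
    a = m @ m.T + n * np.eye(n, dtype=np.float32)
    b = rng.standard_normal(n).astype(np.float32)
    return jnp.asarray(a), jnp.asarray(b)


def max_run_to_run_diff(a, b, runs=5):
    # Solve the same system several times and report the largest
    # elementwise deviation from the first solution.
    solutions = [np.asarray(cg(a, b, tol=1e-5, maxiter=1000)[0]) for _ in range(runs)]
    ref = solutions[0]
    return max(float(np.max(np.abs(s - ref))) for s in solutions[1:])


a, b = make_spd_system()
diff = max_run_to_run_diff(a, b)
print("backend:", jax.default_backend(), "max run-to-run difference:", diff)
```

Running the same script on each backend and comparing the printed difference gives a quick signal of whether a platform is reproducible between runs. A nan in any solution propagates into the printed value.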
I have seen people mention issues with the CUDA version, so I tested CUDA 11.1, 11.2 and 11.4, and they all show the same problem.
To show how much the result changes, here is the output of three different runs:
DeviceArray([ 9.28246680e+03, 1.50545068e+04, 1.90608145e+04, 2.23634746e+04, 2.50702012e+04, 2.76033926e+04, 2.99257559e+04, 3.21613457e+04, 3.42872852e+04,...
DeviceArray([-8.13425984e-03, -1.17020588e-02, -1.27483038e-02, -1.18785836e-02, -9.67487786e-03, -6.41405629e-03, -2.11878261e-03, 3.24898120e-03, 9.95288976e-03,...
DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,...
I am using:
jax 0.3.4
jaxlib 0.3.2+cuda11.cudnn82
scipy 1.8.0
numpy 1.22.3
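For anyone trying to reproduce this, a version list like the one above can be collected with the standard library. This is a small sketch; the helper name installed_versions is made up for illustration, and it reports only what pip metadata knows (it does not show the CUDA or cuDNN build unless that is part of the jaxlib version string).

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "scipy", "numpy")):
    # Look up each package's installed version from its distribution metadata.
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for name, version in installed_versions().items():
    print(name, version)
```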
@tlu7 Just another remark: currently jax passes CUSPARSE_MV_ALG_DEFAULT as a parameter to cusparse spmv in https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse_kernels.cc and https://github.com/google/jax/blob/main/jaxlib/cuda/cusparse.cc . This value is deprecated and might default to a non-deterministic result in general. I would suggest using CUSPARSE_SPMV_COO_ALG2 instead, according to the docs: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-function-spmv

Thanks for looking into this! We recently enabled the lowering of BCOO dot_general to cuSparse (https://github.com/google/jax/pull/12138). Yes, indices_sorted=True is one of the requirements for using cuSparse.
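To illustrate the indices_sorted point, here is a minimal sketch of building a BCOO matrix with sorted indices and passing its matrix-vector product to cg. The tridiagonal test matrix is made up for illustration. Whether the product is actually lowered to cuSparse depends on the backend, the jax version, and requirements other than sorted indices, so this shows only how to satisfy that one requirement.

```python
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

# Illustrative symmetric positive definite tridiagonal matrix (not from the issue).
n = 32
main = 3.0 * np.ones(n, dtype=np.float32)
off = -1.0 * np.ones(n - 1, dtype=np.float32)
a_dense = jnp.asarray(np.diag(main) + np.diag(off, 1) + np.diag(off, -1))
b = jnp.ones(n, dtype=jnp.float32)

# Convert to BCOO and sort the indices explicitly.
a_sp = sparse.BCOO.fromdense(a_dense).sort_indices()
print("indices_sorted:", a_sp.indices_sorted)

# cg accepts a function that computes the matrix-vector product.
x_sparse, _ = cg(lambda v: a_sp @ v, b, tol=1e-5, maxiter=500)
x_dense, _ = cg(a_dense, b, tol=1e-5, maxiter=500)
print("max sparse/dense difference:", float(jnp.max(jnp.abs(x_sparse - x_dense))))
```

Comparing the sparse and dense solutions on each backend is one way to check whether a discrepancy comes from the sparse matrix-vector product or from cg itself.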