Compilation hangs indefinitely on GPU
I am encountering an issue where compilation on GPU hangs forever in a semi-deterministic way (it happens every time, but at slightly different places). All functions have been compiled successfully before (but with different shapes).
This happens in the middle of a large model codebase, and I unfortunately haven’t been able to come up with a reproducer. After 2 minutes I get the “slow compile” warning, then all I can do is send SIGKILL.
I have dumped the HLO but it looks inconspicuous to me:
https://gist.github.com/dionhaefner/e5680e131975b6bf566c1e1cbc554476
The only lead I have is that right before it hangs, I do something like this:
# <do computations on GPU with JAX>
import numpy as onp
import jax.numpy as jnp
import scipy.sparse.linalg

# copy the JAX device arrays to host NumPy arrays before calling SciPy
rhs = onp.asarray(rhs)
x0 = onp.asarray(x0)
linear_solution, info = scipy.sparse.linalg.bicgstab(
    _matrix,
    rhs,
    x0=x0,
    atol=0,
    tol=settings.congr_epsilon,
    maxiter=settings.congr_max_iterations,
    **self._extra_args,
)
# move the host result back into a JAX array
return jnp.asarray(linear_solution)
# a couple of lines later everything hangs
If I comment out the BiCG solver everything works.
This happens both with JAX built from source and with the current wheels. Downgrading jaxlib did not help either. It does work on jaxlib 0.1.64, albeit poorly (a factor of 10 slower, for some reason).
If you have any pointers on how to debug this I would be grateful.
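One generic way to see where the process is stuck before resorting to SIGKILL is Python's standard-library faulthandler module, which can dump the stack of every Python thread once a timeout expires. The sketch below is not part of the original report: the helper names `enable_hang_diagnostics` and `disable_hang_diagnostics` are hypothetical, and the 120-second default only mirrors the two-minute mark mentioned above. It shows Python frames only, so a hang inside the compiler would appear as the Python call that entered it.

```python
import faulthandler
import sys


def enable_hang_diagnostics(timeout_s=120.0, stream=sys.stderr):
    """Dump the Python stack of all threads every `timeout_s` seconds.

    Hypothetical helper: the dump is written by a watchdog thread, so it
    still fires while the main thread is blocked inside native code.
    `stream` must be a real file object with a file descriptor.
    """
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=stream)


def disable_hang_diagnostics():
    """Cancel the pending dump once the suspicious section has finished."""
    faulthandler.cancel_dump_traceback_later()
```

Calling `enable_hang_diagnostics()` just before the section that hangs, and `disable_hang_diagnostics()` after it, would show which Python line each thread is blocked on, including any threads still alive from the SciPy call.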
Top GitHub Comments
FWIW, this does not occur when I do
Could this be SciPy’s internal OpenMP parallelization clashing with JAX’s thread parallelism?
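A quick way to test this hypothesis is to restrict the native thread pools to a single thread and check whether the hang disappears. The sketch below is an assumption-laden experiment, not a confirmed fix: `limit_native_threads` is a hypothetical helper, and which of these environment variables has any effect depends on the OpenMP and BLAS runtimes the installed NumPy and SciPy were built against. The variables are typically read when the native libraries load, so they need to be set before importing numpy, scipy, or jax.

```python
import os

# Environment variables commonly honoured by OpenMP, OpenBLAS, and MKL.
# Whether each one applies depends on how NumPy/SciPy were built.
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")


def limit_native_threads(n=1):
    """Set the thread-count variables to `n` and return what was set.

    Hypothetical helper: call it before importing numpy, scipy, or jax,
    because the native runtimes usually read these values at load time.
    """
    for name in THREAD_ENV_VARS:
        os.environ[name] = str(n)
    return {name: os.environ[name] for name in THREAD_ENV_VARS}


settings_applied = limit_native_threads(1)
# Import numpy, scipy, and jax only after this point.
```

If the hang persists with single-threaded native libraries, a clash between the thread pools becomes a less likely explanation.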
Seems fixed with recent JAX, thanks!