`jaxlib==0.1.44` segfaults when trying to run XLA on GPU
When trying to run JAX with `jaxlib==0.1.44`, I run into a segmentation fault on my machine with Python 3.8 and CUDA 10.2 if I run on GPU. The issue no longer occurs if I downgrade `jaxlib` to `0.1.43`.
I installed `jaxlib` using the installation instructions in the README for both versions, and I set the XLA CUDA directory to the same location in both cases. From what I gather, only the `jaxlib` version is changing to generate the segfault.
I tried to do some digging, and it seems like the segfault is coming from `jaxlib/xla_extension.so`. In particular, here is what `gdb` produces:
0x00007fffd6f991e8 in absl::lts_2020_02_25::Mutex::ReaderLock() () from /home/ziyadedher/research/.venv/lib/python3.8/site-packages/jaxlib/xla_extension.so
Reverting to `jaxlib==0.1.43` fixes the issue.
>>> jax.__version__
'0.1.63'
>>> jaxlib.__version__
'0.1.44'
>>> tensorflow.__version__
'2.2.0-rc3'
Some system information truncated to show the important bits:
$ nvcc --version
Cuda compilation tools, release 10.2, V10.2.89
$ python --version
Python 3.8.2
$ modinfo nvidia
filename: /lib/modules/5.6.4-arch1-1/extramodules/nvidia.ko.xz
version: 440.82
We have a strong suspicion that the bug is here: https://github.com/tensorflow/tensorflow/blob/05991352f7fdb12ed774561269609fd908e7f95e/tensorflow/compiler/xla/python/local_client.cc#L778

`.release()` and `.get()` are called on a `std::unique_ptr` in different arguments to the same function. Argument order of evaluation differs between compilers (e.g., clang vs gcc). We tend to test with `clang` internally (and have never seen this bug), but our external builds are built with `gcc`, which has the opposite order of evaluation. @skye is preparing a fix.

This should be fixed in jaxlib 0.1.45, hot off the press! I'm gonna close this, but please let us know if you're still experiencing segfaults. (Here's the fix for anyone interested: https://github.com/tensorflow/tensorflow/commit/78edbb6403b73d6c79bd58e23e08dc21b5c33847)
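For anyone curious about the kind of hazard described above, here is a minimal C++ sketch of it. This is not the XLA code and not the actual patch: `Buffer`, `TakeOwnership`, `HazardousCall`, and `SequencedCall` are hypothetical names made up for illustration. It only shows why calling `.release()` and `.get()` on the same `std::unique_ptr` in two arguments of one call can behave differently from compiler to compiler, and one common way to avoid that.

```cpp
#include <memory>

// Hypothetical payload type, used only for this illustration.
struct Buffer {
  int device_id = 0;
};

// Hypothetical callee: takes ownership of `owned` and reports whether
// `observed` was still a non-null pointer when the call was made.
bool TakeOwnership(Buffer* owned, const Buffer* observed) {
  std::unique_ptr<Buffer> holder(owned);  // reclaim ownership so nothing leaks
  return observed != nullptr;
}

// Hazardous pattern: the two arguments may be evaluated in either order.
// If release() happens to run first, get() returns nullptr, so the result
// depends on the compiler.
bool HazardousCall() {
  auto buffer = std::make_unique<Buffer>();
  return TakeOwnership(buffer.release(), buffer.get());
}

// Safer pattern: read the raw pointer in its own statement, which is
// sequenced before the call that releases ownership.
bool SequencedCall() {
  auto buffer = std::make_unique<Buffer>();
  const Buffer* observed = buffer.get();
  return TakeOwnership(buffer.release(), observed);
}
```

`SequencedCall()` returns `true` regardless of the compiler, while the result of `HazardousCall()` depends on which argument is evaluated first. In real code the null pointer would typically be dereferenced later, which would be consistent with a segfault like the one reported here.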