Multiple issues with running code on multiple A100s
There seem to be multiple issues with running code on multiple A100 cards (via `pmap`). Running the same code on a single GPU is fine. I can't reproduce all of them, but I will attach one scenario.
I have no idea what could cause these issues, as they are very specific. Here is an example:
```python
import jax
import jax.numpy as jnp


def pca(x, y):
    mean = (y[..., None] * x).mean()
    centered = x - mean
    cov = centered.T @ jnp.diag(y) @ centered
    s, vecs = jnp.linalg.eigh(cov)
    vecs = vecs.T
    return s, vecs


def cond_pca(x, y):
    s, axes = pca(x, y)

    def fn(_):
        s, axes = pca(x, y)
        return s, axes

    s, axes = jax.lax.cond(s.sum() < 5, fn, lambda x: x, (s, axes))
    return axes


jax.pmap(jax.vmap(lambda x: cond_pca(x, jnp.arange(5))))(jnp.ones((2, 4, 5, 3)))
```
It takes quite a while to compile and then prints:
```
2021-08-27 18:00:23.044559: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:206] Failed to find best cuBLAS algorithm, GEMM performance might be suboptimal: Internal: All algorithms tried for %custom-call.2 = f32[4,3,5]{2,1,0} custom-call(f32[4,5,3]{2,1,0} %subtract.151, f32[4,5,5]{2,1,0} %convert.177), custom_call_target="__cublas$gemm", metadata={op_type="dot_general" op_name="pmap(<lambda>)/dot_general[ dimension_numbers=(((2,), (1,)), ((0,), (0,)))\n precision=None\n preferred_element_type=None ]" source_file="/tmp/ipykernel_1178334/410680133.py" source_line=4}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"batch_size\":\"4\",\"lhs_stride\":\"15\",\"rhs_stride\":\"25\"}" failed. Falling back to default algorithm.
```
Altering the code above slightly may make the warning disappear; I had to dig a bit to nail it down. But similar errors of the kind `Failed to find best cuBLAS algorithm` are thrown throughout my code.
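For what it's worth, one workaround I would try (just a sketch on my side, not something confirmed in this thread) is to lower XLA's GPU autotune level so the cuBLAS algorithm search that emits this warning is skipped:

```python
import os

# Assumption: lowering XLA's GPU autotune level makes XLA skip the cuBLAS
# algorithm search that produces the "Failed to find best cuBLAS algorithm"
# warning. The flag must be set before jax is imported for the first time.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # imported after setting the flag on purpose
```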
Another issue is that my training code simply doesn't train on multiple A100s (it works on multiple GTX 1080 Ti cards and on a single A100 with `pmap`). The GPUs are stuck at 100% utilization, but their power consumption is as if they were idling and nothing gets computed. The program is frozen. Sadly, I do not know how to reproduce this without sharing my whole code base, which I can't do.
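For reference, here is a minimal sanity check (just a sketch, not my actual training code) that shows whether a simple collective across devices completes at all; when inter-GPU communication is broken, it hangs in the same way:

```python
import jax
import jax.numpy as jnp

# Minimal cross-device collective: one row per local device, summed over the
# mapped axis. If inter-GPU communication is broken, this call never returns
# instead of raising an error.
n = jax.local_device_count()
out = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")(
    jnp.ones((n, 4))
)
print(out)  # expected: an (n, 4) array filled with n on every device
```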
Here is a bit of information about my setup:
- jax: 0.2.19 [cuda111]
- Nvidia driver: 470.57.02
- Python: 3.9
I simply deactivated IOMMU in the BIOS settings. Inter-GPU communication worked after the reboot.
@n-gao How did you fix this issue? Could you please share a bit of experience? Thanks!