`--config=cuda` overrides explicit options, instead of augmenting them
I am trying to build CUDA-enabled jaxlib on conda-forge, and I am running into a problem where the `--config=cuda` flag replaces our custom toolchain. I understand this is outside your scope, but getting CUDA-enabled JAX on conda-forge seems like a very valuable contribution, and I hope you can help us achieve that. We do build CUDA-enabled TensorFlow using almost exactly the same method, and `--config=cuda` works fine in our TensorFlow builds.
WARNING: option '--config=cuda' (source command line options) was expanded and now overrides the explicit option --crosstool_top=//bazel_toolchain:toolchain with --crosstool_top=@local_config_cuda//crosstool:toolchain
Once this override happens, the `local_config_cuda` toolchain essentially cannot see any of the system libraries, and everything breaks. I haven't looked too deeply into this yet, so I'm bringing it here first in case you have a quick answer or shortcut.
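To make the override in the warning concrete, here is a toy Python model of the two behaviors the warning describes: `--config=NAME` is expanded in place into the options defined for `NAME`, and for a single-valued option such as `--crosstool_top` the last occurrence wins. This is a hypothetical sketch, not Bazel's actual option parser. The `CONFIGS` table, `expand_configs`, and `resolve` are illustrative names, and nested configs, rc files, and multi-valued options are deliberately ignored.

```python
# Toy model (NOT Bazel's implementation) of how a single-valued option
# such as --crosstool_top ends up with its final value.

def expand_configs(args, configs):
    """Replace each --config=NAME with the options registered for NAME, in place."""
    expanded = []
    for arg in args:
        if arg.startswith("--config="):
            name = arg.split("=", 1)[1]
            expanded.extend(configs.get(name, []))
        else:
            expanded.append(arg)
    return expanded


def resolve(args, configs):
    """Return the final value of every --key=value option (last one wins)."""
    final = {}
    for arg in expand_configs(args, configs):
        if arg.startswith("--") and "=" in arg:
            key, value = arg[2:].split("=", 1)
            final[key] = value
    return final


# Hypothetical config table: we assume "cuda" sets --crosstool_top,
# as the warning message reports.
CONFIGS = {"cuda": ["--crosstool_top=@local_config_cuda//crosstool:toolchain"]}
CUSTOM = "--crosstool_top=//bazel_toolchain:toolchain"

# Explicit option first, --config=cuda later: the CUDA toolchain wins.
print(resolve([CUSTOM, "--config=cuda"], CONFIGS)["crosstool_top"])
# prints @local_config_cuda//crosstool:toolchain

# Explicit option after --config=cuda: the custom toolchain wins.
print(resolve(["--config=cuda", CUSTOM], CONFIGS)["crosstool_top"])
# prints //bazel_toolchain:toolchain
```

Under this simplified model, whichever `--crosstool_top` appears last in the expanded argument list is the one that takes effect, which is consistent with the wording of the warning.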
Corresponding PR in conda-forge: https://github.com/conda-forge/jaxlib-feedstock/pull/97
Tagging @hawkinsp, who has been involved in build-related issues before. Thanks!
Another thing you can try is to add your custom options at the end of `.jax_configure.bazelrc`. (Make sure you add them at the end of the file, and to the `.jax_configure.bazelrc` file itself, not to the Python file.)

The arguments to `bazel` are order-sensitive. I wonder if the fix would be to reorder the arguments here: https://github.com/google/jax/blob/7d4d15e260e24c54f9fbc1685a45d910edb61e43/build/build.py#L493 so that the user-provided `--bazel_options` override `--config=cuda` (i.e., come later), rather than the other way around?