jaxlib.mlir._mlir_libs._mlir overrides Python's logging module config.
Description
Importing jaxlib.mlir._mlir_libs._mlir (and by extension any module depending on it, including jaxlib and jax) seems to render Python's logging module unable to work.
I have a script test.py:
import logging
import importlib
import sys
logger = logging.getLogger(__name__)
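import jax  # first import: placed before basicConfig, the INFO line is lost
logging.basicConfig(level=logging.INFO)
# import jax  # second import: used instead of the first, the INFO line appears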
logger.info('Test')
Running python -m test results in no logging output. If the second import statement, after the logging.basicConfig call, is used instead of the first, then the expected INFO:__main__:Test line appears. If another module, e.g. numpy, is used instead of jax, the log line appears. I also see the log line when using jax==0.3.14, jaxlib==0.3.14; it disappears on versions 0.3.15-0.3.18.
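The symptoms above are consistent with one possible mechanism, which this report does not confirm: if something installs a handler on the root logger before the script calls logging.basicConfig, the later call is silently ignored and the root level stays at WARNING, so INFO records are dropped. The sketch below reproduces that behaviour with the standard library only. The logging.debug call is a hypothetical stand-in for whatever the import does, not jaxlib's actual code, and the function name is made up for illustration. Passing force=True (Python 3.8+) is shown as one possible workaround.

```python
import logging


def demo_basicconfig_noop():
    """Show that basicConfig is ignored once the root logger has a handler."""
    root = logging.getLogger()
    # Start from a clean root logger, remembering its previous state.
    saved_handlers, saved_level = root.handlers[:], root.level
    for handler in saved_handlers:
        root.removeHandler(handler)
    root.setLevel(logging.WARNING)
    try:
        # Hypothetical stand-in for an import-time side effect: the
        # module-level helper calls basicConfig() when root has no handlers.
        logging.debug("emitted during import")
        handlers_after_import = len(root.handlers)

        # Ignored, because the root logger already has a handler.
        logging.basicConfig(level=logging.INFO)
        level_after_plain = root.level

        # force=True removes the existing handlers and reconfigures.
        logging.basicConfig(level=logging.INFO, force=True)
        level_after_force = root.level
        return handlers_after_import, level_after_plain, level_after_force
    finally:
        # Restore the root logger to its previous state.
        for handler in root.handlers[:]:
            root.removeHandler(handler)
            handler.close()
        for handler in saved_handlers:
            root.addHandler(handler)
        root.setLevel(saved_level)


if __name__ == "__main__":
    print(demo_basicconfig_noop())
```

If this is indeed what happens, calling logging.basicConfig(level=logging.INFO, force=True) in test.py, or configuring logging before importing jax, would be expected to restore the log line.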
By looking at sys.modules before and after the import, the module that replicates this behaviour with the smallest number of dependencies appears to be jaxlib.mlir._mlir_libs._mlir.
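The sys.modules comparison can be sketched roughly as follows. The helper name modules_added_by_import is made up for illustration and this is not necessarily how the reporter did it; the usage example imports a standard-library module so that it runs without jax installed.

```python
import importlib
import sys


def modules_added_by_import(name):
    """Return the sorted names of modules that importing `name` adds to sys.modules."""
    before = set(sys.modules)
    importlib.import_module(name)
    return sorted(set(sys.modules) - before)


if __name__ == "__main__":
    # Replace "colorsys" with "jax" to list the candidate modules to test one by one.
    for module_name in modules_added_by_import("colorsys"):
        print(module_name)
```

Each name in the result can then be imported on its own in a fresh interpreter, ahead of logging.basicConfig, to see which one makes the log line disappear.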
What jax/jaxlib version are you using?
jax >=0.3.14, jaxlib >=0.3.14
Which accelerator(s) are you using?
CPU
Additional system info
Linux
NVIDIA GPU info
No response
The upstream LLVM fix was merged. This should be fixed as part of the next jaxlib release.
I guess we’re still waiting for the MLIR fix to be merged…