Importing `jax.numpy` leads to Partially Initialised module error
I am running JAX on a Fedora 35 system, with CUDA 11.6, CuDNN 8.2, and driver version 510.60.02. [I installed CuDNN based on the RHEL8 instructions, since Fedora 35 doesn't seem to officially get builds for it.]
I installed JAX by following the instructions in the README.md file. After doing that, simply running the import command gives me the following error:
import jax.numpy as jnp
AttributeError Traceback (most recent call last)
Input In [8], in <cell line: 4>()
2 import nibabel as nib
3 import numpy as np
----> 4 import jax.numpy as jnp
6 class NIIHandler():
8 def __init__(self, TRAIN_DATASET_PATH):
File ~/.local/lib/python3.10/site-packages/jax/__init__.py:58, in <module>
38 from jax._src.config import (
39 config as config,
40 enable_checks as enable_checks,
(...)
55 transfer_guard_device_to_host as transfer_guard_device_to_host,
56 )
57 from .core import eval_context as ensure_compile_time_eval
---> 58 from jax._src.api import (
59 ad, # TODO(phawkins): update users to avoid this.
60 block_until_ready,
61 checkpoint as checkpoint,
62 checkpoint_policies as checkpoint_policies,
63 closure_convert as closure_convert,
64 curry, # TODO(phawkins): update users to avoid this.
65 custom_gradient as custom_gradient,
66 custom_jvp as custom_jvp,
67 custom_vjp as custom_vjp,
68 default_backend as default_backend,
69 device_count as device_count,
70 device_get as device_get,
71 device_put as device_put,
72 device_put_sharded as device_put_sharded,
73 device_put_replicated as device_put_replicated,
74 devices as devices,
75 disable_jit as disable_jit,
76 eval_shape as eval_shape,
77 flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
78 float0 as float0,
79 grad as grad,
80 hessian as hessian,
81 host_count as host_count,
82 host_id as host_id,
83 host_ids as host_ids,
84 jacobian as jacobian,
85 jacfwd as jacfwd,
86 jacrev as jacrev,
87 jit as jit,
88 jvp as jvp,
89 local_device_count as local_device_count,
90 local_devices as local_devices,
91 linearize as linearize,
92 linear_transpose as linear_transpose,
93 make_jaxpr as make_jaxpr,
94 mask as mask,
95 named_call as named_call,
96 pmap as pmap,
97 process_count as process_count,
98 process_index as process_index,
99 pxla, # TODO(phawkins): update users to avoid this.
100 remat as remat,
101 shapecheck as shapecheck,
102 ShapedArray as ShapedArray,
103 ShapeDtypeStruct as ShapeDtypeStruct,
104 # TODO(phawkins): hide tree* functions from jax, update callers to use
105 # jax.tree_util.
106 treedef_is_leaf,
107 tree_flatten,
108 tree_leaves,
109 tree_map,
110 tree_multimap,
111 tree_structure,
112 tree_transpose,
113 tree_unflatten,
114 value_and_grad as value_and_grad,
115 vjp as vjp,
116 vmap as vmap,
117 xla, # TODO(phawkins): update users to avoid this.
118 xla_computation as xla_computation,
119 )
120 from jax.experimental.maps import soft_pmap as soft_pmap
121 from jax.version import __version__ as __version__
File ~/.local/lib/python3.10/site-packages/jax/_src/api.py:61, in <module>
55 from jax._src import traceback_util
56 from jax._src.api_util import (
57 flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
58 argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
59 rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
60 shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
---> 61 from jax._src.lax import lax as lax_internal
62 from jax._src.lib import jax_jit
63 from jax._src.lib import xla_bridge as xb
File ~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:1653, in <module>
1651 tan_p = standard_unop(_float | _complex, 'tan')
1652 ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
-> 1653 if jax._src.lib.mlir_api_version >= 11:
1654 mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
1655 else:
AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)
I am extremely new to JAX, so please do let me know if there is something else I should be trying instead. Attaching my `nvidia-smi` and `nvcc --version` results below.
Thank you very much!
It looks like you have a local file in your directory named `jax.py`, so when you run `import jax` it is loading this file instead of the jax package. I'd suggest renaming this file; for example:

Hi, I got the same error when importing jax, like the screenshot:
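One rough way to test the local-file diagnosis from the reply above is sketched below. It is not taken from the thread: `find_shadowing_files` and `resolved_origin` are hypothetical helper names, and the check assumes the shadowing file sits directly in the directory you launch Python or Jupyter from. It relies only on the standard library (`importlib.util.find_spec`), so it does not need to import jax successfully.

```python
import importlib.util
import os


def find_shadowing_files(module_name, search_dirs):
    """Return paths of `<module_name>.py` files found directly in search_dirs.

    Hypothetical helper for illustration; not part of JAX.
    """
    hits = []
    for directory in search_dirs:
        candidate = os.path.join(directory, module_name + ".py")
        if os.path.isfile(candidate):
            hits.append(candidate)
    return hits


def resolved_origin(module_name):
    """Return the file Python would load for a top-level module, or None.

    For a top-level name, find_spec locates the module without executing it.
    """
    spec = importlib.util.find_spec(module_name)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # Assumption: the working directory is where the notebook/script runs.
    print("Local candidates:", find_shadowing_files("jax", [os.getcwd()]))
    print("`import jax` would load:", resolved_origin("jax"))
```

If the reported origin is a file in your project directory rather than a path under `site-packages`, that file is probably what `import jax` picks up. Renaming it (and restarting the kernel or interpreter) should let the installed package load instead.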