
Importing `jax.numpy` leads to Partially Initialised module error



I am running JAX on a Fedora 35 system with CUDA 11.6, cuDNN 8.2, and NVIDIA driver version 510.60.02.

(I installed cuDNN following the RHEL8 instructions, since Fedora 35 doesn’t seem to get official builds of it.) I installed JAX by following the instructions in the README.md file. After doing that, simply running the import command gives the following error:

import jax.numpy as jnp
AttributeError                            Traceback (most recent call last)
Input In [8], in <cell line: 4>()
      2 import nibabel as nib
      3 import numpy as np 
----> 4 import jax.numpy as jnp
      6 class NIIHandler():
      8     def __init__(self, TRAIN_DATASET_PATH):

File ~/.local/lib/python3.10/site-packages/jax/__init__.py:58, in <module>
     38 from jax._src.config import (
     39   config as config,
     40   enable_checks as enable_checks,
   (...)
     55   transfer_guard_device_to_host as transfer_guard_device_to_host,
     56 )
     57 from .core import eval_context as ensure_compile_time_eval
---> 58 from jax._src.api import (
     59   ad,  # TODO(phawkins): update users to avoid this.
     60   block_until_ready,
     61   checkpoint as checkpoint,
     62   checkpoint_policies as checkpoint_policies,
     63   closure_convert as closure_convert,
     64   curry,  # TODO(phawkins): update users to avoid this.
     65   custom_gradient as custom_gradient,
     66   custom_jvp as custom_jvp,
     67   custom_vjp as custom_vjp,
     68   default_backend as default_backend,
     69   device_count as device_count,
     70   device_get as device_get,
     71   device_put as device_put,
     72   device_put_sharded as device_put_sharded,
     73   device_put_replicated as device_put_replicated,
     74   devices as devices,
     75   disable_jit as disable_jit,
     76   eval_shape as eval_shape,
     77   flatten_fun_nokwargs,  # TODO(phawkins): update users to avoid this.
     78   float0 as float0,
     79   grad as grad,
     80   hessian as hessian,
     81   host_count as host_count,
     82   host_id as host_id,
     83   host_ids as host_ids,
     84   jacobian as jacobian,
     85   jacfwd as jacfwd,
     86   jacrev as jacrev,
     87   jit as jit,
     88   jvp as jvp,
     89   local_device_count as local_device_count,
     90   local_devices as local_devices,
     91   linearize as linearize,
     92   linear_transpose as linear_transpose,
     93   make_jaxpr as make_jaxpr,
     94   mask as mask,
     95   named_call as named_call,
     96   pmap as pmap,
     97   process_count as process_count,
     98   process_index as process_index,
     99   pxla,  # TODO(phawkins): update users to avoid this.
    100   remat as remat,
    101   shapecheck as shapecheck,
    102   ShapedArray as ShapedArray,
    103   ShapeDtypeStruct as ShapeDtypeStruct,
    104   # TODO(phawkins): hide tree* functions from jax, update callers to use
    105   # jax.tree_util.
    106   treedef_is_leaf,
    107   tree_flatten,
    108   tree_leaves,
    109   tree_map,
    110   tree_multimap,
    111   tree_structure,
    112   tree_transpose,
    113   tree_unflatten,
    114   value_and_grad as value_and_grad,
    115   vjp as vjp,
    116   vmap as vmap,
    117   xla,  # TODO(phawkins): update users to avoid this.
    118   xla_computation as xla_computation,
    119 )
    120 from jax.experimental.maps import soft_pmap as soft_pmap
    121 from jax.version import __version__ as __version__

File ~/.local/lib/python3.10/site-packages/jax/_src/api.py:61, in <module>
     55 from jax._src import traceback_util
     56 from jax._src.api_util import (
     57     flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
     58     argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
     59     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     60     shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
---> 61 from jax._src.lax import lax as lax_internal
     62 from jax._src.lib import jax_jit
     63 from jax._src.lib import xla_bridge as xb

File ~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:1653, in <module>
   1651 tan_p = standard_unop(_float | _complex, 'tan')
   1652 ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
-> 1653 if jax._src.lib.mlir_api_version >= 11:
   1654   mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
   1655 else:

AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)

I am extremely new to JAX, so please let me know if there is something else I should be trying instead. I am attaching my `nvidia-smi` and `nvcc --version` output below.

[Screenshot: nvidia-smi output]

[Screenshot: nvcc --version output]

Thank you very much!

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 7

Top GitHub Comments

jakevdp commented, Jun 22, 2022 (1 reaction)

It looks like you have a local file named `jax.py` in your directory, so when you run `import jax`, Python loads this file instead of the jax package. I’d suggest renaming the file; for example:

$ mv jax.py my_jax_script.py
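A minimal sketch of the mechanism described in the comment above, assuming only the Python standard library (JAX itself does not need to be installed). It writes a hypothetical script named `jax.py` into a temporary directory, runs it to reproduce a "partially initialized module" error, and then uses `importlib.util.find_spec` to check which file `import jax` would resolve to. The helper names `resolved_origin` and `is_shadowed_by` are illustrative only; they are not part of JAX or Python.

```python
import importlib
import importlib.util
import os
import subprocess
import sys
import tempfile
from typing import Optional


def resolved_origin(name: str) -> Optional[str]:
    """Return the file that `import name` would load, without executing it.

    If `name` is already in sys.modules, find_spec returns that module's
    spec instead of searching sys.path again.
    """
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


def is_shadowed_by(name: str, directory: str) -> bool:
    """True if `name` resolves to <directory>/<name>.py or
    <directory>/<name>/__init__.py, i.e. a local file wins over the package."""
    origin = resolved_origin(name)
    if origin is None:
        return False
    directory = os.path.realpath(directory)
    local_candidates = {
        os.path.join(directory, name + ".py"),
        os.path.join(directory, name, "__init__.py"),
    }
    return os.path.realpath(origin) in local_candidates


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical user script that happens to be named jax.py.
    script = os.path.join(tmp, "jax.py")
    with open(script, "w") as f:
        f.write("import jax\nprint(jax.numpy)\n")

    # 1) Running it reproduces the circular-import error: the script's own
    #    directory is first on sys.path, so `import jax` re-imports this same
    #    file while it is still executing, and `jax.numpy` does not exist yet.
    run = subprocess.run([sys.executable, script], cwd=tmp,
                         capture_output=True, text=True)
    demo_stderr = run.stderr

    # 2) Detect the shadowing without executing anything.
    sys.path.insert(0, tmp)        # mimic running Python from that directory
    importlib.invalidate_caches()  # make sure the newly written file is seen
    try:
        demo_shadowed = is_shadowed_by("jax", tmp)
    finally:
        sys.path.remove(tmp)

print(demo_stderr.strip().splitlines()[-1])
print("shadowed:", demo_shadowed)
```

In your own environment, calling `resolved_origin("jax")` from the directory where you run your code shows which file Python would load for `import jax`. If it points into that directory rather than into `site-packages`, renaming the local file as suggested should let the real package load.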
carinawmy1 commented, Jul 10, 2022 (0 reactions)

Hi, I got the same error when importing jax; see the screenshots (WechatIMG288, WechatIMG289, WechatIMG290). Thank you very much for your advice!

