
With tensorflow probability: AttributeError: module 'jax' has no attribute 'custom_transforms'


I am reimplementing some Google/DeepMind research code (e.g. https://arxiv.org/pdf/2101.11046.pdf) that uses JAX and relies on TensorFlow Probability.

I am following the tutorial for TensorFlow Probability: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX

The bug below seems to occur on the first import, unless I’m missing something simple (it’s my first time).

Complete example of how to reproduce the bug:

Install jax on Ubuntu 20.04 / Anaconda / Python 3.9.2:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

Install TensorFlow Probability:

pip install --upgrade tensorflow-probability

Try to follow the example: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
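The import line above does not fail on its own: the traceback below passes through `lazy_loader.py`, so the real import of `tensorflow_probability.substrates.jax` is deferred until the module is first touched (here by `dir(tfp)`). A minimal sketch, assuming you want such failures to surface immediately: import the full module path eagerly with `importlib.import_module` and report the error instead of crashing. The helper name `try_load` is hypothetical.

```python
import importlib
from typing import Tuple


def try_load(module_name: str) -> Tuple[bool, str]:
    """Import `module_name` eagerly; return (ok, message) instead of raising.

    Hypothetical helper. importlib.import_module executes the module's
    __init__ right away, so errors that a lazy loader would defer until the
    first attribute access are raised (and caught) here.
    """
    try:
        importlib.import_module(module_name)
    except ImportError as exc:  # also covers ModuleNotFoundError
        return False, f"not installed or failed to import: {exc}"
    except AttributeError as exc:  # e.g. "module 'jax' has no attribute ..."
        return False, f"incompatible versions: {exc}"
    return True, "ok"


# Usage:
# ok, msg = try_load("tensorflow_probability.substrates.jax")
# print(ok, msg)
```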

Full error message/traceback:

🕙 16:59:21 ❯ ipython
Python 3.9.2 (default, Mar  3 2021, 20:02:32) 
Type 'copyright', 'credits' or 'license' for more information
IPython 7.22.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: from tensorflow_probability.substrates import jax as tfp
   ...: 

In [2]: tfp
Out[2]: <module 'tensorflow_probability.substrates.jax'>

In [3]: dir(tfp)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-24f289242cbf> in <module>
----> 1 dir(tfp)

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py in __dir__(self)
     59 
     60   def __dir__(self):
---> 61     module = self._load()
     62     return dir(module)
     63 

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py in _load(self)
     42       self._on_first_access = None
     43     # Import the target module and insert it into the parent's namespace
---> 44     module = importlib.import_module(self.__name__)
     45     if self._parent_module_globals is not None:
     46       self._parent_module_globals[self._local_name] = module

~/miniconda3/lib/python3.9/importlib/__init__.py in import_module(name, package)
    125                 break
    126             level += 1
--> 127     return _bootstrap._gcd_import(name[level:], package, level)
    128 
    129 

~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _gcd_import(name, package, level)

~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _find_and_load(name, import_)

~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _find_and_load_unlocked(name, import_)

~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _load_unlocked(spec)

~/miniconda3/lib/python3.9/importlib/_bootstrap_external.py in exec_module(self, module)

~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _call_with_frames_removed(f, *args, **kwds)

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/__init__.py in <module>
     42 
     43 from tensorflow_probability.python.version import __version__
---> 44 from tensorflow_probability.substrates.jax import bijectors
     45 from tensorflow_probability.substrates.jax import distributions
     46 from tensorflow_probability.substrates.jax import experimental

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py in <module>
     21 # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
     22 
---> 23 from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
     24 from tensorflow_probability.substrates.jax.bijectors.affine import Affine
     25 from tensorflow_probability.substrates.jax.bijectors.affine_linear_operator import AffineLinearOperator

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py in <module>
     21 from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
     22 
---> 23 from tensorflow_probability.substrates.jax.bijectors import bijector
     24 from tensorflow_probability.substrates.jax.internal import assert_util
     25 from tensorflow_probability.substrates.jax.internal import dtype_util

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py in <module>
     33 from tensorflow_probability.substrates.jax.internal import nest_util
     34 from tensorflow_probability.substrates.jax.internal import prefer_static as ps
---> 35 from tensorflow_probability.substrates.jax.math import gradient
     36 from tensorflow_probability.python.internal.backend.jax import nest  # pylint: disable=g-direct-tensorflow-import
     37 

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/__init__.py in <module>
     21 from tensorflow_probability.python.internal import all_util
     22 # from tensorflow_probability.substrates.jax.math import ode
---> 23 from tensorflow_probability.substrates.jax.math import psd_kernels
     24 from tensorflow_probability.substrates.jax.math.bessel import bessel_iv_ratio
     25 from tensorflow_probability.substrates.jax.math.bessel import bessel_ive

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/__init__.py in <module>
     20 
     21 from tensorflow_probability.python.internal import all_util
---> 22 from tensorflow_probability.substrates.jax.math.psd_kernels.exp_sin_squared import ExpSinSquared
     23 from tensorflow_probability.substrates.jax.math.psd_kernels.exponentiated_quadratic import ExponentiatedQuadratic
     24 from tensorflow_probability.substrates.jax.math.psd_kernels.feature_scaled import FeatureScaled

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/exp_sin_squared.py in <module>
     24 from tensorflow_probability.substrates.jax.internal import assert_util
     25 from tensorflow_probability.substrates.jax.internal import tensor_util
---> 26 from tensorflow_probability.substrates.jax.math.psd_kernels.internal import util
     27 from tensorflow_probability.substrates.jax.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
     28 

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/internal/util.py in <module>
    108 
    109 @tf.custom_gradient
--> 110 def sqrt_with_finite_grads(x, name=None):
    111   """A sqrt function whose gradient at zero is very large but finite.
    112 

~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py in _custom_gradient(f)
    420       return cts_in
    421     return value, vjp_
--> 422   @jax.custom_transforms
    423   @functools.wraps(f)
    424   def wrapped(*args, **kwargs):

AttributeError: module 'jax' has no attribute 'custom_transforms'
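The last frame shows TFP's JAX backend (`internal/backend/jax/ops.py`) decorating a function with `@jax.custom_transforms`, which the installed JAX does not provide. A quick way to confirm that diagnosis without importing TFP at all is to probe the `jax` module directly. This is a sketch; `jax_has_attr` is a hypothetical helper, and it returns `None` when JAX is not installed.

```python
import importlib
import importlib.util
from typing import Optional


def jax_has_attr(name: str) -> Optional[bool]:
    """Report whether the installed `jax` module exposes `name`.

    Returns None if jax is not installed, so the check is safe to run anywhere.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    jax = importlib.import_module("jax")
    return hasattr(jax, name)


# Usage: a result of False matches the AttributeError in the traceback.
# print(jax_has_attr("custom_transforms"))
```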


Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

1 reaction
sharadmv commented, May 20, 2021

Sorry for the trouble! TFP has been updated at HEAD and on its nightly branches, but the latest pip release (0.12.2) only works with JAX <= 0.2.11. We plan on having a mainline release soon that will be compatible with the newest JAX version.
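The constraint in this comment can be checked before importing anything. A minimal sketch, assuming Python 3.8+ for `importlib.metadata`: it encodes only the bound stated here (the TFP 0.12.2 release needs JAX <= 0.2.11), and all helper names are hypothetical. The version parser is deliberately rough; `packaging.version` is the more robust choice if it is available.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Upper bound quoted in the maintainer comment: the TFP 0.12.2 pip release
# only works with JAX <= 0.2.11.
MAX_JAX_FOR_TFP_0_12_2 = (0, 2, 11)


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Rough parse of the numeric release part: '0.2.12.dev20210520' -> (0, 2, 12)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # a segment like 'dev2021...' ends the release part
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # a suffix like 'rc1' also ends it
    return tuple(parts)


def jax_ok_for_tfp_0_12_2(jax_version: str) -> bool:
    """True if `jax_version` is within the bound stated for TFP 0.12.2."""
    return release_tuple(jax_version) <= MAX_JAX_FOR_TFP_0_12_2


def installed_version(dist: str) -> Optional[str]:
    """Installed version of distribution `dist`, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


# Usage:
# jax_version = installed_version("jax")
# if jax_version is not None and not jax_ok_for_tfp_0_12_2(jax_version):
#     print(f"jax {jax_version} is too new for tensorflow-probability 0.12.2")
```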

In the meantime, feel free to use our nightlies via:

pip install tfp-nightly[jax]

These should always be compatible with the latest JAX.
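After installing the nightly (or pinning JAX), a short smoke test can confirm that the substrate actually loads and computes. A sketch, assuming `tensorflow_probability` and `jax` are importable; the hypothetical `tfp_jax_smoke_test` returns `None` when either is missing. Accessing `tfp.distributions` forces the lazy load that failed in the original report, and the standard normal's log-density at 0 is -0.5 * log(2π) ≈ -0.9189.

```python
import importlib.util
import math
from typing import Optional

# Reference value: log N(0 | loc=0, scale=1) = -0.5 * log(2 * pi)
EXPECTED_LOG_PROB = -0.5 * math.log(2 * math.pi)


def tfp_jax_smoke_test() -> Optional[float]:
    """Load TFP's JAX substrate and evaluate a log-density, or return None if not installed."""
    for pkg in ("jax", "tensorflow_probability"):
        if importlib.util.find_spec(pkg) is None:
            return None
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions  # first access triggers the real (lazy) import
    standard_normal = tfd.Normal(loc=0.0, scale=1.0)
    return float(standard_normal.log_prob(0.0))


# Usage:
# result = tfp_jax_smoke_test()
# if result is not None:
#     assert abs(result - EXPECTED_LOG_PROB) < 1e-5
```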

1 reaction
mattjj commented, May 20, 2021

(Hey Jaan!)

I think TFP has been updated not to depend on jax.custom_transforms, but maybe the pypi package hasn’t been yet.

@sharadmv can you advise?


