With tensorflow probability: AttributeError: module 'jax' has no attribute 'custom_transforms'
I am reimplementing some Google/DeepMind research code (e.g. https://arxiv.org/pdf/2101.11046.pdf) that uses JAX and relies on TensorFlow Probability.
I am following the tutorial for TensorFlow Probability: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
The bug below seems to occur on the first import, unless I’m missing something simple (this is my first time using it).
Complete example of how to reproduce the bug:
Install jax on Ubuntu 20.04 / Anaconda / Python 3.9.2:
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
Install tensorflow probability:
pip install --upgrade tensorflow-probability
Try to follow the example: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
Full error message/traceback:
🕙 16:59:21 ❯ ipython
Python 3.9.2 (default, Mar 3 2021, 20:02:32)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.22.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: from tensorflow_probability.substrates import jax as tfp
...:
In [2]: tfp
Out[2]: <module 'tensorflow_probability.substrates.jax'>
In [3]: dir(tfp)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-3-24f289242cbf> in <module>
----> 1 dir(tfp)
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py in __dir__(self)
59
60 def __dir__(self):
---> 61 module = self._load()
62 return dir(module)
63
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py in _load(self)
42 self._on_first_access = None
43 # Import the target module and insert it into the parent's namespace
---> 44 module = importlib.import_module(self.__name__)
45 if self._parent_module_globals is not None:
46 self._parent_module_globals[self._local_name] = module
~/miniconda3/lib/python3.9/importlib/__init__.py in import_module(name, package)
125 break
126 level += 1
--> 127 return _bootstrap._gcd_import(name[level:], package, level)
128
129
~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _gcd_import(name, package, level)
~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _find_and_load(name, import_)
~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _find_and_load_unlocked(name, import_)
~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _load_unlocked(spec)
~/miniconda3/lib/python3.9/importlib/_bootstrap_external.py in exec_module(self, module)
~/miniconda3/lib/python3.9/importlib/_bootstrap.py in _call_with_frames_removed(f, *args, **kwds)
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/__init__.py in <module>
42
43 from tensorflow_probability.python.version import __version__
---> 44 from tensorflow_probability.substrates.jax import bijectors
45 from tensorflow_probability.substrates.jax import distributions
46 from tensorflow_probability.substrates.jax import experimental
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py in <module>
21 # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
22
---> 23 from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
24 from tensorflow_probability.substrates.jax.bijectors.affine import Affine
25 from tensorflow_probability.substrates.jax.bijectors.affine_linear_operator import AffineLinearOperator
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py in <module>
21 from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
22
---> 23 from tensorflow_probability.substrates.jax.bijectors import bijector
24 from tensorflow_probability.substrates.jax.internal import assert_util
25 from tensorflow_probability.substrates.jax.internal import dtype_util
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py in <module>
33 from tensorflow_probability.substrates.jax.internal import nest_util
34 from tensorflow_probability.substrates.jax.internal import prefer_static as ps
---> 35 from tensorflow_probability.substrates.jax.math import gradient
36 from tensorflow_probability.python.internal.backend.jax import nest # pylint: disable=g-direct-tensorflow-import
37
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/__init__.py in <module>
21 from tensorflow_probability.python.internal import all_util
22 # from tensorflow_probability.substrates.jax.math import ode
---> 23 from tensorflow_probability.substrates.jax.math import psd_kernels
24 from tensorflow_probability.substrates.jax.math.bessel import bessel_iv_ratio
25 from tensorflow_probability.substrates.jax.math.bessel import bessel_ive
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/__init__.py in <module>
20
21 from tensorflow_probability.python.internal import all_util
---> 22 from tensorflow_probability.substrates.jax.math.psd_kernels.exp_sin_squared import ExpSinSquared
23 from tensorflow_probability.substrates.jax.math.psd_kernels.exponentiated_quadratic import ExponentiatedQuadratic
24 from tensorflow_probability.substrates.jax.math.psd_kernels.feature_scaled import FeatureScaled
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/exp_sin_squared.py in <module>
24 from tensorflow_probability.substrates.jax.internal import assert_util
25 from tensorflow_probability.substrates.jax.internal import tensor_util
---> 26 from tensorflow_probability.substrates.jax.math.psd_kernels.internal import util
27 from tensorflow_probability.substrates.jax.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
28
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/math/psd_kernels/internal/util.py in <module>
108
109 @tf.custom_gradient
--> 110 def sqrt_with_finite_grads(x, name=None):
111 """A sqrt function whose gradient at zero is very large but finite.
112
~/miniconda3/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py in _custom_gradient(f)
420 return cts_in
421 return value, vjp_
--> 422 @jax.custom_transforms
423 @functools.wraps(f)
424 def wrapped(*args, **kwargs):
AttributeError: module 'jax' has no attribute 'custom_transforms'
Issue Analytics
- State:
- Created 2 years ago
- Comments:9 (4 by maintainers)
Top GitHub Comments
Sorry for the trouble! TFP has been updated at HEAD and on its nightly branches, but the latest pip release (0.12.2) only works with JAX <= 0.2.11. We plan on having a mainline release soon that will be compatible with the newest JAX version.
In the meantime, feel free to use our nightlies via:
These should always be compatible with the latest JAX.
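To see whether an environment falls inside the compatibility window described above, a small version check can help. The sketch below is only an illustration: the helper names (`parse_version`, `installed_version`, `jax_ok_for_tfp_0_12_2`) are hypothetical, and the `(0, 2, 11)` bound is taken directly from the comment above rather than from any official compatibility table.

```python
from importlib import metadata

# Assumption from the maintainer comment: TFP 0.12.2 works with JAX <= 0.2.11.
MAX_JAX_FOR_TFP_0_12_2 = (0, 2, 11)


def parse_version(version):
    """Turn a string like '0.2.11' into (0, 2, 11), ignoring suffixes such as 'rc1'."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def installed_version(package):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jax_ok_for_tfp_0_12_2(jax_version):
    """True if the given JAX version is at or below the bound quoted above."""
    return parse_version(jax_version) <= MAX_JAX_FOR_TFP_0_12_2


if __name__ == "__main__":
    # Report what is installed; either package may be missing.
    for name in ("jax", "jaxlib", "tensorflow-probability"):
        print(name, installed_version(name))
    jax_version = installed_version("jax")
    if jax_version is not None:
        print("JAX within the quoted bound:", jax_ok_for_tfp_0_12_2(jax_version))
```

If the last line reports `False` while the TFP release is 0.12.2, the import failure in this issue is the expected outcome under the comment above.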
(Hey Jaan!)
I think TFP has been updated not to depend on jax.custom_transforms, but maybe the PyPI package hasn’t been yet. @sharadmv can you advise?
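For context on what a definition without `jax.custom_transforms` can look like, here is a minimal sketch using `jax.custom_vjp`, the JAX API for attaching a custom reverse-mode rule to a function. It is not TFP's actual code: the function borrows the name `sqrt_with_finite_grads` from the traceback purely for illustration, drops the `name` argument, and the choice of `large` as the "very large but finite" gradient at zero is an assumption.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def sqrt_with_finite_grads(x):
    """Square root whose gradient at zero is large but finite (illustrative only)."""
    return jnp.sqrt(x)


def _sqrt_fwd(x):
    # Forward pass: return the output and save the input as the residual.
    x = jnp.asarray(x)
    return jnp.sqrt(x), x


def _sqrt_bwd(x, g):
    # Assumed stand-in for the infinite gradient at zero.
    large = jnp.sqrt(jnp.finfo(x.dtype).max)
    # Avoid dividing by zero in the branch that is discarded anyway.
    safe_x = jnp.where(x > 0, x, 1.0)
    grad = jnp.where(x > 0, 0.5 / jnp.sqrt(safe_x), large)
    return (g * grad,)


sqrt_with_finite_grads.defvjp(_sqrt_fwd, _sqrt_bwd)
```

With this definition, `jax.grad(sqrt_with_finite_grads)(4.0)` evaluates to 0.25, and the gradient at `0.0` is a finite number instead of `inf`.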