Not compatible with np.tensordot
Environments:
- jaxlib: 0.1.61+cuda110
- jax: 0.2.10
- dm-haiku: 0.0.4.dev0
- numpy: 1.19.5
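If you need to report environment details like the list above, a small helper can read them from the installed distributions. This is a sketch and not part of the original report. It assumes Python 3.8+ for `importlib.metadata`; the traceback below suggests the reporter was on Python 3.6, which would need the `importlib_metadata` backport instead. The helper name `installed_versions` is hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "dm-haiku", "numpy")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print in the same "- name: version" style as the report above.
    for name, version in installed_versions().items():
        print(f"- {name}: {version if version is not None else 'not installed'}")
```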
To reproduce the bug:
- the code is given below
- you might need to run it multiple times
import jax
import haiku as hk
import numpy as np


def net(x):
    torso_net = hk.Conv2D(32, kernel_shape=[8, 8], stride=[4, 4], padding='VALID')
    return hk.BatchApply(torso_net)(x)


def main():
    init_fn, apply_fn = hk.without_apply_rng(
        hk.transform(lambda obs: net(obs))
    )

    # Uncommenting the decorator below also reproduces the bug.
    # @jax.jit
    def initial_params(rng_key):
        _ = np.tensordot(np.zeros((84, 84, 3)), [0., 0., 0.], (-1, 0))
        return init_fn(rng_key, np.zeros((1, 1, 84, 84, 4)))

    rng_key = jax.random.PRNGKey(428)
    params = initial_params(rng_key)


if __name__ == '__main__':
    main()
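The `np.tensordot` call in `initial_params` is ordinary NumPy code whose result is discarded. For reference, `axes=(-1, 0)` sums over the last axis of the `(84, 84, 3)` array and the single axis of the length-3 vector, which leaves an `(84, 84)` result. The sketch below checks that against `np.einsum` and the `@` operator on random data. It only illustrates the semantics of the call and does not address the crash.

```python
import numpy as np

# Same operands as the repro: an (84, 84, 3) array and a length-3 list.
a = np.zeros((84, 84, 3))
b = [0., 0., 0.]

# axes=(-1, 0): contract the last axis of `a` with axis 0 of `b`.
out = np.tensordot(a, b, (-1, 0))

# With random data, the same contraction written two other ways.
rng = np.random.default_rng(0)
x = rng.standard_normal((84, 84, 3))
v = rng.standard_normal(3)
via_tensordot = np.tensordot(x, v, (-1, 0))
via_einsum = np.einsum("ijk,k->ij", x, v)  # sum over k explicitly
via_matmul = x @ v  # matmul treats 1-D `v` as a vector over the last axis of `x`

print(out.shape)
```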
The full error messages/tracebacks:
- with @jax.jit
Segmentation fault (core dumped)
- without @jax.jit
/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py:2970: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "zeros")
Segmentation fault (core dumped)
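Since the crash may take several runs to appear, a small driver can rerun the script and count how often it ends in a segfault. This is a hypothetical helper, not something from the issue. It assumes a POSIX system, where `subprocess` reports a child killed by a signal as a negative return code, and it assumes the reproduction has been saved as `repro.py`; both that file name and the helper name are illustrative.

```python
import signal
import subprocess
import sys


def count_segfaults(args, runs=10):
    """Run `args` as a child process `runs` times and count SIGSEGV deaths.

    On POSIX, a child terminated by signal N has returncode == -N, so a
    segmentation fault shows up as -signal.SIGSEGV (-11 on Linux).
    """
    crashes = 0
    for _ in range(runs):
        result = subprocess.run(
            args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
        if result.returncode == -signal.SIGSEGV:
            crashes += 1
    return crashes


if __name__ == "__main__":
    # Hypothetical file name: save the reproduction script as repro.py first.
    n = count_segfaults([sys.executable, "repro.py"], runs=10)
    print(f"{n}/10 runs ended in a segmentation fault")
```

A non-zero exit that is not a signal (for example a Python exception) is deliberately not counted, so only genuine native crashes show up in the tally.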
Issue Analytics
- Created 3 years ago
- Comments: 9 (5 by maintainers)
Top GitHub Comments
Looks like the bug is gone.
Can you try this again on the latest jaxlib==0.1.64? I think you can just write
pip install --upgrade jaxlib==0.1.64+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
to use the same cuda110 version you have now.
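To confirm the upgrade took effect, a sketch like the following compares the installed jaxlib release against 0.1.64. It assumes Python 3.8+ for `importlib.metadata`, and the helper names `release_tuple` and `jaxlib_at_least` are hypothetical. `release_tuple` keeps only the leading numeric components, so a local tag such as `+cuda110` is dropped and a pre-release like `0.1.64rc1` would be treated as `0.1.64`; a full version parser such as `packaging.version` would be stricter.

```python
import re
from importlib import metadata


def release_tuple(version):
    """Turn '0.1.64+cuda110' into (0, 1, 64) using only leading numeric parts."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jaxlib_at_least(minimum="0.1.64"):
    """Return True/False for installed jaxlib >= `minimum`, or None if absent."""
    try:
        installed = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None
    return release_tuple(installed) >= release_tuple(minimum)


if __name__ == "__main__":
    print("jaxlib >= 0.1.64:", jaxlib_at_least())
```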