
Not compatible with np.tensordot

See original GitHub issue

Environment:

  • jaxlib: 0.1.61+cuda110
  • jax: 0.2.10
  • dm-haiku: 0.0.4.dev0
  • numpy: 1.19.5

To reproduce the bug:

  • the code is given below
  • you might need to run it multiple times
import jax
import haiku as hk
import numpy as np

def net(x):
  torso_net = hk.Conv2D(32, kernel_shape=[8, 8], stride=[4, 4], padding='VALID')
  return hk.BatchApply(torso_net)(x)

def main():
  init_fn, apply_fn = hk.without_apply_rng(
    hk.transform(lambda obs: net(obs))
  )

  # Uncommenting the decorator below also reproduces the bug
  # @jax.jit
  def initial_params(rng_key):
    _ = np.tensordot(np.zeros((84,84,3)), [0., 0., 0.], (-1, 0))
    return init_fn(rng_key, np.zeros((1,1,84,84,4)))

  rng_key = jax.random.PRNGKey(428)
  params = initial_params(rng_key)

if __name__ == '__main__':
  main()
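For reference, the two operations in the repro have well-defined shapes. `np.tensordot(a, b, (-1, 0))` contracts the last axis of `a` with the first axis of `b`, so the (84, 84, 3) array dotted with a length-3 vector gives an (84, 84) result. `hk.BatchApply` merges the leading dimensions of its input (two by default), applies the wrapped module, and splits them back out, so the (1, 1, 84, 84, 4) observation reaches `hk.Conv2D` as a (1, 84, 84, 4) NHWC batch; an 8x8 kernel with stride 4 and `VALID` padding then yields (84 - 8) // 4 + 1 = 20 positions per spatial axis. The sketch below reproduces these shapes in plain NumPy. `conv2d_valid_nhwc` and `batch_apply` are hypothetical helpers written for illustration, not Haiku's implementation; the convolution assumes an HWIO weight layout and omits the bias that `hk.Conv2D` adds by default.

```python
import numpy as np


def conv2d_valid_nhwc(x, w, stride):
    """Naive VALID cross-correlation without bias (illustrative only).

    x: (N, H, W, C_in) input, w: (KH, KW, C_in, C_out) weights,
    stride: (SH, SW). Returns an (N, OH, OW, C_out) array.
    """
    n, h, wd, _ = x.shape
    kh, kw, _, c_out = w.shape
    sh, sw = stride
    # VALID padding: only windows that fit entirely inside the input count.
    out_h = (h - kh) // sh + 1
    out_w = (wd - kw) // sw + 1
    out = np.zeros((n, out_h, out_w, c_out), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]
            # Sum over kernel height, kernel width and input channels.
            out[:, i, j, :] = np.einsum("nhwc,hwco->no", patch, w)
    return out


def batch_apply(f, x, num_dims=2):
    """Merge the leading `num_dims` axes into one, apply f, then split them back."""
    lead = x.shape[:num_dims]
    flat = x.reshape((-1,) + x.shape[num_dims:])
    y = f(flat)
    return y.reshape(lead + y.shape[1:])


# The tensordot call from the repro: contract a's last axis with b's first axis.
t = np.tensordot(np.zeros((84, 84, 3)), [0., 0., 0.], (-1, 0))
print(t.shape)  # (84, 84)

# Conv2D(32, kernel 8x8, stride 4, VALID) applied through a BatchApply-style wrapper.
obs = np.zeros((1, 1, 84, 84, 4), dtype=np.float32)
rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8, 4, 32)).astype(np.float32)
y = batch_apply(lambda v: conv2d_valid_nhwc(v, w, (4, 4)), obs)
print(y.shape)  # (1, 1, 20, 20, 32)
```

Neither shape is unusual, which fits the thread's outcome: the crash looks like a jaxlib problem rather than a misuse of `np.tensordot` or `hk.BatchApply`.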

Full error messages/tracebacks:

Segmentation fault (core dumped)
/usr/local/lib/python3.6/dist-packages/jax/_src/numpy/lax_numpy.py:2970: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "zeros")
Segmentation fault (core dumped)
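The `UserWarning` in the output is most likely a side effect of the dummy input rather than the cause of the crash. `np.zeros` returns `float64` by default, and JAX runs in 32-bit mode unless `jax_enable_x64` is set, so any parameter created from the input's dtype (as far as we can tell, the bias of `hk.Conv2D`) is requested as `float64` and truncated to `float32`. A minimal sketch, assuming you want to stay in 32-bit mode, is to build the example input as `float32` up front; the alternative named in the warning is to enable 64-bit mode, for example with `jax.config.update("jax_enable_x64", True)` or the `JAX_ENABLE_X64` environment variable. Neither change is expected to fix the segmentation fault, which, judging by the comments, went away after upgrading jaxlib.

```python
import numpy as np

# NumPy's default floating dtype is float64; JAX truncates it to float32
# (emitting a UserWarning) unless 64-bit mode is enabled.
obs_default = np.zeros((1, 1, 84, 84, 4))
print(obs_default.dtype)  # float64

# Building the dummy observation as float32 matches JAX's default precision,
# so parameters derived from the input dtype are requested as float32.
obs = np.zeros((1, 1, 84, 84, 4), dtype=np.float32)
print(obs.dtype)  # float32

# An existing float64 array can also be cast; astype returns a new float32 array.
obs_cast = obs_default.astype(np.float32)
print(obs_cast.dtype, obs_cast.shape)  # float32 (1, 1, 84, 84, 4)
```

In the repro, this would amount to passing such a `float32` array to `init_fn` instead of `np.zeros((1,1,84,84,4))`.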

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

bingykang commented, Mar 29, 2021 (1 reaction)

Looks like the bug is gone.

mattjj commented, Mar 28, 2021 (0 reactions)

Can you try this again on the latest jaxlib==0.1.64? I think you can just write pip install --upgrade jaxlib==0.1.64+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html to use the same cuda110 version you have now.
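Before upgrading as suggested, it can help to confirm which jaxlib is actually installed. The sketch below uses only the standard library (`importlib.metadata`, which needs Python 3.8+; on older interpreters such as the 3.6 seen in the traceback, `pip show jaxlib` reports the same version) and compares the release part of the version while ignoring local labels such as `+cuda110`. The helper names are hypothetical, the naive parser assumes purely numeric dotted releases (the `packaging` library handles the general case), and 0.1.64 is simply the version suggested in this thread, not a documented minimum.

```python
from importlib import metadata


def release_tuple(version):
    """'0.1.61+cuda110' -> (0, 1, 61): drop any local label, parse the numeric release."""
    public = version.split("+", 1)[0]
    return tuple(int(part) for part in public.split("."))


def installed_jaxlib_at_least(minimum="0.1.64"):
    """Return True/False, or None if jaxlib is not installed at all."""
    try:
        installed = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None
    return release_tuple(installed) >= release_tuple(minimum)


print(release_tuple("0.1.61+cuda110"))  # (0, 1, 61)
print(release_tuple("0.1.61+cuda110") >= release_tuple("0.1.64"))  # False
print(installed_jaxlib_at_least())  # depends on the local environment
```

Comparing tuples of integers rather than raw strings avoids the classic pitfall where "0.1.9" would sort after "0.1.64".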


