RuntimeError: INTERNAL: Could not find the corresponding function
Hi there, I've encountered the same issue as in #6046, but on an up-to-date version.
I am using Ubuntu 18.04.5 LTS, CUDA 11.3, and an RTX 3090. The script to reproduce the error is:
import jax, imageio, jax.numpy as np, matplotlib.pyplot as plt
from jax.experimental import stax, optimizers
from tqdm.notebook import tqdm as tqdm
from jax import jit, grad, random
rand_key = random.PRNGKey(0)  # fails here with RuntimeError: INTERNAL: Could not find the corresponding function
The script is adapted from here.
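As a quick sanity check before reading the traceback, the following snippet confirms which jax/jaxlib builds are installed and whether the GPU backend is actually active (a minimal diagnostic sketch of my own, assuming a jax 0.2.x-era install where jax.lib.xla_bridge is available):

import jax, jaxlib
from jax.lib import xla_bridge
# jax and jaxlib versions must be mutually compatible and match the local CUDA install
print(jax.__version__, jaxlib.__version__)
# 'gpu' means jaxlib found a usable CUDA runtime; 'cpu' means it silently fell back
print(xla_bridge.get_backend().platform)
# Lists the devices jax will dispatch work to
print(jax.devices())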
The full error message and traceback:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-db26676596ad> in <module>
3 from tqdm.notebook import tqdm as tqdm
4 from jax import jit, grad, random
----> 5 rand_key = random.PRNGKey(0)
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
103 """
104 return _return_prng_keys(
--> 105 True, prng.seed_with_impl(prng.threefry_prng_impl, seed))
106
107 def _fold_in(key: KeyArray, data: int) -> KeyArray:
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/_src/prng.py in seed_with_impl(impl, seed)
189
190 def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
--> 191 return PRNGKeyArray(impl, impl.seed(seed))
192
193
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/_src/prng.py in threefry_seed(seed)
227
228 convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
--> 229 k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
230 k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
231 return lax.concatenate([k1, k2], 0)
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/_src/lax/lax.py in shift_right_logical(x, y)
386 def shift_right_logical(x: Array, y: Array) -> Array:
387 r"""Elementwise logical right shift: :math:`x \gg y`."""
--> 388 return shift_right_logical_p.bind(x, y)
389
390 def eq(x: Array, y: Array) -> Array:
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **params)
265 args, used_axis_names(self, params) if self._dispatch_on_params else None)
266 tracers = map(top_trace.full_raise, args)
--> 267 out = top_trace.process_primitive(self, tracers, params)
268 return map(full_lower, out) if self.multiple_results else full_lower(out)
269
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
610
611 def process_primitive(self, primitive, tracers, params):
--> 612 return primitive.impl(*tracers, **params)
613
614 def process_call(self, primitive, f, tracers, params):
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
274 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
275 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 276 return compiled_fun(*args)
277
278
/scratch/hw501/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
390 device, = compiled.local_devices()
391 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 392 out_bufs = compiled.execute(input_bufs)
393 check_special(prim.name, out_bufs)
394 return result_handler(*out_bufs)
RuntimeError: INTERNAL: Could not find the corresponding function
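Per the traceback, the failure surfaces when XLA executes the compiled shift_right_logical kernel, the first GPU operation that PRNGKey triggers. In #6046 the cause was a CUDA toolkit too old for the RTX 3090 (Ampere, compute capability 8.6, supported from CUDA 11.1 onward), and even with a system-wide CUDA 11.3 a conda environment can shadow it with an older or incomplete toolkit. One hedged way to check whether the CUDA compiler tools are visible to the Python process (a missing ptxas is only one possible cause):

import shutil, subprocess
# ptxas ships with the full CUDA toolkit; the cudatoolkit-dev conda package is one
# way to provide it inside a conda env. None here means it is not on PATH.
print(shutil.which("ptxas"))
# Driver version as reported by nvidia-smi (the driver must support the toolkit in use)
subprocess.run(["nvidia-smi", "--query-gpu=driver_version", "--format=csv"])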
Had the same issue; @junyaoshi's solution worked for me too (installing cudatoolkit-dev). Just posting for reference in case other users run into the same issue.

Thank you @hansen7! I was able to resolve this error using @hansen7's comment above. In addition, for me it was important to install cudatoolkit-dev in my conda env and to have an updated version of the NVIDIA driver. The Ubuntu and CUDA versions didn't seem to matter.
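To verify the fix: after installing cudatoolkit-dev (for example conda install -c conda-forge cudatoolkit-dev; the exact channel is an assumption, adjust for your setup) and updating the NVIDIA driver, the originally failing call should succeed. A quick check:

# Re-run the line that originally raised RuntimeError: INTERNAL
import jax
from jax import random
key = random.PRNGKey(0)  # previously crashed inside shift_right_logical
print(key)               # a working install prints a uint32 key such as [0 0]
print(jax.devices())     # should now list the RTX 3090 as a GPU device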