jax_debug_nans
The current implementation of the quantile function generates NaNs that are detected when jax_debug_nans is activated, even for NaN-free input:
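import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)  # raise as soon as a JAX operation produces a NaN
jnp.quantile(jnp.ones((3, 3)), 0.5)  # input has no NaNs, yet this raised FloatingPointError on the reported version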
The same FloatingPointError can be triggered directly by giving jnp.where a NaN operand:
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
jnp.where(1, float('nan'), 1)
Traceback
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5560, in percentile
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5370, in quantile
return _quantile(a, q, axis, interpolation, keepdims, False)
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5436, in _quantile
a = where(any(isnan(a), axis=axis, keepdims=True), nan, a)
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1686, in where
return _where(condition, x, y)
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1669, in _where
x, y = _promote_dtypes(x, y)
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 268, in _promote_dtypes
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 268, in <listcomp>
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
FloatingPointError: invalid value (nan) encountered in convert_element_type
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
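The last frames show where the NaN comes from: to propagate NaNs, _quantile runs where(any(isnan(a), axis=axis, keepdims=True), nan, a), and the nan literal is flagged while _promote_dtypes converts it (convert_element_type), even though jnp.ones((3, 3)) contains no NaNs. One way to propagate NaNs without ever creating a NaN on clean input is sketched below. propagate_nans is a hypothetical helper, not part of JAX's API and not necessarily how this issue was fixed; it assumes floating-point input.

```python
import jax
import jax.numpy as jnp


def propagate_nans(a, axis=None):
    """Hypothetical helper: turn every slice along `axis` that contains a NaN into NaN.

    Unlike where(any(isnan(a), ...), nan, a), no NaN literal is created, so on
    NaN-free input no intermediate value is ever NaN. Assumes floating-point `a`.
    """
    # Keep NaNs where they already are and put 0 everywhere else.
    nans_only = jnp.where(jnp.isnan(a), a, 0)
    # The sum is NaN exactly for the slices that contain a NaN, and 0 otherwise.
    slice_flag = jnp.sum(nans_only, axis=axis, keepdims=True)
    # Broadcast the NaN over affected slices; all other values pass through untouched.
    return jnp.where(jnp.isnan(slice_flag), slice_flag, a)


# With NaN checking enabled, NaN-free input goes through without a FloatingPointError.
jax.config.update("jax_debug_nans", True)
clean = propagate_nans(jnp.ones((3, 3)), axis=1)
jax.config.update("jax_debug_nans", False)
```

The per-slice flag is derived from the input itself (a sum over values that are either NaN or 0), so any NaN the helper produces is a copy of one already present, and jax_debug_nans only fires when the data really contains NaNs.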
Can you check whether #7648 fixes the issue you're seeing? It hasn't made it into a release yet, so you'd have to install from the main branch.
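If installing from the main branch isn't an option, a stopgap is to switch jax_debug_nans off only around the quantile call. The sketch below uses a hypothetical debug_nans_disabled context manager built on jax.config.update, and reads the current setting through the jax.config.jax_debug_nans attribute. It assumes the call is made eagerly, outside any jax.jit-compiled function; how the flag interacts with tracing is not covered here.

```python
import contextlib

import jax
import jax.numpy as jnp


@contextlib.contextmanager
def debug_nans_disabled():
    """Hypothetical helper: switch jax_debug_nans off inside the block, then restore it."""
    previous = jax.config.jax_debug_nans  # remember the caller's setting
    jax.config.update("jax_debug_nans", False)
    try:
        yield
    finally:
        # Restore the original setting even if the block raises.
        jax.config.update("jax_debug_nans", previous)


jax.config.update("jax_debug_nans", True)

# Only the quantile call runs with NaN checking off; everything else
# keeps jax_debug_nans enabled.
with debug_nans_disabled():
    median = jnp.quantile(jnp.ones((3, 3)), 0.5)
```

Restoring the previous value in a finally clause keeps the flag consistent even if the wrapped call raises, so NaN checking stays active for the rest of the program.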