
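The current implementation of jnp.quantile generates NaNs internally. As a result, it raises a FloatingPointError when jax_debug_nans is enabled, even though the input contains no NaNs.

Minimal reproduction:

import jax
import jax.numpy as jnp

# Raise FloatingPointError as soon as any operation produces a NaN.
jax.config.update("jax_debug_nans", True)

# The input has no NaNs and the expected result is 1.0, yet this raises.
jnp.quantile(jnp.ones((3, 3)), 0.5)

The same error is raised by jnp.where when one of its branches is a NaN literal:

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

# Raises while the NaN scalar is converted during dtype promotion.
jnp.where(1, float('nan'), 1)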

Traceback

  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5560, in percentile
    return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5370, in quantile
    return _quantile(a, q, axis, interpolation, keepdims, False)
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5436, in _quantile
    a = where(any(isnan(a), axis=axis, keepdims=True), nan, a)
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1686, in where
    return _where(condition, x, y)
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1669, in _where
    x, y = _promote_dtypes(x, y)
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 268, in _promote_dtypes
    return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
  File "/home/esac/projects/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 268, in <listcomp>
    return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
FloatingPointError: invalid value (nan) encountered in convert_element_type
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
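Judging by the last frames of the traceback, the failure happens in the NaN-propagation step of _quantile. Any slice along axis that contains a NaN is replaced entirely by NaN, so that slice's quantile comes out as NaN, as in NumPy. The sketch below copies that expression into a standalone helper so you can see what it computes with jax_debug_nans off. The name propagate_nans is hypothetical and is not a JAX API. Note that the jnp.nan constant is passed to where() even when no slice contains a NaN. That constant is the value the traceback shows being converted in _promote_dtypes when the check fires.

```python
import jax.numpy as jnp


def propagate_nans(a, axis):
    """Hypothetical helper mirroring the where(...) line from _quantile.

    Every slice along `axis` that contains at least one NaN is replaced
    by NaN; slices without NaNs are returned unchanged.
    """
    has_nan = jnp.any(jnp.isnan(a), axis=axis, keepdims=True)
    # jnp.nan is handed to where() regardless of what has_nan contains.
    return jnp.where(has_nan, jnp.nan, a)


clean = jnp.ones((3, 3))
dirty = jnp.array([[1.0, 2.0], [jnp.nan, 3.0]])

print(propagate_nans(clean, axis=1))  # unchanged: all ones
print(propagate_nans(dirty, axis=1))  # second row becomes all NaN
```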

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 9

Top GitHub Comments

1 reaction
jakevdp commented, Aug 30, 2021

Can you check whether #7648 fixes the issue you’re seeing? It’s not yet made it into a release, so you’d have to install from the main branch.

0 reactions
manuel-delverme commented, Aug 30, 2021

werks
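For readers on a JAX release that predates #7648, one possible workaround is to switch jax_debug_nans off only around the quantile call, so NaN checking stays active for the rest of the program. This workaround was not suggested in the thread. The minimal sketch below assumes only jax.config.update and attribute access to the flag (jax.config.jax_debug_nans). The context manager name debug_nans_disabled is hypothetical.

```python
import contextlib

import jax
import jax.numpy as jnp


@contextlib.contextmanager
def debug_nans_disabled():
    """Hypothetical helper: temporarily turn off jax_debug_nans."""
    previous = jax.config.jax_debug_nans
    jax.config.update("jax_debug_nans", False)
    try:
        yield
    finally:
        # Restore the caller's setting even if the body raised.
        jax.config.update("jax_debug_nans", previous)


jax.config.update("jax_debug_nans", True)

with debug_nans_disabled():
    median = jnp.quantile(jnp.ones((3, 3)), 0.5)

# NaN checking is active again outside the `with` block.
print(median, jax.config.jax_debug_nans)
```

The trade-off is that a genuine NaN produced inside the with block will no longer be reported, so the block should be kept as narrow as possible.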


