jax.numpy.argmin/jax.numpy.argmax fail with nans
- Created 3 years ago
- Comments: 13 (9 by maintainers)
Thanks - looks like it’s only a problem on GPU, not on CPU or TPU. This suggests it may be an XLA issue.
I think this was already fixed by https://github.com/google/jax/pull/6764. That PR isn’t in a `jax` release yet, though, so to see the fix you’d need to use `jax` from GitHub head.

(This both was and was not an XLA issue. It’s an XLA issue that GPU acts differently, because XLA/GPU lacks support for variadic reductions, forcing us to use a different code path on GPU. The fallback path handling NaNs differently was a JAX issue.)
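
To make the "variadic reduction" point concrete, here is a rough pure-Python sketch of an argmax written as a single reduction over (value, index) pairs. It is not JAX's or XLA's actual implementation; the function names `naive_argmax` and `argmax_pair_reduce` are made up for illustration. The first version compares values directly, so a NaN never wins a comparison and is silently skipped. The second treats NaN as larger than any number and breaks ties by lowest index, which is the behavior NumPy's `argmax` has.

```python
import math
from functools import reduce


def naive_argmax(values):
    """Argmax with a plain '>' comparison (hypothetical, for illustration).

    Any comparison against NaN is False, so a NaN that is not in the
    first position is silently skipped.
    """
    if not values:
        raise ValueError("attempt to get argmax of an empty sequence")

    def combine(a, b):
        # a and b are (value, index) pairs; keep b only if strictly larger.
        return b if b[0] > a[0] else a

    pairs = [(float(v), i) for i, v in enumerate(values)]
    return reduce(combine, pairs)[1]


def argmax_pair_reduce(values):
    """NaN-propagating argmax over (value, index) pairs (hypothetical sketch).

    NaN is treated as larger than every number, and ties (including
    NaN vs. NaN) resolve to the lowest index.
    """
    if not values:
        raise ValueError("attempt to get argmax of an empty sequence")

    def combine(a, b):
        (a_val, a_idx), (b_val, b_idx) = a, b
        a_nan, b_nan = math.isnan(a_val), math.isnan(b_val)
        if a_nan != b_nan:
            # Exactly one side is NaN: the NaN wins.
            return a if a_nan else b
        if a_nan or a_val == b_val:
            # Both NaN, or equal values: prefer the lower index.
            return a if a_idx <= b_idx else b
        return a if a_val > b_val else b

    pairs = [(float(v), i) for i, v in enumerate(values)]
    return reduce(combine, pairs)[1]


data = [1.0, float("nan"), 3.0]
print(naive_argmax(data))        # 2: the NaN was skipped
print(argmax_pair_reduce(data))  # 1: the NaN is reported
```

Because `combine` in the second version compares indices explicitly rather than relying on evaluation order, it gives the same answer however the pairs are grouped, which is the property a parallel reduction needs.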