Tips for debugging NaNs in gradient?
Hi there,

I am running an optimisation using gradients from JAX, and everything goes well for a number of steps until the gradients returned are all nan. I am having a bit of a hard time tracking down where the problem is; the forward calculations all seem to be fine.
Is there some way I can work out which operation is causing the nans from `grad`? This would be really useful.
Thanks!
That’s a magic word to get us to help you out ASAP 😃
I added some basic nan debugging machinery in #482. As with other config options there are a few ways to turn it on:

- set the `JAX_DEBUG_NANS` environment variable to something truthy,
- add `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file, or
- add `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`.
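To make those three options concrete, here is a minimal sketch (it follows the `from jax.config import config` import path used above; newer JAX releases also expose `jax.config.update` directly):

```python
# Option 1: environment variable, set before the process starts, e.g.
#   JAX_DEBUG_NANS=1 python my_script.py

# Option 2: flip the option programmatically near the top of your main file.
from jax.config import config
config.update("jax_debug_nans", True)

# Option 3: let absl parse JAX's flags, then pass --jax_debug_nans=True
# on the command line instead of hard-coding it here.
# config.parse_flags_with_absl()
```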
Switching that option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked, and if a nan is present it will re-run the function in de-optimized op-by-op mode (effectively removing one level of `@jit` at a time).

There could be tricky situations that arise, like nans that only occur under a `@jit` but don’t get produced in de-optimized mode. In that case you’ll see a warning message print out, but your code will continue to execute, so we can dig in deeper.

If the nans are being produced in the backward pass of a gradient evaluation, then when an exception is raised, several frames up in the stack trace you’ll be in the `backward_pass` function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. If it’s not immediately obvious, you can poke around a bit to find the primitive that’s producing the nan by doing things in an interactive debugger like `p eqn.primitive` in that stack frame.

How does that sound? This is a good opportunity to add exactly the tooling we want; JAX is tiny and easy to instrument, so there’s no reason not to get this right.
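As an illustration (the function and values below are my own, not from this thread), this is the kind of failure the option catches: a forward pass that is perfectly fine but whose gradient is nan, which with `jax_debug_nans` switched on becomes a hard error you can post-mortem debug instead of a silent nan:

```python
import jax
import jax.numpy as jnp
from jax.config import config

config.update("jax_debug_nans", True)

def f(x):
    # The forward pass is fine everywhere, but the gradient is
    # x / sqrt(x**2), which evaluates to 0/0 = nan at x = 0.
    return jnp.sqrt(x ** 2)

print(f(0.0))            # 0.0 -- no nan in the forward pass
print(jax.grad(f)(0.0))  # with jax_debug_nans on, this raises an error
                         # instead of quietly returning nan

# From an interactive session you can then post-mortem debug
# (e.g. `import pdb; pdb.pm()` or `%debug` in IPython), walk a few frames
# up into backward_pass, and inspect the offending op with `p eqn.primitive`.
```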
Great question. I don’t think we have a solution in place for this right now, but I think we can make one.
There are at least two things to solve here:
- `np.seterr(invalid="raise")`
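For reference (this snippet is mine, not from the thread), plain NumPy’s `np.seterr` already gives the behaviour being asked about here: invalid floating-point operations raise immediately instead of silently producing nan:

```python
import numpy as np

np.seterr(invalid="raise")  # raise FloatingPointError instead of returning nan

x = np.array([0.0])
try:
    y = x / x  # 0.0 / 0.0 is an "invalid" operation and now raises
except FloatingPointError as err:
    print("caught:", err)
```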