Tips for debugging NaNs in gradient?

Hi there,

I am running an optimisation using gradients from Jax, and everything goes well for a number of steps until the gradients returned are all nan. I am having a bit of a hard time tracking down where the problem is; the forward calculations all seem to be fine.

Is there some way I can work out which operation is causing the nans returned by grad? This would be really useful.

Thanks!

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 7
  • Comments: 21 (9 by maintainers)

Top GitHub Comments

29 reactions
mattjj commented, Mar 7, 2019

Blocking a user is the worst feeling! That’s a magic word to get us to help you out ASAP 😃

I added some basic nan debugging machinery in #482. As with other config options there are a few ways to turn it on:

  1. you can set the JAX_DEBUG_NANS environment variable to something truthy,
  2. you can add from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file,
  3. you can add from jax.config import config and config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True.
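A minimal sketch of option 2, assuming a recent JAX release in which the same switch is usually written as jax.config.update(...) rather than through the from jax.config import config import used above. It turns the check on, triggers a nan by taking the log of a negative number, and turns the check off again:

```python
import jax
import jax.numpy as jnp

# Option 2, in the spelling accepted by recent JAX releases.
jax.config.update("jax_debug_nans", True)

try:
    # log of a negative number produces a nan, which the checker catches.
    jnp.log(-1.0)
    raised = False
except FloatingPointError as err:
    raised = True
    print("caught:", err)
finally:
    # Switch the check off again so later code is not slowed down.
    jax.config.update("jax_debug_nans", False)
```

For option 1, the environment variable is read when JAX is imported, so it is normally set on the command line, e.g. JAX_DEBUG_NANS=True python your_script.py (your_script.py is a placeholder).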

Switching that option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode (effectively removing one level of @jit at a time).
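To illustrate the @jit path, here is a small sketch built around a hypothetical normalize function that returns 0 / 0 = nan for an all-zero input. Without the flag, the jitted call quietly returns nans; with it, the call raises FloatingPointError. The wording of the error message differs between JAX versions, so the sketch only relies on the exception type:

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Hypothetical example: an all-zero input gives 0 / 0 = nan in every entry.
    return x / jnp.linalg.norm(x)

zeros = jnp.zeros(3)

# Flag off: the nan propagates silently out of the jitted function.
plain = normalize(zeros)

# Flag on: the jitted output is checked; on a nan, JAX re-runs the function
# op-by-op to locate the offending primitive and then raises.
jax.config.update("jax_debug_nans", True)
try:
    normalize(zeros)
    raised = False
except FloatingPointError as err:
    raised = True
    print(err)
finally:
    jax.config.update("jax_debug_nans", False)
```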

There could be tricky situations that arise, like nans that only occur under a @jit but don’t get produced in de-optimized mode. In that case you’ll see a warning message print out but your code will continue to execute, so we can dig in deeper.

If the nans are being produced in the backward pass of a gradient evaluation, then when the exception is raised you'll find the backward_pass function several frames up in the stack trace. That function is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. If the culprit isn't immediately obvious, you can poke around in that stack frame with an interactive debugger, e.g. by running p eqn.primitive, to find the primitive that's producing the nan.
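The backward-pass case can be reproduced with a well-known pattern in which the forward pass is clean but the gradient is nan: at x = 0, jnp.where selects the constant branch, yet reverse mode still multiplies a zero cotangent by the infinite derivative of sqrt at 0, and 0 * inf = nan. The function f below is a hypothetical stand-in for user code. The sketch also prints the primitive names in the jaxpr of the gradient, a non-interactive way to look at the same kind of eqn.primitive objects that backward_pass iterates over:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical user code: fine in the forward pass, because at x = 0
    # the where picks the constant branch and sqrt(0) = 0 is finite.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

value = f(0.0)                 # finite
grad_value = jax.grad(f)(0.0)  # nan: zero cotangent times the inf slope of sqrt at 0

# Primitive names in the gradient's jaxpr (forward and backward operations).
closed = jax.make_jaxpr(jax.grad(f))(0.0)
print([eqn.primitive.name for eqn in closed.jaxpr.eqns])

# With the nan checker on, the same gradient call raises instead of returning nan.
jax.config.update("jax_debug_nans", True)
try:
    jax.grad(f)(0.0)
    raised = False
except FloatingPointError:
    raised = True
finally:
    jax.config.update("jax_debug_nans", False)
```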

How does that sound? This is a good opportunity to add exactly the tooling we want; JAX is tiny and easy to instrument, so there’s no reason not to get this right.

7 reactions
mattjj commented, Mar 4, 2019

Great question. I don’t think we have a solution in place for this right now, but I think we can make one.

There are at least two things to solve here:

  1. set up the equivalent of np.seterr(invalid="raise")
  2. catch nans on the backward pass, and associate them helpfully with user code
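For reference, this is the NumPy behaviour that item 1 asks JAX to mirror. The sketch uses np.errstate, the context-manager form of np.seterr, so the error setting is restored when the block exits; the input values are arbitrary:

```python
import numpy as np

x = np.array([1.0, 0.0, -1.0])

# Invalid operations ignored: sqrt(-1) silently becomes nan.
with np.errstate(invalid="ignore"):
    quiet = np.sqrt(x)

# Equivalent of np.seterr(invalid="raise"), scoped to the with-block.
try:
    with np.errstate(invalid="raise"):
        np.sqrt(x)
    raised = False
except FloatingPointError:
    raised = True
```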