Add HOWTO: Gradient clipping
We currently don't have a canonical implementation of gradient clipping in Flax. Some optimizers have it built in.
Possible implementation:
import jax
import jax.numpy as jnp

def clip_grad_norm(grad, max_norm):
    # Global L2 norm: the norm of the vector of per-leaf norms.
    norm = jnp.linalg.norm(jnp.asarray(
        jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.linalg.norm, grad))))
    # Rescale only when the global norm exceeds max_norm (factor <= 1),
    # so gradients already within the bound pass through unchanged.
    factor = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda x: x * factor, grad)
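A minimal sketch of how this helper might be used in a plain JAX training step, building on the snippet above. The toy loss_fn and the SGD update are illustrative stand-ins, not part of the issue:

def loss_fn(params, batch):
    # Toy least-squares loss standing in for a real model's loss.
    preds = batch['x'] @ params['w']
    return jnp.mean((preds - batch['y']) ** 2)

def train_step(params, batch, learning_rate=1e-3, max_norm=1.0):
    grads = jax.grad(loss_fn)(params, batch)
    grads = clip_grad_norm(grads, max_norm)  # clip before the update
    # Plain SGD; any optimizer update would slot in here instead.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)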
Top GitHub Comments
Yes. I actually just saw that there is an implementation there for what we discussed.
optax.clip_by_global_norm()
Shouldn’t this be part of optax?
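For reference, a minimal sketch of how the optax transformation mentioned above is typically chained in front of an optimizer. The max_norm of 1.0 and the Adam settings are illustrative, and params and grads are assumed to be the parameter and gradient pytrees from the snippet earlier:

import optax

# Clip the global gradient norm to 1.0, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)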