Using optax with partially complex models
Hello!
I'm trying to implement the Fourier Neural Operator model from https://github.com/devzhk/PINO/, translating the PyTorch code there into JAX.
I ran into an issue where the optax optimizer_state becomes complex for all weight matrices of my model, rather than only for the leaves corresponding to complex weights. This only happens after a while, and I've only been able to reproduce it while using an optax scheduler. Using plain optax.adam seems to work without issue.
Full reproducing script: https://colab.research.google.com/drive/1a5BFz0G20t_VvXOfQqJfqHo4-HL3Ok3X?usp=sharing.
EDIT: This is related to (but possibly distinct from) https://github.com/deepmind/optax/issues/196
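A stripped-down sketch of the setup (the shapes, parameter names, and schedule values here are made up for illustration; the Colab above has the full script):

```python
import jax
import jax.numpy as jnp
import optax

# Toy pytree mixing real and complex leaves, mimicking an FNO whose
# spectral weights are complex while the rest of the network is real.
params = {
    "dense_w": jnp.ones((4, 4)),                          # real leaf
    "spectral_w": jnp.ones((4, 4), dtype=jnp.complex64),  # complex leaf
}

def loss_fn(p):
    # Real-valued loss; jax.grad supports complex inputs when the output is real.
    return jnp.sum(p["dense_w"] ** 2) + jnp.sum(jnp.abs(p["spectral_w"]) ** 2)

# Adam behind a schedule, as described above.
schedule = optax.exponential_decay(
    init_value=1e-3, transition_steps=100, decay_rate=0.9)
opt = optax.adam(schedule)
state = opt.init(params)

for _ in range(200):
    grads = jax.grad(loss_fn)(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)

# The real leaf's optimizer state should stay real; the bug shows it
# eventually turning complex.
print(jax.tree_util.tree_map(lambda x: x.dtype, state))
```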
Yeah, that's what we tried to argue with the optax developers. Hopefully once #297 is merged this will be addressed.

@wdphy16, does PR #279 address gradient clipping as well?
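For context, the subtlety with clipping is that a global norm must accumulate the squared magnitude of complex leaves. A minimal sketch of the distinction (not optax's actual implementation; the leaf names are made up):

```python
import jax
import jax.numpy as jnp

grads = {
    "dense_w": jnp.array([3.0, 4.0]),       # real gradients
    "spectral_w": jnp.array([1.0 + 2.0j]),  # complex gradients
}

# Complex-safe global norm: sum |g|^2 = g * conj(g) over all leaves.
def global_norm(tree):
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.abs(g) ** 2) for g in leaves))

# Summing g * g instead would keep complex terms that can cancel,
# yielding a meaningless "norm" for clipping thresholds.
print(global_norm(grads))  # sqrt(9 + 16 + 5) = sqrt(30)
```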
By the way, can you try using master? The current release of optax does not include the changes to work with complex numbers.
Current master silently computes the wrong norm in adam.
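To illustrate what "the wrong norm" means here (a sketch of the arithmetic, not optax internals): Adam's second moment should accumulate the squared magnitude g * conj(g), which is real, rather than g * g, which stays complex:

```python
import jax.numpy as jnp

g = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])  # a complex gradient leaf
b2 = 0.999

# Wrong: g * g stays complex, e.g. (1 + 2j)**2 = -3 + 4j, so the
# accumulated second moment (and anything divided by its square root)
# is meaningless.
wrong = g * g

# Complex-safe: |g|^2 = g * conj(g) is real and non-negative.
nu = jnp.zeros(g.shape)
nu = b2 * nu + (1 - b2) * (g * jnp.conj(g)).real
print(wrong, nu)
```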