
Using optax with partially complex models

See original GitHub issue

Hello!

I’m trying to implement the Fourier Neural Operator model from https://github.com/devzhk/PINO/, translating the PyTorch code there into JAX. I ran into an issue where the optax optimizer_state becomes complex for all weight matrices of my model, rather than only for the leaves corresponding to complex weights. This only happens after a number of steps, and I’ve only been able to reproduce it while using an optax scheduler; using only optax.adam seems to work without issue.

Full reproducing script: https://colab.research.google.com/drive/1a5BFz0G20t_VvXOfQqJfqHo4-HL3Ok3X?usp=sharing.

EDIT: This is related to (but possibly distinct from) https://github.com/deepmind/optax/issues/196
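Since the reproduction lives in the linked Colab rather than inline, here is a minimal, self-contained sketch (not the author’s script) of the kind of setup being described: a parameter tree with both real and complex leaves, optimized with adam behind a schedule and global-norm clipping, followed by a check of the dtypes in the optimizer state. The parameter names, shapes, schedule, and toy loss are all illustrative choices.

```python
import jax
import jax.numpy as jnp
import optax

# Toy "partially complex" parameter tree: one real leaf, one complex leaf.
params = {
    "real_w": jnp.ones((4, 4), dtype=jnp.float32),
    "complex_w": jnp.ones((4, 4), dtype=jnp.complex64),
}

def loss_fn(p):
    # Real-valued loss over mixed real/complex parameters.
    return jnp.sum(jnp.abs(p["real_w"]) ** 2) + jnp.sum(jnp.abs(p["complex_w"]) ** 2)

# Adam behind a schedule, with global-norm clipping in front.
schedule = optax.exponential_decay(init_value=1e-3, transition_steps=100, decay_rate=0.9)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(learning_rate=schedule))
opt_state = optimizer.init(params)

@jax.jit
def step(p, state):
    grads = jax.grad(loss_fn)(p)
    updates, state = optimizer.update(grads, state, p)
    return optax.apply_updates(p, updates), state

for _ in range(200):
    params, opt_state = step(params, opt_state)

# Only the state leaves belonging to "complex_w" should be complex here.
print(jax.tree_util.tree_map(lambda x: x.dtype, opt_state))
```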

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 8 (1 by maintainers)

Top GitHub Comments

1 reaction
PhilipVinc commented, Jan 24, 2022

optax should at least output some form of warning when handling complex parameters…

Yeah, that’s what we tried to argue with the optax developers. Hopefully once #297 is merged this will be addressed.

I think my initial observation was due to the gradient clipping. The code for computing the global norm for clipping has the same issue as Adam: it squares a potentially complex number rather than taking the squared modulus.

@wdphy16 does PR #279 address gradient clipping as well?
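To make the clipping point above concrete, here is a small illustration (not optax’s actual code) of a global norm computed by squaring the leaves directly versus using the squared modulus; only the latter is real and non-negative when a complex leaf is present. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

def global_norm_squaring(tree):
    # Squares each leaf directly; a complex leaf makes the result complex.
    return jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(tree)))

def global_norm_modulus(tree):
    # Uses the squared modulus |x|^2, real for both real and complex leaves.
    return jnp.sqrt(sum(jnp.sum(jnp.abs(x) ** 2) for x in jax.tree_util.tree_leaves(tree)))

grads = {"real_w": jnp.array([3.0, 4.0]), "complex_w": jnp.array([1.0 + 1.0j])}
print(global_norm_squaring(grads))  # sqrt(25 + 2j): complex, meaningless as a norm
print(global_norm_modulus(grads))   # sqrt(25 + 2) ≈ 5.196
```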

1 reaction
PhilipVinc commented, Jan 22, 2022

By the way, can you try using master? The current release of optax does not include the changes needed to work with complex numbers.

Current master silently computes the wrong norm in adam.
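Whether a given install exhibits this can be probed directly: after one adam update on a single complex parameter, the second-moment (“nu”) entry of the state should have no imaginary part. A minimal check along those lines, where the toy loss and parameter name are illustrative:

```python
import jax
import jax.numpy as jnp
import optax

params = {"z": jnp.asarray(1.0 + 1.0j)}
grads = jax.grad(lambda p: jnp.abs(p["z"]) ** 2)(params)

opt = optax.adam(1e-3)
state = opt.init(params)
_, state = opt.update(grads, state, params)

# On a version where adam squares the complex gradient, the second-moment
# ("nu") leaf picks up a nonzero imaginary part; with the fix it stays real.
print(jax.tree_util.tree_map(jnp.imag, state))
```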


