
Extracting learning rate from optimizer state directly


Hi! I’m trying to extract the learning rate from an optax optimizer directly, for logging to TensorBoard.

I know I could get it from my learning rate schedule object instead by passing in the step, but we’ve previously run into situations where the optimizer step number and the expected step number went out of sync (our fault, not optax’s), so to be safe we’d like to get it directly from the optimizer object. In TensorFlow you can call self._optimizer._get_hyper('learning_rate') to access it, since it gets stored via _set_hyper. Is there an easy way to do something similar in optax?

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

6 reactions
mkunesch commented, Oct 31, 2021

Hi! Thanks for the question!

The optimizers themselves do not have any internal state, and by default optimizers that use a learning rate schedule only store the current step in the optimizer state, not the current learning rate. There are two ways around this, both of which have worked well for me in the past:

  • You can get the step from the optimizer state and pass it to the schedule: optimizers that use a schedule (e.g. scale_by_schedule) store the step in the optimizer state (e.g. in the count field of ScaleByScheduleState), so you can extract this count from the optimizer state and pass it to the schedule by hand to get the current learning rate (see the sketch after this list). This is similar to what you have already done, except that now the step is guaranteed to be correct.
  • You can use inject_hyperparams in schedule.py to make any hyperparameter a modifiable part of the optimizer state. This means you can promote the learning rate to be part of the optimizer state and then read it from there directly. There is more info and an example in the docstring of inject_hyperparams.
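
A minimal sketch of the first suggestion, assuming an SGD-style optimizer assembled from scale_by_schedule and scale (the variable names and the exact chain are illustrative, not the only way to set this up):

import jax.numpy as jnp
import optax

schedule = optax.linear_schedule(0.1, 0.0001, 10)

# SGD with a learning rate schedule, written out explicitly so the schedule
# state is easy to locate in the chained optimizer state.
tx = optax.chain(
    optax.scale_by_schedule(schedule),
    optax.scale(-1.0),
)

params = {'w': jnp.ones((5, 5))}
opt_state = tx.init(params)

grads = {'w': jnp.full((5, 5), 0.1)}
updates, opt_state = tx.update(grads, opt_state)

# The chained state is a tuple of per-transform states; the first entry is the
# ScaleByScheduleState, whose count field holds the number of steps taken so far.
count = opt_state[0].count
current_lr = schedule(count)
print(count, current_lr)  # 1 and the schedule value at step 1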

Let me know if you have any questions about the two suggestions above - I’m happy to provide more detail or a code snippet!

I should also say that since logging the learning rate is so common, we are thinking about providing a more direct way of accessing it in the optimizer state. Any thoughts or comments on this are welcome, so I’ll also label this as an enhancement, at least for now.

1 reaction
n2cholas commented, Apr 18, 2022

@borisdayma this will still work when the hyperparameter is a function: opt_state.hyperparams stores the most recent value of the hyperparameter. Below is a code snippet illustrating this:

import optax
import jax.numpy as jnp

# inject_hyperparams promotes the arguments of `optimizer` (here the learning
# rate and eps) into the optimizer state, even when they are passed as schedules.
@optax.inject_hyperparams
def optimizer(learning_rate, eps=1e-8):
  return optax.chain(
      optax.scale_by_rss(initial_accumulator_value=0.0, eps=eps),
      optax.scale(learning_rate),
  )

tx = optimizer(optax.linear_schedule(0.1, 0.0001, 10))
opt_state = tx.init({'w': jnp.ones((5, 5)), 'b': jnp.zeros((5,))})
print(opt_state.hyperparams['learning_rate'])  # schedule value at step 0

grads = {'w': jnp.full((5, 5), 0.1), 'b': jnp.full((5,), 0.1)}
updates, new_opt_state = tx.update(grads, opt_state)
print(new_opt_state.hyperparams['learning_rate'])  # schedule value after one step

Output:

0.1
0.09001
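
For the TensorBoard logging the question asks about, one possible way to consume this value is sketched below; the SummaryWriter, log directory, and step counter are placeholders and not part of the original snippet (this assumes PyTorch's torch.utils.tensorboard writer, but any TensorBoard writer works the same way):

from torch.utils.tensorboard import SummaryWriter  # assumption: PyTorch's TensorBoard writer

writer = SummaryWriter(log_dir='logs')  # placeholder log directory
step = 1                                # placeholder step counter
lr = float(new_opt_state.hyperparams['learning_rate'])
writer.add_scalar('learning_rate', lr, global_step=step)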