Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging third-party libraries. It collects links to all the places you might be looking while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

Please consider supporting optimization of meta-parameters

See original GitHub issue

This is a very exciting project! I was just considering using flax.optim when I found optax, and I love the elegant combine.chain design of the various optimizer aliases. Very cool!

I’d like to consider learning as an iterated function of the parameters, which itself depends on meta-parameters (e.g., the learning rate). Then, I can use the fixed-point theorem to calculate the gradient of the loss on a batch with respect to the meta-parameters.
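
Roughly, the reasoning (spelled out here for clarity; this gloss is not from the issue itself): if one optimizer step is θ_{t+1} = f(θ_t, η) with meta-parameters η, and training converges to a fixed point θ* = f(θ*, η), then the implicit function theorem gives

  dθ*/dη = (I − ∂f/∂θ)⁻¹ ∂f/∂η

so the meta-gradient of a batch loss L is dL/dη = (∂L/∂θ) · dθ*/dη. Computing ∂f/∂η requires differentiating the update rule with respect to its meta-parameters, which is where the current design gets in the way.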

Unfortunately, optax’s GradientTransformations are implemented using functions that close over their values, which means these values cannot be JAX tracers. From my understanding, you cannot take the derivative with respect to step_size if the step size is a closed-over value.
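
For concreteness, here is one way to expose step_size to jax.grad with the current API (this sketch is mine, not from the issue; the toy parameters and loss are illustrative): the transformation has to be rebuilt inside the function being differentiated, so the closure captures a tracer rather than a plain Python float.

import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, 2.0])   # toy parameters
grads = jnp.array([0.1, -0.2])   # pretend batch gradients

def meta_loss(step_size):
  # Rebuild the transformation inside the traced function so that the
  # closed-over step_size is a tracer rather than a concrete float.
  tx = optax.scale(-step_size)
  opt_state = tx.init(params)
  updates, _ = tx.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return jnp.sum(new_params ** 2)

# Gradient of the post-step loss with respect to the meta-parameter.
meta_grad = jax.grad(meta_loss)(0.01)

This works, but the transformation is reconstructed on every call and its meta-parameter is invisible from the outside, which motivates the request below.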

I know this might be a serious change, but would it be possible, instead of having:

def scale(step_size: float) -> GradientTransformation:
  ...
  return GradientTransformation(init_fn, update_fn)

To implement the abstract methods init_fn and update_fn in a subclass:

class Scale(GradientTransformation):
  def __init__(self, step_size):
    ...

This design would allow (a rough sketch follows this list):

  • taking derivatives with respect to various meta-parameters (like step_size),
  • inspecting objects (scale.step_size is available in the object-oriented approach) for debugging,
  • comparing objects and preventing recompilation of jitted functions. If, for some reason, you call scale(1e-3) twice, you get a different object each time, and these objects will not compare equal. If these objects are passed to a jitted function, the function will be recompiled even though the two objects are logically equivalent.
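
To make the proposal concrete, here is a rough sketch (mine, not from the issue) of an object-style Scale registered as a JAX pytree, so that step_size is an ordinary dynamic leaf rather than a closed-over float:

import jax
from typing import NamedTuple

class ScaleState(NamedTuple):
  """Scaling keeps no running statistics, so the state is empty."""

@jax.tree_util.register_pytree_node_class
class Scale:
  def __init__(self, step_size):
    self.step_size = step_size

  def init(self, params):
    del params
    return ScaleState()

  def update(self, updates, state, params=None):
    del params
    new_updates = jax.tree_util.tree_map(lambda g: self.step_size * g, updates)
    return new_updates, state

  # Registering the class as a pytree makes step_size a dynamic leaf:
  # jax.grad can differentiate through it, scale.step_size is inspectable,
  # and two Scale instances that differ only in step_size reuse the same
  # jit trace instead of forcing a recompile.
  def tree_flatten(self):
    return (self.step_size,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del aux_data
    return cls(*children)

Whether such a class should still satisfy the GradientTransformation interface, with init and update as methods rather than closed-over functions, is exactly the design question this issue raises.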

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 25 (10 by maintainers)

Top GitHub Comments

1 reaction
Waterkin commented, May 1, 2022

> @Waterkin It appears your eta is a dict.
>
> Anyway, I think the optax solution is over-complicated. It’s a lot easier to simply use the transformations in tjax.gradient, which have all meta-parameters as ordinary dynamic values.

Thanks for your reply, Neil. I’ll see if tjax helps.

0 reactions
NeilGirdhar commented, May 1, 2022

@Waterkin It appears your eta is a dict.

Anyway, I think the optax solution is over-complicated. It’s a lot easier to simply use the transformations in tjax.gradient, which have all meta-parameters as ordinary dynamic values.
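
For context, "the optax solution" mentioned here presumably refers to optax.inject_hyperparams, which moves hyperparameters into the optimizer state rather than closing over them. A minimal sketch, assuming a recent optax release (the toy parameters are illustrative):

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}   # toy parameters

# Wrap the optimizer factory so its hyperparameters become part of the state.
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=1e-3)
state = opt.init(params)

# The learning rate now lives in the state as an array, i.e. an ordinary
# dynamic value that can be inspected, modified between steps, or traced.
print(state.hyperparams['learning_rate'])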

Read more comments on GitHub >

