Please consider supporting optimization of meta-parameters
This is a very exciting project! I was just considering using flax.optim when I found optax, and I love the elegant `combine.chain` design of the various optimizer aliases. Very cool!
I’d like to consider learning as an iterated function of the parameters, which itself depends on meta-parameters (e.g., the learning rate). Then, I can use the fixed point theorem to calculate the gradient of the loss on a batch with respect to the meta-parameters.
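As a minimal sketch of the meta-gradient idea (plain JAX, no optax; the names here are illustrative), one inner SGD step can be differentiated with respect to the learning rate simply by making the learning rate an explicit argument:

```python
import jax
import jax.numpy as jnp

def loss(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

def inner_step(lr, params, batch):
    # One SGD step; lr is an ordinary traced argument, not a closed-over constant.
    grads = jax.grad(loss)(params, batch)
    return params - lr * grads

def meta_loss(lr, params, batch):
    # Loss after one update, viewed as a function of the meta-parameter lr.
    return loss(inner_step(lr, params, batch), batch)

params = jnp.zeros(3)
batch = (jnp.ones((8, 3)), jnp.ones(8))
d_loss_d_lr = jax.grad(meta_loss)(0.1, params, batch)
```

This is exactly what the closure-based design prevents, as described next.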
Unfortunately, optax’s `GradientTransformation`s are implemented using functions that close over values, which means that these values cannot be JAX tracers. From my understanding, you cannot take the derivative with respect to the `step_size` if the step size is a closed-over value.
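Concretely (a sketch, assuming the current optax API), the step size is unreachable once the transformation has been constructed:

```python
import optax

# step_size is baked into the init_fn/update_fn closures at construction
# time as a plain Python float; it is neither an attribute of tx nor a
# leaf of the optimizer state, so jax.grad has no input to differentiate.
tx = optax.scale(-1e-3)
```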
I know this might be a serious change, but would it be possible, instead of having:

```python
def scale(step_size: float) -> GradientTransformation:
  ...
  return GradientTransformation(init_fn, update_fn)
```
to implement the abstract methods `init_fn` and `update_fn` in an inherited class:

```python
class Scale(GradientTransformation):
  def __init__(self, step_size):
    ...
```
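Here is a minimal sketch of how such a class could also be registered as a JAX pytree (hypothetical code, not existing optax API), so that `step_size` becomes a dynamic leaf that tracing and `jax.grad` can see:

```python
from jax import tree_util

@tree_util.register_pytree_node_class
class Scale:
    """Hypothetical class-based scale transformation."""

    def __init__(self, step_size):
        self.step_size = step_size

    def init(self, params):
        return ()  # this transformation is stateless

    def update(self, updates, state, params=None):
        del params  # unused
        return tree_util.tree_map(lambda u: self.step_size * u, updates), state

    def tree_flatten(self):
        # Expose step_size as a dynamic leaf rather than a static constant.
        return (self.step_size,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)
```

Because `Scale(1e-3)` flattens to its `step_size` leaf, two instances with equal step sizes are equivalent pytrees, and gradients can flow through the leaf.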
This design would allow:

- taking derivatives with respect to various meta-parameters (like `step_size`),
- inspecting objects (`scale.step_size` is available in the object-oriented approach) for debugging, and
- comparing objects and preventing recompilation of jitted functions. If, for some reason, you call `scale(1e-3)` twice, you get a different object each time, and these objects will not compare equal. If these objects are passed to a jitted function, the function will be recompiled even though the objects would normally be equal (see the snippet below).
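The recompilation point is easy to demonstrate with the current API:

```python
import optax

tx1 = optax.scale(1e-3)
tx2 = optax.scale(1e-3)
# Each call builds fresh init/update closures, so two otherwise
# identical transformations never compare equal:
print(tx1 == tx2)  # False
```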
Thanks for your reply, Neil. I’ll see if tjax helps.
@Waterkin It appears your `eta` is a dict. Anyway, I think the optax solution is over-complicated. It’s a lot easier to simply use the transformations in `tjax.gradient`, which have all meta-parameters as ordinary dynamic values.
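For reference, the optax-side solution alluded to above is presumably `optax.inject_hyperparams`, which moves hyperparameters into the optimizer state so they become ordinary dynamic values; a minimal sketch:

```python
import jax
import jax.numpy as jnp
import optax

def meta_loss(lr, params, grads):
    # inject_hyperparams lifts learning_rate into the optimizer state,
    # so a traced lr is allowed here.
    tx = optax.inject_hyperparams(optax.sgd)(learning_rate=lr)
    state = tx.init(params)
    updates, _ = tx.update(grads, state)
    return jnp.sum(optax.apply_updates(params, updates) ** 2)

params = jnp.ones(3)
grads = jnp.ones(3)
d_loss_d_lr = jax.grad(meta_loss)(1e-3, params, grads)
```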