What does optax.masked do?
I don't understand what optax.masked does.

I would expect that the masked optimizer

optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})

would only apply the optimisation to the sub-leaves of Dense and not optimise the sub-leaves of bias. That would mean the masked updates should match the sgd ones for Dense and be zero for bias.

However, it seems to me that the masked updates are correct for the sub-leaves of Dense (where the mask is True), but they are the identity (the untouched gradients) where the mask is False.

Is this intended behaviour? It seems rather strange to me. I was trying to update only a subset of the weights of my model, but this was not working.
MWE:
import jax.numpy as jnp
import jax
import optax
pars = {"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)
op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
op_state = op.init(pars)
masked_updates, new_op_state = op.update(grad, op_state, pars)
>>> masked_updates
{'Dense': {'bias': DeviceArray([-0.1, -0.1, -0.1], dtype=float32),
           'kernel': DeviceArray([[-0.1, -0.1, -0.1],
                                  [-0.1, -0.1, -0.1]], dtype=float32)},
 'bias': DeviceArray([1., 1.], dtype=float32)}
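A minimal sketch of how the expected behaviour (zero updates for the frozen subtree) can be obtained, assuming optax.multi_transform and optax.set_to_zero are available in the installed optax version; the "train"/"freeze" label names are arbitrary:

import jax
import jax.numpy as jnp
import optax

pars = {"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}, "bias": jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)

# Label every leaf: "train" leaves are optimised with sgd, "freeze" leaves get zero updates.
labels = {"Dense": {"kernel": "train", "bias": "train"}, "bias": "freeze"}
op = optax.multi_transform({"train": optax.sgd(0.1), "freeze": optax.set_to_zero()}, labels)

op_state = op.init(pars)
updates, new_op_state = op.update(grad, op_state, pars)
# updates["Dense"] matches plain sgd(0.1); updates["bias"] is all zeros.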
I think it does, because Dead Code Elimination will see that they are not used anywhere. (zero_grad calls zeros_like.)

@pharringtonp19 sorry, I was on holiday. @n2cholas I eventually figured this out. I admit this behaviour makes complete sense within the gradient-transformation framework of optax; however, it took me a while to figure it out.

The current docstring states

Mask updates so only a subset of them are computed.

In optax's view, that means that only a subset of the gradients are transformed, and the rest aren't. But if one isn't familiar with optax, I would immediately think that only some gradients are computed and the rest are zeroed out or whatever.

I think the docstring could be modified to state something like:
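As an illustration of that reading (a sketch, assuming optax.chain, optax.masked and optax.set_to_zero behave as in current optax): the first masked transform rewrites only the True-masked leaves, so chaining a second masked transform around optax.set_to_zero on the inverted mask zeroes the remaining updates, which is the behaviour the original post expected:

import jax
import jax.numpy as jnp
import optax

pars = {"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}, "bias": jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)

mask = {"Dense": True, "bias": False}      # leaves transformed by sgd
inv_mask = {"Dense": False, "bias": True}  # leaves to be zeroed instead of passed through

op = optax.chain(
    optax.masked(optax.sgd(0.1), mask),
    optax.masked(optax.set_to_zero(), inv_mask),
)

op_state = op.init(pars)
updates, _ = op.update(grad, op_state, pars)
# updates["Dense"] is the sgd update; updates["bias"] is zeros rather than the raw gradient.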