Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging third-party libraries. It collects links to all the places you might be looking while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

What does optax.masked do?

See original GitHub issue

I don’t understand what optax.masked does.

I would expect that the masked optimizer

optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})

would only apply the optimisation to the sub-leaves of Dense and not optimise the sub-leaves of bias, which means the masked updates should match the SGD ones for Dense and be zero for bias.

However, it seems to me that the masked updates are correct for the sub-leaves of Dense (where the mask is True), but they are just the identity where the mask is False.

Is this intended behaviour? It seems rather strange to me. I was trying to update only a subset of the weights of my model, but this was not working.

MWE:

import jax.numpy as jnp
import jax
import optax

pars = {"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)

op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})

op_state = op.init(pars)

masked_updates, new_op_state = op.update(grad, op_state,  pars)
>>> masked_updates
{'Dense': {'bias': DeviceArray([-0.1, -0.1, -0.1], dtype=float32), 'kernel': DeviceArray([[-0.1, -0.1, -0.1],
             [-0.1, -0.1, -0.1]], dtype=float32)}, 'bias': DeviceArray([1., 1.], dtype=float32)}
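For reference (this is not part of the original issue): if the goal is to get zero updates for the masked-out leaves rather than pass-through, one way is to route those leaves to a transformation that zeroes its updates. A minimal sketch, assuming a recent optax version where optax.multi_transform and optax.set_to_zero are available:

import jax
import jax.numpy as jnp
import optax

pars = {"Dense": {"kernel": jnp.zeros((2, 3)), "bias": jnp.zeros(3)}, "bias": jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)

# Route each top-level subtree to a different transformation:
# "train" leaves get SGD, "freeze" leaves get their updates replaced by zeros.
labels = {"Dense": "train", "bias": "freeze"}
op = optax.multi_transform(
    {"train": optax.sgd(0.1), "freeze": optax.set_to_zero()},
    labels,
)

op_state = op.init(pars)
updates, op_state = op.update(grad, op_state, pars)
# updates["Dense"] holds the SGD updates, while updates["bias"] is all zeros,
# so optax.apply_updates would leave the frozen parameters untouched.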

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 5
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

2 reactions
PhilipVinc commented, Aug 12, 2021

I think it does, because dead-code elimination will see that they are not used anywhere (zero_grad calls zeros_like).
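(The zero_grad mentioned above is not shown in this thread; presumably it is a transformation along the lines of the sketch below, i.e. a tree_map of zeros_like over the updates, which is also roughly what optax.set_to_zero does.)

import jax
import jax.numpy as jnp
import optax

def zero_grad() -> optax.GradientTransformation:
    # Hypothetical transformation that replaces every update with zeros.
    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        # Only the shapes/dtypes of the incoming updates are used, never their
        # values, which is what lets dead-code elimination drop the arithmetic
        # that produced them.
        return jax.tree_map(jnp.zeros_like, updates), state

    return optax.GradientTransformation(init_fn, update_fn)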

2 reactions
PhilipVinc commented, Aug 11, 2021

@pharringtonp19 sorry, I was on holiday. @n2cholas I eventually figured this out. I admit this behaviour makes complete sense within optax’s framework of gradient transformations, but it took me a while to figure it out.

The current docstring states “Mask updates so only a subset of them are computed.” In optax’s view, that means only a subset of the gradients are transformed and the rest aren’t. But someone not familiar with optax would immediately think that only some gradients are computed and the rest are zeroed out, or something similar.

I think the docstring could be modified to state something like:

  """Mask updates so only a subset of them are transformed, while the rest are 
passed-through unchanged.
"""

Top Results From Across the Web

Common Optimizers — Optax documentation - Read the Docs
You can use optax.masked to make your own AdamaxW variant where additive_weight_decay is applied only to a subset of params.

Build a Transformer in JAX from scratch: how to write and train ...
All we do is build the causal mask, using np.tril(), which nullifies all elements of the array above the kth diagonal, multiply with our...

optax Changelog - PyUp.io
Do not use None in optax.masked. in https://github.com/deepmind/optax/pull/338 ... One useful use case is setting the upper bound for the second momentum, ...

dev/seq2seq/run_seq2seq_flax.py · dalle-mini ... - Hugging Face
The mask is True for parameters that should be decayed. ... For more details about the parameters please check https://github.com/deepmind/optax/blob/ ...

[N] DeepMind releases two new Jax Libraries: Optax for ...
Distributed training is convenient already, but I'm excited for it to get even more convenient as things like gmap and mask mature.
