[flax.linen] Disable dropout in SelfAttention dynamically when using `setup`
Problem you have encountered:
Dropout can be enabled/disabled within the forward call, since `flax.linen.Dropout` accepts the boolean `deterministic` in its `__call__` method, see: https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. However, the `flax.linen.SelfAttention` class does not allow its dropout layer to be enabled/disabled in the `__call__` method, but only in the `__init__` method, see: https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html#MultiHeadDotProductAttention
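For contrast, a minimal sketch of how `flax.linen.Dropout` takes the flag at call time (the variable names here are just illustrative):

```python
import flax.linen as nn
from jax import random, numpy as jnp

drop = nn.Dropout(rate=0.1)
x = jnp.ones((1, 8))

# the same layer instance can run in train or eval mode per call
y_train = drop.apply({}, x, deterministic=False, rngs={"dropout": random.PRNGKey(0)})
y_eval = drop.apply({}, x, deterministic=True)
```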
Let’s say we have the following Attention layer:
```python
import flax.linen as nn
from jax import lax, random, numpy as jnp


class DummyAttention(nn.Module):
    num_heads: int

    @nn.compact
    def __call__(self, inputs, deterministic):
        inputs = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=deterministic,
        )(inputs)
        return inputs


model = DummyAttention(num_heads=2)
key1, key2 = random.split(random.PRNGKey(0))
x = jnp.ones((1, 8))

params = model.init({"params": key1, "dropout": key2}, x, True)

# train
print(model.apply(params, x, False, rngs={"dropout": key2}))
# eval
print(model.apply(params, x, True))
```
Now let’s say we want to rewrite the module using the `setup` function:
```python
class DummyAttention(nn.Module):
    num_heads: int

    def setup(self):
        self.self_attn = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=deterministic,  # <- not available in setup
        )

    def __call__(self, inputs, deterministic):
        # problem... the "deterministic" flag cannot be passed here
        inputs = self.self_attn(inputs)
        return inputs
```
We encounter a problem here, no? We cannot pass `deterministic` dynamically to `SelfAttention` anymore. Using `setup(self)` with `flax.linen.Dropout` is no problem, as its `__call__` method accepts a `deterministic` argument.
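For comparison, a minimal sketch of a `setup`-style module using `flax.linen.Dropout`, where the flag can still be forwarded per call (the module name is illustrative):

```python
class DummyDropout(nn.Module):
    def setup(self):
        self.dropout = nn.Dropout(rate=0.1)

    def __call__(self, inputs, deterministic):
        # deterministic can be forwarded at call time
        return self.dropout(inputs, deterministic=deterministic)
```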
Would it maybe be possible to add `deterministic` to the `__call__` method of `flax.linen.SelfAttention`? I think it would also better align the API of `flax.linen.SelfAttention` with that of `flax.linen.Dropout`, and allow dropout to be dynamically enabled/disabled when using `setup`.

I’d be happy to open a PR that adds an optional `deterministic` arg to `flax.linen.SelfAttention.__call__`. It could maybe override `SelfAttention`'s `deterministic` attribute in case it’s passed - what do you think?
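A rough sketch of the override idea, written as a wrapper module rather than a change to Flax itself (class and attribute names here are assumptions, not the actual Flax implementation):

```python
from typing import Optional

class SelfAttentionWithFlag(nn.Module):
    """Illustrative wrapper: a call-time flag takes precedence over the attribute."""
    num_heads: int
    dropout_rate: float = 0.1
    deterministic: bool = True  # default fixed at construction time

    @nn.compact
    def __call__(self, inputs, deterministic: Optional[bool] = None):
        # prefer the call-time flag; fall back to the attribute otherwise
        if deterministic is None:
            deterministic = self.deterministic
        return nn.attention.SelfAttention(
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
        )(inputs)
```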
Top GitHub Comments
Yes precisely, PyTorch “solves” this with hard-coded mutable state (train vs eval), which is both incomplete (what if you have a third mode?) and a bit dangerous (what if you forget to switch between the modes?).
@jheek also makes the point that if you have a bound module instance that you tinker with in Colab it’s pretty weird to have to clone the instance and replace the variables with the variables from another instance to switch from one mode to another. This can happen for modules defined inside other modules, or for “interactive mutable modules” that will soon land in Flax that allow for easy tinkering in Colab.
So I now see that the change you’re asking for here is a good idea. It adds a bit of complexity to the implementation but serves our users better, and one of our guiding philosophies is to “agree” to own more complexity to serve users. A little bit like the “Absorb the complexity” in https://react.christmas/2019/24, but not to get too philosophical here…
Yeah, I’ve thought about it. But it would kind of mean that for training we would create two separate instances, no? Let’s say we have:
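(A minimal sketch of such a module, with `deterministic` fixed as an attribute at construction time; the exact snippet wasn’t shown, so the names are illustrative.)

```python
class DummyAttention(nn.Module):
    num_heads: int
    deterministic: bool

    def setup(self):
        # dropout behaviour is fixed once, at construction time
        self.self_attn = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=self.deterministic,
        )

    def __call__(self, inputs):
        return self.self_attn(inputs)
```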
=> this would mean we need to do:
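Something along these lines (again just a sketch, building on the module above):

```python
# two separate module instances, differing only in the deterministic flag
train_model = DummyAttention(num_heads=2, deterministic=False)
val_model = DummyAttention(num_heads=2, deterministic=True)

params = train_model.init({"params": key1, "dropout": key2}, x)
print(train_model.apply(params, x, rngs={"dropout": key2}))  # train (dropout active)
print(val_model.apply(params, x))                            # eval (dropout disabled)
```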
no?
I’m not a big fan of having two separate instances `train_model` and `val_model`. I’d much rather prefer to have `model.apply(..., deterministic=not train)`.
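(Roughly what that preferred pattern would look like, assuming a module such as the `SelfAttentionWithFlag` sketch above whose `__call__` accepts the flag: one instance, one set of params, and the mode chosen per call.)

```python
model = SelfAttentionWithFlag(num_heads=2)
params = model.init({"params": key1, "dropout": key2}, x)

train = True
out = model.apply(
    params, x,
    deterministic=not train,
    rngs={"dropout": key2} if train else {},
)
```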