[flax.linen] Disable dropout in SelfAttention dynamically when using `setup`

Problem you have encountered:

Dropout can be enabled or disabled within the forward call, since flax.linen.Dropout accepts a boolean deterministic argument in its __call__ function; see: https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout .

However the flax.linen.SelfAttention class does not allow its dropout layer to be enabled/disabled in the __call__ method, but only in the __init__ method, see: https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html#MultiHeadDotProductAttention

Let’s say we have the following Attention layer:

import flax.linen as nn
from jax import lax, random, numpy as jnp


class DummyAttention(nn.Module):
    num_heads: int

    @nn.compact
    def __call__(self, inputs, deterministic):
        inputs = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=deterministic,
        )(inputs)
        return inputs

model = DummyAttention(num_heads=2)

key1, key2 = random.split(random.PRNGKey(0))
x = jnp.ones((1, 8))
params = model.init({"params": key1, "dropout": key2}, x, True)

# train
print(model.apply(params, x, False, rngs={"dropout": key2}))

# eval
print(model.apply(params, x, True))

Now let’s say we want to rewrite the module using the setup function:

class DummyAttention(nn.Module):
    num_heads: int

    def setup(self):
        self.self_attn = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=deterministic,  # problem: "deterministic" is not defined inside setup
        )

    def __call__(self, inputs, deterministic):
        # problem..."deterministic" flag cannot be passed
        inputs = self.self_attn(inputs)
        return inputs

We encounter a problem here, no? We cannot pass deterministic dynamically to SelfAttention anymore.

Using setup(self) with flax.linen.Dropout is no problem as the call method accepts a deterministic argument.

Would it maybe be possible to add deterministic to the __call__ method of flax.SelfAttention? I think it would also better align the API of flax.Dropout with flax.SelfAttention and make it possible to dynamically enable/disable dropout when using setup.

I’d be happy to open a PR that adds an optional deterministic arg to flax.linen.SelfAttention.__call__. It could maybe overwrite SelfAttention's deterministic attribute in case it’s passed - what do you think?
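The override rule proposed above could look roughly like the following. This is a plain-Python sketch of the idea only: resolve_deterministic is a hypothetical helper, not a Flax function, and it makes no claim about how Flax resolves the flag internally.

```python
from typing import Optional


def resolve_deterministic(attribute: Optional[bool], call_arg: Optional[bool]) -> bool:
    """Hypothetical helper: a value passed to __call__ wins over the module attribute."""
    if call_arg is not None:
        # The call-time value overrides whatever was set at construction.
        return call_arg
    if attribute is None:
        # Neither source provided a value, so the mode is ambiguous.
        raise ValueError("`deterministic` must be set on the module or passed to __call__.")
    return attribute


# Attribute set at construction, no call-time argument: attribute is used.
from_attribute = resolve_deterministic(True, None)

# Call-time argument given: it overrides the attribute.
from_call = resolve_deterministic(True, False)
```

Raising when neither value is set is one possible design choice; silently defaulting to one mode would risk evaluating with dropout left on.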

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 10 (3 by maintainers)

Top GitHub Comments

3 reactions
avital commented, Jan 13, 2021

Yes precisely, PyTorch “solves” this with hard-coded mutable state (train vs eval) which is both incomplete (what if you have a third mode?) and also a bit dangerous (what if you forget to switch between the modes).

@jheek also makes the point that if you have a bound module instance that you tinker with in Colab it’s pretty weird to have to clone the instance and replace the variables with the variables from another instance to switch from one mode to another. This can happen for modules defined inside other modules, or for “interactive mutable modules” that will soon land in Flax that allow for easy tinkering in Colab.

So I now see that the change you’re asking for here is a good idea. It adds a bit of complexity to the implementation but serves our users better, and one of our guiding philosophies is to “agree” to own more complexity to serve users. A little bit like the “Absorb the complexity” in https://react.christmas/2019/24, but not to get too philosophical here…

1 reaction
patrickvonplaten commented, Jan 13, 2021

Yeah, I’ve thought about it. But it would kind of mean that for training we would create two separate instances, no? Let’s say we have:

class DummyAttention(nn.Module):
    num_heads: int
    deterministic: bool

    def setup(self):
        self.self_attn = nn.attention.SelfAttention(
            dropout_rate=0.1,
            num_heads=self.num_heads,
            deterministic=self.deterministic,
        )

    def __call__(self, inputs):
        inputs = self.self_attn(inputs)
        return inputs

=> this would mean we need to do:

train_model = DummyAttention(num_heads=3, deterministic=False)
# init model and train weights
# ...
val_model = DummyAttention(num_heads=3, deterministic=True)
# apply new instance with trained weights

no?

I’m not a big fan of having two separate instances, train_model and val_model. I’d much rather have

model.apply(...., deterministic=not train)
