Add HOWTO: Batch Normalization
Batch Normalization is more complicated than most layers because it mutates its moving averages during training.
There is already a discussion on this (#921), but it would be nice to add a HOWTO so the code below can be included as well. The code is copied from a Colab by @levskaya and highlights the general state-management API you use for any stateful computation in a neural network.
Also take a look at the comments at #1489.
import numpy as np
import jax
from jax import random, lax, numpy as jnp
import flax
from flax import linen as nn
We define a trivial conv + BN layer.
class Foo(nn.Module):
  train: bool
  filters: int

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=jnp.float32)(x)
    return x
key = random.PRNGKey(0)
x = jnp.ones((5,4,4,3))
# We instantiate the layer then call its init function to get initial variable collections.
foo_vars = Foo(filters=7, train=True).init(key, x)
foo_vars
This returns the following:
FrozenDict({
    params: {
        Conv_0: {
            kernel: DeviceArray([[[[ 0.50138927, 0.7354811 , 0.7896391 , -0.63713336,
                                     1.081016 , -0.29067358, -0.3780927 ],
                                   [ 0.7357814 , 0.24682549, -0.55378306, 0.16909008,
                                     0.85014457, 1.0167135 , 0.19896305],
                                   [ 1.1461202 , 0.8548834 , -1.0578486 , -0.6013309 ,
                                     0.2501557 , 0.3332178 , -0.36248836]]]], dtype=float32),
        },
        BatchNorm_0: {
            scale: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
    },
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
        },
    },
})
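One way to inspect just the structure of these collections is to map `jnp.shape` over the pytree. This is a small sketch using the standard JAX tree utilities, not part of the original Colab:

print(jax.tree_util.tree_map(jnp.shape, foo_vars))
# params/Conv_0/kernel:               (1, 1, 3, 7)  -- a 1x1 conv from 3 to 7 features
# params/BatchNorm_0/{scale, bias}:   (7,)          -- one entry per feature
# batch_stats/BatchNorm_0/{mean, var}: (7,)         -- the running statistics per feature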
We explicitly say which variable collections the apply function may mutate; the mutated collections are then returned as an auxiliary output.
y1, new_batch_stats = Foo(filters=7, train=True).apply(foo_vars, x, mutable=['batch_stats'])
new_batch_stats
This returns the following:
FrozenDict({
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([ 0.23832898, 0.18371896, -0.08219925, -0.10693741,
                                0.21813159, 0.10592576, -0.05416182], dtype=float32),
            var: DeviceArray([0.9000007 , 0.90000004, 0.9 , 0.90000004, 0.9000001 ,
                              0.9 , 0.9 ], dtype=float32),
        },
    },
})
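These values follow from the moving-average update used by nn.BatchNorm, `ema = momentum * ema + (1 - momentum) * batch_statistic`. Because `x` is all ones, the conv output is constant over the batch and spatial dimensions, so the per-feature batch variance is (numerically) zero and the batch mean equals the sum of the kernel over its input channels. A quick sanity check, sketched on top of the `foo_vars` and `new_batch_stats` from above:

kernel = foo_vars['params']['Conv_0']['kernel']   # shape (1, 1, 3, 7)
batch_mean = kernel.sum(axis=(0, 1, 2))           # conv output for an all-ones input
momentum = 0.9
# mean: 0.9 * 0 + 0.1 * batch_mean
print(jnp.allclose(new_batch_stats['batch_stats']['BatchNorm_0']['mean'],
                   (1 - momentum) * batch_mean, atol=1e-5))
# var: 0.9 * 1 + 0.1 * 0 = 0.9 (the batch variance of a constant signal is 0)
print(jnp.allclose(new_batch_stats['batch_stats']['BatchNorm_0']['var'],
                   momentum * jnp.ones(7), atol=1e-5))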
We stitch the params and batch_stats collections back together to evaluate again. Normally the params would also have been updated by a training step using an optimizer (see the training-step sketch below).
new_foo_vars = {'params': foo_vars['params'], 'batch_stats': new_batch_stats['batch_stats']}  # the mutated variables come back grouped by collection name
y2, even_newer_batch_stats = Foo(filters=7, train=True).apply(new_foo_vars, x, mutable=['batch_stats'])
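In a real training loop the parameter update and the batch-stats update happen together. Below is a minimal sketch of how that might look with optax; the `optax.adam` optimizer, learning rate, and dummy loss are illustrative assumptions, not part of the original Colab. At evaluation time the module is constructed with `train=False` and applied without any mutable collections, so the stored running averages are used:

import optax

model = Foo(filters=7, train=True)
eval_model = Foo(filters=7, train=False)

tx = optax.adam(1e-3)                      # illustrative optimizer choice
opt_state = tx.init(foo_vars['params'])

@jax.jit
def train_step(params, batch_stats, opt_state, batch):
  def loss_fn(params):
    y, mutated = model.apply({'params': params, 'batch_stats': batch_stats},
                             batch, mutable=['batch_stats'])
    loss = jnp.mean(y ** 2)                # dummy loss, just for the sketch
    return loss, mutated['batch_stats']
  (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, new_batch_stats, opt_state, loss

params, batch_stats = foo_vars['params'], foo_vars['batch_stats']
params, batch_stats, opt_state, loss = train_step(params, batch_stats, opt_state, x)

# Evaluation: use_running_average=True, nothing is mutated.
y_eval = eval_model.apply({'params': params, 'batch_stats': batch_stats}, x)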
Top GitHub Comments
Hi @cccntu – what do you mean by “different”? `pop` returns two parts: the variable collection that you popped, and the remaining collections still all grouped together. Maybe the docstring for `FrozenDict.pop` isn’t clear enough? But this is working as intended.

(BTW in your example you wrote `params = model.init(rng, dummy_input)` – it should probably be `variables = model.init(rng, dummy_input)`, as parameters are one of possibly multiple variable collections.)

Does this help? What could we improve in our documentation so that this would be less confusing?
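For illustration, a small sketch of that split, reusing the `foo_vars` from above:

# pop returns (the remaining collections, the collection you popped):
rest, params = foo_vars.pop('params')
print(list(rest.keys()))    # ['batch_stats']
print(list(params.keys()))  # ['Conv_0', 'BatchNorm_0']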
See the ImageNet example for the canonical example of combining BatchNorm and pmap. There indeed we sync the statistics before evaluation with `pmean`.
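The pattern is roughly the following. This is a sketch rather than the exact ImageNet example code; `state` here stands for a replicated train state with a `batch_stats` field, which is an assumption for the sake of illustration:

# Average the per-device batch statistics across all devices.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')

def sync_batch_stats(state):
  # Assumes `state` is a flax.struct-style dataclass replicated across devices,
  # with a `batch_stats` field holding the BatchNorm running statistics.
  return state.replace(batch_stats=cross_replica_mean(state.batch_stats))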