
Add HOWTO: Batch Normalization

See original GitHub issue

Batch Normalization is more complicated than most layers because of the mutation of moving averages during training.

There is a Discussion on this already (#921), but it may be nice to add a HOWTO for this since we could include the code below as well. The code below is copied from a Colab by @levskaya and highlights the general state-management API you use for any stateful computation in a neural network.

Also take a look at the comments at #1489.

import numpy as np

import jax
from jax import random, lax, numpy as jnp

import flax
from flax import linen as nn

We define a trivial Conv + BatchNorm layer:

class Foo(nn.Module):
  train: bool
  filters: int

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=jnp.float32)(x)
    return x

key = random.PRNGKey(0)
x = jnp.ones((5,4,4,3))

# We instantiate the layer then call its init function to get initial variable collections.
foo_vars = Foo(filters=7, train=True).init(key, x)
foo_vars

This returns the following:

FrozenDict({
    params: {
        Conv_0: {
            kernel: DeviceArray([[[[ 0.50138927,  0.7354811 ,  0.7896391 , -0.63713336,
                             1.081016  , -0.29067358, -0.3780927 ],
                           [ 0.7357814 ,  0.24682549, -0.55378306,  0.16909008,
                             0.85014457,  1.0167135 ,  0.19896305],
                           [ 1.1461202 ,  0.8548834 , -1.0578486 , -0.6013309 ,
                             0.2501557 ,  0.3332178 , -0.36248836]]]], dtype=float32),
        },
        BatchNorm_0: {
            scale: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
    },
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            var: DeviceArray([1., 1., 1., 1., 1., 1., 1.], dtype=float32),
        },
    },
})
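
(An aside, not from the original Colab: if you only want to see the structure of these collections rather than the full arrays, you can map jnp.shape over the tree.)

# Summarize the collections by array shape instead of printing values.
jax.tree_util.tree_map(jnp.shape, foo_vars)
# Returns the same nested structure with shape tuples as leaves,
# e.g. Conv_0/kernel: (1, 1, 3, 7); BatchNorm_0 mean/var/scale/bias: (7,).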

We explicitly state which variable collections may be mutated by the apply function; those are then returned as auxiliary outputs alongside the model output.

y1, new_batch_stats = Foo(filters=7, train=True).apply(foo_vars, x, mutable=['batch_stats'])
new_batch_stats

This returns the following:

FrozenDict({
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([ 0.23832898,  0.18371896, -0.08219925, -0.10693741,
                          0.21813159,  0.10592576, -0.05416182], dtype=float32),
            var: DeviceArray([0.9000007 , 0.90000004, 0.9       , 0.90000004, 0.9000001 ,
                         0.9       , 0.9       ], dtype=float32),
        },
    },
})

Note in the output above that the running variance moved from 1.0 toward 0.9: the input is constant (all ones), so the batch variance is zero and the momentum update gives 0.9 · 1 + 0.1 · 0 = 0.9. We now stitch the params and batch_stats collections back together to evaluate again. Normally the params would also have been updated by a training step using an optimizer (see the sketch below).

# new_batch_stats is keyed by collection name, so pull out 'batch_stats'.
new_foo_vars = {'params': foo_vars['params'], 'batch_stats': new_batch_stats['batch_stats']}
y2, even_newer_batch_stats = Foo(filters=7, train=True).apply(new_foo_vars, x, mutable=['batch_stats'])
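
To make the training loop concrete, here is a minimal sketch of a full training step that updates both collections. It is not from the original issue: it assumes optax as the optimizer library, a toy mean-squared-error loss against a made-up target, and illustrative names (train_step, loss_fn, tx).

import optax

model = Foo(filters=7, train=True)
tx = optax.sgd(learning_rate=0.01)
opt_state = tx.init(foo_vars['params'])
y_target = jnp.zeros((5, 4, 4, 7))  # hypothetical regression target

@jax.jit
def train_step(params, batch_stats, opt_state, x, y_target):
  def loss_fn(params):
    # Mark batch_stats as mutable so the updated running averages
    # come back as an auxiliary output next to the prediction.
    y, updated_state = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        x, mutable=['batch_stats'])
    loss = jnp.mean((y - y_target) ** 2)
    return loss, updated_state['batch_stats']

  (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, new_batch_stats, opt_state, loss

params, batch_stats, opt_state, loss = train_step(
    foo_vars['params'], foo_vars['batch_stats'], opt_state, x, y_target)

For evaluation you would construct the module with train=False, so that use_running_average=True and the stored statistics are used; no mutable argument is needed then, e.g. y = Foo(filters=7, train=False).apply({'params': params, 'batch_stats': batch_stats}, x).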

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 16
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

1 reaction
avital commented, Feb 1, 2021

Hi @cccntu – what do you mean by “different”? pop returns two parts: the variable collection that you popped, and the remaining collections still all grouped together, e.g.:

variables = model.init(...)
# assume variables['params'] and variables['batch_stats'] are present here
other_variables, params = variables.pop('params')

# here params == variables['params'], and other_variables['batch_stats'] == variables['batch_stats']

Maybe the docstring for FrozenDict.pop isn’t clear enough? But this is working as intended.

(BTW in your example you wrote params = model.init(rng, dummy_input) – it should probably be variables = model.init(rng, dummy_input), since params is just one of possibly several variable collections.)

Does this help? What could we improve in our documentation so that this would be less confusing?
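
As a usage sketch (not part of the comment above): a common pattern is to split the collections once right after init, hand params to the optimizer, and carry the remaining collections alongside, reassembling them for apply.

variables = Foo(filters=7, train=True).init(key, x)
other_variables, params = variables.pop('params')

# params goes to the optimizer; other_variables (here just batch_stats)
# is passed back to apply together with the possibly-updated params.
y, mutated = Foo(filters=7, train=True).apply(
    {'params': params, **other_variables}, x, mutable=['batch_stats'])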

0 reactions
jheek commented, Mar 8, 2021

“I am very interested in the best practices for BatchNorm (or batch_stats in general, I guess) when used inside a pmap.”

See the ImageNet example for the canonical example of combining BatchNorm and pmap. There indeed we sync the statistics before evaluation with pmean.
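
For reference, a rough sketch of that pattern (loosely adapted from the ImageNet example; assumes the batch_stats collection has been replicated across devices, e.g. with flax.jax_utils.replicate):

# Average every leaf of the replicated batch_stats pytree across devices
# before running evaluation.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'batch'), axis_name='batch')

def sync_batch_stats(replicated_batch_stats):
  return cross_replica_mean(replicated_batch_stats)

During training, passing axis_name='batch' to nn.BatchNorm (matching the axis_name of the surrounding pmap) makes the layer compute its batch statistics across all devices instead of per device.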


Top Results From Across the Web

How to Accelerate Learning of Deep Neural Networks With ...
A new BatchNormalization layer can be added to the model after the hidden layer before the output layer. Specifically, after the activation ...

Where to apply batch normalization on standard CNNs
Andrew Ng says that batch normalization should be applied immediately before the non-linearity of the current layer. The authors of the BN paper ...

Batch Normalization in practice: an example with Keras ...
Batch Normalization in practice: an example with Keras and TensorFlow 2.0. A step by step tutorial to add and customize batch normalization.

Hands-On Guide To Implement Batch Normalization in ...
Batch normalization is a feature that we add between the layers of the neural network and it continuously takes the output from the ...

Batch Normalization in Keras - An Example
In this report, we'll show you how to add batch normalization to a Keras model, and observe the effect BatchNormalization has as we ...
