Batchnorm example doesn't work

Hi there!

I’m sorry if this issue is already known, but apparently the BatchNorm layer isn’t working correctly. For instance, the example code listed in the documentation fails with ValueError: Unable to parse module assembly (see diagnostics):

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# ValueError: Unable to parse module assembly (see diagnostics)

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

3 reactions
patrick-kidger commented, Sep 22, 2022

Oh, interesting. Yup, I can reproduce this on my GPU machine too.

The bad news is that I don’t think this is something that can be fixed quickly. This looks like an obscure error from JAX doing something weird.

The good news is that core JAX has some incoming changes that might see equinox.experimental.BatchNorm switching to a totally different, and robust, implementation.

(For context about why BatchNorm is such a tricky operation to support: it’s because getting/setting its running statistics is a side effect. Meanwhile JAX, being functional, basically doesn’t support side effects! Supporting this properly, without compromising on all the lovely things that make JAX efficient, is pretty difficult.)

So right now I think the answer is a somewhat unsatisfying “let’s wait and see”.
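To make the side-effect point concrete: in pure JAX, the usual workaround is to treat the running statistics as explicit inputs and outputs instead of mutating them. The sketch below is a hypothetical illustration (the names batch_norm and apply_batch_norm, and the momentum/eps defaults, are assumptions), not how equinox.experimental.BatchNorm is implemented. It uses jax.lax.pmean over the vmapped axis named "batch", the same mechanism that axis_name="batch" enables in the repro above, to compute the batch statistics.

```python
import jax
import jax.numpy as jnp


def batch_norm(x, running_mean, running_var, *, axis_name="batch",
               momentum=0.99, eps=1e-5):
    """Normalise one example `x` using statistics taken across the vmapped axis.

    Hypothetical sketch: instead of mutating the running statistics in place
    (a side effect), the updated statistics are returned with the output.
    """
    # Collectives over the named vmap axis give per-feature batch statistics.
    mean = jax.lax.pmean(x, axis_name)
    var = jax.lax.pmean((x - mean) ** 2, axis_name)
    y = (x - mean) / jnp.sqrt(var + eps)
    # Exponential moving averages of the batch statistics.
    new_mean = momentum * running_mean + (1 - momentum) * mean
    new_var = momentum * running_var + (1 - momentum) * var
    return y, (new_mean, new_var)


def apply_batch_norm(xs, running_mean, running_var):
    # Map over the leading (batch) axis; the running statistics are shared.
    y, (new_mean, new_var) = jax.vmap(
        batch_norm, in_axes=(0, None, None), axis_name="batch"
    )(xs, running_mean, running_var)
    # Every batch element computes identical statistics, so keep one copy.
    return y, (new_mean[0], new_var[0])


if __name__ == "__main__":
    xs = jnp.arange(8.0).reshape(4, 2)
    y, (mean, var) = jax.jit(apply_batch_norm)(xs, jnp.zeros(2), jnp.ones(2))
    print(y, mean, var)
```

The caller is responsible for storing the returned statistics and passing them back in on the next step. Threading state explicitly like this keeps the function pure, so it composes with jax.jit and jax.vmap without any hidden side effects.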

0 reactions
ciupakabra commented, Nov 21, 2022

Seems to work with equinox==0.9.2. One small thing is that without jitting it seems to be very slow:

import equinox as eqx
import jax.random as jrandom
import jax.nn as jnn
import jax.numpy as jnp
import jax
import time

class Network(eqx.Module):
    net: eqx.Module

    def __init__(self, in_size, out_size, width, depth, *, key, bn=True):

        keys = jrandom.split(key, depth + 1)
        layers = []
        if depth == 0:
            layers.append(eqx.nn.Linear(in_size, out_size, key=keys[0]))
        else:
            layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
            if bn:
                layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
            for i in range(depth - 1):
                layers.append(eqx.nn.Linear(width, width, key=keys[i + 1]))
                if bn: 
                    layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                layers.append(eqx.nn.Lambda(jnn.relu))
            layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))

        self.net = eqx.nn.Sequential(layers)

    def __call__(self, x):
        return self.net(x)



if __name__=="__main__":

    key = jrandom.PRNGKey(0)

    init_key, data_key = jrandom.split(key, 2)

    net = Network(10, 5, 3, 300, key=init_key, bn=False)
    bn_net = Network(10, 5, 3, 300, key=init_key, bn=True)

    x = jrandom.normal(data_key, (32, 10))

    func = jax.vmap(net, axis_name="batch")
    jitted = jax.jit(func)

    bn_func = jax.vmap(bn_net, axis_name="batch")
    bn_jitted = jax.jit(bn_func)  # jit the BatchNorm network (not `func`)

    # compile; JAX dispatches asynchronously, so block until each result is ready
    jitted(x).block_until_ready()
    bn_jitted(x).block_until_ready()

    start = time.time()
    y = func(x).block_until_ready()
    finish = time.time()
    print(f"Wout BN / Wout JIT took: {finish-start:.2f}")

    start = time.time()
    y = jitted(x).block_until_ready()
    finish = time.time()
    print(f"Wout BN / With JIT took: {finish-start:.2f}")

    start = time.time()
    y = bn_func(x).block_until_ready()
    finish = time.time()
    print(f"With BN / Wout JIT took: {finish-start:.2f}")

    start = time.time()
    y = bn_jitted(x).block_until_ready()
    finish = time.time()
    print(f"With BN / With JIT took: {finish-start:.2f}")

On my GPU machine this outputs:

Wout BN / Wout JIT took: 0.42
Wout BN / With JIT took: 0.00
With BN / Wout JIT took: 15.67
With BN / With JIT took: 0.00
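The observation that the un-jitted network is much slower can be checked with a small timing helper. This is a sketch under assumptions: deep_mlp is a hypothetical stand-in for a deep network (it is not the Network class above), and mean_runtime and compile_time are illustrative helper names. The helper warms each function up once, blocks on every result because JAX dispatches work asynchronously, and uses the ahead-of-time jax.jit(fn).lower(...).compile() API to measure compilation separately from execution.

```python
import time

import jax
import jax.numpy as jnp


def deep_mlp(x, depth=300):
    """Hypothetical stand-in for a deep network: many small sequential ops."""
    w = jnp.eye(x.shape[-1])
    for _ in range(depth):
        x = jax.nn.relu(x @ w)
    return x


def mean_runtime(fn, *args, repeats=5):
    """Average wall-clock seconds per call of fn(*args), after one warm-up call."""
    fn(*args).block_until_ready()  # warm-up; triggers compilation if fn is jitted
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches work asynchronously, so block before reading the clock.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


def compile_time(fn, *args):
    """Seconds spent tracing, lowering and compiling fn for these argument shapes.

    May be shorter if JAX has already cached a compilation for these shapes.
    """
    start = time.perf_counter()
    jax.jit(fn).lower(*args).compile()
    return time.perf_counter() - start


if __name__ == "__main__":
    x = jnp.ones((32, 10))
    print(f"compile: {compile_time(deep_mlp, x):.4f} s")
    print(f"eager:   {mean_runtime(deep_mlp, x):.4f} s/call")
    print(f"jitted:  {mean_runtime(jax.jit(deep_mlp), x):.4f} s/call")
```

In eager mode each layer's operations are dispatched one at a time, whereas jax.jit traces the function once and runs a single compiled XLA program. The reported numbers (0.42 s versus 15.67 s without JIT) suggest the BatchNorm layers add substantial per-call overhead outside jit; a helper like this only measures that gap, it does not explain it.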