BatchNorm example doesn't work
Hi there!
I'm sorry if this issue is already known, but apparently the BatchNorm layer isn't working correctly. For instance, the example code listed in the documentation fails with ValueError: Unable to parse module assembly (see diagnostics):
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# ValueError: Unable to parse module assembly (see diagnostics)
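The axis_name="batch" argument ties the BatchNorm layer to the axis being vmapped over: inside vmap, per-batch statistics are computed with a collective over that named axis. A minimal pure-JAX sketch of the same mechanism (the function name here is illustrative, not part of Equinox):

```python
import jax
import jax.numpy as jnp

def center(x):
    # jax.lax.pmean averages over the named, vmapped axis "batch",
    # which is how per-batch statistics can be computed inside vmap.
    mean = jax.lax.pmean(x, axis_name="batch")
    return x - mean

xs = jnp.arange(6.0).reshape(3, 2)
out = jax.vmap(center, axis_name="batch")(xs)
# each column of `out` now has zero mean across the batch
```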
Issue Analytics
- State:
- Created a year ago
- Comments: 5 (2 by maintainers)
Top Results From Across the Web

Freezing network with batch norm does not work with TRT
Describe the problem. When I freeze a protobuf that contains batch normalization and then try to use it with TRT, it fails with...

Batch Norm Explained Visually — How it works, and why ...
As we discussed above, during Training, Batch Norm starts by calculating the mean and variance for a mini-batch. However, during Inference, we ...

Poor Result with BatchNormalization - Stack Overflow
Issue 1: DCGAN paper suggest to use BN (Batch Normalization) both the generator and discriminator. But, I couldn't get better result with BN ...

The Batch Normalization layer of Keras is broken - Datumbox
In this blog post, I will try to build a case for why Keras' BatchNormalization layer does not play nice with Transfer Learning,...

The Danger of Batch Normalization in Deep Learning - Mindee
This works well in practice, but we cannot do the same at inference ... For example, it will not be obvious if a...
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Oh, interesting. Yup, I can reproduce this on my GPU machine too.
The bad news is that I don’t think this is something that can be fixed quickly. This looks like an obscure error from JAX doing something weird.
The good news is that core JAX has some incoming changes that might see equinox.experimental.BatchNorm switching to a totally different, and robust, implementation.
(For context about why BatchNorm is such a tricky operation to support: getting/setting its running statistics is a side effect. Meanwhile JAX, being functional, basically doesn't support side effects! Supporting this properly, without compromising on all the lovely things that make JAX efficient, is pretty difficult.)
So right now I think the answer is a somewhat unsatisfying "let's wait and see".
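One purely functional way around the side-effect problem is to thread the running statistics through the computation explicitly, as extra inputs and outputs, instead of mutating them in place. A hand-rolled sketch of that idea (this is not Equinox's implementation; all names and the momentum value are illustrative):

```python
import jax.numpy as jnp

def batchnorm_apply(x, state, momentum=0.99, eps=1e-5):
    # Normalize x over the batch axis and return updated running stats.
    # The state is passed in and returned explicitly, so the function
    # stays pure and plays nicely with JAX transformations.
    running_mean, running_var = state
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    out = (x - batch_mean) / jnp.sqrt(batch_var + eps)
    new_state = (
        momentum * running_mean + (1 - momentum) * batch_mean,
        momentum * running_var + (1 - momentum) * batch_var,
    )
    return out, new_state

x = jnp.array([[0.0, 10.0], [2.0, 12.0], [4.0, 14.0]])
state = (jnp.zeros(2), jnp.ones(2))  # initial running mean and variance
out, state = batchnorm_apply(x, state)
```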
Seems to work with equinox==0.9.2. One small thing is that without jitting it seems to be very slow: on my GPU machine this outputs
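Without jit, every call dispatches each operation to the device eagerly, which is often the cause of this kind of slowness; compiling the function once with jax.jit (or, for Equinox modules, eqx.filter_jit) usually fixes it. A generic sketch of the pattern, using a stand-in function rather than the model above:

```python
import jax
import jax.numpy as jnp

def slow_fn(x):
    # Many small ops: cheap individually, but slow when dispatched eagerly.
    for _ in range(50):
        x = jnp.tanh(x) + 0.1
    return x

# Compile the whole chain into a single fused program.
fast_fn = jax.jit(slow_fn)

x = jnp.ones((100,))
_ = fast_fn(x).block_until_ready()  # first call triggers compilation
y = fast_fn(x)                      # later calls reuse the compiled program
```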