Implementing Batch Normalization
See original GitHub issue

In Flax, Batch Normalization is a bit finicky, since each call to `apply` requires marking `batch_stats` as mutable and merging the updated `batch_stats` back into the variables afterward.
```python
import flax.linen as nn
import jax.numpy as jnp
from jax import random

# use_running_average=False: normalize with the batch statistics and
# update the running averages stored in the `batch_stats` collection.
bn = nn.BatchNorm(use_running_average=False)
x = jnp.arange(24.0).reshape(4, 6)
variables = bn.init(random.PRNGKey(0), x)
# Mark the batch stats as mutable so apply returns the updated collection
x_normed, mutated_vars = bn.apply(variables, x, mutable=['batch_stats'])
variables = {**variables, **mutated_vars}  # Update the variables with our diff
x_normed2, mutated_vars2 = bn.apply(variables, x, mutable=['batch_stats'])
```
How could this be implemented as a Module in Equinox? I’m happy to submit an implementation given some guidance.
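For context, the state that Flax threads through `batch_stats` is just a running mean and variance. A minimal plain-NumPy sketch of one training step (a hypothetical helper for illustration, not Flax's or Equinox's actual implementation) looks like this:

```python
import numpy as np

def batch_norm_train(x, gamma, beta, running_mean, running_var,
                     momentum=0.9, eps=1e-5):
    """One training step of batch normalization over axis 0.

    Returns the normalized batch plus the updated running statistics --
    the same state that Flax threads through `batch_stats`.
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)   # normalize with batch stats
    out = gamma * x_hat + beta                # learned scale and shift
    # Exponential moving average of the statistics, used at inference time
    new_mean = momentum * running_mean + (1 - momentum) * mean
    new_var = momentum * running_var + (1 - momentum) * var
    return out, new_mean, new_var

x = np.arange(24.0).reshape(4, 6)
gamma, beta = np.ones(6), np.zeros(6)
out, rm, rv = batch_norm_train(x, gamma, beta, np.zeros(6), np.ones(6))
```

The question for Equinox is therefore where `running_mean`/`running_var` should live, given that Equinox modules are immutable PyTrees.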
Issue Analytics
- State:
- Created a year ago
- Comments: 9 (9 by maintainers)
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Yeah, this is a known issue with JAX – namely, that `host_callback.call` handles errors differently depending on OS, device, or phase of the moon. See also https://github.com/google/jax/issues/9457

I don’t think there’s much that can be done about this one from the point of view of Equinox, I’m afraid.
Well, I got successfully nerd-sniped into spending my weekend implementing this. (Mostly the new “stateful” technology that makes this possible.) `equinox.experimental.BatchNorm` now exists. Happy hacking.