Default serialisation fails for `BatchNorm`.
Hi,

The default serialisation fails when a model containing `BatchNorm` is serialised. Here is a small test script, run on the dev branch:
import equinox as eqx


def test_serialise_bn(getkey):
    # A one-layer model is enough to trigger the failure.
    net = eqx.nn.Sequential(
        [
            eqx.experimental.BatchNorm(3, axis_name="batch"),
        ]
    )
    eqx.tree_serialise_leaves('/tmp/net.eqx', net)
    assert True
It fails with the error:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../equinox/serialisation.py:183: in tree_serialise_leaves
    jtu.tree_map(_serialise, filter_spec, pytree)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:181: in _serialise
    jtu.tree_map(__serialise, x, is_leaf=is_leaf)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:179: in __serialise
    spec(f, y)
../equinox/serialisation.py:50: in default_serialise_filter_spec
    value, _, _ = x.unsafe_get()
../equinox/experimental/stateful.py:112: in unsafe_get
    return _state_cache[self._obj]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <WeakKeyDictionary at 0x7fc6deff8d90>
key = <equinox.experimental.stateful._IndexObj object at 0x7fc6deec12b0>

    def __getitem__(self, key):
>       return self.data[ref(key)]
E       KeyError: <weakref at 0x7fc736f7e590; to '_IndexObj' at 0x7fc6deec12b0>

../../../miniconda3/envs/equinox/lib/python3.8/weakref.py:383: KeyError
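
The traceback ends in a plain `WeakKeyDictionary.__getitem__` lookup, which raises `KeyError` whenever nothing has been stored under the given key. One plausible reading is that no state had been stored yet for the freshly constructed layer's index object. The following is a minimal, standard-library-only sketch of that failure mode. It is not Equinox's implementation: `IndexObj`, `state_cache`, and `get_or_none` are hypothetical names chosen for illustration, and only `unsafe_get` mirrors a name from the traceback.

```python
import weakref


class IndexObj:
    """Hypothetical stand-in for the `_IndexObj` key seen in the traceback."""


# Hypothetical module-level cache, keyed weakly on index objects.
state_cache = weakref.WeakKeyDictionary()


def unsafe_get(index):
    # Plain lookup: raises KeyError if no state was ever stored for `index`.
    return state_cache[index]


def get_or_none(index):
    # Defensive variant: report "no state yet" instead of raising.
    return state_cache.get(index, None)


fresh = IndexObj()  # like a newly built layer: nothing stored yet
try:
    unsafe_get(fresh)
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Once a state is stored under the key, the same lookup succeeds.
state_cache[fresh] = (0.0, 1.0)
```

Under this assumption, a serialiser would need to handle the "no state stored yet" case explicitly, for example with a lookup like `get_or_none`, rather than indexing the cache directly.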
Top GitHub Comments
Excellent. And btw you definitely shouldn’t set `_state`. This is part of some deep magic to make inference mode work without the cost of looking the value up at runtime. Use `eqx.experimental.set_state` if you ever want to modify the state manually.

Sorry to reopen this issue.
I tried saving a model mid-training and then loading the checkpoint to resume training. This appears to break on the branch containing the fix (https://github.com/patrick-kidger/equinox/pull/172).

Here is a small script to reproduce the behaviour. It needs to be run twice: once with `LOAD=False` and then with `LOAD=True`. With `LOAD=False`, the script works as intended and the net is serialised to disk. With `LOAD=True`, I get the error