
Default serialisation fails for `BatchNorm`.


Hi,

The default serialisation fails when a model containing `BatchNorm` is serialised. Here is a small test script, executed on the dev branch:


import equinox as eqx


def test_serialise_bn(getkey):
    # A freshly constructed BatchNorm is enough to trigger the failure.
    net = eqx.nn.Sequential(
        [
            eqx.experimental.BatchNorm(3, axis_name="batch"),
        ]
    )

    eqx.tree_serialise_leaves('/tmp/net.eqx', net)

which fails with the following error:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../equinox/serialisation.py:183: in tree_serialise_leaves
    jtu.tree_map(_serialise, filter_spec, pytree)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:181: in _serialise
    jtu.tree_map(__serialise, x, is_leaf=is_leaf)
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../../miniconda3/envs/equinox/lib/python3.8/site-packages/jax/_src/tree_util.py:201: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../equinox/serialisation.py:179: in __serialise
    spec(f, y)
../equinox/serialisation.py:50: in default_serialise_filter_spec
    value, _, _ = x.unsafe_get()
../equinox/experimental/stateful.py:112: in unsafe_get
    return _state_cache[self._obj]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <WeakKeyDictionary at 0x7fc6deff8d90>
key = <equinox.experimental.stateful._IndexObj object at 0x7fc6deec12b0>

    def __getitem__(self, key):
>       return self.data[ref(key)]
E       KeyError: <weakref at 0x7fc736f7e590; to '_IndexObj' at 0x7fc6deec12b0>

../../../miniconda3/envs/equinox/lib/python3.8/weakref.py:383: KeyError
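
For what it's worth, my reading of the traceback (an inference on my part, not something stated in the issue): the running statistics live in a WeakKeyDictionary that is only populated on the first forward pass, so a freshly constructed BatchNorm has no entry for unsafe_get to find. Under that assumption, a sketch of a workaround is to run one forward pass before serialising:

import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

net = eqx.nn.Sequential([eqx.experimental.BatchNorm(3, axis_name="batch")])
x = jnp.ones((8, 3))
keys = jrandom.split(jrandom.PRNGKey(0), 8)

# One vmapped forward pass populates the weak-key state cache...
jax.vmap(net, axis_name="batch")(x, key=keys)

# ...so the serialiser now has state to read.
eqx.tree_serialise_leaves('/tmp/net.eqx', net)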


Top GitHub Comments

patrick-kidger commented, Aug 8, 2022 (1 reaction)

Excellent. And btw you definitely shouldn't set `_state`. This is part of some deep magic to make inference mode work without the cost of looking the value up at runtime.

Use `eqx.experimental.set_state` if you ever want to modify the state manually.
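
For reference, a minimal sketch of that experimental stateful API (signatures recalled from the equinox 0.x documentation, so treat them as approximate; the equinox.experimental stateful module has since been removed from the library):

import jax.numpy as jnp
import equinox as eqx

# A StateIndex is the handle that the stateful machinery keys its cache on.
index = eqx.experimental.StateIndex()

# Store a value against the index...
eqx.experimental.set_state(index, jnp.zeros(3))

# ...and read it back; the second argument supplies the expected shape/dtype.
value = eqx.experimental.get_state(index, jnp.zeros(3))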

paganpasta commented, Aug 18, 2022 (0 reactions)

Sorry to reopen this issue.

I tried saving a model mid-training and loading the checkpoint to resume. The behaviour still seems to break on the fixed branch (https://github.com/patrick-kidger/equinox/pull/172).

Sharing a small script to reproduce the behaviour. It needs to be run twice: first with `LOAD = False`, then with `LOAD = True`. With `LOAD = False` the script works as intended and the net is serialised to disk. With `LOAD = True`, I get the error

    loss, grads = compute_loss(model, x, y, keys)
  File "/tmp/equinox/equinox/grad.py", line 30, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/tmp/equinox/equinox/grad.py", line 27, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "adv.py", line 13, in compute_loss
    logits = jax.vmap(model, axis_name=('batch'))(x, key=keys)
  File "/tmp/equinox/equinox/nn/composed.py", line 129, in __call__
    x = layer(x, key=key)
  File "/tmp/equinox/equinox/experimental/batch_norm.py", line 161, in __call__
    lambda: get_state(
  File "/opt/conda/lib/python3.7/site-packages/jax/experimental/host_callback.py", line 1334, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
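
For context, my interpretation rather than anything stated in the thread: the experimental stateful machinery appears to be built on jax.experimental.host_callback.call, which has no JVP rule, so differentiating through get_state fails in exactly this way. The limitation reproduces in isolation, assuming a JAX version old enough to still ship host_callback:

import jax
from jax.experimental import host_callback as hcb


def f(x):
    # host_callback.call has no JVP rule registered, so differentiating
    # through it raises the NotImplementedError quoted above.
    return hcb.call(lambda v: v * 2, x, result_shape=x)


jax.grad(f)(1.0)

With that aside, the full reproduction script: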

import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax

import numpy as np
import equinox as eqx

LOAD = False  # run the script once with False, then again with True

@eqx.filter_value_and_grad
def compute_loss(model, x, y, keys):
    logits = jax.vmap(model, axis_name=('batch'))(x, key=keys)
    one_hot_actual = jax.nn.one_hot(y, num_classes=5)
    return optax.softmax_cross_entropy(logits, one_hot_actual).mean()
        

@eqx.filter_jit
def make_step(model, x, y, keys, optimizer, opt_state):
    loss, grads = compute_loss(model, x, y, keys)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

net = eqx.nn.Sequential(
    [
        eqx.nn.Linear(3, 5, key=jrandom.PRNGKey(0)),
        eqx.experimental.BatchNorm(5, axis_name='batch'),
    ]
)

x = jnp.asarray(np.random.rand(10, 3))
y = jnp.asarray(np.random.randint(0, 5, 10))  # labels in [0, 5) to match num_classes=5
key = jrandom.split(jrandom.PRNGKey(0), 10)
if LOAD:
    net = eqx.tree_deserialise_leaves('/tmp/net.eqx', net)

optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(eqx.filter(net, eqx.is_array))
_, net, _ = make_step(net, x, y, key, optimizer, opt_state)
eqx.tree_serialise_leaves('/tmp/net.eqx', net)
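
A side observation on the script above (mine, not from the thread): resuming mid-training properly would also mean checkpointing the optimiser state, since otherwise the Adam moments restart from zero on every load. A sketch using the same serialisation helpers, which work on any pytree:

# Hypothetical extension: persist the optimiser state alongside the model.
eqx.tree_serialise_leaves('/tmp/opt_state.eqx', opt_state)

# And on the LOAD=True run, restore it before training continues:
# opt_state = eqx.tree_deserialise_leaves('/tmp/opt_state.eqx', opt_state)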
