Unexpected (?) BatchNorm behavior: model's flattened form changes through iterations

Hi, this might be something that is already known, or perhaps I’m not using the library as intended; apologies in advance if that’s the case. First, some background:

I’m writing code for a scenario that features a form of pipeline parallelism. I have a model, which I split into parts (modules), and each part runs on a different device. The results of each part are passed on to the next in a loop. The model features BatchNorm (I’m trying to reproduce some known results that use it, although I’m now aware that BatchNorm is finicky in JAX).

As a test case, I feed the exact same batch of samples in each of the first N iterations, then do some updates on my model. I repeat this procedure with a new batch, which is fed repeatedly for the next N iterations. As a sanity check, the model should output the same values within each block of N consecutive iterations. This is not the case, though, and I think BatchNorm might be the issue.

To debug, I thought I’d check whether the model’s parameters change during these N iterations by flattening the model and comparing it to its previous version. However, I run into “List arity mismatch” errors. A very simplified example that exhibits this behavior is below. To simulate my use case, the second module/part only runs from the third iteration onward (i >= 2). Even for i = 1, the two model “versions” are not comparable (one was taken before running anything, the second after running the first module/part).

If I remove the BatchNorm layers, there are no errors, which leads me to believe that BatchNorm modifying its state is the problem. Am I using something wrong here? If not, how can I work around this, and what could cause my model’s output to be different for the same inputs?

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_pt2 = eqx.nn.Sequential([   
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))  # the same batch is fed on every iteration
flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
for i in range(10):
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))

    # Compare against the previous snapshot; this raises "List arity mismatch"
    # already at i = 1, after model_pt1 has been called once.
    diffs = jax.tree_map(lambda x, y: jnp.abs(x - y), flattened_model, prev_flattened_model)
    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    if i >= 2:  # the second part only runs from the third iteration onward
        y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
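The error itself comes from jax.tree_util.tree_map (jax.tree_map is an older alias) trying to pair up two lists of leaves of different lengths, so it can be reproduced with plain JAX and no Equinox. In this minimal sketch the "before" list holds the two arrays shown in the traceback later in the thread (ones and zeros), while the two extra arrays in the "after" list are hypothetical placeholders for whatever leaves BatchNorm gains after its first call. Comparing jax.tree_util.tree_structure of both lists shows the mismatch before tree_map fails.

```python
import jax
import jax.numpy as jnp

# Leaves before the first call: the two arrays from the traceback (weight, bias).
before = [jnp.ones(4), jnp.zeros(4)]
# Leaves after the first call: two extra arrays (hypothetical placeholder values).
after = before + [jnp.zeros(4), jnp.ones(4)]

# Checking structures first shows the two flattened "models" are not comparable.
same_structure = jax.tree_util.tree_structure(after) == jax.tree_util.tree_structure(before)
print("same structure:", same_structure)

try:
    jax.tree_util.tree_map(lambda a, b: jnp.abs(a - b), after, before)
    raised = False
except ValueError as err:
    # Same class of error as the "List arity mismatch" reported in the issue.
    raised = True
    print("ValueError:", err)
```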

Issue Analytics

  • State: closed
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

uuirs commented, Nov 28, 2022 (1 reaction)

Hi @geomlyd

  1. Just try this. The reason is that calling model_pt2 (for the first time) triggers another modification.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
model_pt2 = eqx.nn.Sequential([   
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))
flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
for i in range(10):
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(eqx.filter(model_combined, eqx.is_inexact_array))
    
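    # Call both parts on every iteration, so both BatchNorm layers
    # have been called before any comparison is made.
    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
    if i >= 2:  # both snapshots now come from after the first calls
        diffs = jax.tree_map(lambda x, y: jnp.abs(x - y), flattened_model, prev_flattened_model)
        assert eqx.tree_equal(flattened_model, prev_flattened_model)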
  2. There is no need to filter BatchNorm specially, since it already stops gradients through its running statistics with jax.lax.stop_gradient. Usually eqx.is_array is enough.
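Point 2 relies on jax.lax.stop_gradient: a value wrapped in it is treated as a constant by jax.grad, so its gradient comes out as zero. The toy function below is a hypothetical illustration, not Equinox's BatchNorm implementation; toy_layer, weight and running_mean are made-up names.

```python
import jax
import jax.numpy as jnp

def toy_layer(params, x):
    # Hypothetical simplified layer: the running mean is used in the forward pass
    # but treated as a constant for differentiation via stop_gradient.
    running_mean = jax.lax.stop_gradient(params["running_mean"])
    return jnp.sum(params["weight"] * (x - running_mean))

params = {"weight": jnp.ones(4), "running_mean": jnp.full(4, 0.5)}
x = jnp.arange(4.0)

grads = jax.grad(toy_layer)(params, x)
print(grads["weight"])        # gradient w.r.t. weight is x - running_mean
print(grads["running_mean"])  # blocked by stop_gradient, so all zeros
```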

ciupakabra commented, Nov 25, 2022 (1 reaction)

Looks like this is related to #234 (see my last comment there). Essentially, after the first call to BatchNorm, the treedef of the BatchNorm module changes, leaving it with more leaves than before, so the two lists of leaves cannot be compared or subtracted. A smaller example that reproduces the error:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.experimental.BatchNorm(input_size=4, axis_name="batch")

x = jr.normal(dkey, (10, 4))

flat_model, treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

jax.vmap(model, axis_name="batch")(x)

new_flat_model, new_treedef = jax.tree_util.tree_flatten(eqx.filter(model, eqx.is_inexact_array))

diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), new_flat_model, flat_model)

which also throws:

Traceback (most recent call last):
  File "/Users/andriusovsianas/repos/test/test.py", line 18, in <module>
    diffs = jax.tree_map(lambda x, y : jnp.abs(x - y), new_flat_model, flat_model)
  File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/Users/andriusovsianas/miniconda/lib/python3.10/site-packages/jax/_src/tree_util.py", line 206, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: List arity mismatch: 2 != 4; list: [DeviceArray([1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32)].

I’m just now trying to figure out exactly how BatchNorm is implemented and what the fix would be, but the maintainers will probably find it faster.
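One possible workaround for the comparison in the original question, sketched under the assumption that skipping iterations where the structure changed is acceptable, is to check the tree structure before diffing. diff_if_same_structure is a hypothetical helper, not part of JAX or Equinox; it should accept any pytree, such as the result of eqx.filter(model_combined, eqx.is_inexact_array), although the usage below only uses plain lists of arrays.

```python
import jax
import jax.numpy as jnp

def diff_if_same_structure(new_tree, old_tree):
    """Hypothetical helper: elementwise |new - old| if both pytrees have the
    same structure, otherwise None (instead of raising ValueError)."""
    if jax.tree_util.tree_structure(new_tree) != jax.tree_util.tree_structure(old_tree):
        return None
    return jax.tree_util.tree_map(lambda a, b: jnp.abs(a - b), new_tree, old_tree)

# Usage with stand-in leaves: the structure changed, so no diff is computed.
before = [jnp.ones(4), jnp.zeros(4)]
after = before + [jnp.zeros(4), jnp.ones(4)]
print(diff_if_same_structure(after, before))  # None

# Same structure: the diff is computed leaf by leaf.
later = [jnp.ones(4), jnp.zeros(4), jnp.zeros(4), jnp.ones(4)]
diffs = diff_if_same_structure(later, after)
print(max(float(jnp.max(d)) for d in diffs))  # 0.0
```

Returning None rather than raising lets the loop keep running; raising a custom error or logging both tree structures would be equally reasonable choices.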
