Unexpected (?) BatchNorm behavior: model's flattened form changes through iterations
Hi, this might be something that is already known, or perhaps I'm not using the library as intended; apologies in advance if that's the case. First, some background info:
I'm writing code for a scenario that features a form of pipeline parallelism. I have a model, which I split into parts/modules, and each part is run on a different device. The results of each part are passed on to the next in a loop. The model features BatchNorm (I'm trying to implement some known results that use it, although I'm now aware that BatchNorm is finicky in JAX).
As a test case, I feed N batches of the exact same samples in the first N iterations, then do some updates on my model. I repeat this procedure with a new batch, which is fed repeatedly for the next N iterations. As a sanity check, within each group of N consecutive iterations the model should output the same values. This is not the case, though, and I think BatchNorm might be the issue.
To debug, I thought I'd check whether the model's parameters change during these N iterations, by flattening the model and comparing it to its previous version. However, I run into errors about "List arity mismatch". A very simplified example that exhibits this sort of behavior is below. To simulate my use case, the second module/part is only run from the third iteration onward. Even for i = 1, the two model "versions" are not comparable (one was taken before running anything, the other after running the first module/part).
If I remove the BatchNorm layers there are no errors, which leads me to believe that the problem is that BatchNorm modifies its state. Am I using something wrong here? If not, how can I work around this, and what could cause my model's output to be different for the same inputs?
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)

# First part of the model, run on one device.
model_pt1 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

# Second part of the model; to simulate the pipeline, it is only run from the
# third iteration onward.
model_pt2 = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=4, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

model_combined = [model_pt1, model_pt2]

x = jr.normal(dkey, (10, 3))

flattened_model, _ = jax.tree_util.tree_flatten(
    eqx.filter(model_combined, eqx.is_inexact_array)
)

for i in range(10):
    # Compare the current flattened parameters against the previous iteration's.
    prev_flattened_model = flattened_model
    flattened_model, _ = jax.tree_util.tree_flatten(
        eqx.filter(model_combined, eqx.is_inexact_array)
    )
    # This is where the "List arity mismatch" error appears once BatchNorm has
    # updated its state, because the two leaf lists no longer have the same length.
    diffs = jax.tree_map(lambda a, b: jnp.abs(a - b), flattened_model, prev_flattened_model)

    y1 = jax.vmap(model_pt1, axis_name="batch")(x)
    if i >= 2:
        y2 = jax.vmap(model_pt2, axis_name="batch")(y1)
Top GitHub Comments
Hi @geomlyd
[...] model_pt2 triggers another modification. [...] BatchNorm [...] as it already has set no gradient for its running statistics by jax.lax.stop_gradient. Usually just is_array.

Looks like this is related to #234 (see my last comment there). Essentially, after the first call to BatchNorm, the treedef of the BatchNorm module changes, leaving it with more leaves than before, and so the two lists of leaves cannot be compared / subtracted. A smaller working example than the one above is: [...] which also throws: [...]
I'm just now trying to figure out exactly how BatchNorm is implemented and what the fix would be, but the maintainers will probably find it faster.
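To make the diagnosis above concrete, here is a minimal sketch (not the commenter's omitted example; it assumes the same eqx.experimental.BatchNorm API used in the issue, and the names bn and x are just illustrative). If the explanation is right, the leaf count printed after the first call should be larger than the one printed before it, which is exactly why the tree_map in the reproduction fails with an arity mismatch:

import equinox as eqx
import jax
import jax.random as jr

bn = eqx.experimental.BatchNorm(input_size=4, axis_name="batch")
x = jr.normal(jr.PRNGKey(0), (10, 4))

# Leaf count before any call has been made.
leaves_before = jax.tree_util.tree_leaves(eqx.filter(bn, eqx.is_inexact_array))
print(len(leaves_before))

# The first call updates BatchNorm's internal state (its running statistics).
jax.vmap(bn, axis_name="batch")(x)

# Leaf count after the first call; per the comment above, the module's treedef
# has changed and it now flattens to more leaves than before.
leaves_after = jax.tree_util.tree_leaves(eqx.filter(bn, eqx.is_inexact_array))
print(len(leaves_after))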