Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU
Hey, thanks for the great work!
I’m using BatchNorm in my network, but I have set the `use_running_average` parameter of the BatchNorm layers to `True`, which means they do not compute mean/std statistics from the batch passing through the network; they use the pre-computed running statistics instead. Thus, the network’s behaviour should not change from batch to batch (ideally, at least; I believe this should hold).
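For reference, here is a minimal sketch of this setup in Flax (illustrative only; the actual network is the WideResNet in the gist linked below, and the module here is a made-up stand-in):

```python
# Illustrative sketch of applying a Flax model with frozen BatchNorm statistics.
# use_running_average=True makes BatchNorm read the stored batch_stats instead of
# computing statistics from the current batch.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Dense(32)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.relu(x)


model = Block()
x = jnp.ones((4, 16))
variables = model.init(jax.random.PRNGKey(0), x)  # contains 'params' and 'batch_stats'

# Evaluation mode: no mutable collections, running averages are used as-is.
y = model.apply(variables, x, train=False)
```

With `use_running_average=True` and no mutable collections, `apply` is a pure function of the variables and the input, so the same example should map to the same output regardless of which batch it appears in.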
I’ve provided a simple Colab notebook that reproduces the issue. The Colab needs two files to run properly:
- `wide_resnet_jax.py`: the Python file containing the shallow WideResNet module implemented using Flax. You can download it from this gist: https://gist.github.com/mohamad-amin/5334109dba81b9c26e7b4d1ded7fd9ad
- `psd_data.pkl`, which can be downloaded from: https://drive.google.com/file/d/18eb93M34vaWjFyNzIq-vnOfll0T6HCjT/view?usp=sharing
`psd_data.pkl` is the pickled version of a dict containing three things:
- `data`: the train and test data used for training the model.
- `params`: the trained parameters of the WideResNet module we’re using, which achieve 1.0 train accuracy and 0.89 test accuracy.
- `labels`: the labels of the datapoints in `data`, to double-check the accuracies.
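A hedged sketch of how this bundle might be loaded (the exact key layout is an assumption based on the description above):

```python
# Hypothetical loading of psd_data.pkl; the key names ("data", "params", "labels")
# follow the description above and are assumptions about the exact layout.
import pickle

with open("psd_data.pkl", "rb") as f:
    bundle = pickle.load(f)

data = bundle["data"]      # train/test inputs
params = bundle["params"]  # trained WideResNet parameters
labels = bundle["labels"]  # labels for double-checking accuracy
```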
The problem that I have is:
```python
import jax.numpy as jnp
import numpy as np

# apply_fn, params and X_train are defined in the Colab notebook.
ys = []
for i in range(10):
    ys.append(apply_fn(params, X_train[i:i+1]))  # one example at a time
ys = jnp.stack(ys).squeeze()
vs = apply_fn(params, X_train[:10])              # the same ten examples in one batch
np.allclose(ys, vs)
# Outputs False!
```
which shows that the network’s output for the same inputs varies with the batch size. I expect this to output True, as the parameters are fixed and the BatchNorm layers use their pre-computed running statistics. Am I doing something wrong?
https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing
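One way to quantify how large the mismatch actually is, rather than relying on `np.allclose` at its default tolerances (a hypothetical helper, not part of the notebook):

```python
# Hypothetical helper to measure the size of the discrepancy between the
# per-example and batched outputs from the snippet above.
import numpy as np

def max_abs_and_rel_error(a, b):
    a, b = np.asarray(a), np.asarray(b)
    abs_err = float(np.max(np.abs(a - b)))
    rel_err = abs_err / (float(np.max(np.abs(b))) + 1e-12)
    return abs_err, rel_err

# Usage with the snippet above: max_abs_and_rel_error(ys, vs).
# A relative error around 1e-6 is consistent with float32 rounding rather than
# a modelling bug.
```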
Top GitHub Comments
Closing because dtype behavior is now consistent since dropping the default float32 dtype
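For context on why the dtype matters here: even a plain matmul in float32 can produce slightly different results for a batched call versus a per-row loop, because the reduction order changes. A small, self-contained illustration (not from the thread):

```python
# Hypothetical illustration (not from the thread): in float32, a batched matmul
# and a per-row loop may not be bit-identical because the reduction order differs.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (10, 512))
w = jax.random.normal(jax.random.PRNGKey(1), (512, 512))

batched = a @ w                                                # one batched matmul
looped = jnp.concatenate([a[i:i + 1] @ w for i in range(10)])  # row by row

print(float(jnp.max(jnp.abs(batched - looped))))  # tiny, but often nonzero
print(bool(jnp.all(batched == looped)))           # bit-exact equality is not guaranteed
```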
Here is a slightly simpler example that I think reproduces your issue (sketched below).
The previous example uses stdev=1 weights for the kernels, which gives you infinities/NaNs if you stack a bunch of them. In this example you get a relative error of approximately 10^-6; stacking more layers will make the errors larger.
I added a check against NumPy as well, which gives RMS errors of roughly 10^-6 for all pairs among `logits`, `logits_np`, and `logits_loop`:
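The code block from this comment is not reproduced above; the following is a reconstruction sketch with assumed details (a small stack of default-initialized `Dense` layers rather than the actual model), comparing `logits` from one batched apply, `logits_loop` from a per-example loop, and `logits_np` from a float64 NumPy replay:

```python
# Reconstruction sketch (details are assumptions, not the original comment's code):
# compare one batched apply, a per-example loop, and a float64 NumPy reference.
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn


class MLP(nn.Module):
    features: int = 256
    depth: int = 8

    @nn.compact
    def __call__(self, x):
        for _ in range(self.depth):
            x = jnp.tanh(nn.Dense(self.features)(x))
        return nn.Dense(10)(x)


model = MLP()
x = jax.random.normal(jax.random.PRNGKey(0), (16, 128))
params = model.init(jax.random.PRNGKey(1), x)

logits = model.apply(params, x)                              # one batched call
logits_loop = jnp.concatenate(
    [model.apply(params, x[i:i + 1]) for i in range(x.shape[0])]
)                                                            # one example at a time


def forward_np(variables, x):
    # Replay the same Dense/tanh stack in float64 NumPy as a reference.
    x = np.asarray(x, dtype=np.float64)
    layers = variables["params"]
    names = sorted(layers, key=lambda s: int(s.split("_")[1]))  # Dense_0 ... Dense_N
    for name in names:
        kernel = np.asarray(layers[name]["kernel"], dtype=np.float64)
        bias = np.asarray(layers[name]["bias"], dtype=np.float64)
        x = x @ kernel + bias
        if name != names[-1]:
            x = np.tanh(x)
    return x


def rms(a, b):
    return float(np.sqrt(np.mean((np.asarray(a) - np.asarray(b)) ** 2)))


logits_np = forward_np(params, x)
print(rms(logits, logits_loop), rms(logits, logits_np), rms(logits_loop, logits_np))
# The comment above reports RMS errors of roughly 10^-6 for all pairs at float32 precision.
```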