
Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU

See original GitHub issue

Hey, thanks for the great work!

I’m using BatchNorm in my network, but I have set the use_running_average parameter of the BatchNorm layers to True, which means they do not compute mean/std statistics from the data passing through the network and instead use the pre-computed running statistics. As a result, the network’s behaviour should not depend on how the inputs are batched (ideally, at least, that is what I expect).
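
For reference, here is a minimal sketch of the setup I mean (not the actual WideResNet from the notebook; the toy module and shapes are made up purely for illustration):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy module for illustration only.
class Block(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(16)(x)
    # use_running_average=True: normalize with the stored batch_stats,
    # never with statistics computed from the current batch.
    x = nn.BatchNorm(use_running_average=True)(x)
    return nn.relu(x)

model = Block()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x)  # creates 'params' and 'batch_stats'
y = model.apply(variables, x)  # batch_stats are only read, so no mutable=['batch_stats'] needed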

I’ve provided a simple Colab notebook that reproduces the issue. The Colab needs two files to run properly:

psd_data.pkl is the pickled version of a dict containing three things (see the loading sketch after this list):

  • data: The train and test data used for training the model.
  • params: The trained parameters of the WideResNet module that we’re using, such that it will achieve 1.0 train accuracy and 0.89 test accuracy.
  • labels: The labels of the datapoints in data, to double check the accuracies.
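
Loading it looks roughly like this (only the keys listed above are assumed):

import pickle

with open("psd_data.pkl", "rb") as f:
  bundle = pickle.load(f)

data = bundle["data"]      # train and test inputs
params = bundle["params"]  # trained WideResNet parameters
labels = bundle["labels"]  # labels for double-checking the accuracies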

The problem that I have is:

import numpy as np          # imports assumed from the notebook
import jax.numpy as jnp

# Apply the network to the first 10 training points one at a time...
ys = []
for i in range(10):
  ys.append(apply_fn(params, X_train[i:i+1]))
ys = jnp.stack(ys).squeeze()

# ...and to the same 10 points as a single batch.
vs = apply_fn(params, X_train[:10])

np.allclose(ys, vs)
# Outputs False!

which shows that the network’s behaviour varies with the batch size. I expect this to output True, as I have fixed the parameters and the BatchNorm layers. Am I doing something wrong?
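
A quick way to quantify the mismatch, rather than relying on np.allclose’s default tolerances (this diagnostic is just a sketch, not part of the notebook):

diff = np.abs(np.asarray(ys) - np.asarray(vs))
print(diff.max())                                        # worst absolute error
print((diff / (np.abs(np.asarray(vs)) + 1e-12)).max())   # worst relative error
print(np.allclose(ys, vs, rtol=1e-4, atol=1e-6))         # float32-scale tolerance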

https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 22 (6 by maintainers)

Top GitHub Comments

2 reactions
jheek commented, Sep 5, 2022

Closing because dtype behavior is now consistent since dropping the default float32 dtype
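
Not the Flax change itself, but a minimal sketch of the dtype angle: the same batched-vs-per-row comparison done in float64 (via jax_enable_x64; shapes chosen arbitrarily) shrinks the reduction-order differences far below allclose’s default tolerance:

import jax
jax.config.update("jax_enable_x64", True)  # must happen before creating arrays

import numpy as np
import jax.numpy as jnp
from jax import random

w = random.normal(random.PRNGKey(0), (256, 256), dtype=jnp.float64)
x = random.normal(random.PRNGKey(1), (128, 256), dtype=jnp.float64)

full = np.asarray(x @ w)                                            # whole batch at once
rows = np.stack([np.asarray(x[i:i+1] @ w)[0] for i in range(128)])  # one row at a time
print(np.allclose(full, rows))  # float64 differences are ~1e-15, so this should print True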

1 reaction
jheek commented, Jan 6, 2022

Here is a slightly simpler example that I think reproduces your issue:

import numpy as np
from jax import random
from jax import nn
import jax

# 10 stacked 256x256 matmuls with lecun_normal weights, applied to a
# batch of 128 random inputs.
init_fn = nn.initializers.lecun_normal()
lhs = random.normal(random.PRNGKey(0), (128, 256))

def conv(x):
  # Run the matmul chain in jax (float32 by default) and pull the result to host.
  for i in range(10):
    rhs = init_fn(random.PRNGKey(i), (256, 256))
    x = x @ rhs
  return jax.device_get(x)

def np_conv(x):
  # The same chain, but with numpy performing the matmuls.
  x = jax.device_get(x)
  for i in range(10):
    rhs = jax.device_get(init_fn(random.PRNGKey(i), (256, 256)))
    x = x @ rhs
  return x

logits = conv(lhs)        # whole batch at once
logits_np = np_conv(lhs)  # whole batch, numpy matmuls

# The same rows evaluated one at a time.
logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])[0]

print(np.allclose(logits_loop, logits))  # outputs False

The previous example uses stdev=1 weights for the kernels, which gives you infinities/NaNs if you stack a bunch of them. In this example you get a relative error of approximately 10^-6; stacking more layers makes the errors larger.
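
To see that growth, here is a small sketch reusing init_fn and lhs from the snippet above (the depths are arbitrary; this is not part of the original comment):

def stacked(x, depth):
  # Same matmul chain as conv(), but with a configurable depth.
  for i in range(depth):
    x = x @ init_fn(random.PRNGKey(i), (256, 256))
  return jax.device_get(x)

for depth in (1, 5, 10, 20):
  full = stacked(lhs, depth)
  rows = np.concatenate([stacked(lhs[i:i+1], depth) for i in range(128)])
  rel = np.abs(rows - full) / (np.abs(full) + 1e-12)
  print(depth, rel.max())  # the worst relative error grows with depth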

I added a check against numpy as well, which gives RMS errors of roughly 10^-6 for all pairs among logits, logits_np, and logits_loop:

[screenshot of the pairwise RMS error values]
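
Such a pairwise check could look like this (the rms helper and the loop are a sketch, not the code from the screenshot; they reuse the variables defined above):

def rms(a, b):
  return np.sqrt(np.mean((a - b) ** 2))

for name, (a, b) in {
    "loop vs jax": (logits_loop, logits),
    "loop vs numpy": (logits_loop, logits_np),
    "jax vs numpy": (logits, logits_np),
}.items():
  print(name, rms(a, b))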
