
Non fully reproducible results on GPU

See original GitHub issue

Although the random key is fixed (e.g. jax.random.PRNGKey(0)), the results of different runs are always different.

My question is: how can one make this behavior deterministic? My expectation is that with a fixed random key, all runs should produce the same result.

Thank you in advance.

Use the following code to reproduce the issue (I simply took the MNIST example and removed the shuffling):

import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow_datasets as tfds

# A small CNN from the Flax MNIST example (legacy flax.nn API).
class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = jax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x

# Per-example negative log-likelihood of the true label, vmapped over the batch.
@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))

  _, model = CNN.create_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])

  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)

  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
         % (epoch+1,
          metrics['loss'], metrics['accuracy'] * 100))

# Train twice with the same fixed PRNG key; on GPU the two runs still differ.
train()
train()
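
One thing worth ruling out: jax.random itself is deterministic for a fixed key, so the divergence is not coming from the PRNG. A minimal sanity check (separate from the script above):

import jax

# The PRNG is fully deterministic for a fixed key: the same key and shape
# always produce the same values.
key = jax.random.PRNGKey(0)
print(jax.random.normal(key, (3,)))
print(jax.random.normal(key, (3,)))  # prints exactly the same values again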

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 13 (5 by maintainers)

Top GitHub Comments

1 reaction
goingtosleep commented on Oct 28, 2020

I confirm that for the MNIST example, this issue is solved. With the following command:

!export XLA_FLAGS=--xla_gpu_deterministic_reductions && export TF_CUDNN_DETERMINISTIC=1 && echo $XLA_FLAGS, $TF_CUDNN_DETERMINISTIC && python main.py

the results are consistent between 2 runs (on Google Colab):

--xla_gpu_deterministic_reductions, 1
2020-10-28 16:14:10.248318: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18

eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18
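
If you prefer to set these from Python rather than from the shell, a rough equivalent (a minimal sketch; the exact flag names can change between jax/XLA versions) is to export the variables before jax or TensorFlow is imported:

import os

# Both variables must be set before importing jax (or anything else that
# initializes the GPU backend); flag names may differ across jax/XLA versions.
os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_reductions'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import jax  # imported only after the environment is configured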

1 reaction
avital commented on Mar 10, 2020

Yes, indeed, at the moment XLA builds on GPU aren’t fully reproducible; see e.g. https://github.com/google/jax/issues/565. I’ll check with the JAX team to learn more.
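
A major source of this is that floating-point addition is not associative, so a parallel GPU reduction that accumulates values in a different order can return a slightly different result on every run. A small, JAX-independent illustration in NumPy:

import numpy as np

# Floating-point addition is not associative: grouping the same three
# numbers differently changes the float32 result.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
print((a + b) + c)   # 1.0
print(a + (b + c))   # 0.0, because b + c rounds back to -1e8

# The same effect appears when a large array is summed in a different order;
# the two sums typically differ slightly in the last digits.
x = np.random.RandomState(0).rand(1_000_000).astype(np.float32)
print(x.sum(), np.random.RandomState(1).permutation(x).sum())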

