Not fully reproducible results on GPU
Although the random key is fixed (e.g. jax.random.PRNGKey(0)), the results of different runs are always different.
My question is: how can one fix the random behavior? My expectation is that, with a fixed random key, every run should produce the same result.
Thank you in advance.
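For what it's worth, the key itself seems to behave deterministically; a minimal check (separate from the training script below) is:

import jax

key = jax.random.PRNGKey(0)
# JAX uses a counter-based PRNG, so drawing from a fixed key prints the same
# values on every run; the run-to-run differences described here only appear
# once the full model is compiled and run on the GPU.
print(jax.random.normal(key, (3,)))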
Use the following code to reproduce the issue (I simply took the MNIST example and removed the shuffle):
import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow_datasets as tfds


class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = jax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x


@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]


def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}


@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer


@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])


def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))

  _, model = CNN.create_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])
  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)

  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)
    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
          % (epoch + 1,
             metrics['loss'], metrics['accuracy'] * 100))


# Run training twice with the same fixed PRNG key; on GPU the printed metrics
# differ between the two calls.
train()
train()
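One way to narrow down where the variation comes from is to force the CPU backend before anything else runs (a rough sketch; it assumes the configuration is applied before the first computation):

import jax

# Select the CPU backend before any computation is traced or compiled; this
# has to happen before the model, optimizer, or data pipeline are created.
jax.config.update('jax_platform_name', 'cpu')

With this at the top of the script, the two train() calls above would be expected to print identical metrics, which points at the GPU kernels rather than the data pipeline or the PRNG.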
Issue Analytics
- Created 4 years ago
- Comments: 13 (5 by maintainers)
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
I confirm that for the MNIST example this issue is solved. With the following command:
!export XLA_FLAGS=--xla_gpu_deterministic_reductions && export TF_CUDNN_DETERMINISTIC=1 && echo $XLA_FLAGS, $TF_CUDNN_DETERMINISTIC && python main.py
results are consistent between two runs (on Google Colab).
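For completeness, the same flags can also be set from Python, as long as this happens before JAX initializes its backend (a sketch based on the command above; the flag names are taken from it):

import os

# Must be set before importing jax, otherwise the backend is already
# configured (flag names taken from the shell command above).
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

import jax  # imported only after the environment variables are in place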
Yes, indeed at the moment XLA builds on GPU aren’t fully reproducible, e.g. https://github.com/google/jax/issues/565. I’ll check with the JAX team to learn more.