PixelCNN++ example flaky on TPUs, probably due to bfloat16
The PixelCNN++ example should reach a test loss of 2.92, but we're finding that this doesn't happen consistently. @marcvanzee is in the process of doing more rigorous testing: running training many times with different random seeds on different versions of the model, with and without JAX omni-staging enabled. The data we have so far (see below) is limited, but we expect to follow up soon with more detail and hopefully identify whether this problem is real or just bad luck. For now we're focusing on GPU and will look at TPU afterwards.
Training run with:
`python train.py --batch_size=320 --num_epochs=1800`
Sheet with run statistics: https://docs.google.com/spreadsheets/d/1IDtAKUTHr6MTynKvl5szsepb9_ta-arZU-3yvyj6KQs/edit?usp=sharing
Update by @marcvanzee (Oct 29, 2020)
Performance on GPU is now as expected on Linen. Jamie is investigating TPUs and is still seeing some unexpected behavior, so I will assign the issue to him.
Update by @j-towns (Feb 1, 2020)
I think the discrepancy between GPU and TPU is likely caused by bfloat16 vs. float32 in conv ops. We should test this hypothesis by running on TPU with `precision=lax.Precision.HIGHEST`, and potentially add a command-line argument to let the user choose which precision setting to use.
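To illustrate what that test would involve, here is a minimal sketch of how a precision flag could be threaded into the model's conv layers. The `--precision` flag name and the `ConvBlock` module are hypothetical; `lax.Precision` and the `precision` argument of `flax.linen.Conv` are the real JAX/Flax knobs being discussed, and `Precision.HIGHEST` forces full float32 accumulation on TPU instead of the default bfloat16 passes.

```python
# Sketch only: flag name and ConvBlock are illustrative, not the example's actual code.
from absl import flags
from jax import lax
import flax.linen as nn

flags.DEFINE_enum(
    'precision', 'default', ['default', 'high', 'highest'],
    'Matmul/conv precision on TPU (default uses bfloat16 passes).')

_PRECISIONS = {
    'default': lax.Precision.DEFAULT,
    'high': lax.Precision.HIGH,
    'highest': lax.Precision.HIGHEST,  # full float32 accumulation on TPU
}

class ConvBlock(nn.Module):  # hypothetical module for illustration
  features: int
  precision: lax.Precision = lax.Precision.DEFAULT

  @nn.compact
  def __call__(self, x):
    # Passing `precision` tells XLA which conv precision to use instead of
    # the TPU default; HIGHEST should remove any bfloat16-induced drift.
    return nn.Conv(self.features, kernel_size=(3, 3),
                   precision=self.precision)(x)
```

The model would then be constructed with something like `ConvBlock(features=..., precision=_PRECISIONS[FLAGS.precision])`, so a single run with `--precision=highest` on TPU would confirm or rule out the bfloat16 hypothesis.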
Top GitHub Comments
From @j-towns in personal communication: “i’ve closed it for now because i basically don’t have time to run experiments on TPU. I guess if it’s working fine on GPU then any bug that still exists is unlikely to be in Flax but more likely in JAX or XLA. so i’m not too worried about it.”
Update: I did three runs, but I wasn't using different random seeds. @jheek suggested trying this, so I will re-run a few more experiments in the coming week and report back once I have the numbers.
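A sweep like the one suggested could be scripted along these lines. This is only a sketch: the `--seed` flag is an assumption about train.py's interface, not a confirmed flag, so substitute whatever seed option the example actually exposes.

```python
# Minimal sketch of a multi-seed sweep over the training command from above.
import subprocess

for seed in range(5):
    subprocess.run(
        ['python', 'train.py',
         '--batch_size=320', '--num_epochs=1800',
         f'--seed={seed}'],  # --seed is an assumption, not a confirmed flag
        check=True)
```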