
PixelCNN++ example flakey on TPUs, probably due to bfloat16

See original GitHub issue

The PixelCNN++ example should reach a test loss of 2.92, but we're finding that this doesn't happen consistently. @marcvanzee is in the process of doing more rigorous testing: running training many times with different random seeds on different versions of the model, with and without JAX omnistaging enabled. The data we have so far (see below) is limited, but we expect to follow up soon with more detail and hopefully identify whether this problem is real or just bad luck. For now we're focusing on GPU; we will look at TPU afterwards.

Training was run with:

python train.py --batch_size=320 --num_epochs=1800 

Sheet with run statistics: https://docs.google.com/spreadsheets/d/1IDtAKUTHr6MTynKvl5szsepb9_ta-arZU-3yvyj6KQs/edit?usp=sharing
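To make the seed/omnistaging sweep described above concrete, here is a minimal sketch (not the example's own training code) of how one might vary the PRNG seed and toggle omnistaging in a 2020-era JAX release; the seed values and the loop are purely illustrative, and the omnistaging switch has since been removed from JAX.

import jax
from jax.config import config

# Omnistaging toggle as it existed in late-2020 JAX releases; in later
# versions omnistaging is always on and this switch no longer exists.
config.enable_omnistaging()  # or config.disable_omnistaging()

for seed in (0, 1, 2):  # illustrative seeds
    rng = jax.random.PRNGKey(seed)  # per-run seed for parameter init etc.
    # ... build the model and run training with this rng ...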

Update by @marcvanzee (Oct 29, 2020)

GPU performance is now as expected on Linen. Jamie is investigating TPUs and is still seeing some unexpected behavior, so I will assign the issue to him.

Update by @j-towns (Feb 1, 2021)

I think the discrepancy between GPU and TPU is likely caused by bfloat16 vs float32 in conv ops. We should test this hypothesis by running on TPU with precision=lax.Precision.HIGHEST, and potentially add a command-line argument to allow the user to choose which precision setting to use.
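As a rough illustration of this proposed test, here is a hedged sketch (not the PixelCNN++ example's actual code) of passing precision=lax.Precision.HIGHEST through Flax conv layers; the SmallConvNet module and the idea of wiring it to a --precision flag are hypothetical.

import jax
from jax import lax, numpy as jnp
from flax import linen as nn

class SmallConvNet(nn.Module):
  # Precision.DEFAULT lets TPUs use bfloat16 accumulation inside convs;
  # Precision.HIGHEST forces full float32 multiplies.
  precision: lax.Precision = lax.Precision.DEFAULT

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3), precision=self.precision)(x)
    x = nn.relu(x)
    x = nn.Conv(features=32, kernel_size=(3, 3), precision=self.precision)(x)
    return x

# Illustrative usage, e.g. as selected by a (hypothetical) --precision flag:
model = SmallConvNet(precision=lax.Precision.HIGHEST)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8, 8, 3)))

Running the same TPU training twice, once with DEFAULT and once with HIGHEST, and comparing final test losses would be one way to confirm or rule out the bfloat16 hypothesis.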

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 8 (1 by maintainers)

Top GitHub Comments

1 reaction
marcvanzee commented, Oct 29, 2020

From @j-towns in personal communication: “i’ve closed it for now because i basically don’t have time to run experiments on TPU. I guess if it’s working fine on GPU then any bug that still exists is unlikely to be in Flax but more likely in JAX or XLA. so i’m not too worried about it.”

1 reaction
marcvanzee commented, Oct 8, 2020

Update: I did three runs, but I wasn't using different random seeds. @jheek suggested trying this, so I will re-run a few more experiments in the coming week and report back once I have the numbers.
