
TPU cannot do simple arithmetic!

See original GitHub issue

I am trying to do a simple matrix multiplication on TPU, but it gives the wrong result:

import jax.numpy as np
import numpy as onp

# On CPU
x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]])
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]])
z = onp.dot(x, y)

# On TPU
x_ = np.asarray(x)
y_ = np.asarray(y)
z_ = np.dot(x_, y_)

print('JAX device:', x_.device())

# Compare
print('CPU result:', z)
print('TPU result:', z_)
assert np.allclose(z, z_)

Output:

JAX device: TPU_0(process=0,(0,0,0,0))
CPU result: [[0.17372088 0.09648504]
 [0.21183069 0.10812637]]
TPU result: [[0.17405128 0.09669876]
 [0.21180916 0.10805416]]
Traceback (most recent call last):
  File "/home/ayaka/main.py", line 21, in <module>
    assert np.allclose(z, z_)
AssertionError

Manual calculation:

0.3744 * 0.3946 + 0.1656 * 0.1569 = 0.17372088

So the result on CPU is correct, while the result on TPU is wrong.

Library versions:

jax                  0.3.4
jaxlib               0.3.2
libtpu-nightly       0.1.dev20220315
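
As the maintainer comments below explain, the discrepancy comes from the TPU rounding matmul inputs to bfloat16 by default. A minimal CPU sketch of that rounding (the jnp.bfloat16 casts here are an illustration, not the exact TPU code path) approximately reproduces the reported TPU numbers:

import jax.numpy as jnp
import numpy as onp

x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]], dtype=onp.float32)
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]], dtype=onp.float32)

# Round the inputs to bfloat16 (8-bit mantissa), then multiply-accumulate
# in float32, mimicking the TPU MXU's bf16-multiply / f32-accumulate path.
x_bf = jnp.asarray(x).astype(jnp.bfloat16).astype(jnp.float32)
y_bf = jnp.asarray(y).astype(jnp.bfloat16).astype(jnp.float32)
print(jnp.dot(x_bf, y_bf))  # close to the TPU result above, not the CPU one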

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments:5 (5 by maintainers)

Top GitHub Comments

3 reactions
tomhennigan commented, Mar 21, 2022

> Is there any research indicating that low precision will not affect model performance? As deep learning models grow larger and larger, I am thinking that different precisions may result in totally different outputs after many layers of operations.

It is important to note that while the default precision is low, it is deterministic: if you train a model in low precision and run inference on that trained model in low precision, you should get the expected answer.
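
A quick way to check the determinism claim (a minimal sketch; the inputs are the ones from the report):

import jax.numpy as jnp

x = jnp.asarray([[0.3744, 0.1656], [0.4707, 0.1663]])
y = jnp.asarray([[0.3946, 0.1186], [0.1569, 0.3145]])

# The default-precision matmul is low precision but bit-identical from
# run to run on the same hardware:
assert (jnp.dot(x, y) == jnp.dot(x, y)).all()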

For very large models it is typical to drop the precision, because you need to compute so many floating-point operations to train them that the performance improvement yields a very significant reduction in training time.

For Gopher (a large language model from DeepMind), we discuss low-precision training with bfloat16 (even lower than JAX’s f32 defaults) in Appendix C.2 of our paper: https://arxiv.org/pdf/2112.11446.pdf.

> How much training time can I save when using lower precision?

Typically accelerators have special hardware (“tensor cores”) for half-precision (e.g. bf16, f16) compute, and you can expect computations to run somewhere between 5 and 10x faster than full-precision f32 computations.
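
A rough way to measure this on your own accelerator (a sketch; the matrix size is arbitrary and the 5-10x figure will vary by hardware):

import time
import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

def bench(precision):
    f = jax.jit(lambda a: jnp.dot(a, a, precision=precision))
    f(x).block_until_ready()           # compile and warm up first
    t0 = time.perf_counter()
    f(x).block_until_ready()           # time only the compiled computation
    return time.perf_counter() - t0

print('default (bf16 on TPU):', bench(None))
print('highest (full f32)   :', bench(jax.lax.Precision.HIGHEST))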

JAX’s default precision for f32 dot product means the actual computation is done in bf16 on the TPU, so the performance improvement is significant vs. Precision.HIGH or Precision.HIGHEST.
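
For anyone who needs the f32-accurate numbers from the original report, the matmul precision can be raised per call or globally (a sketch; both knobs exist in the JAX versions listed above):

import jax
import jax.numpy as jnp

x = jnp.asarray([[0.3744, 0.1656], [0.4707, 0.1663]])
y = jnp.asarray([[0.3946, 0.1186], [0.1569, 0.3145]])

# Per call: request full f32 precision for this dot product only.
z_exact = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

# Globally: make every matmul default to full f32 precision.
jax.config.update('jax_default_matmul_precision', 'float32')
z_also_exact = jnp.dot(x, y)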

0 reactions
mattjj commented, Mar 21, 2022

I think we can close this issue, but please let me know if not!

Read more comments on GitHub >

