TPU cannot do simple arithmetic!
I am trying to do a simple matrix multiplication on TPU, but it gives a wrong result:
import jax.numpy as np
import numpy as onp
# On CPU
x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]])
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]])
z = onp.dot(x, y)
# On TPU
x_ = np.asarray(x)
y_ = np.asarray(y)
z_ = np.dot(x_, y_)
print('JAX device:', x_.device())
# Compare
print('CPU result:', z)
print('TPU result:', z_)
assert np.allclose(z, z_)
Output:
JAX device: TPU_0(process=0,(0,0,0,0))
CPU result: [[0.17372088 0.09648504]
[0.21183069 0.10812637]]
TPU result: [[0.17405128 0.09669876]
[0.21180916 0.10805416]]
Traceback (most recent call last):
File "/home/ayaka/main.py", line 21, in <module>
assert np.allclose(z, z_)
AssertionError
Manual calculation of the first element z[0, 0]:
0.3744 * 0.3946 + 0.1656 * 0.1569 = 0.17372088
So the result on CPU is correct, while the result on TPU is wrong.
Library versions:
jax 0.3.4
jaxlib 0.3.2
libtpu-nightly 0.1.dev20220315
It is important to note that while the default precision is low, it is deterministic, so if you train a model in low precision and run inference on that trained model in low precision, you should get the expected answer.
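For example (a sketch, not a quote from the thread), you can check this determinism by running the same default-precision matmul twice with the arrays from the example above and comparing the results bit for bit:
z1 = np.dot(x_, y_)
z2 = np.dot(x_, y_)
assert (z1 == z2).all()  # bit-identical results across repeated runs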
For very large models it is typical to drop the precision, because training requires so many floating-point operations that the improvement in performance has a very significant effect on training time.
For Gopher (a large language model from DeepMind) we discuss low-precision training with bfloat16 (even lower than the f32 defaults in JAX) in Appendix C.2 of our paper: https://arxiv.org/pdf/2112.11446.pdf.
Typically accelerators have special hardware (“tensor cores”) for half precision (e.g. bf16, f16) compute and you can expect computations to run somewhere between 5-10x faster than full precision f32 computations.
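As an illustration (a sketch, not from the original thread), you can make the reduced precision explicit by casting the inputs to bfloat16 yourself, which is roughly what the default matmul precision does on TPU:
x16 = x_.astype(np.bfloat16)
y16 = y_.astype(np.bfloat16)
z16 = np.dot(x16, y16)  # runs on the TPU's bf16 matrix units
print(z16.dtype)        # bfloat16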
JAX's default precision for f32 dot products means the actual computation is done in bf16 on the TPU, so the performance improvement is significant compared with Precision.HIGH or Precision.HIGHEST. I think we can close this issue, but please let me know if not!
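If you want the TPU result to match the CPU result, you can request higher precision for this particular dot product. A minimal sketch of that fix, reusing the names from the example above:
import jax

z_hi = np.dot(x_, y_, precision=jax.lax.Precision.HIGHEST)
print('TPU result (HIGHEST):', z_hi)
assert np.allclose(z, z_hi)  # now matches the CPU result within f32 tolerance
There is also a jax_default_matmul_precision configuration option if you want to raise the default matmul precision globally rather than per call.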