Matrix multiplication precision API
Operations that do matrix multiplication in JAX accept a `precision` argument for controlling numerical precision when executed on TPUs.
The current API seems non-ideal to me:
- You have to pass an enum value (e.g., `np.dot(x, y, precision=lax.Precision.HIGHEST)`). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., `np.dot(x, y, precision='highest')`); see the sketch after this list.
- The current names for the precision levels (“highest”, “high” and “default”) are not very descriptive. In my ideal world we would use some direct indication of the corresponding precision (e.g., bfloat16 multiplication with float32 accumulation), but at the very least can we switch “default” to “low”?
- The default low precision is a bit of a footgun, at least when doing anything that isn’t implementing a neural-net layer. In my opinion, it would be much safer to use “highest” precision by default on float32 data (which isn’t that much slower). Neural-net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the `@` infix operator.
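To make the difference concrete, here is a minimal sketch of the two call styles. The enum form is the current API; whether the plain-string form is accepted depends on the JAX version (newer releases do accept it):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 4), dtype=jnp.float32)
y = jnp.ones((4, 4), dtype=jnp.float32)

# Current API: pass an enum value from lax.Precision.
z_enum = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# Proposed API: pass a string, matching NumPy/SciPy conventions.
z_str = jnp.dot(x, y, precision='highest')
```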
Users have also requested a way to set a more “global” default precision.
One possible mechanism to do this is via a scope, e.g.:
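A minimal sketch of such a scope, assuming the `jax.default_matmul_precision` context manager that JAX’s documentation describes for this purpose (the string `"float32"` is one of its documented settings):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Every matmul in this scope that doesn't request an explicit
# precision runs at the scoped default instead of the platform default.
with jax.default_matmul_precision("float32"):
    z = x @ y
```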
I would suggest that it should override only operations with default precision.
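Concretely, that would mean an explicit `precision=` argument keeps winning over the scoped default; a hypothetical sketch:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

with jax.default_matmul_precision("float32"):
    a = x @ y  # no explicit precision: the scoped default applies
    # Explicit precision: left untouched by the surrounding scope.
    b = jnp.dot(x, y, precision=lax.Precision.DEFAULT)
```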
“3-pass bfloat16” is coincidentally very close to (slightly higher than?) the precision of NVIDIA’s new “TensorFloat-32”. So that could also be a good name for this intermediate precision on TPUs.
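For reference, here is how such names could surface as string settings. The values `"bfloat16_3x"` (the three-pass scheme) and `"tensorfloat32"` are among those JAX’s documentation lists for `jax.default_matmul_precision`, though availability may vary by version:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Three-pass bfloat16 on TPU: accuracy roughly comparable to
# NVIDIA's TensorFloat-32.
with jax.default_matmul_precision("bfloat16_3x"):
    z1 = x @ y

with jax.default_matmul_precision("tensorfloat32"):
    z2 = x @ y
```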