Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging 3rd party libraries. It collects links to all the places you might be looking at while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

Matrix multiplication precision API

See original GitHub issue

Operations that do matrix multiplication in JAX accept a precision argument for controlling numerical precision when executed on TPUs.

The current API seems non-ideal to me:

  1. You have to pass an enum value (e.g., np.dot(x, y, precision=lax.Precision.HIGHEST)). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., np.dot(x, y, precision='highest')); see the sketch after this list.
  2. The current names for precision levels (“highest”, “high” and “default”) are not very descriptive. In my ideal world we would use some direct indication of the corresponding precision (e.g., bfloat16 multiplication with float32 accumulation), but at the very least can we switch “default” to “low”?
  3. The default low precision is a bit of a footgun, at least when doing anything that isn’t implementing a neural net layer. In my opinion, it would be much safer to use “highest” precision by default on float32 data (it isn’t that much slower). Neural net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the @ infix operator.
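
For illustration, here is a minimal sketch of the two calling styles point 1 contrasts. The enum form is the one described in the issue; the string form is what it asks for (newer JAX releases accept it, but exact accepted values vary by version):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Enum-based form described in point 1
z_enum = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# String-based form the issue proposes (assumption: supported in recent JAX versions)
z_str = jnp.dot(x, y, precision="highest")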

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 10
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

5 reactions
hawkinsp commented, Aug 19, 2020

Users have also requested a way to set a more “global” default precision.

One possible mechanism to do this is via a scope, e.g.:

with jax.precision("highest"):
  ...

I would suggest that it should override only operations with default precision.
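
For reference, JAX later shipped a context manager along these lines (see the jax.default_matmul_precision result below). This is a hedged sketch; the accepted precision names may differ between versions:

import jax
import jax.numpy as jnp

x = jnp.ones((64, 64), dtype=jnp.float32)

with jax.default_matmul_precision("float32"):
    # Matmuls that don't pass an explicit `precision` pick up this default;
    # ops with an explicit precision keep it, matching the suggestion above.
    y = x @ x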

1 reaction
shoyer commented, May 16, 2020

“3 pass bfloat16” is coincidentally very close to (slightly higher than?) the precision of Nvidia’s new “tensorfloat32”. So that could also be a good name for this intermediate precision on TPUs.
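
As a hedged illustration of that naming idea, later JAX versions accept dtype-style precision strings for the default; "tensorfloat32" is one name I believe landed, though spellings and availability depend on the version and backend:

import jax
import jax.numpy as jnp

a = jnp.ones((32, 32), dtype=jnp.float32)

# "tensorfloat32" names the intermediate precision directly, as suggested above
with jax.default_matmul_precision("tensorfloat32"):
    b = a @ a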

Read more comments on GitHub >

Top Results From Across the Web

Matrix Multiplication - AI Engine API User Guide - Xilinx
This class template is parametrized with the matrix multiplication shape (MxKxN), the data types and, optionally, the requested accumulation precision.
Read more >
Precision error on matrix multiplication - Stack Overflow
A problem with your code as it now stands is that the multiplication is being done in float precision then added to a...
Read more >
Matrix Multiplication — OneDNN documentation
The matrix multiplication (MatMul) primitive computes the product of two 2D tensors ... C++ API example demonstrating how one can perform reduced precision...
Read more >
Matrix math for the web - Web APIs - MDN Web Docs
This article explores how to create matrices and how to use them with CSS ... What does multiplying by the identity matrix look...
Read more >
jax.default_matmul_precision - JAX documentation
Some platforms, like TPU, offer configurable precision levels for matrix multiplication and convolution computations, trading off accuracy for speed.
Read more >
