jit `lax.scatter_add` does not improve performance on CPU
On CPU, the following script
import numpy as onp
import jax.numpy as np
from jax import jit, lax
from jax.config import config; config.update("jax_platform_name", "cpu")
@jit
def vector_to_tril_matrix(t):
    # Flat positions of the lower-triangular entries of a 79 x 79 matrix.
    idx = np.reshape(np.arange(79 * 79), (79, 79))[onp.tril_indices(79, 0)]
    # Scatter-add each row of t into a zero buffer at those positions.
    x = lax.scatter_add(np.zeros((t.shape[0], 79 * 79)),
                        np.expand_dims(idx, axis=-1), t,
                        lax.ScatterDimensionNumbers(
                            update_window_dims=range(t.ndim - 1),
                            inserted_window_dims=(t.ndim - 1,),
                            scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return np.reshape(x, (-1, 79, 79))
%time vector_to_tril_matrix(np.ones((8000, 3160)))
%time vector_to_tril_matrix(np.ones((8000, 3160)))
returns 5.61 s and 5.28 s, while on GPU the same script returns 631 ms and 9.29 ms.
Because the gap between the first (compiling) call and the cached call is large on GPU, I would expect the same to happen on CPU, but that does not seem to be the case.
However, testing with a smaller input batch on CPU, we can see that jit does help. For example,
%time vector_to_tril_matrix(np.ones((8, 3160)))
%time vector_to_tril_matrix(np.ones((8, 3160)))
returns 254 ms and 786 µs, which is expected.
Because the batch size is 1000x smaller, I would expect vector_to_tril_matrix(np.ones((8000, 3160))) to take around 786 µs * 1000 = 786 ms (or much less than that if vectorization works here), but the first test shows that this is not the case.
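(For readers on recent JAX versions: the same computation can be expressed with the indexed-update syntax x.at[...].add(...), which lowers to a scatter-add under the hood. The sketch below is only illustrative and is not the API that was available when this issue was filed; the helper name vector_to_tril_matrix_at and the constant N are mine.)

import numpy as onp
import jax.numpy as jnp
from jax import jit

N = 79  # N * (N + 1) // 2 == 3160 lower-triangular entries

@jit
def vector_to_tril_matrix_at(t):
    # Flat positions of the lower-triangular entries of a flattened N x N matrix.
    idx = onp.reshape(onp.arange(N * N), (N, N))[onp.tril_indices(N, 0)]
    # Scatter-add each row of t into a zero buffer at those positions.
    x = jnp.zeros((t.shape[0], N * N)).at[:, idx].add(t)
    return x.reshape(-1, N, N)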
The commit mentioned above should greatly increase the speed of this computation; I observed a ~20x speedup on my machine.
Note that you may still see a nonlinear slowdown as you increase the computation’s size. That is, if the tensors are 10x bigger, the computation may be 100x slower, or worse. This is a fundamental property of CPUs, caches, etc.
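As an aside on measuring this: jitted JAX calls return before the computation finishes (dispatch is asynchronous), so timings are more trustworthy if you block on the result. Here is a minimal sketch, assuming the vector_to_tril_matrix function defined in the issue above, that separates the first (compile + run) call from the cached call for a few batch sizes:

import time
import jax.numpy as jnp

def timed(f, x):
    # block_until_ready() waits for the asynchronous computation to finish,
    # so the measurement covers execution rather than just dispatch.
    start = time.perf_counter()
    f(x).block_until_ready()
    return time.perf_counter() - start

for batch in (8, 80, 800, 8000):
    x = jnp.ones((batch, 3160))
    first = timed(vector_to_tril_matrix, x)   # includes tracing + XLA compilation for this shape
    cached = timed(vector_to_tril_matrix, x)  # reuses the cached executable
    print(f"batch={batch:5d}  first={first:.3f}s  cached={cached:.4f}s")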
Thank you for reporting this bug!
Yes, but the promise of XLA is to solve that problem once and for all and remove it as a concern for its users.
That is, systems like TF often need to implement separate CPU and GPU versions of many ops (and those systems often have O(hundreds) or O(thousands) of ops). XLA aims to solve that problem once and for all by having a relatively small number of ops (like O(dozens)) and being able to generate code for the ops and their compositions for multiple different backends. That makes it super easy to bring up a new front-end system like JAX that can immediately target CPU or GPU (or TPU!). But the XLA team still has to solve the hard problem of generating different instructions and kernels for different platforms.
So XLA devs are solving the hard multi-device problems for us (and other libraries like JAX), and moreover their task is much better scoped because they rely on a small number of ops and compositionality.
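To make that concrete from the JAX side, the same jitted function can be run against whichever backends happen to be present, and XLA compiles a separate executable for each one. A rough sketch, again assuming vector_to_tril_matrix from above; note that jax.devices(<backend>) raises a RuntimeError when that backend is not available:

import jax
import jax.numpy as jnp

x = jnp.ones((8, 3160))

for backend in ("cpu", "gpu"):
    try:
        device = jax.devices(backend)[0]
    except RuntimeError:
        continue  # this backend is not available on this machine
    # Committing the input to a device makes jit execute there.
    y = vector_to_tril_matrix(jax.device_put(x, device))
    print(backend, y.shape)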