jit `lax.scatter_add` does not improve performance in CPU

On CPU, the following script

import numpy as onp
import jax.numpy as np
from jax import jit, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

@jit
def vector_to_tril_matrix(t):
    # Flat indices of the lower-triangular entries of a 79 x 79 matrix.
    idx = np.reshape(np.arange(79 * 79), (79, 79))[onp.tril_indices(79, 0)]
    # Scatter-add each row of t into those positions of a (batch, 79*79) zero matrix.
    x = lax.scatter_add(np.zeros((t.shape[0], 79 * 79)), np.expand_dims(idx, axis=-1), t,
                        lax.ScatterDimensionNumbers(update_window_dims=range(t.ndim - 1),
                                                    inserted_window_dims=(t.ndim - 1,),
                                                    scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return np.reshape(x, (-1, 79, 79))

%time vector_to_tril_matrix(np.ones((8000, 3160)))
%time vector_to_tril_matrix(np.ones((8000, 3160)))

returns 5.61 s and 5.28 s, while on GPU the script returns 631 ms and 9.29 ms.

Because the difference between the first (uncached) call and the cached call is large on GPU, I would expect the same on CPU, but that does not seem to be the case.
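
One way to see the compile-vs-cached split more explicitly is to force synchronization with block_until_ready() rather than relying on %time, which also folds in asynchronous dispatch. The following is a minimal sketch added here for illustration (it is not part of the original report) and reuses the vector_to_tril_matrix function defined above:

import time

x = np.ones((8000, 3160))

start = time.time()
vector_to_tril_matrix(x).block_until_ready()  # first call: traces, compiles, and runs
print("first call:", time.time() - start)

start = time.time()
vector_to_tril_matrix(x).block_until_ready()  # second call: reuses the cached executable
print("cached call:", time.time() - start)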

However, testing with a smaller input batch on CPU, we can see that jit works. For example,

%time vector_to_tril_matrix(np.ones((8, 3160)))
%time vector_to_tril_matrix(np.ones((8, 3160)))

returns 254 ms and 786 µs, which is expected.

Because the batch dimension in that test is 1000× smaller, I would expect vector_to_tril_matrix(np.ones((8000, 3160))) to take roughly 786 µs × 1000 = 786 ms (or much less than that if vectorization works here), but the first test shows that this is not the case.
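
To check that scaling expectation directly, one could time the compiled function across several batch sizes. The sketch below is an illustrative addition (not from the issue) and assumes the function above is already defined; note that jit recompiles for each new input shape, so each size gets a warm-up call first.

import time

for batch in (8, 80, 800, 8000):
    x = np.ones((batch, 3160))
    vector_to_tril_matrix(x).block_until_ready()  # warm-up: compile for this shape
    start = time.time()
    vector_to_tril_matrix(x).block_until_ready()  # timed: cached executable only
    print(batch, time.time() - start)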

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 12 (12 by maintainers)

Top GitHub Comments

4 reactions
jlebar commented, May 19, 2019

The commit mentioned above should greatly increase the speed of this computation; I observed a ~20x speedup on my machine.

Note that you may still observe a nonlinear slowdown as you increase the computation’s size. That is, if the tensors are 10x bigger, the computation may be 100x slower, or worse. This is a fundamental property of CPUs, caches, etc.
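
As a rough illustration of the sizes involved (a back-of-the-envelope estimate added here, not part of the comment), the float32 arrays in the large-batch reproduction are already hundreds of megabytes, far beyond typical CPU cache sizes:

# Approximate float32 sizes for the 8000-row case above.
operand_mb = 8000 * 79 * 79 * 4 / 1e6   # zeros((8000, 6241)) -> ~200 MB
updates_mb = 8000 * 3160 * 4 / 1e6      # ones((8000, 3160))  -> ~101 MB
print(operand_mb, updates_mb)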

Thank you for reporting this bug!

2 reactions
mattjj commented, May 10, 2019

It looks like it takes a lot of effort for the XLA devs to support various platforms, because the implementations are done separately.

Yes, but the promise of XLA is to solve that problem once and for all and remove it as a concern for its users.

That is, systems like TF often need to implement separate CPU and GPU versions of many ops (and those systems often have O(hundreds) or O(thousands) of ops). XLA aims to solve that problem once and for all by having a relatively small number of ops (like O(dozens)) and being able to generate code for the ops and their compositions for multiple different backends. That makes it super easy to bring up a new front-end system like JAX that can immediately target CPU or GPU (or TPU!). But the XLA team still has to solve the hard problem of generating different instructions and kernels for different platforms.

So XLA devs are solving the hard multi-device problems for us (and other libraries like JAX), and moreover their task is much better scoped because they rely on a small number of ops and compositionality.
