
np.take and np.einsum aren't fused properly


I’m trying to translate Product Key Memory in PyTorch into JAX, and this requires translating nn.EmbeddingBag with per_sample_weights, since I haven’t found a counterpart in JAX (if you know of one, please let me know). For this, I wrote scatter_weighted_sum, a weighted-sum version of scatter_add, in the hope that it would fuse efficiently. However, jit didn’t fuse np.take, reshape and np.einsum properly, which resulted in a huge intermediate object. Since #1979 concluded that this sort of op would be fused on GPU, I was wondering what is causing this problem. If by any chance this isn’t supported on GPU, should it work on TPU? I’m using JAX ver. 0.1.67 with various Colab GPUs.

import jax.numpy as np
from jax import jit, random

hidden_dim = 512
n_keys = 512
batch = 2 ** 15
knn = 32
heads = 4
key = random.PRNGKey(0)
values = random.normal(key, (n_keys ** 2, hidden_dim))
indices = random.randint(key, (batch * heads, knn), 0, n_keys ** 2)
weights = random.normal(key, (batch * heads, knn))

@jit
def scatter_weighted_sum(inputs, indices, weights):
    num_bags = weights.shape[-1]
    dim = inputs.shape[-1]
    indices = indices.reshape(-1)
    # Gather all rows at once, then do a per-bag weighted sum over the knn axis.
    tmp = inputs.take(indices, axis=0).reshape(-1, num_bags, dim)
    return np.einsum('ind, in -> id', tmp, weights)
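For the shapes above, the gathered intermediate tmp has shape (batch * heads, knn, hidden_dim) = (131072, 32, 512). In float32 that is 131072 * 32 * 512 * 4 bytes = 8 GiB, all of which must be materialized if the gather is not fused into the einsum reduction.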

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

3 reactions
shz0116 commented, Jan 28, 2021

Using embeddingbag instead of embedding+reduce will have a “huge” impact not only on performance but also on memory footprint for a large number of indices. Computing the “sum” operation on the fly saves a great amount of memory, since we do not need to store all the table lookup results. A higher priority should be given to this issue.

3 reactions
hawkinsp commented, May 27, 2020

I agree, XLA is not fusing the gather with the reduction, which would avoid materializing the intermediate value, and I don’t see a good reason for that. I filed a couple of bugs for the XLA teams.

A possible workaround would be to use lax.map to compute the elements in chunks.
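A rough sketch of that workaround (my own, not from the issue thread): split the batch into chunks and let lax.map run the gather-plus-einsum once per chunk, so only one chunk’s gathered rows are live at a time. The chunk count is an assumed value and must divide batch * heads evenly.

import jax.numpy as np
from jax import jit, lax

@jit
def scatter_weighted_sum_chunked(inputs, indices, weights):
    num_chunks = 64  # assumed chunk count; must evenly divide the batch dimension
    num_bags = weights.shape[-1]
    dim = inputs.shape[-1]
    # Reshape so the leading axis enumerates chunks; lax.map iterates over it.
    idx_chunks = indices.reshape(num_chunks, -1, num_bags)
    w_chunks = weights.reshape(num_chunks, -1, num_bags)

    def one_chunk(args):
        idx, w = args
        # Only this chunk's (chunk_size, num_bags, dim) gather is materialized.
        tmp = inputs.take(idx.reshape(-1), axis=0).reshape(-1, num_bags, dim)
        return np.einsum('ind, in -> id', tmp, w)

    out = lax.map(one_chunk, (idx_chunks, w_chunks))
    return out.reshape(-1, dim)

This trades one large gather for num_chunks smaller ones; whether it actually keeps peak memory down still depends on XLA fusing the body of each chunk.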


