np.take and np.einsum aren't fused properly
I'm trying to translate Product Key Memory from PyTorch into JAX, and this requires a counterpart to nn.EmbeddingBag with per_sample_weights, which I haven't found in JAX (but if you know of one, please let me know). For this I wrote scatter_weighted_sum, a weighted-sum version of scatter_add, in the hope that it would be efficient once fused. However, jit didn't fuse np.take, reshape, and np.einsum properly, which resulted in a huge intermediate array. Since #1979 concluded that this sort of op sequence would be fused on GPU, I'm wondering what is causing this problem. If by any chance this isn't supported on GPU, should it work on TPU? I'm using JAX ver. 0.1.67 with various Colab GPUs.
import jax.numpy as np
from jax import jit, random

hidden_dim = 512
n_keys = 512
batch = 2 ** 15
knn = 32
heads = 4

key = random.PRNGKey(0)
values = random.normal(key, (n_keys ** 2, hidden_dim))
indices = random.randint(key, (batch * heads, knn), 0, n_keys ** 2)
weights = random.normal(key, (batch * heads, knn))

@jit
def scatter_weighted_sum(inputs, indices, weights):
    num_bags = weights.shape[-1]
    dim = inputs.shape[-1]
    indices = indices.reshape(-1)
    # Gather the rows for every (row, bag) pair, then contract the bag
    # axis against the per-sample weights.
    tmp = inputs.take(indices, axis=0).reshape(-1, num_bags, dim)
    return np.einsum('ind, in -> id', tmp, weights)
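For the sizes above, the gathered intermediate tmp has shape (batch * heads, knn, hidden_dim) = (131072, 32, 512); in float32 that is 131072 × 32 × 512 × 4 bytes ≈ 8.6 GB, which is the huge intermediate in question and easily exceeds the memory of a typical Colab GPU.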
Using embedding_bag instead of embedding + reduce has a "huge" impact not only on performance but also on the memory footprint for a large number of indices. Computing the "sum" on the fly saves a great amount of memory, since we do not need to store all of the table-lookup results. A higher priority should be given to this issue.
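To illustrate the "on the fly" reduction, here is a minimal sketch (the function name and shapes are hypothetical, not an existing JAX API) that accumulates one bag member per step, so only a (rows, dim) accumulator and a single gathered slice are live at any time:

import jax.numpy as np
from jax import lax

def embedding_bag_on_the_fly(table, indices, weights):
    # table: (vocab, dim); indices, weights: (rows, num_bags)
    def body(k, acc):
        rows = table[indices[:, k]]              # gather one bag member per row
        return acc + weights[:, k, None] * rows  # weighted accumulation
    acc0 = np.zeros((indices.shape[0], table.shape[-1]), table.dtype)
    return lax.fori_loop(0, indices.shape[-1], body, acc0)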
I agree: XLA is not fusing the gather with the reduction, which would avoid materializing the intermediate value, and I don't see a good reason for that. I filed a couple of bugs with the XLA teams.
A possible workaround would be to use lax.map to compute the elements in chunks.
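A minimal sketch of that workaround, assuming the shapes from the snippet above; the function name, the chunk count num_chunks, and the requirement that it divide batch * heads are all assumptions, not part of the original report:

from functools import partial
import jax.numpy as np
from jax import jit, lax

@partial(jit, static_argnums=3)
def scatter_weighted_sum_chunked(inputs, indices, weights, num_chunks=64):
    num_bags = weights.shape[-1]
    dim = inputs.shape[-1]
    # Split the rows into num_chunks groups along the leading axis;
    # num_chunks must divide the number of rows exactly.
    idx_chunks = indices.reshape(num_chunks, -1, num_bags)
    w_chunks = weights.reshape(num_chunks, -1, num_bags)

    def one_chunk(args):
        idx, w = args
        # Only this chunk's gathered rows are materialized at a time.
        tmp = inputs.take(idx.reshape(-1), axis=0).reshape(-1, num_bags, dim)
        return np.einsum('ind, in -> id', tmp, w)

    # lax.map runs one_chunk sequentially over the leading (chunk) axis.
    out = lax.map(one_chunk, (idx_chunks, w_chunks))
    return out.reshape(-1, dim)

With the sizes above, this drops the live intermediate from ≈ 8.6 GB to ≈ 134 MB per chunk (2048 × 32 × 512 float32), at the cost of a sequential loop over the chunks.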