Einsum is slow
An example where np.einsum is slower than manual matmul/transpositions, on both CPU and GPU. (The example from #1966 runs equally fast for me, but this one is consistently slower.)
import jax.numpy as np
import jax.random as random
from jax import jit  # the issue originally used the old jax.api import path

a = random.normal(random.PRNGKey(1), (100, 20, 20, 3))
b = random.normal(random.PRNGKey(2), (200, 20, 20, 3))

@jit
def matmul(a, b):
    # Move the shared (x, y) axes to the front so matmul broadcasts over them,
    # then restore the nmxy layout.
    return np.transpose(
        np.matmul(np.transpose(a, axes=(1, 2, 0, 3)),
                  np.transpose(b, axes=(1, 2, 3, 0))),
        axes=(2, 3, 0, 1))

@jit
def einsum(a, b):
    return np.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

np.sum(np.abs(einsum(a, b) - matmul(a, b)))  # sanity-check the two agree

%timeit einsum(a, b).block_until_ready()
%timeit matmul(a, b).block_until_ready()
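As a point of comparison, plain NumPy exposes the contraction plan that optimize=True would use via np.einsum_path (JAX's einsum lowers to XLA and may plan the contraction differently, so this is only indicative):

```python
import numpy as np

a = np.random.rand(100, 20, 20, 3).astype(np.float32)
b = np.random.rand(200, 20, 20, 3).astype(np.float32)

# einsum_path returns the chosen contraction order plus a human-readable
# report of estimated FLOP counts and intermediate sizes.
path, report = np.einsum_path('nxyc,mxyc->nmxy', a, b, optimize='optimal')
print(report)
```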
Also note that if you run it on CPU, the difference between the two methods' outputs becomes non-zero:
DeviceArray(0.01003271, dtype=float32)
Not sure how concerning that is.
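A small non-zero difference like this is expected in float32: addition is not associative, so two implementations that reduce over the same axis in different orders can round differently. A minimal illustration:

```python
import numpy as np

x = np.float32(1e8)
y = np.float32(-1e8)
z = np.float32(1.0)

# float32 addition is not associative: grouping changes the result, so
# different reduction orders (einsum vs. matmul kernels) can disagree slightly.
print((x + y) + z)  # 1.0
print(x + (y + z))  # 0.0: y + z rounds back to -1e8 at float32 precision
```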
Issue Analytics
- State:
- Created 4 years ago
- Reactions:2
- Comments:9 (6 by maintainers)
FYI, I have revisited the example below on:
Will file bugs against XLA:CPU and XLA:GPU!
Ran some benchmarks on several libraries; times in seconds on a Ryzen 3900X:
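The benchmark table itself is not reproduced here. A minimal NumPy-only harness along the same lines (same shapes and contraction as the example above; function names are illustrative) might look like:

```python
import timeit
import numpy as np

a = np.random.rand(100, 20, 20, 3).astype(np.float32)
b = np.random.rand(200, 20, 20, 3).astype(np.float32)

def via_matmul(a, b):
    # Move the shared (x, y) axes to the front so matmul broadcasts over them.
    return np.transpose(
        np.matmul(np.transpose(a, (1, 2, 0, 3)), np.transpose(b, (1, 2, 3, 0))),
        (2, 3, 0, 1))

def via_einsum(a, b):
    return np.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

t_matmul = timeit.timeit(lambda: via_matmul(a, b), number=10)
t_einsum = timeit.timeit(lambda: via_einsum(a, b), number=10)
print(f'matmul: {t_matmul:.4f}s  einsum: {t_einsum:.4f}s')
```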