matmul slow for complex dtypes
Hey, thanks for the great work here!
I noticed that matmuls for complex dtypes are ~20x to 25x slower on my MacBook than they are for real dtypes. Here is a simple script that does the timing. I'm wondering whether this is expected, and what I can do to speed it up. Thanks!
import numpy as np
import jax
from jax import config
config.update('jax_enable_x64',True)
import time
@jax.jit
def do_matvec_simple(matrix, vector):
    # Accumulate 100 matrix-vector products inside one jitted function.
    res = 0
    for _ in range(100):
        res += matrix @ vector
    return res

@jax.jit
def do_matmul_simple(matrix1, matrix2):
    # Accumulate 100 matrix-matrix products inside one jitted function.
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res

def run_timings_dot(dtype, D):
    matrix = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    vector = jax.numpy.array(np.random.rand(D).astype(dtype))
    # Time 100 un-jitted matrix-vector products dispatched from Python.
    t1=time.time()
    for _ in range(100):
        res = matrix @ vector
    res.block_until_ready()
    print(f'loop over 100 matrix-vector muls in dtype {np.dtype(dtype).name}', time.time() -t1)
    # Warm-up call so that compilation time is excluded from the measurement.
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    t1 = time.time()
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    print(f'jit 100 do_matvec_simple for dtype {np.dtype(dtype).name}', time.time() - t1)

def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    # Time 100 un-jitted matrix-matrix products dispatched from Python.
    t1=time.time()
    for _ in range(100):
        res = A@B
    res.block_until_ready()
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', time.time() -t1)
    # Warm-up call so that compilation time is excluded from the measurement.
    res = do_matmul_simple(A,B)
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A,B)
    res.block_until_ready()
    print(f'jit 100 do_matmul_simple for dtype {np.dtype(dtype).name}', time.time() - t1)

print('######## timings for matrix-vector ###########')
print(' ---------- float64 --------------')
run_timings_dot(np.float64, 1000)
print(' ---------- complex128 --------------')
run_timings_dot(np.complex128, 1000)
print()
print()
print('######## timings for matrix-matrix ###########')
print(' ---------- float64 --------------')
run_timings_matmul_simple(np.float64, 400)
print(' ---------- complex128 --------------')
run_timings_matmul_simple(np.complex128, 400)
Update: disabling double precision seems to increase the slowdown to ~100x.
Issue Analytics
- State:
- Created 3 years ago
- Comments:7 (4 by maintainers)
Yes, I think there’s a bug here. On CPU, XLA is falling back to a naive implementation of matmul for complex types instead of calling into an optimized implementation as it does for floating point types. This should be easy to fix.
Thanks for fixing!
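For anyone stuck on a version that still has the fallback described above, one possible workaround (not from this thread, just a sketch) is to express the complex product through real matmuls, on the assumption that those do reach the optimized CPU kernel. It uses the identity (Ar + i Ai)(Br + i Bi) = (Ar Br - Ai Bi) + i (Ar Bi + Ai Br). The helper name `complex_matmul_via_real` is hypothetical, and whether this actually helps on a given setup should be checked with timings like the ones in the original post.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Same double-precision setting as in the timing script above.
jax.config.update("jax_enable_x64", True)


@jax.jit
def complex_matmul_via_real(a, b):
    """Hypothetical helper: compute a @ b for complex inputs with four real matmuls."""
    ar, ai = jnp.real(a), jnp.imag(a)
    br, bi = jnp.real(b), jnp.imag(b)
    real_part = ar @ br - ai @ bi
    imag_part = ar @ bi + ai @ br
    # Recombine the two real arrays into a single complex array.
    return jax.lax.complex(real_part, imag_part)


# Random complex test matrices (sizes chosen arbitrarily for the check).
rng = np.random.default_rng(0)
D = 64
A = jnp.asarray(rng.standard_normal((D, D)) + 1j * rng.standard_normal((D, D)))
B = jnp.asarray(rng.standard_normal((D, D)) + 1j * rng.standard_normal((D, D)))

result = complex_matmul_via_real(A, B)

# Compare against NumPy's complex matmul as an independent reference.
reference = np.asarray(A) @ np.asarray(B)
max_abs_err = float(np.max(np.abs(np.asarray(result) - reference)))
print("max abs error vs NumPy:", max_abs_err)
```

This trades one complex matmul for four real ones plus some extra memory traffic, so it should only pay off while the complex path is much slower than the real one.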