
matmul slow for complex dtypes

See original GitHub issue

Hey, thanks for the great work here!

I noticed that matmuls for complex dtypes are ~20x to 25x slower on my MacBook than they are for real dtypes. Here is a simple script that does the timing. Wondering if this is expected, or what I can do to speed it up. Thanks!

import numpy as np
import jax
from jax import config
config.update('jax_enable_x64', True)  # enable float64 / complex128
import time


@jax.jit
def do_matvec_simple(matrix, vector):
    # 100 matrix-vector products inside a single jitted computation
    res = 0
    for _ in range(100):
        res += matrix @ vector
    return res


@jax.jit
def do_matmul_simple(matrix1, matrix2):
    # 100 matrix-matrix products inside a single jitted computation
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res


def run_timings_dot(dtype, D):
    matrix = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    vector = jax.numpy.array(np.random.rand(D).astype(dtype))

    # eager loop; block on the last result so asynchronously dispatched
    # work is included in the measurement
    t1 = time.time()
    for _ in range(100):
        res = matrix @ vector
    res.block_until_ready()
    print(f'loop over 100 matrix-vector muls in dtype {np.dtype(dtype).name}', time.time() - t1)

    # first call compiles; time only the second call
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    t1 = time.time()
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    print(f'jit 100 do_matvec_simple for dtype {np.dtype(dtype).name}', time.time() - t1)


def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D, D).astype(dtype))

    t1 = time.time()
    for _ in range(100):
        res = A @ B
    res.block_until_ready()
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', time.time() - t1)

    # first call compiles; time only the second call
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    print(f'jit 100 do_matmul_simple for dtype {np.dtype(dtype).name}', time.time() - t1)


print('########  timings for matrix-vector  ###########')
print('      ----------   float64   --------------')
run_timings_dot(np.float64, 1000)
print('      ----------   complex128   --------------')
run_timings_dot(np.complex128, 1000)
print()
print()
print('########  timings for matrix-matrix  ###########')
print('      ----------   float64   --------------')
run_timings_matmul_simple(np.float64, 400)
print('      ----------   complex128   --------------')
run_timings_matmul_simple(np.complex128, 400)

Update: disabling double precision seems to increase the slowdown to ~100x.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

2 reactions
hawkinsp commented, May 13, 2020

Yes, I think there’s a bug here. On CPU, XLA is falling back to a naive implementation of matmul for complex types instead of calling into an optimized implementation as it does for floating point types. This should be easy to fix.
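
Until a fix lands, one possible workaround (my own sketch, not something proposed in the thread, and assuming only the real-dtype matmul hits the optimized CPU path) is to express the complex product through matmuls of the real and imaginary parts, using (Ar + i*Ai)(Br + i*Bi) = (Ar*Br - Ai*Bi) + i(Ar*Bi + Ai*Br):

import numpy as np
import jax
import jax.numpy as jnp
from jax import config
config.update('jax_enable_x64', True)

@jax.jit
def complex_matmul_via_real(a, b):
    # Split into real/imaginary parts and use four real-dtype matmuls.
    # Assuming the real path is optimized, this costs roughly 4x a real
    # matmul instead of the much larger naive-complex penalty.
    ar, ai = jnp.real(a), jnp.imag(a)
    br, bi = jnp.real(b), jnp.imag(b)
    return (ar @ br - ai @ bi) + 1j * (ar @ bi + ai @ br)

# sanity check against the built-in complex matmul
A = jnp.array(np.random.rand(400, 400) + 1j * np.random.rand(400, 400))
B = jnp.array(np.random.rand(400, 400) + 1j * np.random.rand(400, 400))
assert jnp.allclose(complex_matmul_via_real(A, B), A @ B)

Whether this actually wins on a given backend is something to benchmark; it is a stopgap sketch, not the fix that was later applied in XLA.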

0 reactions
mganahl commented, Jun 27, 2020

Thanks for fixing!

Read more comments on GitHub >

Top Results From Across the Web

  • numpy matmul very slow when one of the two is np.array ...
    A.real is not a "new" array; it's a method of accessing the real values of the complex dtype. I don't have time to...
  • numpy matmul very slow when one of the two is np.array ...
    I discovered that when matmul-ing two numpy arrays, if one of the two is the real or imaginary part of a bigger... (a minimal view-versus-copy sketch follows after this list)
  • Faster Matrix Multiplications in Numpy - Benjamin Johnston
    Matrix multiplications in NumPy are reasonably fast without the need for optimization. However, if every second counts, it is possible to ...
  • tf.linalg.matmul | TensorFlow v2.11.0
    c = tf.matmul(a, b); c  # `a` * `b` -> <tf.Tensor: shape=(2, 2), dtype=int32, numpy=array([[ 58, 64], [139, 154]], dtype=int32)>. A batch matrix...
  • Supported NumPy features - Numba documentation
    NumPy dtypes provide type information useful when compiling, ... On Python 3.5 and above, the matrix multiplication operator from PEP 465 (i.e. a...
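
A side note on the second Stack Overflow result above: A.real is a strided view into the complex buffer rather than a fresh array. A minimal NumPy-only illustration (my own sketch, not code from that answer) of the view-versus-copy difference:

import numpy as np

A = np.random.rand(1000, 1000) + 1j * np.random.rand(1000, 1000)
B = np.random.rand(1000, 1000)

view = A.real                      # strided view into the complex buffer, no copy
copy = np.ascontiguousarray(view)  # contiguous float64 copy

print(view.flags['C_CONTIGUOUS'], copy.flags['C_CONTIGUOUS'])  # False True

# Both give the same product; the contiguous copy can be handed straight
# to the optimized BLAS routine, while the strided view typically cannot.
np.testing.assert_allclose(view @ B, copy @ B)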
