FFT on CPU noticeably slower than SciPy's FFT
The FFT implementation in JAX appears to be noticeably slower than SciPy's, even though both are based on some flavor of the PocketFFT implementation.
from jax.config import config
config.update("jax_enable_x64", True)
import timeit
import numpy as np
from scipy import fft as sp_fft
from jax import numpy as jnp
from jax import jit, random
key = random.PRNGKey(42)
jr = random.normal(key, (2**26, ))
r = np.array(jr)
N_IT = 7
timing = timeit.timeit(lambda: np.fft.fft(r), number=N_IT) / N_IT
print(f"NumPy: {timing} s")
timing = timeit.timeit(lambda: sp_fft.fft(r), number=N_IT) / N_IT
print(f"SciPy: {timing} s")
jax_fft = jnp.fft.fft
timing = timeit.timeit(
    lambda: jax_fft(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (unjitted): {timing} s")
jax_fft_jit = jit(jax_fft)
jax_fft_jit(jr) # Warm-up
timing = timeit.timeit(
    lambda: jax_fft_jit(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (jitted): {timing} s")
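As an aside, the SciPy timing above runs single-threaded by default; SciPy's pocketfft backend can also run multithreaded via the `workers` argument (available since SciPy 1.4). A minimal sketch of that variant, on a smaller array than the benchmark above:

```python
import numpy as np
from scipy import fft as sp_fft

# Smaller array than the 2**26 benchmark, just to illustrate the API.
rng = np.random.default_rng(42)
r = rng.standard_normal(2**20)

single = sp_fft.fft(r)             # default: one worker thread
multi = sp_fft.fft(r, workers=-1)  # workers=-1 uses all available cores

# Both paths must produce the same transform.
assert np.allclose(single, multi)
```

Whether the multithreaded path helps depends on the transform size and core count, but it is worth including when comparing against JAX's CPU FFT.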
On an AMD Ryzen 7 4800H (with Radeon Graphics) with JAX 0.2.18 and jaxlib 0.1.69 installed from PyPI, I get the following timings:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
NumPy: 3.5420487681424544 s
SciPy: 1.9076589474290293 s
JAX (unjitted): 4.5806920641433475 s
JAX (jitted): 4.63735637514245 s
The timings improve if I compile jaxlib myself, though it still compares unfavorably to SciPy (self-compiled JAX, unjitted: 3.506302507857202 s).
My hypothesis is that JAX is distributing binaries that are suboptimal for recent AMD CPUs and, much more importantly, that JAX is probably using an outdated version of PocketFFT.
Issue Analytics
- State:
- Created 2 years ago
- Comments: 32 (23 by maintainers)
Just to clarify earlier parts of the discussion: the C version of pocketfft is the older one and should no longer be used unless the available environment is C-only. The C++ version of pocketfft performs better, has many more capabilities, and has the nice property of coming as a single header. If you have access to C++17, I recommend the FFT component of ducc0, which is the evolution of C++ pocketfft with many minor improvements (but nothing really groundbreaking).
That’s correct, but if we are talking about the FFT component only, I think I could make that available under BSD terms. Please let me know if you’d be interested in that.
(The directly FFT-related files in ducc0 still have the BSD licensing header, as far as I remember, but some of the support header files don’t, so they’d need to be adjusted.)
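For reference, ducc0 also ships Python bindings, so its FFT can be tried directly from Python without touching the C++ header. A minimal sketch, assuming the `ducc0` package is installed from PyPI; its complex-to-complex transform is `ducc0.fft.c2c`, with `nthreads` controlling threading:

```python
import numpy as np

try:
    import ducc0  # pip install ducc0; assumed available for this sketch
except ImportError:
    ducc0 = None

if ducc0 is not None:
    rng = np.random.default_rng(42)
    a = rng.standard_normal(2**20).astype(np.complex128)

    # Complex-to-complex forward FFT along axis 0.
    out = ducc0.fft.c2c(a, axes=(0,), forward=True, nthreads=4)

    # Cross-check against NumPy's FFT.
    assert np.allclose(out, np.fft.fft(a))
```

This makes it straightforward to benchmark ducc0's FFT against SciPy and JAX on the same arrays as above.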