FFT on CPU noticeably slower than SciPy's FFT

The FFT implementation in JAX seems to be noticeably slower than the one in SciPy, even though both use some flavor of PocketFFT.

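```python
import timeit

import jax

# Enable 64-bit floats so JAX computes in the same precision as NumPy/SciPy.
# (`from jax.config import config` worked in JAX 0.2.x but is deprecated in newer releases.)
jax.config.update("jax_enable_x64", True)

import numpy as np
from scipy import fft as sp_fft
from jax import numpy as jnp
from jax import jit, random

key = random.PRNGKey(42)
jr = random.normal(key, (2**26,))  # 2**26 float64 samples (~512 MiB) as a JAX array
r = np.array(jr)                   # the same data as a NumPy array

N_IT = 7

timing = timeit.timeit(lambda: np.fft.fft(r), number=N_IT) / N_IT
print(f"NumPy: {timing} s")
timing = timeit.timeit(lambda: sp_fft.fft(r), number=N_IT) / N_IT
print(f"SciPy: {timing} s")

jax_fft = jnp.fft.fft
# JAX dispatches asynchronously; block_until_ready() makes the timing include the computation.
timing = timeit.timeit(
    lambda: jax_fft(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (unjitted): {timing} s")

jax_fft_jit = jit(jax_fft)
jax_fft_jit(jr).block_until_ready()  # Warm-up: trigger compilation outside the timed loop
timing = timeit.timeit(
    lambda: jax_fft_jit(jr).block_until_ready(), number=N_IT
) / N_IT
print(f"JAX (jitted): {timing} s")
```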

On an AMD Ryzen 7 4800H (with Radeon Graphics) with JAX 0.2.18 and jaxlib 0.1.69 installed from PyPI, I get the following timings:

```
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
NumPy: 3.5420487681424544 s
SciPy: 1.9076589474290293 s
JAX (unjitted): 4.5806920641433475 s
JAX (jitted): 4.63735637514245 s
```
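Numbers like these are only comparable across machines if the exact package versions and the backend are recorded; the warning above already shows that JAX fell back to the CPU. A minimal sketch for collecting that information, assuming the packages were installed with pip so that `importlib.metadata` can see them (the field labels are illustrative):

```python
import platform
from importlib.metadata import version

import jax

# Installed versions of the libraries being compared.
for pkg in ("jax", "jaxlib", "numpy", "scipy"):
    print(f"{pkg}: {version(pkg)}")

# Host CPU as reported by the standard library (may be empty on some systems).
print("CPU:", platform.processor() or platform.machine())

# Devices JAX will use; the platform is "cpu" when no GPU/TPU is found.
devices = jax.devices()
print("JAX devices:", devices)
print("Backend platform:", devices[0].platform)
```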

The timings improve if I build jaxlib from source myself, though JAX still compares unfavorably to SciPy (self-compiled JAX, unjitted: 3.506302507857202 s).

My hypothesis is that the JAX binaries distributed on PyPI are suboptimal for recent AMD CPUs and, much more importantly, that JAX is probably using an outdated version of PocketFFT.
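Before attributing the gap to the binaries or the PocketFFT version, it can help to rule out measurement artifacts: the original script averages only a few runs and does not warm up the unjitted path, whose first call also compiles. The following is a minimal sketch of a somewhat more careful harness, not the reporter's method. It assumes a much smaller, hypothetical size `N = 2**16` so it runs quickly; `REPEATS` and `NUMBER` are likewise illustrative. It first checks that all four variants return the same spectrum (which doubles as a warm-up), then reports the best of several repeats, which is less sensitive to background noise than a mean.

```python
import timeit

import jax
jax.config.update("jax_enable_x64", True)  # match NumPy/SciPy float64 precision

import numpy as np
import jax.numpy as jnp
from scipy import fft as sp_fft

N = 2**16     # hypothetical size, much smaller than the issue's 2**26
REPEATS = 5   # illustrative: number of independent timing runs
NUMBER = 3    # illustrative: calls per timing run

rng = np.random.default_rng(42)
r = rng.standard_normal(N)  # float64 input for NumPy/SciPy
jr = jnp.asarray(r)         # the same data as a JAX array

jax_fft_jit = jax.jit(jnp.fft.fft)

candidates = {
    "NumPy": lambda: np.fft.fft(r),
    "SciPy": lambda: sp_fft.fft(r),
    "JAX (unjitted)": lambda: jnp.fft.fft(jr).block_until_ready(),
    "JAX (jitted)": lambda: jax_fft_jit(jr).block_until_ready(),
}

# 1. Correctness: every implementation should agree to floating-point tolerance.
#    Each call here also serves as a warm-up (and triggers compilation for JAX).
reference = np.fft.fft(r)
for name, fn in candidates.items():
    out = np.asarray(fn())
    assert np.allclose(out, reference, rtol=1e-9, atol=1e-6), name

# 2. Timing: best of REPEATS runs, each averaged over NUMBER calls.
timings = {
    name: min(timeit.repeat(fn, number=NUMBER, repeat=REPEATS)) / NUMBER
    for name, fn in candidates.items()
}
for name, t in timings.items():
    print(f"{name}: {t * 1e3:.3f} ms")
```

Scaling `N` back up to `2**26` reproduces the issue's problem size, at the cost of much longer runs and several GiB of memory.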

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 32 (23 by maintainers)

Top GitHub Comments

1 reaction
mreineck commented, Aug 18, 2021

Just to clarify earlier parts of the discussion: the C version of pocketfft is the older one, which should no longer be used unless only a C environment is available. The C++ version of pocketfft has better performance, many more capabilities, and the nice property of coming as a single header. If you have access to C++17, I recommend the FFT component of ducc0, which is the evolution of C++ pocketfft with many minor improvements (but nothing really groundbreaking).
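The C-versus-C++ distinction is consistent with the NumPy/SciPy gap in the issue's timings: at the time, `numpy.fft` was built on the C pocketfft, while `scipy.fft` (added in SciPy 1.4) is built on the C++ header. One capability the C++ version exposes through SciPy is multithreading via the `workers` argument. The sketch below compares the two on a hypothetical batch of 64 transforms of length `2**14`; the batch shape is an illustrative assumption. As far as I understand pocketfft's threading, it splits work across independent 1-D transforms, so `workers` is unlikely to speed up a single long 1-D transform like the one in the issue.

```python
import os
import timeit

import numpy as np
from scipy import fft as sp_fft

# Hypothetical batch: 64 independent length-2**14 transforms along the last axis.
rng = np.random.default_rng(0)
batch = rng.standard_normal((64, 2**14))

single = sp_fft.fft(batch, axis=-1, workers=1)     # one thread
threaded = sp_fft.fft(batch, axis=-1, workers=-1)  # -1 wraps to os.cpu_count() threads
numpy_result = np.fft.fft(batch, axis=-1)

# Threading changes scheduling, not the math: results agree to round-off.
assert np.allclose(single, threaded)
assert np.allclose(single, numpy_result)

for label, fn in [
    ("NumPy", lambda: np.fft.fft(batch, axis=-1)),
    ("SciPy, workers=1", lambda: sp_fft.fft(batch, axis=-1, workers=1)),
    (f"SciPy, workers=-1 ({os.cpu_count()} CPUs)",
     lambda: sp_fft.fft(batch, axis=-1, workers=-1)),
]:
    t = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    print(f"{label}: {t * 1e3:.2f} ms")
```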

1 reaction
mreineck commented, Aug 18, 2021

> I don’t think we can use ducc0 ourselves in JAX, because it has a GPL license and JAX has an Apache license.

That’s correct, but if we are talking about the FFT component only, I think I could make that available under BSD terms. Please let me know if you’d be interested in that.

(The directly FFT-related files in ducc0 still have the BSD licensing header, as far as I remember, but some of the support header files don’t, so they’d need to be adjusted.)
