STFT/ISTFT pre-allocate output buffers

Description

STFT allocates a new output buffer on every call, but sometimes you might want it to compute directly into an existing buffer. For example, Griffin-Lim alternates stft and istft on every iteration, and each intermediate result is discarded after use. It would be better if we could pass an out= parameter that is used instead of allocating a new buffer; this way, we could cut down on redundant memory allocations.

If the shape or dtype of the provided buffer doesn't match what the function would produce, it should raise an exception.
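A minimal sketch of the requested out= contract, assuming plain NumPy arrays. The helper name _prepare_out and the toy windowed_frames function are hypothetical (not part of librosa): the helper allocates a fresh buffer when out is None and otherwise checks shape and dtype, raising on mismatch. ValueError is used only to keep the example self-contained; librosa would presumably raise its own ParameterError. np.hanning is a stand-in window, not librosa's "hann".

```python
import numpy as np


def _prepare_out(out, shape, dtype):
    """Return a buffer of the requested shape/dtype (hypothetical helper).

    If ``out`` is None, a new zero-filled array is allocated.
    Otherwise ``out`` is validated and returned unchanged, so the
    caller writes directly into the user's memory.
    """
    shape = tuple(shape)
    dtype = np.dtype(dtype)
    if out is None:
        return np.zeros(shape, dtype=dtype)
    if out.shape != shape:
        raise ValueError(f"out.shape={out.shape} does not match expected {shape}")
    if out.dtype != dtype:
        raise ValueError(f"out.dtype={out.dtype} does not match expected {dtype}")
    return out


def windowed_frames(x, n_fft, hop_length, out=None):
    """Toy example: slice x into windowed frames, optionally into ``out``."""
    n_frames = 1 + (len(x) - n_fft) // hop_length
    buf = _prepare_out(out, (n_fft, n_frames), x.dtype)
    window = np.hanning(n_fft).astype(x.dtype)
    for t in range(n_frames):
        start = t * hop_length
        # Write each frame in place; the result array is not reallocated
        np.multiply(window, x[start:start + n_fft], out=buf[:, t])
    return buf
```

In a Griffin-Lim-style loop, the caller would allocate the buffer once and pass it as out= on every iteration, so the returned array is the same object each time.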

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 24 (24 by maintainers)

Top GitHub Comments

1 reaction
bmcfee commented, Jun 24, 2022

Okay, finally shook out the bugs in istft:

Minimal-copy / pre-allocated istft
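import numpy as np
from numba import jit

from librosa import get_fftlib, util
from librosa.filters import get_window, window_sumsquare
from librosa.util.exceptions import ParameterError


def istft(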
    stft_matrix,
    hop_length=None,
    win_length=None,
    n_fft=None,
    window="hann",
    center=True,
    dtype=None,
    length=None,
    out=None,
):

    if n_fft is None:
        n_fft = 2 * (stft_matrix.shape[-2] - 1)

    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    ifft_window = get_window(window, win_length, fftbins=True)

    # Pad out to match n_fft, and add broadcasting axes
    ifft_window = util.pad_center(ifft_window, size=n_fft)
    ifft_window = util.expand_to(ifft_window, ndim=stft_matrix.ndim, axes=-2)

    # For efficiency, trim STFT frames according to signal length if available
    if length:
        if center:
            padded_length = length + 2 * (n_fft//2)
        else:
            padded_length = length
        n_frames = min(stft_matrix.shape[-1], int(np.ceil(padded_length / hop_length)))
    else:
        n_frames = stft_matrix.shape[-1]

    if dtype is None:
        dtype = util.dtype_c2r(stft_matrix.dtype)

    shape = list(stft_matrix.shape[:-2])
    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    
    if length:
        expected_signal_len = length
    elif center:
        expected_signal_len -= 2*(n_fft//2)
    
    shape.append(expected_signal_len)
    
    if out is None:
        y = np.zeros(shape, dtype=dtype)
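    elif tuple(out.shape) != tuple(shape):
        # Compare shapes exactly: np.allclose would accept near-equal lengths
        raise ParameterError(f'Shape mismatch for provided output array out.shape={out.shape} != {shape}')
    elif out.dtype != np.dtype(dtype):
        raise ParameterError(f'dtype mismatch for provided output array out.dtype={out.dtype} != {dtype}')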
    else:
        y = out
        # Since we'll be doing overlap-add here, this needs to be initialized to zero.
        y.fill(0.)

    fft = get_fftlib()

    if center:
        # First frame that does not depend on padding
        #  k * hop_length - n_fft//2 >= 0
        # k * hop_length >= n_fft // 2
        # k >= (n_fft//2 / hop_length)
        
        start_frame = int(np.ceil((n_fft//2) / hop_length))
        
        # Do overlap-add on the head block
        ytmp = ifft_window * fft.irfft(stft_matrix[..., :start_frame], n=n_fft, axis=-2)
        
        shape[-1] = n_fft + hop_length * (start_frame - 1)
        head_buffer = np.zeros(shape, dtype=dtype)
        
        __overlap_add(head_buffer, ytmp, hop_length)
        
        # If y is smaller than the head buffer, take everything
        if y.shape[-1] < shape[-1] - n_fft//2:
            y[..., :] = head_buffer[..., n_fft//2:y.shape[-1]+n_fft//2]
        else:
            # Trim off the first n_fft//2 samples from the head and copy into target buffer
            y[..., :shape[-1]-n_fft//2] = head_buffer[..., n_fft//2:]
        
        # This offset compensates for any differences between frame alignment
        # and padding truncation
        offset = start_frame * hop_length - n_fft//2
        
    else:
        start_frame = 0
        offset = 0

    n_columns = util.MAX_MEM_BLOCK // (
        np.prod(stft_matrix.shape[:-1]) * stft_matrix.itemsize
    )
    n_columns = max(n_columns, 1)
    
    frame = 0
    for bl_s in range(start_frame, n_frames, n_columns):
        
        bl_t = min(bl_s + n_columns, n_frames)

        # invert the block and apply the window function
        ytmp = ifft_window * fft.irfft(stft_matrix[..., bl_s:bl_t], n=n_fft, axis=-2)

        # Overlap-add the istft block starting at the i'th frame
        __overlap_add(y[..., frame * hop_length + offset:], ytmp, hop_length)

        frame += bl_t - bl_s

    # Normalize by sum of squared window
    ifft_window_sum = window_sumsquare(
        window=window,
        n_frames=n_frames,
        win_length=win_length,
        n_fft=n_fft,
        hop_length=hop_length,
        dtype=dtype,
    )
    
    if center:
        start = n_fft//2
    else:
        start = 0
           
    ifft_window_sum = util.fix_length(ifft_window_sum[..., start:], size=y.shape[-1])

    approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum)
    
    y[..., approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices]
    
    return y

@jit(nopython=True)
def __overlap_add(y, ytmp, hop_length):
    # numba-accelerated overlap add for inverse stft
    # y is the pre-allocated output buffer
    # ytmp is the windowed inverse-stft frames
    # hop_length is the hop-length of the STFT analysis

    n_fft = ytmp.shape[-2]
    N = n_fft
    for frame in range(ytmp.shape[-1]):
        sample = frame * hop_length
        if N > y.shape[-1] - sample:
            N = y.shape[-1] - sample
        
        y[..., sample : (sample + N)] += ytmp[..., :N, frame]
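When debugging, it can help to cross-check the numba kernel against a plain NumPy loop. The sketch below (overlap_add_reference is a hypothetical name) follows the same accumulate-and-truncate logic as __overlap_add but runs without numba, and adds an explicit early exit for frames that start at or beyond the end of y. It is slower and meant only for testing.

```python
import numpy as np


def overlap_add_reference(y, ytmp, hop_length):
    """Plain-NumPy overlap-add, accumulating into ``y`` in place.

    y          : output buffer, shape (..., n_samples)
    ytmp       : windowed frames, shape (..., n_fft, n_frames)
    hop_length : number of samples between successive frame starts
    """
    n_fft, n_frames = ytmp.shape[-2], ytmp.shape[-1]
    for frame in range(n_frames):
        sample = frame * hop_length
        if sample >= y.shape[-1]:
            break  # this frame (and all later ones) starts past the buffer
        # Truncate the last frames so they fit inside the buffer
        n = min(n_fft, y.shape[-1] - sample)
        y[..., sample:sample + n] += ytmp[..., :n, frame]
    return y
```

With all-ones frames of size 4 and hop_length=2, three frames summed into an 8-sample buffer give [1, 1, 2, 2, 2, 2, 1, 1], which is a quick sanity check for either implementation.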

The savings are … non-existent. On a benchmark similar to the one above, we get about a 20 ms reduction in run time and a 5-10 MB reduction in memory footprint. So, not nothing, but maybe not worth all the effort?
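The benchmark behind these numbers isn't shown in the thread. One way to take similar measurements, assuming the allocations go through NumPy (which reports its data buffers to Python's tracemalloc), is sketched below. scale_fresh and scale_inplace are hypothetical toy stand-ins for an allocating code path and an out= code path; they are not the istft above.

```python
import time
import tracemalloc

import numpy as np


def peak_bytes(fn, *args, **kwargs):
    """Run fn once; return (result, elapsed seconds, peak traced bytes)."""
    tracemalloc.start()
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - t0
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return result, elapsed, peak


def scale_fresh(x, n_iter):
    # Each iteration allocates a brand-new output array
    y = None
    for _ in range(n_iter):
        y = x * 0.5
    return y


def scale_inplace(x, n_iter, out):
    # Each iteration writes into the same caller-provided buffer
    for _ in range(n_iter):
        np.multiply(x, 0.5, out=out)
    return out
```

The buffer passed to scale_inplace is allocated before tracing starts, so the reported peak only reflects per-call temporaries.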

It does pass the following battery of tests though:

import numpy as np
import librosa

# Compare the sketch `istft` defined above against librosa.istft
for center in [True, False]:
    for N in [512, 1024, 2048, 10000]:
        y = np.random.randn(2, N)
        for n_fft in [1023, 1024, 1025]:
            if (not center) and N < n_fft:
                continue
            for hop_length in [129, 255, 256, 257, 384, n_fft]:
                for length in [None, y.shape[-1], y.shape[-1]//2]:
                    print(N, n_fft, hop_length, center, length)
                    D = librosa.stft(y, hop_length=hop_length, n_fft=n_fft, center=center)
                    yi1 = librosa.istft(D, hop_length=hop_length, n_fft=n_fft, center=center, length=length)
                    yi2 = istft(D, hop_length=hop_length, n_fft=n_fft, center=center, length=length)
                    assert np.allclose(yi1, yi2)
                    out = np.empty_like(yi2)
                    yi3 = istft(D, hop_length=hop_length, n_fft=n_fft, center=center, out=out, length=length)
                    assert yi3 is out
                    assert np.allclose(yi2, yi3)
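The tests above sweep n_fft, hop_length, center and length, which mostly exercises the length and offset bookkeeping in the istft sketch (expected_signal_len, start_frame, offset). The hypothetical helper istft_geometry below copies that arithmetic into a pure function so it can be checked in isolation; it does not touch any audio. For example, with n_fft=1023 and hop_length=129, frame 4 is the first one that starts at or after the 511-sample padding (4 * 129 = 516), so offset = 5.

```python
import math


def istft_geometry(n_frames, n_fft, hop_length, center=True, length=None):
    """Return (output length, start_frame, offset) as computed by the istft sketch."""
    expected = n_fft + hop_length * (n_frames - 1)
    if length:
        expected = length
    elif center:
        # Centered framing padded n_fft // 2 samples on each side
        expected -= 2 * (n_fft // 2)
    if center:
        # First frame k with k * hop_length - n_fft // 2 >= 0
        start_frame = math.ceil((n_fft // 2) / hop_length)
        # Misalignment between that frame's start and the padding boundary
        offset = start_frame * hop_length - n_fft // 2
    else:
        start_frame, offset = 0, 0
    return expected, start_frame, offset
```

Without length=, 6 frames at n_fft=1024 and hop_length=384 give 1920 output samples, so length= is how a caller pins the output to a specific duration such as the original 2048 samples.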
0 reactions
bmcfee commented, Jun 24, 2022

Punting this to 0.10 due to unanticipated API breakage.
