Depth-first vs. breadth-first traversal

JTFS (TimeFrequencyScattering1D) ran 40% faster than TS (Scattering1D) on GPU, and I couldn’t fathom why. On CPU it is several times slower, as expected. TS is an exact subset of JTFS, so this should be impossible.

Then it dawned on me: JTFS batches the lowpassing for each n2. To test this, I changed the compute graph of TS to mimic JTFS: an 87% speedup, beating JTFS. Note that this only concerns the lowpassing stage; the n2<->n1 computes could still be batched in both TS and JTFS.
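A minimal NumPy sketch of what that change does to the low-pass stage, assuming a hypothetical Gaussian `phi_hat` in place of kymatio's `phi` and leaving out subsampling and unpadding. All |U_2| rows that share an n2 are either low-passed one FFT pair at a time (per path) or stacked and low-passed with a single FFT pair (batched); both give the same coefficients.

```python
import numpy as np


def gaussian_lowpass_hat(N, sigma):
    # Hypothetical stand-in for phi: Gaussian low-pass on the rfft grid
    freqs = np.fft.rfftfreq(N)
    return np.exp(-(freqs ** 2) / (2 * sigma ** 2))


def lowpass_per_path(U, phi_hat):
    # Old TS graph: one rfft/irfft pair per (n1, n2) path
    N = U.shape[-1]
    return np.stack([np.fft.irfft(np.fft.rfft(u) * phi_hat, n=N) for u in U])


def lowpass_batched(U, phi_hat):
    # JTFS-style graph: one rfft/irfft pair over all n1 sharing this n2
    N = U.shape[-1]
    return np.fft.irfft(np.fft.rfft(U, axis=-1) * phi_hat, n=N, axis=-1)


rng = np.random.default_rng(0)
U_2 = np.abs(rng.standard_normal((12, 1024)))  # 12 hypothetical |U_2| rows
phi_hat = gaussian_lowpass_hat(1024, sigma=0.01)

S_per_path = lowpass_per_path(U_2, phi_hat)
S_batched = lowpass_batched(U_2, phi_hat)
print(np.allclose(S_per_path, S_batched))  # True
```

Only the number of FFT calls differs between the two graphs, which is the quantity the GPU timings turn out to be sensitive to.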

The new TS was still only roughly 20 to 40% faster than JTFS in the configurations I tested, even though JTFS does strictly more work, so torch’s FFT is either poorly optimized for singletons or exceptionally optimized for batching. The latter seems to be the case: JTFS is far slower on a multi-core CPU, yet on GPU it trails the new TS by only ~40%, since all of its frequential scattering is batched. Averaged over 10 iterations:

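Ratios of mean runtimes (t_ts_old: current Scattering1D, t_ts_new: rewritten Scattering1D below, t_jtfs: TimeFrequencyScattering1D); values above 1 mean the rewritten TS is faster.

Q=8
N      | t_ts_old / t_ts_new | t_jtfs / t_ts_new
-------+---------------------+------------------
2048   | 1.69                | 1.37
8192   | 1.67                | 1.22
32768  | 1.87                | 1.38
131072 | 1.80                | 1.31

Q=16
N      | t_ts_old / t_ts_new | t_jtfs / t_ts_new
-------+---------------------+------------------
2048   | 1.56                | 1.37
8192   | 1.72                | 1.23
32768  | 2.16                | 1.28
131072 | 2.02                | 1.43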
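To read the tables: each entry is a ratio of mean runtimes, so a ratio r corresponds to an (r - 1) × 100% speedup (1.87 is an 87% speedup, i.e. the new TS needs about 53% of the old time). A small sketch that averages each column over N using the numbers above; the dictionary layout and helper names are only for illustration:

```python
# Ratios copied from the two tables above (10-iteration averages)
ratios = {
    8: {"ts_old/ts_new": [1.69, 1.67, 1.87, 1.80],
        "jtfs/ts_new": [1.37, 1.22, 1.38, 1.31]},
    16: {"ts_old/ts_new": [1.56, 1.72, 2.16, 2.02],
         "jtfs/ts_new": [1.37, 1.23, 1.28, 1.43]},
}


def mean(values):
    return sum(values) / len(values)


def percent_faster(ratio):
    # ratio = t_slow / t_fast, so the fast run has (ratio - 1) * 100 %
    # higher throughput (and takes 1 / ratio of the slow run's time)
    return (ratio - 1.0) * 100.0


summary = {Q: {name: mean(vals) for name, vals in cols.items()}
           for Q, cols in ratios.items()}

for Q, cols in summary.items():
    for name, r in cols.items():
        print(f"Q={Q:<2} {name:<13} mean ratio {r:.4f} "
              f"-> {percent_faster(r):.1f}% faster")
```

The column means come out to 1.7575 (Q=8) and 1.865 (Q=16) for t_ts_old / t_ts_new, and 1.32 (Q=8) and 1.3275 (Q=16) for t_jtfs / t_ts_new.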

I think this is convincing enough to rewrite the compute graph for Scattering1D. My implementation is below; it still needs matching meta etc.

Scattering1D

def scattering1d(x, pad_fn, unpad, backend, J, log2_T, psi1, psi2, phi,
        ind_start=None, ind_end=None, oversampling=0,
        max_order=2, average=True, size_scattering=(0, 0, 0),
        out_type='array'):
    """
    Main function implementing the 1-D scattering transform.

    Parameters
    ----------
    x : Tensor
        a torch Tensor of size `(B, 1, N)` where `N` is the temporal size
    pad_fn : function
        pads the input to a dyadic size
    unpad : function
        `unpad(x, i0, i1)` truncates `x` to the indices that correspond to
        the actual (unpadded) signal
    backend : module
        backend providing `subsample_fourier`, `modulus`, `rfft`, `ifft`,
        `irfft`, `cdgmm`, `concatenate` and `concatenate_v2`
    J : int
        scale of the scattering
    log2_T : int
        (log2 of) temporal support of low-pass filter, controlling amount of
        imposed time-shift invariance and maximum subsampling
    psi1 : list of dictionaries
        first-order filters (in the Fourier domain), indexed by `n1`.
        `psi1[n1]['j']` is the log2 of the filter's maximal downsampling
        factor, `psi1[n1]['xi']` its center frequency, and `psi1[n1][k]`
        the filter to apply to an input subsampled by `2**k`
    psi2 : list of dictionaries
        second-order filters, indexed by `n2`. Same remarks as for psi1
    phi : dictionary
        low-pass filters of scale :math:`2^J`, with keys `k` where
        :math:`2^k` is the downsampling factor of the input.
        The array `phi[k]` is a real-valued filter.
    ind_start : dictionary of ints, optional
        indices to truncate the signal to recover only the
        parts which correspond to the actual signal after padding and
        downsampling. Defaults to None
    ind_end : dictionary of ints, optional
        See description of ind_start
    oversampling : int, optional
        how much to oversample the scattering (with respect to :math:`2^J`):
        the higher, the larger the resulting scattering
        tensor along time. Defaults to `0`
    max_order : int, optional
        maximum scattering order, 1 or 2. Defaults to `2`
    average : boolean, optional
        whether to low-pass (average) the coefficients with `phi`.
        Defaults to `True`
    size_scattering : tuple
        Contains the number of channels of the scattering, precomputed for
        speed-up. Defaults to `(0, 0, 0)`.
    out_type : str, optional
        `'array'` to return a single concatenated tensor (requires
        `average=True`); otherwise a list of dictionaries with keys
        `'coef'`, `'j'` and `'n'`. Defaults to `'array'`

    """
    subsample_fourier = backend.subsample_fourier
    modulus = backend.modulus
    rfft = backend.rfft
    ifft = backend.ifft
    irfft = backend.irfft
    cdgmm = backend.cdgmm
    concatenate = backend.concatenate
    concatenate_v2 = backend.concatenate_v2

    # Per-order lists of {'coef', 'j', 'n'} dictionaries; concatenated into
    # one tensor at the end if out_type == 'array' and average
    batch_size = x.shape[0]
    kJ = max(log2_T - oversampling, 0)
    temporal_size = ind_end[kJ] - ind_start[kJ]
    out_S_0, out_S_1, out_S_2 = [], [], []

    # pad to a dyadic size (rfft below handles the real-to-complex step)
    U_0 = pad_fn(x)
    # compute the Fourier transform
    U_0_hat = rfft(U_0)

    # Get S0
    k0 = max(log2_T - oversampling, 0)

    if average:
        S_0_c = cdgmm(U_0_hat, phi[0])
        S_0_hat = subsample_fourier(S_0_c, 2**k0)
        S_0_r = irfft(S_0_hat)

        S_0 = unpad(S_0_r, ind_start[k0], ind_end[k0])
    else:
        S_0 = x
    out_S_0.append({'coef': S_0,
                    'j': (),
                    'n': ()})

    # First order:
    U_1_list = []
    for n1 in range(len(psi1)):
        # Convolution + downsampling
        j1 = psi1[n1]['j']

        k1 = max(min(j1, log2_T) - oversampling, 0)

        assert psi1[n1]['xi'] < 0.5 / (2**k1)
        U_1_c = cdgmm(U_0_hat, psi1[n1][0])
        U_1_hat = subsample_fourier(U_1_c, 2**k1)
        U_1_c = ifft(U_1_hat)

        # Take the modulus
        U_1_m = modulus(U_1_c)

        if average or max_order > 1:
            U_1_hat = rfft(U_1_m)
        U_1_list.append(U_1_hat)

        if average:
            # Convolve with phi_J
            k1_J = max(log2_T - k1 - oversampling, 0)
            S_1_c = cdgmm(U_1_hat, phi[k1])
            S_1_hat = subsample_fourier(S_1_c, 2**k1_J)
            S_1_r = irfft(S_1_hat)

            S_1 = unpad(S_1_r, ind_start[k1_J + k1], ind_end[k1_J + k1])
        else:
            S_1 = unpad(U_1_m, ind_start[k1], ind_end[k1])

        out_S_1.append({'coef': S_1,
                        'j': (j1,),
                        'n': (n1,)})

    if max_order == 2:
        # 2nd order
        for n2 in range(len(psi2)):
            j2 = psi2[n2]['j']

            if j2 == 0:
                continue

            Y_2_list = []
            for n1 in range(len(U_1_list)):
                j1 = psi1[n1]['j']
                if j1 >= j2:
                    continue
                U_1_hat = U_1_list[n1]
                k1 = max(min(j1, log2_T) - oversampling, 0)

                assert psi2[n2]['xi'] < psi1[n1]['xi']

                # convolution + downsampling
                k2 = max(min(j2, log2_T) - k1 - oversampling, 0)

                U_2_c = cdgmm(U_1_hat, psi2[n2][k1])
                U_2_hat = subsample_fourier(U_2_c, 2**k2)
                # take the modulus
                U_2_c = ifft(U_2_hat)

                U_2_m = modulus(U_2_c)
                Y_2_list.append(U_2_m)

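            # No first-order path with j1 < j2: nothing to compute for this n2
            if not Y_2_list:
                continue

            # For a fixed j2, every kept n1 has the same total subsampling
            # k1 + k2, so all |U_2| can be stacked along channels and
            # low-passed in one batched call using the last loop's k1, k2
            U_2_arr = concatenate_v2(Y_2_list, axis=1)

            if average:
                U_2_hat = rfft(U_2_arr)

                # Convolve with phi_J
                k2_J = max(log2_T - k2 - k1 - oversampling, 0)

                S_2_c = cdgmm(U_2_hat, phi[k1 + k2])
                S_2_hat = subsample_fourier(S_2_c, 2**k2_J)
                S_2_r = irfft(S_2_hat)

                S_2 = unpad(S_2_r, ind_start[k1 + k2 + k2_J],
                            ind_end[k1 + k2 + k2_J])
            else:
                # Unaveraged: unpad the whole batch, not just the last |U_2|
                S_2 = unpad(U_2_arr, ind_start[k1 + k2], ind_end[k1 + k2])

            # Split the batch back into one entry per (n1, n2) path; channel i
            # of S_2 holds the i-th n1 that satisfied j1 < j2
            n1_kept = [n1 for n1 in range(len(U_1_list))
                       if psi1[n1]['j'] < j2]
            for i, n1 in enumerate(n1_kept):
                out_S_2.append({'coef': S_2[:, i:i+1],
                                'j': (psi1[n1]['j'], j2),
                                'n': (n1, n2)})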

    out_S = []
    out_S.extend(out_S_0)
    out_S.extend(out_S_1)
    out_S.extend(out_S_2)

    if out_type == 'array' and average:
        out_S = concatenate([c['coef'] for c in out_S])

    return out_S

__all__ = ['scattering1d']
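The batched second-order low-pass relies on one property of the subsampling formulas: for a fixed j2, every n1 with j1 < j2 ends up with the same total subsampling k1 + k2, so all arrays in Y_2_list have the same length, and phi[k1 + k2], k2_J, ind_start and ind_end are the same for every kept n1. With a = min(j2, log2_T) - oversampling, j1 < j2 gives k1 <= max(a, 0), hence k1 + k2 = k1 + max(a - k1, 0) = max(a, k1) = max(a, 0). A brute-force check of that identity, reusing the k1/k2 expressions from the code above (the helper names are hypothetical):

```python
def subsampling_factors(j1, j2, log2_T, oversampling):
    # Same expressions as k1 and k2 in the second-order loop of scattering1d
    k1 = max(min(j1, log2_T) - oversampling, 0)
    k2 = max(min(j2, log2_T) - k1 - oversampling, 0)
    return k1, k2


def total_subsampling_is_shared(J, log2_T, oversampling):
    """True if, for every j2 in 1..J, all j1 < j2 give the same k1 + k2."""
    for j2 in range(1, J + 1):
        totals = {sum(subsampling_factors(j1, j2, log2_T, oversampling))
                  for j1 in range(j2)}
        expected = max(min(j2, log2_T) - oversampling, 0)
        if totals != {expected}:
            return False
    return True


# Sweep a range of configurations
print(all(total_subsampling_is_shared(J, log2_T, oversampling)
          for J in range(1, 13)
          for log2_T in range(0, 13)
          for oversampling in range(0, 4)))  # True
```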
benchmark
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from kymatio import Scattering1D, TimeFrequencyScattering1D
from kymatio.visuals import plotscat
from timeit import default_timer as dtime

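def timeit(fn, n_iters=10):
    # Wall-clock average, used for the CPU (NumPy frontend) path
    t0 = dtime()
    for _ in range(n_iters):
        fn()
        # Only touch CUDA when it is available, so CPU-only setups also run
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    return (dtime() - t0) / n_iters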


def sc0(sc, x):
    return sc(x)

def sc1(jtfs, x):
    return jtfs(x)

#%%
def bench(device, N, params_tm, params_fr, n_iters=10, dont_print=()):
    GPU = bool(device == 'gpu')
    frontend = 'torch' if GPU else 'numpy'

    x = (torch.randn(N, device='cuda') if GPU else
         np.random.randn(N))
    sc   = Scattering1D(shape=N, **params_tm, frontend=frontend)
    jtfs = TimeFrequencyScattering1D(shape=N, **params_tm, **params_fr,
                                     frontend=frontend)
    if GPU:
        sc = sc.cuda()
        jtfs = jtfs.cuda()

    # warmup
    _ = sc(x)
    _ = jtfs(x)

    if GPU:
        # Pass the wrappers via globals instead of `from __main__ import ...`,
        # so this also works when bench() is imported from another module
        t_sc_fn = benchmark.Timer(
            stmt='sc0(sc, x)',
            globals={'x': x, 'sc': sc, 'sc0': sc0},
        )
        t_jtfs_fn = benchmark.Timer(
            stmt='sc1(jtfs, x)',
            globals={'x': x, 'jtfs': jtfs, 'sc1': sc1},
        )
        t_sc_obj   = t_sc_fn.timeit(  n_iters)
        t_jtfs_obj = t_jtfs_fn.timeit(n_iters)
        t_sc   = t_sc_obj.mean
        t_jtfs = t_jtfs_obj.mean
    else:
        t_sc   = timeit(lambda: sc(x),   n_iters)
        t_jtfs = timeit(lambda: jtfs(x), n_iters)

    tm_str = ", ".join(f"{k}={v}" for k, v in params_tm.items()
                       if k not in dont_print)
    fr_str = ", ".join(f"{k}={v}" for k, v in params_fr.items()
                       if k not in dont_print)
    print(("N={}, {}\n"
           "{}, {}\n"
           "SC:   {:.3f} sec\n"
           "JTFS: {:.3f} sec\n").format(N, device.upper(), tm_str, fr_str,
                                        t_sc, t_jtfs))
    return t_sc, t_jtfs

n_iters = 10
N_all = (2048, 8192, 32768, 131072)[:]
devices = ('gpu', 'cpu')[:1]

params_tm = dict(Q=8, max_pad_factor=1)
params_fr = dict(F=8, J_fr=5, Q_fr=1, max_pad_factor_fr=1,
                 pad_mode_fr='zero')
dont_print = ('max_pad_factor', 'max_pad_factor_fr', 'pad_mode_fr')

t_sc_all, t_jtfs_all = {}, {}
for device in devices:
    t_sc_all[device], t_jtfs_all[device] = [], []
    for N in N_all:
        params_tm.update(dict(T=N//4, J=int(np.log2(N) - 2)))
        t_sc, t_jtfs = bench(device, N, params_tm, params_fr, n_iters, dont_print)
        t_sc_all[device].append(t_sc)
        t_jtfs_all[device].append(t_jtfs)

#%%
for device in devices:
    r = np.array(t_jtfs_all[device]) / np.array(t_sc_all[device])
    plotscat(
        np.log2(N_all), r,
        title="t_jtfs / t_sc | {}, {}-iter avg".format(device.upper(), n_iters),
        xlabel='log2(N)', ylabel='t_jtfs / t_sc', ylims=(0, 1.05*r.max()), show=1)

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 6

Top GitHub Comments

1 reaction
lostanlen commented, May 2, 2022

This way, if power users want a faster implementation, they can code their own breadth first search core, and then replace ScatteringObject.scattering#d with the new implementation.

Right. I think this is what @eickenberg had in mind with #735: “modularizing the codebase in a semantically meaningful way”. In #735 he also proposed implementing the scattering operator (convolution with many wavelets and/or one low-pass) as a Python generator, which would allow quickly switching between DFS and BFS at runtime with little extra code.

Alternatively, users can now subclass the Torch scattering object and redefine the scattering function here without having to redefine the core scattering. This makes development of different scatterings much easier for the end user (like if they’re trying to implement optimizable scattering ;D)

Yes, this is another important perspective, aimed at power users and whoever wants to improve/extend the current definitions we have.
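A minimal sketch of the generator idea, assuming a toy tree where every node is expanded with every filter (no j1 < j2 pruning, low-pass, or subsampling). `scattering_paths` and the Gaussian filters are hypothetical, not kymatio API. One generator yields depth-first or breadth-first order depending on a single argument; both modes visit the same nodes, and a BFS order places all nodes of one level next to each other, which is what makes batching easy.

```python
import numpy as np
from collections import deque


def wavelet_modulus(U, psi_hat):
    # |U * psi| computed in the Fourier domain (no subsampling in this toy)
    return np.abs(np.fft.ifft(np.fft.fft(U) * psi_hat))


def scattering_paths(x, psi_hats, max_order=2, mode="dfs"):
    """Yield (path, U) for every path up to max_order, in DFS or BFS order."""
    frontier = deque([((), x)])
    while frontier:
        # DFS treats the frontier as a stack, BFS as a queue
        path, U = frontier.pop() if mode == "dfs" else frontier.popleft()
        yield path, U
        if len(path) == max_order:
            continue
        children = [(path + (n,), wavelet_modulus(U, psi_hat))
                    for n, psi_hat in enumerate(psi_hats)]
        # Push DFS children in reverse so they are popped in increasing n
        frontier.extend(reversed(children) if mode == "dfs" else children)


# Toy example: three hypothetical Gaussian band-pass filters
N = 256
freqs = np.fft.fftfreq(N)
psi_hats = [np.exp(-((freqs - xi) ** 2) / (2 * 0.02 ** 2))
            for xi in (0.25, 0.125, 0.0625)]
x = np.random.default_rng(0).standard_normal(N)

dfs = list(scattering_paths(x, psi_hats, mode="dfs"))
bfs = list(scattering_paths(x, psi_hats, mode="bfs"))
print([p for p, _ in dfs][:4])  # [(), (0,), (0, 0), (0, 1)]
print([p for p, _ in bfs][:4])  # [(), (0,), (1,), (2,)]
```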

0 reactions
MuawizChaudhary commented, Apr 30, 2022

A power feature would be to let users replace the core scattering function here.

That is, we remove the "from ...scattering#d.core.scattering#d import scattering#d" line from each framework-specific frontend, move that import into base_frontend instead, and set "self.scattering#d = scattering#d.core.scattering#d".

This way, if power users want a faster implementation, they can code their own breadth first search core, and then replace ScatteringObject.scattering#d with the new implementation.

Alternatively, users can now subclass the Torch scattering object and redefine the scattering function here without having to redefine the core scattering. This makes development of different scatterings much easier for the end user (like if they’re trying to implement optimizable scattering ;D)
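A minimal sketch of that layout, with hypothetical class and function names (not kymatio's actual frontends) and trivial cores that just tag their output. The core is a class attribute set once in the base frontend, so it can be replaced per subclass or per instance, while subclasses can still redefine the scattering method itself:

```python
def scattering1d_dfs(x):
    # Hypothetical stand-in for the default (depth-first) core function
    return ["dfs", x]


def scattering1d_bfs(x):
    # Hypothetical user-written breadth-first core with the same signature
    return ["bfs", x]


class BaseFrontend1D:
    # The core is bound once here instead of being imported in every
    # framework-specific frontend; staticmethod keeps it from receiving self
    scattering1d = staticmethod(scattering1d_dfs)

    def scattering(self, x):
        # Always go through the attribute so the core can be swapped
        return self.scattering1d(x)


class TorchFrontend1D(BaseFrontend1D):
    pass


class FastTorchFrontend1D(TorchFrontend1D):
    # Option 1: keep the frontend, replace only the core traversal
    scattering1d = staticmethod(scattering1d_bfs)


class CustomTorchFrontend1D(TorchFrontend1D):
    # Option 2: redefine the scattering method, still reusing the core
    def scattering(self, x):
        return super().scattering(x) + ["post-processed"]


# Option 3: swap the core on a single instance at runtime
sc = TorchFrontend1D()
sc.scattering1d = scattering1d_bfs
print(sc.scattering(2))  # ['bfs', 2]
```

Assigning a plain function on the instance works because instance attributes take precedence over the class-level staticmethod, and a function stored on the instance is not bound to `self`.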
