Depth-first vs. breadth-first traversal
JTFS (joint time-frequency scattering) ran 40% faster than TS (time scattering) on GPU, and I couldn't fathom why. On CPU it is several times slower, as expected. TS is an exact subset of JTFS, so this should be impossible.
Then it dawned on me: JTFS batches lowpassing for each `n2`. To test, I rewrote the TS compute graph to mimic JTFS: an 87% speedup, beating JTFS. Note this only changes the lowpassing stage; the `n2`<->`n1` computations can still be batched in both TS and JTFS.
The new TS was still only ~20-40% faster than JTFS in the configs I tested, so torch's FFT is either poorly optimized for singletons or exceptionally optimized for batching. The latter seems to be the case: JTFS is far slower on a multi-core CPU, and, per the ~40% gap, it stays close to TS on GPU despite doing much more work, since all of its frequential scattering is batched. Runtime ratios, averaged over 10 iterations (values above 1 mean the new TS is faster):
Q=8

| N | t_ts_old / t_ts_new | t_jtfs / t_ts_new |
|---:|---:|---:|
| 2048 | 1.69 | 1.37 |
| 8192 | 1.67 | 1.22 |
| 32768 | 1.87 | 1.38 |
| 131072 | 1.80 | 1.31 |

Q=16

| N | t_ts_old / t_ts_new | t_jtfs / t_ts_new |
|---:|---:|---:|
| 2048 | 1.56 | 1.37 |
| 8192 | 1.72 | 1.23 |
| 32768 | 2.16 | 1.28 |
| 131072 | 2.02 | 1.43 |
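A runtime ratio r = t_slow / t_fast means the faster run is (r - 1) × 100 % faster. A small sketch applying that to the numbers in the two tables (the numbers are copied verbatim; nothing new is measured):

```python
# (N, t_ts_old / t_ts_new, t_jtfs / t_ts_new), copied from the tables above
ratios = {
    8: [(2048, 1.69, 1.37), (8192, 1.67, 1.22),
        (32768, 1.87, 1.38), (131072, 1.80, 1.31)],
    16: [(2048, 1.56, 1.37), (8192, 1.72, 1.23),
         (32768, 2.16, 1.28), (131072, 2.02, 1.43)],
}


def pct_faster(ratio):
    """t_slow / t_fast = r means the faster run is (r - 1) * 100 % faster."""
    return round((ratio - 1) * 100)


vs_old = [pct_faster(r) for rows in ratios.values() for _, r, _ in rows]
vs_jtfs = [pct_faster(r) for rows in ratios.values() for _, _, r in rows]
print(f"new TS vs old TS: {min(vs_old)}% to {max(vs_old)}% faster")
print(f"new TS vs JTFS:   {min(vs_jtfs)}% to {max(vs_jtfs)}% faster")
# new TS vs old TS: 56% to 116% faster
# new TS vs JTFS:   22% to 43% faster
```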
I think this is convincing enough to rewrite the compute graph for `Scattering1D`. My implementation is below; it still needs matching `meta`, etc.
Scattering1D
```python
def scattering1d(x, pad_fn, unpad, backend, J, log2_T, psi1, psi2, phi,
                 ind_start=None, ind_end=None, oversampling=0,
                 max_order=2, average=True, size_scattering=(0, 0, 0),
                 out_type='array'):
    """
    Main function implementing the 1-D scattering transform.

    Parameters
    ----------
    x : Tensor
        a Tensor of size `(B, 1, N)` where `N` is the temporal size
    pad_fn : function
        pads the input to the (dyadic) length of the filters
    unpad : function
        `unpad(x, i0, i1)` keeps only `x[..., i0:i1]`
    backend : module
        provides `rfft`, `ifft`, `irfft`, `cdgmm`, `modulus`,
        `subsample_fourier`, `concatenate` and `concatenate_v2`
    J : int
        scale of the scattering
    log2_T : int
        (log2 of) temporal support of low-pass filter, controlling amount of
        imposed time-shift invariance and maximum subsampling
    psi1 : list of dictionaries
        first-order filters (in the Fourier domain); `psi1[n1][k]` is the
        filter for an input subsampled by `2**k`, and `psi1[n1]['j']`,
        `psi1[n1]['xi']` hold its scale and center frequency
    psi2 : list of dictionaries
        second-order filters, same structure as `psi1`
    phi : dictionary
        low-pass filter of scale `2**log2_T`; `phi[k]` is the real-valued
        filter for an input subsampled by `2**k`
    ind_start : dictionary of ints, optional
        indices to truncate the signal to recover only the
        parts which correspond to the actual signal after padding and
        downsampling, keyed by log2 of the total subsampling.
        Defaults to None
    ind_end : dictionary of ints, optional
        See description of ind_start
    oversampling : int, optional
        how much to oversample the scattering (with respect to :math:`2^J`):
        the higher, the larger the resulting scattering
        tensor along time. Defaults to `0`
    max_order : int, optional
        `1` or `2`: whether to compute the 2nd order. Defaults to `2`.
    average : boolean, optional
        whether to lowpass (average) the coefficients with `phi`.
        Defaults to `True`
    size_scattering : tuple
        Contains the number of channels of the scattering, precomputed for
        speed-up. Defaults to `(0, 0, 0)`.
    out_type : str, optional
        `'array'` to concatenate all coefficients into one tensor (requires
        `average=True`); otherwise a list of dictionaries with keys
        `'coef'`, `'j'`, `'n'`. Defaults to `'array'`.
    """
    subsample_fourier = backend.subsample_fourier
    modulus = backend.modulus
    rfft = backend.rfft
    ifft = backend.ifft
    irfft = backend.irfft
    cdgmm = backend.cdgmm
    concatenate = backend.concatenate
    concatenate_v2 = backend.concatenate_v2

    # S is simply a dictionary if we do not perform the averaging...
    batch_size = x.shape[0]
    kJ = max(log2_T - oversampling, 0)
    temporal_size = ind_end[kJ] - ind_start[kJ]
    out_S_0, out_S_1, out_S_2 = [], [], []

    # pad to a dyadic size and make it complex
    U_0 = pad_fn(x)
    # compute the Fourier transform
    U_0_hat = rfft(U_0)

    # Get S0
    k0 = max(log2_T - oversampling, 0)
    if average:
        S_0_c = cdgmm(U_0_hat, phi[0])
        S_0_hat = subsample_fourier(S_0_c, 2**k0)
        S_0_r = irfft(S_0_hat)
        S_0 = unpad(S_0_r, ind_start[k0], ind_end[k0])
    else:
        S_0 = x
    out_S_0.append({'coef': S_0,
                    'j': (),
                    'n': ()})

    # First order:
    U_1_list = []
    for n1 in range(len(psi1)):
        # Convolution + downsampling
        j1 = psi1[n1]['j']
        k1 = max(min(j1, log2_T) - oversampling, 0)
        assert psi1[n1]['xi'] < 0.5 / (2**k1)
        U_1_c = cdgmm(U_0_hat, psi1[n1][0])
        U_1_hat = subsample_fourier(U_1_c, 2**k1)
        U_1_c = ifft(U_1_hat)

        # Take the modulus
        U_1_m = modulus(U_1_c)

        if average or max_order > 1:
            U_1_hat = rfft(U_1_m)
            U_1_list.append(U_1_hat)

        if average:
            # Convolve with phi_J
            k1_J = max(log2_T - k1 - oversampling, 0)
            S_1_c = cdgmm(U_1_hat, phi[k1])
            S_1_hat = subsample_fourier(S_1_c, 2**k1_J)
            S_1_r = irfft(S_1_hat)
            S_1 = unpad(S_1_r, ind_start[k1_J + k1], ind_end[k1_J + k1])
        else:
            S_1 = unpad(U_1_m, ind_start[k1], ind_end[k1])

        out_S_1.append({'coef': S_1,
                        'j': (j1,),
                        'n': (n1,)})

    if max_order == 2:
        # 2nd order: for each n2, gather every valid n1 and lowpass them
        # in one batched call
        for n2 in range(len(psi2)):
            j2 = psi2[n2]['j']
            if j2 == 0:
                continue

            Y_2_list, n1_list = [], []
            for n1 in range(len(U_1_list)):
                j1 = psi1[n1]['j']
                if j1 >= j2:
                    continue
                U_1_hat = U_1_list[n1]
                k1 = max(min(j1, log2_T) - oversampling, 0)
                assert psi2[n2]['xi'] < psi1[n1]['xi']

                # convolution + downsampling
                k2 = max(min(j2, log2_T) - k1 - oversampling, 0)
                U_2_c = cdgmm(U_1_hat, psi2[n2][k1])
                U_2_hat = subsample_fourier(U_2_c, 2**k2)
                # take the modulus
                U_2_c = ifft(U_2_hat)
                U_2_m = modulus(U_2_c)
                Y_2_list.append(U_2_m)
                n1_list.append(n1)

            if not Y_2_list:
                continue  # no n1 with j1 < j2: nothing to concatenate

            # Stack all n1 paths along the channel axis. `k1 + k2` equals
            # max(min(j2, log2_T) - oversampling, 0) for every n1, so the
            # last loop values of k1, k2 are valid for the whole batch.
            U_2_arr = concatenate_v2(Y_2_list, axis=1)
            if average:
                U_2_hat = rfft(U_2_arr)
                # Convolve with phi_J
                k2_J = max(log2_T - k2 - k1 - oversampling, 0)
                S_2_c = cdgmm(U_2_hat, phi[k1 + k2])
                S_2_hat = subsample_fourier(S_2_c, 2**k2_J)
                S_2_r = irfft(S_2_hat)
                S_2 = unpad(S_2_r, ind_start[k1 + k2 + k2_J],
                            ind_end[k1 + k2 + k2_J])
            else:
                S_2 = unpad(U_2_arr, ind_start[k1 + k2], ind_end[k1 + k2])

            # row i of S_2 holds path (n1_list[i], n2); rows are numbered
            # over the kept n1 only, so index by i rather than n1
            for i, n1 in enumerate(n1_list):
                out_S_2.append({'coef': S_2[:, i:i+1],
                                'j': (psi1[n1]['j'], j2),
                                'n': (n1, n2)})

    out_S = []
    out_S.extend(out_S_0)
    out_S.extend(out_S_1)
    out_S.extend(out_S_2)

    if out_type == 'array' and average:
        out_S = concatenate([c['coef'] for c in out_S])

    return out_S


__all__ = ['scattering1d']
```
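In the batched second order, each `n2` keeps only the `n1` with `j1 < j2` (and `n2` with `j2 == 0` is skipped), so rows of the concatenated array are numbered over the kept `n1` only. A row index equals `n1` only when every skipped `n1` comes after all kept ones. A pure-Python sketch of that bookkeeping; the helper name is hypothetical and the scale lists stand in for `psi1[n1]['j']` and `psi2[n2]['j']`:

```python
from typing import Dict, List, Sequence, Tuple


def second_order_groups(j1s: Sequence[int], j2s: Sequence[int]
                        ) -> Dict[int, List[Tuple[int, int]]]:
    """Map each n2 to (n1, row) pairs: which first-order paths are batched
    for that n2, and which row of the stacked array each one occupies."""
    groups = {}
    for n2, j2 in enumerate(j2s):
        if j2 == 0:
            continue  # skipped, as in scattering1d
        kept = [n1 for n1, j1 in enumerate(j1s) if j1 < j2]
        if kept:  # an empty group has nothing to concatenate
            groups[n2] = [(n1, row) for row, n1 in enumerate(kept)]
    return groups


# Scales sorted ascending: rows coincide with n1.
print(second_order_groups([0, 0, 1, 1, 2], [0, 1, 2]))
# {1: [(0, 0), (1, 1)], 2: [(0, 0), (1, 1), (2, 2), (3, 3)]}

# Unsorted scales: n1 = 0 is skipped, so rows no longer equal n1.
print(second_order_groups([2, 0, 1], [2]))
# {0: [(1, 0), (2, 1)]}
```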
benchmark
```python
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from kymatio import Scattering1D, TimeFrequencyScattering1D
from kymatio.visuals import plotscat
from timeit import default_timer as dtime


def timeit(fn, n_iters=10):
    """Wall-clock average over `n_iters` calls (used for the CPU runs)."""
    t0 = dtime()
    for _ in range(n_iters):
        fn()
        # guard so the CPU path also runs on machines without CUDA
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    return (dtime() - t0) / n_iters


def sc0(sc, x):
    return sc(x)


def sc1(jtfs, x):
    return jtfs(x)

#%%
def bench(device, N, params_tm, params_fr, n_iters=10, dont_print=()):
    GPU = bool(device == 'gpu')
    frontend = 'torch' if GPU else 'numpy'
    x = (torch.randn(N, device='cuda') if GPU else
         np.random.randn(N))
    sc = Scattering1D(shape=N, **params_tm, frontend=frontend)
    jtfs = TimeFrequencyScattering1D(shape=N, **params_tm, **params_fr,
                                     frontend=frontend)
    if GPU:
        sc = sc.cuda()
        jtfs = jtfs.cuda()

    # warmup
    _ = sc(x)
    _ = jtfs(x)

    if GPU:
        # torch.utils.benchmark.Timer synchronizes CUDA when needed
        t_sc_fn = benchmark.Timer(
            stmt='sc0(sc, x)',
            setup='from __main__ import sc0',
            globals={'x': x, 'sc': sc},
        )
        t_jtfs_fn = benchmark.Timer(
            stmt='sc1(jtfs, x)',
            setup='from __main__ import sc1',
            globals={'x': x, 'jtfs': jtfs},
        )
        t_sc_obj = t_sc_fn.timeit(n_iters)
        t_jtfs_obj = t_jtfs_fn.timeit(n_iters)
        t_sc = t_sc_obj.mean
        t_jtfs = t_jtfs_obj.mean
    else:
        t_sc = timeit(lambda: sc(x), n_iters)
        t_jtfs = timeit(lambda: jtfs(x), n_iters)

    tm_str = ", ".join(f"{k}={v}" for k, v in params_tm.items()
                       if k not in dont_print)
    fr_str = ", ".join(f"{k}={v}" for k, v in params_fr.items()
                       if k not in dont_print)
    print(("N={}, {}\n"
           "{}, {}\n"
           "SC: {:.3f} sec\n"
           "JTFS: {:.3f} sec\n").format(N, device.upper(), tm_str, fr_str,
                                        t_sc, t_jtfs))
    return t_sc, t_jtfs


n_iters = 10
N_all = (2048, 8192, 32768, 131072)[:]
devices = ('gpu', 'cpu')[:1]
params_tm = dict(Q=8, max_pad_factor=1)
params_fr = dict(F=8, J_fr=5, Q_fr=1, max_pad_factor_fr=1,
                 pad_mode_fr='zero')
dont_print = ('max_pad_factor', 'max_pad_factor_fr', 'pad_mode_fr')

t_sc_all, t_jtfs_all = {}, {}
for device in devices:
    t_sc_all[device], t_jtfs_all[device] = [], []
    for N in N_all:
        params_tm.update(dict(T=N//4, J=int(np.log2(N) - 2)))
        t_sc, t_jtfs = bench(device, N, params_tm, params_fr, n_iters,
                             dont_print)
        t_sc_all[device].append(t_sc)
        t_jtfs_all[device].append(t_jtfs)

#%%
for device in devices:
    r = np.array(t_jtfs_all[device]) / np.array(t_sc_all[device])
    plotscat(
        np.log2(N_all), r,
        title="t_jtfs / t_sc | {}, {}-iter avg".format(device.upper(),
                                                        n_iters),
        xlabel='log2(N)', ylabel='t_jtfs / t_sc', ylims=(0, 1.05*r.max()),
        show=1)
```
Top GitHub Comments
Right. I think this is what @eickenberg had in mind with #735: “modularizing the codebase in a semantically meaningful way”. In #735 he also proposed implementing the scattering operator (convolution with many wavelets and/or one low-pass) by means of a Python generator, which would allow quickly switching between DFS and BFS at runtime, with little extra code.
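To make the generator idea concrete, here is a toy sketch; it is not the design from #735, which this thread does not spell out. It enumerates paths only (no convolutions), reuses the same hypothetical scale list `js` at every order, and applies the `j1 < j2` rule when extending a path. The only difference between DFS and BFS is which end of the frontier is popped: a stack gives DFS, a queue gives BFS.

```python
from collections import deque
from typing import Iterator, Sequence, Tuple


def scattering_paths(js: Sequence[int], max_order: int,
                     mode: str = "dfs") -> Iterator[Tuple[int, ...]]:
    """Yield scattering paths as tuples of filter indices (n1, n2, ...).

    Toy model: a path may be extended by filter n only if js[n] is strictly
    larger than the scale of its last filter (the `j1 < j2` rule).
    """
    if mode not in ("dfs", "bfs"):
        raise ValueError(f"mode must be 'dfs' or 'bfs', got {mode!r}")
    frontier = deque([()])  # the empty path is the zeroth order (the input)
    while frontier:
        path = frontier.popleft() if mode == "bfs" else frontier.pop()
        yield path  # a real core would compute this path's coefficient here
        if len(path) == max_order:
            continue
        children = [path + (n,) for n in range(len(js))
                    if not path or js[n] > js[path[-1]]]
        if mode == "dfs":
            children.reverse()  # so the stack pops children in index order
        frontier.extend(children)


js = [0, 1, 2]  # hypothetical scales of three filters
print(list(scattering_paths(js, max_order=2, mode="dfs")))
# [(), (0,), (0, 1), (0, 2), (1,), (1, 2), (2,)]
print(list(scattering_paths(js, max_order=2, mode="bfs")))
# [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]
```

BFS yields every first-order path before any second-order one, which is the ordering that makes batching a whole group of paths possible.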
Yes, this is another important perspective, aimed at power users and whoever wants to improve/extend the current definitions we have.
A power feature would be to let users replace the core scattering function here. That is, we remove the `from ...scattering#d.core.scattering#d import scattering#d` line from each specific framework's frontend, instead move the import into `base_frontend`, and set `self.scattering#d = scattering#d.core.scattering#d`. This way, if power users want a faster implementation, they can code their own breadth-first core and then replace `ScatteringObject.scattering#d` with the new implementation. Alternatively, users can now subclass the Torch scattering object and redefine the `scattering` function here without having to redefine the core scattering. This makes developing different scatterings much easier for the end user (e.g. if they're trying to implement optimizable scattering ;D)
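The injection idea can be sketched in plain Python. Everything here is hypothetical (`BaseFrontend`, `TorchFrontend`, `default_core`, `bfs_core`), and the attribute is named `core` rather than `scattering#d` only to avoid clashing with the `scattering` method. It shows both routes: swapping the core on an object, or subclassing and redefining `scattering`.

```python
from typing import Any, Callable


def default_core(x: Any) -> Any:
    """Hypothetical stand-in for the library's depth-first core."""
    return ("dfs", x)


def bfs_core(x: Any) -> Any:
    """Hypothetical user-supplied breadth-first core."""
    return ("bfs", x)


class BaseFrontend:
    def __init__(self) -> None:
        # The core is set once here, not imported by each framework frontend.
        self.core: Callable[[Any], Any] = default_core

    def scattering(self, x: Any) -> Any:
        return self.core(x)


class TorchFrontend(BaseFrontend):
    """Hypothetical framework frontend with no core import of its own."""


# Route 1: swap the core on an existing object.
sc = TorchFrontend()
sc.core = bfs_core


# Route 2: subclass and redefine `scattering`, leaving the core untouched.
class MyFrontend(TorchFrontend):
    def scattering(self, x: Any) -> Any:
        kind, y = self.core(x)
        return (kind + "+custom", y)


print(TorchFrontend().scattering(1))  # ('dfs', 1)
print(sc.scattering(2))               # ('bfs', 2)
print(MyFrontend().scattering(3))     # ('dfs+custom', 3)
```

Storing the core as an instance attribute is what makes the swap a one-line assignment; a function assigned on an instance is not bound, so `self.core(x)` receives only `x`.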