Backprop support for lfilter
🚀 Feature
It is currently not possible to backpropagate gradients through lfilter because of this in-place operation: https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L661
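To illustrate the failure mode, here is a minimal sketch (not torchaudio's actual code) of the same pattern: a recursive loop that reads a view of a buffer and then writes back into it in place, which invalidates a tensor autograd saved for the backward pass.

```python
import torch

a = torch.tensor(0.5, requires_grad=True)
buf = torch.ones(3)
for i in range(1, 3):
    # The multiplication saves the view buf[i - 1] for backward; the in-place
    # write below bumps buf's version counter and invalidates that saved tensor.
    buf[i] = a * buf[i - 1]
buf.sum().backward()
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation
```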
Motivation
Without backprop support, lfilter isn’t worth the PyTorch overhead (a plain numba implementation, for example, is much faster). When I saw that it was implemented here, I was hoping to use it in place of my own implementation (a custom RNN), which is honestly too slow.
Pitch
I would love to see that in-place operation replaced with something that supports backprop, though I’m not sure what the most efficient way to do that is.
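For example (just a sketch of the idea, not a drop-in replacement): the recursive part of the loop could append each sample to a Python list and stack at the end, so every step stays an out-of-place op that autograd can track. This assumes normalized coefficients (a0 = 1) and that the feedforward (b) part has already been applied:

```python
import torch

def feedback_loop(x, a_coeffs):
    # Hypothetical out-of-place version of lfilter's recursive loop:
    # x is 1-D and already contains the feedforward contribution.
    n_order = a_coeffs.numel()
    outputs = [torch.zeros_like(x[0]) for _ in range(n_order - 1)]  # zero initial conditions
    for i in range(x.numel()):
        o = x[i]
        for k in range(1, n_order):
            o = o - a_coeffs[k] * outputs[-k]  # y[i] -= a_k * y[i - k]
        outputs.append(o)
    return torch.stack(outputs[n_order - 1:])  # drop the zero padding
```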
Alternatives
I implemented transposed direct form II digital filters as custom RNNs, but the performance is pretty poor (which seems to be a problem with the fuser). Below is the simplest version I tried; it works, but as I said it’s quite slow.
```python
from typing import List, Tuple

import torch
from torch import Tensor, jit


class DigitalFilterModel(jit.ScriptModule):
    def __init__(self):
        super(DigitalFilterModel, self).__init__()

    @jit.script_method
    def forward(self, x, coeffs, v1, v2, v3):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
        # Assumed shapes: x (batch, time); coeffs (batch, 7) holding
        # b0..b3, a1..a3; v1, v2, v3 (batch,) filter state.
        seq_len = x.shape[1]
        output = torch.jit.annotate(List[Tensor], [])
        xs = x.unbind(1)        # per-sample views along time
        cs = coeffs.unbind(1)   # b0, b1, b2, b3, a1, a2, a3
        for i in range(seq_len):
            sample = xs[i]
            # Transposed direct form II: the output taps the first state variable.
            out = cs[0] * sample + v1
            output.append(out)
            v1 = cs[1] * sample - cs[4] * out + v2
            v2 = cs[2] * sample - cs[5] * out + v3
            v3 = cs[3] * sample - cs[6] * out
        return torch.stack(output, 1), v1, v2, v3
```
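For reference, here’s how I call it (shapes as assumed in the comments above); gradients flow back to the coefficients:

```python
model = DigitalFilterModel()
x = torch.randn(2, 100)                              # (batch, time)
coeffs = (0.1 * torch.randn(2, 7)).requires_grad_()  # b0..b3, a1..a3 (kept small for stability)
v = torch.zeros(2)                                   # rest initial conditions
y, v1, v2, v3 = model(x, coeffs, v, v, v)
y.sum().backward()                                   # gradients reach coeffs
```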
Another alternative I’ve used, when I only need to backprop through the filter but not optimize the actual coefficients, is to take advantage of the fact that tanh is close to linear for very small inputs and configure a standard RNN to be equivalent to the digital filter. Crushing the input, then rescaling the output to keep the RNN in its linear region, gives a result very close to the original filter, but this is obviously quite a hack:
```python
import torch
import torch.nn as nn


class RNNTDFWrapper(nn.Module):
    def __init__(self, eps=1e-9):
        super(RNNTDFWrapper, self).__init__()
        self.eps = eps
        # One tanh RNN layer, 1 input feature, 4 hidden units, no bias,
        # batch_first (keyword args avoid the nonlinearity positional-arg pitfall).
        self.rnn = nn.RNN(1, 4, num_layers=1, bias=False, batch_first=True)

    def set_coefficients(self, coeffs):
        # coeffs = [b0, b1, b2, b3, a1, a2, a3] (assumed ordering).
        # Feedforward taps b0..b3 go on the input-to-hidden weights.
        self.rnn.weight_ih_l0.data[:, :] = torch.tensor(coeffs[:4]).view(-1, 1)
        # Hidden-to-hidden weights implement the state shift and the feedback
        # taps a1..a3; hidden unit 0 carries the filter output.
        self.rnn.weight_hh_l0.data[:, :] = 0.0
        self.rnn.weight_hh_l0.data[0, 1] = 1.0
        self.rnn.weight_hh_l0.data[1, 2] = 1.0
        self.rnn.weight_hh_l0.data[2, 3] = 1.0
        self.rnn.weight_hh_l0.data[:3, 0] = -1.0 * torch.tensor(coeffs[4:])

    def forward(self, x):
        batch_size = x.shape[0]
        # Crush the input so tanh stays in its (near-)linear region...
        x = self.eps * x.view(batch_size, -1, 1)
        x, _ = self.rnn(x)
        # ...then undo the scaling on the output channel.
        x = (1.0 / self.eps) * x[:, :, 0]
        return x
```
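Used like this (coefficients chosen arbitrarily for the example); the output is differentiable with respect to the input, just not with respect to the coefficients:

```python
wrapper = RNNTDFWrapper()
wrapper.set_coefficients([0.2, 0.1, 0.05, 0.01, -0.5, 0.2, -0.1])  # b0..b3, a1..a3
x = torch.randn(1, 200, requires_grad=True)
y = wrapper(x)       # (1, 200), very close to the exact linear filter
y.sum().backward()   # gradients flow back to x through the RNN
```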
Top GitHub Comments
@vincentqb thanks, I’ll take a look.
Thanks for writing this and sharing it with the community! If torchscriptability is not a concern, then this is a great way to bind the forward and the backward pass 😃 This is in fact how we (temporarily) bind the prototype RNN transducer here in torchaudio.
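The pattern being described is a torch.autograd.Function that pairs a fast forward kernel with a hand-written backward. A rough sketch, where fast_lfilter_forward / fast_lfilter_backward are hypothetical kernels (e.g. numba or C++), not existing torchaudio functions:

```python
import torch

class LFilterFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, a_coeffs, b_coeffs):
        # fast_lfilter_forward is a placeholder for a fast non-autograd kernel.
        output = fast_lfilter_forward(waveform, a_coeffs, b_coeffs)
        ctx.save_for_backward(waveform, a_coeffs, b_coeffs, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        waveform, a_coeffs, b_coeffs, output = ctx.saved_tensors
        # fast_lfilter_backward is likewise a placeholder; it must return
        # one gradient per forward input.
        return fast_lfilter_backward(grad_output, waveform, a_coeffs, b_coeffs, output)
```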
Such custom autograd functions (in both Python and C++) are not currently supported by TorchScript, though. Using this within torchaudio directly in place of the current `lfilter` (which is torchscriptable) would unfortunately be BC-breaking. In the long term, we’ll need to register the backward pass with autograd. Here’s a tutorial for how to do this in a torchscriptable manner.