Basic linear algebra for complex numbers
🚀 Feature
Support basic linear algebra for complex numbers.
Motivation
I talked with @sw005320 about https://github.com/nttcslab-sp/dnn_wpe, and it turns out that the matrix inversion implemented with real numbers is unstable. In a beamforming example, @Emrys365 observed a 5 dB difference in signal-to-distortion ratio (SDR) when he replaced the inversion with NumPy code (torch: 5 dB, NumPy: 10 dB).
I tried torch.inverse and torch.solve, and interestingly they already work with complex tensors in 1.6.0.dev20200623+cpu (this is not mentioned in https://github.com/pytorch/pytorch/issues/33152).
Is it possible to support torch.matmul and some other linear algebra functions?
I also tried calling backward after torch.solve, and the code fails with an exception saying that matmul is not implemented.
Does anyone know how the gradient is defined in torch for complex numbers? Is it grad_real + j grad_imag or grad_real - j grad_imag?
And how can I add or fix a gradient when I find a broken implementation?
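The two candidate conventions can be compared without torch. The following is a pure-Python sketch that estimates the partial derivatives of a real-valued loss with respect to Re(z) and Im(z) by central finite differences. It then checks which combination is a descent direction. The loss, its target value and the helper names are hypothetical. As far as I know, the convention PyTorch later documented (its "Autograd for Complex Numbers" notes, written after this issue) yields the grad_real + j grad_imag form for a real-valued loss. That form equals 2 * dL/d(conj z), so z - lr * grad is a descent step. The sketch only illustrates the math; it does not call PyTorch.

```python
def loss(z):
    """Hypothetical real-valued loss |z - target|^2."""
    target = 0.5 - 1.5j
    return abs(z - target) ** 2


def partials(f, z, h=1e-6):
    """Central finite differences of f with respect to Re(z) and Im(z)."""
    d_real = (f(z + h) - f(z - h)) / (2 * h)
    d_imag = (f(z + 1j * h) - f(z - 1j * h)) / (2 * h)
    return d_real, d_imag


z0 = 2.0 + 1.0j
gr, gi = partials(loss, z0)
grad_plus = complex(gr, gi)    # grad_real + j grad_imag
grad_minus = complex(gr, -gi)  # grad_real - j grad_imag

lr = 0.1
print("loss before:       ", loss(z0))
print("step with '+' form:", loss(z0 - lr * grad_plus))
print("step with '-' form:", loss(z0 - lr * grad_minus))
```

For this loss, the '+' form equals 2 * (z0 - target) and lowers the loss. The '-' form flips the imaginary component of the step and raises the loss.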
Pitch
- Use native functions for linear algebra of complex numbers instead of https://github.com/kamo-naoyuki/pytorch_complex
- torch.solve and torch.inverse lack a backward function
- torch.matmul does not work
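Until torch.matmul accepts complex tensors, a complex product can be assembled from four real products. This is what the matmul helper in the workaround under Additional context does: (Ar + j Ai)(Br + j Bi) = (Ar Br - Ai Bi) + j (Ar Bi + Ai Br). The stdlib-only sketch below checks that identity on nested lists standing in for tensors. The function names and example matrices are hypothetical.

```python
# Hypothetical stdlib-only helpers; nested lists stand in for tensors.

def real_matmul(a, b):
    """Plain matrix product of two nested lists (also works on complex entries)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def combine(a, b, sign=1.0):
    """Element-wise a + sign * b."""
    return [[x + sign * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def complex_matmul_via_real(a, b):
    """(Ar + j Ai)(Br + j Bi) = (Ar Br - Ai Bi) + j (Ar Bi + Ai Br)."""
    ar = [[z.real for z in row] for row in a]
    ai = [[z.imag for z in row] for row in a]
    br = [[z.real for z in row] for row in b]
    bi = [[z.imag for z in row] for row in b]
    o_real = combine(real_matmul(ar, br), real_matmul(ai, bi), sign=-1.0)
    o_imag = combine(real_matmul(ar, bi), real_matmul(ai, br))
    return [[complex(r, i) for r, i in zip(rr, ri)]
            for rr, ri in zip(o_real, o_imag)]


A = [[1 + 2j, 3 - 1j], [0.5j, -2 + 0j]]
B = [[2 - 1j, 1j], [1 + 1j, 4 + 0j]]

C_split = complex_matmul_via_real(A, B)  # four real products
C_direct = real_matmul(A, B)             # same loop on Python complex numbers
print(C_split)
print(C_direct)
```

Both results agree up to floating-point rounding. The split version costs four real matrix products, so it is a stopgap rather than a replacement for native complex support.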
Alternatives
- https://github.com/kamo-naoyuki/pytorch_complex has precision problems with the matrix inverse.
Additional context
Currently, I am considering jumping between pytorch_complex and torch.autograd.Function:
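```python
import torch


def hermite(a):
    # Conjugate (Hermitian) transpose over the last two dimensions.
    return a.transpose(-2, -1).conj()


def matmul(t1, t2):
    # Complex matrix product built from four real products,
    # because torch.matmul does not accept complex inputs here.
    real1, imag1 = t1.real, t1.imag
    real2, imag2 = t2.real, t2.imag
    o_real = torch.matmul(real1, real2) - torch.matmul(imag1, imag2)
    o_imag = torch.matmul(real1, imag2) + torch.matmul(imag1, real2)
    return o_real + 1j * o_imag


class Solve(torch.autograd.Function):
    """x = A^{-1} b with a hand-written backward."""

    @staticmethod
    def forward(ctx, A, b):
        # PyTorch 1.6 API: torch.solve(B, A) returns (solution, LU).
        # It was later deprecated in favor of torch.linalg.solve(A, B).
        x, _ = torch.solve(b, A)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        A, x = ctx.saved_tensors
        # gb = A^{-H} grad_output
        gb, _ = torch.solve(grad_output, hermite(A))
        # gA = -gb x^H
        gA = -matmul(gb, hermite(x))
        return gA, gb
```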
Top GitHub Comments
@anjali411 @boeddeker I will coordinate the meeting then.
In PyTorch 1.9, complex numbers are supported in most (if not all) of the linear algebra functions. Thanks for the feedback!