[Feature Request] efficient logdet for KroneckerProductLazyTensor
🚀 Feature Request
It appears that KroneckerProductLazyTensor doesn't compute efficient logdets (as a sum of the scaled log-determinants of the factors). This means it costs O(n^2m^2) for a 2-factor Kronecker kernel with n x n and m x m factors instead of O(n^2 + m^2), even if we're using efficient inv_quads.
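For reference, the identity behind this request is log det(A ⊗ B) = m · log det(A) + n · log det(B) for an n x n factor A and an m x m factor B. Below is a minimal sketch in plain PyTorch that checks it against the dense Kronecker product. The helper names `kron_logdet` and `random_spd` are made up for illustration and are not part of gpytorch.

```python
import torch


def random_spd(size, seed):
    """Hypothetical helper: a well-conditioned symmetric positive definite matrix."""
    g = torch.Generator().manual_seed(seed)
    W = torch.randn(size, size, generator=g, dtype=torch.float64)
    return W @ W.T + size * torch.eye(size, dtype=torch.float64)


def kron_logdet(A, B):
    """log det(A ⊗ B) computed only from the two factors (illustrative, not gpytorch API)."""
    n, m = A.shape[-1], B.shape[-1]
    # Each factor's log-determinant is scaled by the size of the *other* factor.
    return m * torch.logdet(A) + n * torch.logdet(B)


A = random_spd(4, seed=0)  # n x n factor
B = random_spd(3, seed=1)  # m x m factor

fast = kron_logdet(A, B)                 # works on the 4 x 4 and 3 x 3 factors only
dense = torch.logdet(torch.kron(A, B))   # forms the full 12 x 12 matrix
print(fast.item(), dense.item())
```

The two printed values should agree up to floating-point error; the factored version never materializes the nm x nm matrix.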
This is fairly straightforward to add and would speed up multitask models and probably a bunch of other stuff. I'm happy to do it, just a bit confused on (a) what reduce_inv_quad does and (b) where the efficient inv_quad is happening right now and whether I can reuse it or also need to implement an efficient kron-solve.
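To make "efficient kron-solve" concrete, here is a rough sketch in plain PyTorch, not a description of what gpytorch does internally. It relies on the standard identity (A ⊗ B) flat(Z) = flat(A Z Bᵀ) for a row-major flatten of an n x m matrix Z, so the solve reduces to one solve per factor. The names `kron_solve` and `random_spd` are hypothetical.

```python
import torch


def random_spd(size, seed):
    """Hypothetical helper: a well-conditioned symmetric positive definite matrix."""
    g = torch.Generator().manual_seed(seed)
    W = torch.randn(size, size, generator=g, dtype=torch.float64)
    return W @ W.T + size * torch.eye(size, dtype=torch.float64)


def kron_solve(A, B, y):
    """Solve (A ⊗ B) x = y without forming the Kronecker product (illustrative only)."""
    n, m = A.shape[-1], B.shape[-1]
    Y = y.reshape(n, m)                    # undo the row-major flatten
    left = torch.linalg.solve(A, Y)        # A^{-1} Y
    # Right-multiply by B^{-T}: left @ B^{-T} == (B^{-1} left^T)^T
    Z = torch.linalg.solve(B, left.T).T
    return Z.reshape(-1)


A = random_spd(4, seed=0)
B = random_spd(3, seed=1)
y = torch.randn(12, dtype=torch.float64, generator=torch.Generator().manual_seed(2))

fast_solve = kron_solve(A, B, y)
dense_solve = torch.linalg.solve(torch.kron(A, B), y)  # reference using the full 12 x 12 matrix
inv_quad = y @ fast_solve                              # y^T (A ⊗ B)^{-1} y
```

With a solve like this, the quadratic form needed by the likelihood is a single dot product on top of two factor-sized solves.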
Simple example:
import gpytorch
import torch
from gpytorch.functions import logdet, inv_matmul
from gpytorch.utils.cholesky import psd_safe_cholesky

LOG2PI = 1.8378770664093453  # log(2 * pi)


def lazy_matrixnormal_likelihood(X, R, C):
    """
    Likelihood for matrix-variate normal
    (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    using lazy logdet and inv_matmul
    """
    row_logdet = logdet(R)
    col_logdet = logdet(C)
    logdet_term = X.shape[0] * col_logdet + X.shape[1] * row_logdet
    trace_term = torch.trace(inv_matmul(R, X).matmul(inv_matmul(C, X.t())))
    return -0.5 * (X.shape[1] * X.shape[0] * LOG2PI + logdet_term + trace_term)


def lazy_kron_mvn(X, K):
    """likelihood for multivariate normal using lazy
    """
    logdet_term = logdet(K)
    trace_term = X.t().matmul(inv_matmul(K, X))
    return -0.5 * (X.shape[0] * LOG2PI + logdet_term + trace_term)


def chol_logdet(L):
    """
    log determinant of a matrix with cholesky factor L
    """
    return 2 * torch.sum(torch.log(torch.diag(L)))


def cholesky_solve(x, L):
    """
    Solve positive definite system using the cholesky factor
    """
    y = torch.triangular_solve(x, L, upper=False).solution
    return torch.triangular_solve(y, L.t(), upper=True).solution


def nonlazy_matrixnormal_likelihood(X, R, C):
    cholr = psd_safe_cholesky(R)
    cholc = psd_safe_cholesky(C)
    logdet_term = X.shape[0] * chol_logdet(cholc) + X.shape[1] * chol_logdet(cholr)
    trace_term = torch.trace(cholesky_solve(X, cholr).matmul(cholesky_solve(X.t(), cholc)))
    return -0.5 * (X.shape[1] * X.shape[0] * LOG2PI + logdet_term + trace_term)


def kron_bench(n=100, m=100, device='cpu'):
    # %timeit is an IPython magic, so run this in IPython or Jupyter.
    print(f"====== n={n},m={m}, device={device} ======")
    X = torch.arange(n, dtype=torch.float, device=device)[:,None] @ torch.arange(m, dtype=torch.float, device=device)[None,:]
    y = torch.randn(m, n, device=device)
    yflat = y.flatten()
    muflat = torch.zeros(n*m, device=device)
    k = gpytorch.kernels.RBFKernel()
    # X is n x m, so k(X.t()) is m x m (rows of y) and k(X) is n x n (columns of y).
    # For n == m, as in every run below, X is symmetric and both kernels coincide.
    K_row = (k(X.t()) + torch.eye(m)).to(device)
    K_col = (k(X) + torch.eye(n)).to(device)
    # y.flatten() is row-major, so the covariance of yflat is K_row ⊗ K_col.
    K_kron = gpytorch.lazy.KroneckerProductLazyTensor(K_row, K_col).to(device)

    print(f"\n matrix-normal with lazy solve/logdet")
    print(f"logp: {lazy_matrixnormal_likelihood(y, K_row, K_col)}")
    %timeit lazy_matrixnormal_likelihood(y, K_row, K_col)

    print(f"\n matrix-normal with torch cholesky")
    print(f"logp: {nonlazy_matrixnormal_likelihood(y, K_row.evaluate(), K_col.evaluate())}")
    %timeit nonlazy_matrixnormal_likelihood(y, K_row.evaluate(), K_col.evaluate())

    print(f"\n multivariate normal with KronLazy")
    print(f"logp: {lazy_kron_mvn(yflat, K_kron)}")
    %timeit lazy_kron_mvn(yflat, K_kron)

    print(f"\n gpytorch multivariate normal with Kronlazy")
    mvn_distr = gpytorch.distributions.multivariate_normal.MultivariateNormal(mean=muflat, covariance_matrix=K_kron)
    print(f"logp: { mvn_distr.log_prob(yflat)}")
    %timeit mvn_distr.log_prob(yflat)
    print("\n")


kron_bench(n=100, m=100, device='cpu')
kron_bench(n=100, m=100, device='cuda')
kron_bench(n=1000, m=1000, device='cpu')
kron_bench(n=1000, m=1000, device='cuda')
====== n=100,m=100, device=cpu ======
matrix-normal with lazy solve/logdet
logp: -17354.173828125
1.55 ms ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
matrix-normal with torch cholesky
logp: -17354.173828125
716 µs ± 11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
multivariate normal with KronLazy
logp: -17354.173828125
100 ms ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
gpytorch multivariate normal with Kronlazy
logp: -17354.171875
98.9 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
====== n=100,m=100, device=cuda ======
matrix-normal with lazy solve/logdet
logp: -17373.841796875
8.63 ms ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
matrix-normal with torch cholesky
logp: -17373.841796875
5.62 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
multivariate normal with KronLazy
logp: -17373.841796875
43.5 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
gpytorch multivariate normal with Kronlazy
logp: -17373.83984375
30.2 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
====== n=1000,m=1000, device=cpu ======
matrix-normal with lazy solve/logdet
logp: -1737016.125
820 ms ± 7.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
matrix-normal with torch cholesky
logp: -1737016.25
40.1 ms ± 420 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
multivariate normal with KronLazy
logp: -1737016.0
16.2 s ± 157 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
gpytorch multivariate normal with Kronlazy
logp: -1737015.75
17.3 s ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
====== n=1000,m=1000, device=cuda ======
matrix-normal with lazy solve/logdet
logp: -1735884.75
87 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
matrix-normal with torch cholesky
logp: -1736324.125
23.9 ms ± 892 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
multivariate normal with KronLazy
logp: -1736284.5
570 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
gpytorch multivariate normal with Kronlazy
logp: -1736323.625
578 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Issue Analytics
- State:
- Created 4 years ago
- Reactions:1
- Comments:16 (13 by maintainers)
Top GitHub Comments
This should be finally resolved by #1786 amongst other improvements to KroneckerProductLazyTensor.

Looks like this is still open and will require inv_quad_logdet for KroneckerProductLazyTensor? Currently the efficient Cholesky will be used if the matrix is small enough, but not for larger matrices (where this would be useful).
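As a rough illustration of what a combined inverse-quadratic-form and log-determinant computation could look like for a two-factor Kronecker product, here is a sketch in plain PyTorch that reuses one Cholesky factorization per factor. It is an assumption-laden example and not the code from #1786 or gpytorch's actual inv_quad_logdet; the function name `kron_inv_quad_logdet` and the helper `random_spd` are hypothetical.

```python
import torch


def random_spd(size, seed):
    """Hypothetical helper: a well-conditioned symmetric positive definite matrix."""
    g = torch.Generator().manual_seed(seed)
    W = torch.randn(size, size, generator=g, dtype=torch.float64)
    return W @ W.T + size * torch.eye(size, dtype=torch.float64)


def kron_inv_quad_logdet(A, B, y):
    """Return (y^T (A ⊗ B)^{-1} y, log det(A ⊗ B)) from factor-sized Cholesky factors."""
    n, m = A.shape[-1], B.shape[-1]
    L_A = torch.linalg.cholesky(A)   # n x n lower-triangular factor
    L_B = torch.linalg.cholesky(B)   # m x m lower-triangular factor

    Y = y.reshape(n, m)                            # row-major unflatten
    left = torch.cholesky_solve(Y, L_A)            # A^{-1} Y
    solve = torch.cholesky_solve(left.T, L_B).T    # A^{-1} Y B^{-1} (B is symmetric)
    inv_quad = (Y * solve).sum()

    # log det of each factor is twice the sum of the log-diagonal of its Cholesky factor,
    # scaled by the size of the other factor.
    logdet = 2 * (m * L_A.diagonal().log().sum() + n * L_B.diagonal().log().sum())
    return inv_quad, logdet


A = random_spd(5, seed=0)
B = random_spd(3, seed=1)
y = torch.randn(15, dtype=torch.float64, generator=torch.Generator().manual_seed(2))

inv_quad, logdet = kron_inv_quad_logdet(A, B, y)

# Dense reference on the full 15 x 15 matrix, for checking only.
K = torch.kron(A, B)
ref_inv_quad = y @ torch.linalg.solve(K, y)
ref_logdet = torch.logdet(K)
```

The factorizations here have the size of the individual factors, so the cost does not depend on whether the full nm x nm matrix is small enough for a dense Cholesky.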