
[Feature Request] efficient logdet for KroneckerProductLazyTensor


🚀 Feature Request

It appears that KroneckerProductLazyTensor doesn’t compute efficient logdets (as a weighted sum of the log-determinants of the factors). This means that for a 2-factor Kronecker kernel with n x n and m x m factors, the logdet work scales with the full nm x nm matrix (e.g. O(n^3m^3) for a dense Cholesky) instead of with the factors (O(n^3 + m^3) for a Cholesky of each factor), even if we’re using efficient inv_quads.
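The identity behind the request: for an n x n factor A and an m x m factor B, det(A ⊗ B) = det(A)^m · det(B)^n, so log|A ⊗ B| = m·log|A| + n·log|B|. A quick numerical check in plain PyTorch, assuming random symmetric positive-definite factors in float64 (`random_spd` is a hypothetical helper, not part of gpytorch):

```python
import torch

torch.manual_seed(0)


def random_spd(k: int) -> torch.Tensor:
    """Hypothetical helper: random symmetric positive-definite k x k matrix."""
    M = torch.randn(k, k, dtype=torch.float64)
    return M @ M.T + k * torch.eye(k, dtype=torch.float64)


n, m = 4, 3
A = random_spd(n)  # n x n factor
B = random_spd(m)  # m x m factor

# Dense route: materialize the full (n*m) x (n*m) Kronecker product.
dense_logdet = torch.logdet(torch.kron(A, B))

# Factor route: only the two small factors are ever decomposed.
factor_logdet = m * torch.logdet(A) + n * torch.logdet(B)

print(dense_logdet.item(), factor_logdet.item())
```

The two printed values should agree up to floating-point rounding; only the factor route avoids building the nm x nm matrix.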

This is fairly straightforward to add and would speed up multitask models and probably a bunch of other stuff. I’m happy to do it, but I’m a bit confused about (a) what reduce_inv_quad does and (b) where the efficient inv_quad currently happens, and whether I can reuse it or also need to implement an efficient Kronecker solve.
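On (b), one standard way to get an efficient Kronecker solve (not necessarily what gpytorch does internally) uses (A ⊗ B) vec(X) = vec(A X B^T) for a row-major vec, so solving against A ⊗ B reduces to one solve with each factor: X = A^{-1} Y B^{-1} for symmetric A and B. A minimal sketch assuming dense SPD factors; `kron_solve`, `kron_inv_quad` and `random_spd` are hypothetical names used only for illustration:

```python
import torch

torch.manual_seed(0)


def random_spd(k):
    """Hypothetical helper: random symmetric positive-definite k x k matrix."""
    M = torch.randn(k, k, dtype=torch.float64)
    return M @ M.T + k * torch.eye(k, dtype=torch.float64)


def kron_solve(A, B, y):
    """Solve (A kron B) x = y using only Cholesky factors of A and B.

    Relies on (A kron B) vec(X) = vec(A X B^T) with row-major vec, so
    x = vec(A^{-1} Y B^{-1}) when A and B are symmetric positive definite.
    """
    n, m = A.shape[0], B.shape[0]
    Y = y.reshape(n, m)                   # undo the row-major flatten
    L_A = torch.linalg.cholesky(A)        # O(n^3)
    L_B = torch.linalg.cholesky(B)        # O(m^3)
    Z = torch.cholesky_solve(Y, L_A)      # A^{-1} Y, shape n x m
    X = torch.cholesky_solve(Z.T, L_B).T  # (A^{-1} Y) B^{-1}
    return X.reshape(-1)


def kron_inv_quad(A, B, y):
    """y^T (A kron B)^{-1} y without forming the nm x nm matrix."""
    return y @ kron_solve(A, B, y)


n, m = 4, 3
A, B = random_spd(n), random_spd(m)
y = torch.randn(n * m, dtype=torch.float64)

fast = kron_inv_quad(A, B, y)
dense = y @ torch.linalg.solve(torch.kron(A, B), y)  # dense reference
```

Each factor is factorized once, so the per-solve cost is O(nm(n + m)) after the O(n^3 + m^3) factorizations.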

Simple example:

import gpytorch
import torch
from gpytorch.functions import logdet, inv_matmul
from gpytorch.utils.cholesky import psd_safe_cholesky
LOG2PI = 1.8378770664093453

def lazy_matrixnormal_likelihood(X, R, C):
    """
    Log-likelihood of a zero-mean matrix-variate normal
    (https://en.wikipedia.org/wiki/Matrix_normal_distribution)
    using lazy logdet and inv_matmul
    """
    row_logdet = logdet(R)
    col_logdet = logdet(C)

    logdet_term = X.shape[0] * col_logdet + X.shape[1] * row_logdet

    trace_term = torch.trace(inv_matmul(R, X).matmul(inv_matmul(C, X.t())))

    return -0.5 * (X.shape[1] * X.shape[0] * LOG2PI + logdet_term + trace_term)


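def lazy_kron_mvn(X, K):
    """
    Log-likelihood of a zero-mean multivariate normal with (lazy) covariance K,
    using lazy logdet and inv_matmul
    """
    logdet_term = logdet(K)

    # quadratic form X^T K^{-1} X (X is the flattened 1-D observation vector)
    trace_term = X.t().matmul(inv_matmul(K, X))
    return -0.5 * (X.shape[0] * LOG2PI + logdet_term + trace_term)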


def chol_logdet(L):
    """
    log determinant of a matrix with cholesky factor L
    """
    return 2 * torch.sum(torch.log(torch.diag(L)))


def cholesky_solve(x, L):
    """
    Solve positive definite system using the cholesky factor 
    """
    y = torch.triangular_solve(x, L, upper=False).solution
    return torch.triangular_solve(y, L.t(), upper=True).solution

def nonlazy_matrixnormal_likelihood(X, R, C):
    """Same log-likelihood as above, using dense Cholesky factors of R and C"""
    cholr = psd_safe_cholesky(R)
    cholc = psd_safe_cholesky(C)

    logdet_term = X.shape[0] * chol_logdet(cholc) + X.shape[1] * chol_logdet(cholr)

    trace_term = torch.trace(cholesky_solve(X, cholr).matmul(cholesky_solve(X.t(), cholc)))

    return -0.5 * (X.shape[1] * X.shape[0] * LOG2PI + logdet_term + trace_term)


def kron_bench(n=100, m=100, device='cpu'):
    # %timeit is an IPython magic, so run this in IPython/Jupyter
    print(f"====== n={n},m={m}, device={device} ======")
    # n x m feature matrix: its rows are inputs for K_col, its columns for K_row
    X = torch.arange(n, dtype=torch.float, device=device)[:, None] @ torch.arange(m, dtype=torch.float, device=device)[None, :]
    y = torch.randn(m, n, device=device)  # m x n observation matrix
    yflat = y.flatten()  # row-major flatten, so its covariance is K_row kron K_col
    muflat = torch.zeros(n * m, device=device)
    k = gpytorch.kernels.RBFKernel().to(device)
    K_row = k(X.t()) + torch.eye(m, device=device)  # m x m row covariance
    K_col = k(X) + torch.eye(n, device=device)      # n x n column covariance
    K_kron = gpytorch.lazy.KroneckerProductLazyTensor(K_row, K_col)

    print("\n matrix-normal with lazy solve/logdet")
    print(f"logp: {lazy_matrixnormal_likelihood(y, K_row, K_col)}")
    %timeit lazy_matrixnormal_likelihood(y, K_row, K_col)

    print("\n matrix-normal with torch cholesky")
    print(f"logp: {nonlazy_matrixnormal_likelihood(y, K_row.evaluate(), K_col.evaluate())}")
    %timeit nonlazy_matrixnormal_likelihood(y, K_row.evaluate(), K_col.evaluate())

    print("\n multivariate normal with KronLazy")
    print(f"logp: {lazy_kron_mvn(yflat, K_kron)}")
    %timeit lazy_kron_mvn(yflat, K_kron)

    print("\n gpytorch multivariate normal with Kronlazy")
    mvn_distr = gpytorch.distributions.MultivariateNormal(mean=muflat, covariance_matrix=K_kron)
    print(f"logp: {mvn_distr.log_prob(yflat)}")
    %timeit mvn_distr.log_prob(yflat)

    print("\n")

kron_bench(n=100, m=100, device='cpu')
kron_bench(n=100, m=100, device='cuda')
kron_bench(n=1000, m=1000, device='cpu')
kron_bench(n=1000, m=1000, device='cuda')
====== n=100,m=100, device=cpu ======

 matrix-normal with lazy solve/logdet
logp: -17354.173828125
1.55 ms ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

 matrix-normal with torch cholesky
logp: -17354.173828125
716 µs ± 11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

 multivariate normal with KronLazy
logp: -17354.173828125
100 ms ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

 gpytorch multivariate normal with Kronlazy
logp: -17354.171875
98.9 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


====== n=100,m=100, device=cuda ======

 matrix-normal with lazy solve/logdet
logp: -17373.841796875
8.63 ms ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

 matrix-normal with torch cholesky
logp: -17373.841796875
5.62 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

 multivariate normal with KronLazy
logp: -17373.841796875
43.5 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

 gpytorch multivariate normal with Kronlazy
logp: -17373.83984375
30.2 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


====== n=1000,m=1000, device=cpu ======

 matrix-normal with lazy solve/logdet
logp: -1737016.125
820 ms ± 7.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 matrix-normal with torch cholesky
logp: -1737016.25
40.1 ms ± 420 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

 multivariate normal with KronLazy
logp: -1737016.0
16.2 s ± 157 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 gpytorch multivariate normal with Kronlazy
logp: -1737015.75
17.3 s ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


====== n=1000,m=1000, device=cuda ======

 matrix-normal with lazy solve/logdet
logp: -1735884.75
87 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

 matrix-normal with torch cholesky
logp: -1736324.125
23.9 ms ± 892 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

 multivariate normal with KronLazy
logp: -1736284.5
570 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

 gpytorch multivariate normal with Kronlazy
logp: -1736323.625
578 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
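The matrix-normal and flattened-MVN numbers are only comparable if both compute the same density. For an m x n observation Y with row covariance R (m x m) and column covariance C (n x n), the row-major flatten `Y.flatten()` has covariance R ⊗ C, row factor first. In the benchmark, n = m and the feature matrix X is symmetric, so the row and column kernels coincide and the factor order cannot show up in the logp values. A small float64 check with dense matrices; `matrix_normal_logpdf` and `random_spd` are hypothetical helpers, not gpytorch API:

```python
import math
import torch


def random_spd(k):
    """Hypothetical helper: random symmetric positive-definite k x k matrix."""
    M = torch.randn(k, k, dtype=torch.float64)
    return M @ M.T + k * torch.eye(k, dtype=torch.float64)


def matrix_normal_logpdf(Y, R, C):
    """log MN(Y | 0, R, C): row covariance R (m x m), column covariance C (n x n)."""
    m, n = Y.shape
    L_R = torch.linalg.cholesky(R)
    L_C = torch.linalg.cholesky(C)
    # log|R kron C| = n log|R| + m log|C|
    logdet = 2 * (n * L_R.diagonal().log().sum() + m * L_C.diagonal().log().sum())
    # tr(C^{-1} Y^T R^{-1} Y) = sum over entries of (R^{-1} Y C^{-1}) * Y
    W = torch.cholesky_solve(Y, L_R)      # R^{-1} Y
    W = torch.cholesky_solve(W.T, L_C).T  # R^{-1} Y C^{-1}
    quad = (W * Y).sum()
    return -0.5 * (m * n * math.log(2 * math.pi) + logdet + quad)


torch.manual_seed(0)
m, n = 3, 4
R, C = random_spd(m), random_spd(n)
Y = torch.randn(m, n, dtype=torch.float64)
zero = torch.zeros(m * n, dtype=torch.float64)

mn = matrix_normal_logpdf(Y, R, C)
# Row-major flatten pairs with R kron C (row factor first) ...
mvn = torch.distributions.MultivariateNormal(
    zero, covariance_matrix=torch.kron(R, C)).log_prob(Y.flatten())
# ... while the swapped order is, in general, a different distribution.
mvn_swapped = torch.distributions.MultivariateNormal(
    zero, covariance_matrix=torch.kron(C, R)).log_prob(Y.flatten())
```

With distinct row and column covariances, only the R ⊗ C ordering reproduces the matrix-normal value.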

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 16 (13 by maintainers)

Top GitHub Comments

1 reaction
wjmaddox commented, Nov 24, 2021

This should finally be resolved by #1786, among other improvements to KroneckerProductLazyTensor.

0 reactions
Balandat commented, Mar 7, 2021

Looks like this is still open and will require an inv_quad_logdet implementation for KroneckerProductLazyTensor? Currently the efficient Cholesky will be used if the matrix is small enough, but not for larger matrices (where this would be useful).
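inv_quad_logdet returns both terms the MVN log-likelihood needs, so a Kronecker-aware version can get both from the factors alone. A minimal plain-PyTorch sketch of what such a method could compute, assuming dense SPD factors; `kron_inv_quad_logdet` and `random_spd` are hypothetical names, the `reduce_inv_quad` flag is assumed to mean "sum the per-column inverse quadratic forms", and this is not the code that landed in #1786. It covers only a pure Kronecker product: a Kronecker product plus a diagonal noise term is no longer a Kronecker product, so this factorization would not apply to it directly.

```python
import torch


def kron_inv_quad_logdet(A, B, rhs, reduce_inv_quad=True):
    """Hypothetical Kronecker-aware inv_quad_logdet for K = A kron B.

    A: n x n SPD, B: m x m SPD, rhs: (n*m,) or (n*m, t).
    Returns (inv_quad, logdet) using only Cholesky factors of A and B.
    """
    n, m = A.shape[0], B.shape[0]
    L_A = torch.linalg.cholesky(A)  # O(n^3)
    L_B = torch.linalg.cholesky(B)  # O(m^3)

    # log|A kron B| = m log|A| + n log|B|, with log|A| = 2 * sum(log(diag(L_A)))
    logdet = 2 * (m * L_A.diagonal().log().sum() + n * L_B.diagonal().log().sum())

    R = rhs if rhs.dim() == 2 else rhs.unsqueeze(-1)  # (n*m, t)
    t = R.shape[-1]
    Y = R.T.reshape(t, n, m)  # column j is the row-major vec of an n x m matrix Y_j
    Z = torch.cholesky_solve(Y, L_A.expand(t, n, n))  # A^{-1} Y_j
    X = torch.cholesky_solve(Z.transpose(-1, -2), L_B.expand(t, m, m)).transpose(-1, -2)  # A^{-1} Y_j B^{-1}
    inv_quad = (X * Y).sum(dim=(-1, -2))  # y_j^T K^{-1} y_j for each column j
    if reduce_inv_quad:
        inv_quad = inv_quad.sum()
    return inv_quad, logdet


def random_spd(k):
    """Hypothetical helper: random symmetric positive-definite k x k matrix."""
    M = torch.randn(k, k, dtype=torch.float64)
    return M @ M.T + k * torch.eye(k, dtype=torch.float64)


torch.manual_seed(0)
A, B = random_spd(4), random_spd(3)
y = torch.randn(12, dtype=torch.float64)
inv_quad, logdet = kron_inv_quad_logdet(A, B, y)

K = torch.kron(A, B)  # dense reference, only viable for small n*m
ref_inv_quad = y @ torch.linalg.solve(K, y)
ref_logdet = torch.logdet(K)
```

The dense reference is there only to check the factor-based result; for large n and m it is exactly the nm x nm work the sketch avoids.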

