[Bug] LKJCholeskyFactorPrior fails on GPU

🐛 Bug

LKJCholeskyFactorPrior.log_prob fails on the GPU with a device-mismatch error. This also breaks GPU support for botorch’s KroneckerMultiTaskGP. cc @Balandat

To reproduce

** Code snippet to reproduce **

import torch
from gpytorch.priors import LKJCholeskyFactorPrior

a = torch.randn(5, 5)
mat = a @ a.t() + torch.diag(torch.rand(5))
inv_sqrt = torch.diag(mat.diag().reciprocal())
corrmat = inv_sqrt @ mat @ inv_sqrt

prior = LKJCholeskyFactorPrior(5, 0.5)
prior.log_prob(corrmat)  # works on the CPU

prior = prior.to(torch.device("cuda:0"))
prior.log_prob(corrmat.cuda())  # raises the RuntimeError shown below

## botorch error
from botorch.models import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.randn(30, 1).cuda()
train_y = torch.randn(30, 3).cuda()
model = KroneckerMultiTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

mll(model(*model.train_inputs), model.train_targets)  # same error

** Stack trace/error message **

File ~/gpytorch/gpytorch/priors/prior.py:27, in Prior.log_prob(self, x)
     22 def log_prob(self, x):
     23     r"""
     24     :return: log-probability of the parameter value under the prior
     25     :rtype: torch.Tensor
     26     """
---> 27     return super(Prior, self).log_prob(self.transform(x))

File ~/miniconda3/lib/python3.9/site-packages/torch/distributions/lkj_cholesky.py:117, in LKJCholesky.log_prob(self, value)
    115 order = torch.arange(2, self.dim + 1)
    116 order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
--> 117 unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
    118 # Compute normalization constant (page 1999 of [1])
    119 dm1 = self.dim - 1

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Expected Behavior

LKJCholeskyFactorPrior.log_prob should not raise when the prior and its input are both on the GPU.

System information

  • pytorch 1.11

Additional context

Might be fixed in pytorch nightly but thought I’d point it out:

https://github.com/pytorch/pytorch/issues/58774

I could also put up a PR here that copies that log prob and just forces the order tensor onto the proper device.
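A minimal sketch of what such a copy might look like, assuming the only change needed is to create the order tensor on the same device as the input. The function name lkj_cholesky_log_prob and its signature are hypothetical; the arithmetic follows the lines of torch.distributions.LKJCholesky.log_prob shown in the stack trace plus the standard LKJ normalization constant, and it has not been tested against gpytorch itself.

```python
import math

import torch


def lkj_cholesky_log_prob(value, dim, concentration):
    """Hypothetical device-aware copy of the LKJ Cholesky log-density.

    `value` is a Cholesky factor of a correlation matrix. Every tensor
    created here lives on `value.device`, so no CPU/GPU mix can occur.
    """
    concentration = torch.as_tensor(
        concentration, dtype=value.dtype, device=value.device
    )
    # Diagonal entries of the factor, skipping the first (which is always 1).
    diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
    # The key change: build `order` on the input's device.
    order = torch.arange(2, dim + 1, dtype=value.dtype, device=value.device)
    order = 2 * (concentration - 1).unsqueeze(-1) + dim - order
    unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
    # Normalization constant of the LKJ density.
    dm1 = dim - 1
    alpha = concentration + 0.5 * dm1
    denominator = torch.lgamma(alpha) * dm1
    numerator = torch.mvlgamma(alpha - 0.5, dm1)
    pi_constant = 0.5 * dm1 * math.log(math.pi)
    normalize_term = pi_constant + numerator - denominator
    return unnormalized_log_pdf - normalize_term
```

On the CPU this should agree with torch.distributions.LKJCholesky(dim, concentration).log_prob(value); on the GPU it would be called with a CUDA tensor as value.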

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 5

Top GitHub Comments

neerajprad commented, Apr 19, 2022 (2 reactions)

This is a bug which was fixed by @dme65 last month, that’s why the nightly seems to work fine. I’ll also comment on the issue.

neerajprad commented, Apr 19, 2022 (0 reactions)

Easiest fix is probably just to perform that computation on the CPU.

Unfortunately, I don’t have a good solution in the interim. A hacky one that we have previously used in Pyro is to monkey-patch until the fix makes it into a release. See https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/torch_patch.py#L65 for one way to patch the log prob; the patch can be removed once PyTorch’s next release is out.
