# [Bug] LKJCholeskyFactorPrior fails on GPU
## 🐛 Bug
LKJCholeskyFactorPrior fails on the GPU. This also breaks GPU support for botorch's `KroneckerMultiTaskGP`.

cc @Balandat
## To reproduce

**Code snippet to reproduce**
```python
import torch
from gpytorch.priors import LKJCholeskyFactorPrior

# Build a random 5 x 5 correlation matrix.
a = torch.randn(5, 5)
mat = a @ a.t() + torch.diag(torch.rand(5))
inv_sqrt = torch.diag(mat.diag().rsqrt())  # rsqrt, not reciprocal, for a unit diagonal
corrmat = inv_sqrt @ mat @ inv_sqrt

prior = LKJCholeskyFactorPrior(5, 0.5)
prior.log_prob(corrmat)  # fine on CPU

prior = prior.to(torch.device("cuda:0"))
prior.log_prob(corrmat.cuda())  # RuntimeError (see trace below)

## botorch error
from botorch.models import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.randn(30, 1).cuda()
train_y = torch.randn(30, 3).cuda()
model = KroneckerMultiTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll(model(*model.train_inputs), model.train_targets)  # same error
```
**Stack trace/error message**
```
File ~/gpytorch/gpytorch/priors/prior.py:27, in Prior.log_prob(self, x)
     22 def log_prob(self, x):
     23     r"""
     24     :return: log-probability of the parameter value under the prior
     25     :rtype: torch.Tensor
     26     """
---> 27     return super(Prior, self).log_prob(self.transform(x))

File ~/miniconda3/lib/python3.9/site-packages/torch/distributions/lkj_cholesky.py:117, in LKJCholesky.log_prob(self, value)
    115 order = torch.arange(2, self.dim + 1)
    116 order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
--> 117 unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
    118 # Compute normalization constant (page 1999 of [1])
    119 dm1 = self.dim - 1

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
## Expected Behavior

`LKJCholesky.log_prob` shouldn't error on CUDA inputs. As the traceback shows, the `order` tensor created inside `log_prob` is allocated on the CPU by default while `diag_elems` lives on `cuda:0`, hence the device mismatch.
## System information

- PyTorch 1.11
## Additional context

This might already be fixed in the PyTorch nightly, but I thought I'd point it out:
https://github.com/pytorch/pytorch/issues/58774

I could also put up a PR here that copies that log prob and forces the `order` tensor onto the proper device.
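
For reference, a minimal sketch of what that change would look like in `torch/distributions/lkj_cholesky.py` (the exact form of the upstream patch is an assumption based on the traceback above, not the merged fix):

```python
# Inside LKJCholesky.log_prob: allocate `order` on the distribution's
# device instead of the CPU default, so it matches `diag_elems`.
order = torch.arange(2, self.dim + 1, device=self.concentration.device)
```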
## Top GitHub Comments
This is a bug that was fixed by @dme65 last month; that's why the nightly works fine. I'll also comment on the upstream issue.
Unfortunately, I don't have a good interim solution. A hacky one that we have previously used in Pyro is to monkey-patch the distribution until the fix makes it into a release. See https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/torch_patch.py#L65 for one way to patch the log prob; the patch can be removed once PyTorch's next release lands.
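
A minimal sketch of such a patch, assuming the PyTorch 1.11 body of `LKJCholesky.log_prob` partially visible in the traceback above (the normalization-constant block follows that implementation and should be checked against the installed version; the only intended change is the device of `order`):

```python
import math
import torch
from torch.distributions import LKJCholesky

def _patched_log_prob(self, value):
    # Copy of LKJCholesky.log_prob from PyTorch 1.11 with one change:
    # `order` is created on the device of the input value rather than
    # the CPU default. Remove once the upstream fix is released.
    if self._validate_args:
        self._validate_sample(value)
    diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
    order = torch.arange(2, self.dim + 1, device=value.device)  # the fix
    order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
    unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
    # Normalization constant, unchanged from the upstream implementation.
    dm1 = self.dim - 1
    alpha = self.concentration + 0.5 * dm1
    denominator = torch.lgamma(alpha) * dm1
    numerator = torch.mvlgamma(alpha - 0.5, dm1)
    pi_constant = 0.5 * dm1 * math.log(math.pi)
    normalize_term = pi_constant + numerator - denominator
    return unnormalized_log_pdf - normalize_term

LKJCholesky.log_prob = _patched_log_prob
```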