Large negative eigenvalues in covariance matrices with CUDA and torch==1.9.1 in batch mode
The symptom: posterior predictive covariance matrices have very large negative eigenvalues, well beyond what can be cured with jitter.
The conditions: exact GP regression in batch mode on CUDA with torch==1.9.1. Downgrading to torch==1.9.0 fixes the problem, as does moving to CPU. The problem also does not occur with a single GP rather than a batch.
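For reference, "jitter" here means adding a small multiple of the identity to the diagonal. A minimal sketch of that check (the helper name and the `jitter` value are illustrative, not part of GPyTorch's API):

```python
import torch

def min_eig_with_jitter(covar: torch.Tensor, jitter: float = 1e-6) -> torch.Tensor:
    """Add jitter * I to a (possibly batched) covariance matrix and return
    the smallest eigenvalue, to check whether jitter alone restores PSD-ness."""
    eye = torch.eye(covar.shape[-1], dtype=covar.dtype, device=covar.device)
    return torch.linalg.eigvalsh(covar + jitter * eye).min()
```

Eigenvalues on the order of -1e5, as in the output below, are far beyond what any reasonable jitter can absorb.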
To reproduce
```python
import gpytorch, torch

print(gpytorch.__version__)
print(torch.__version__)

# Set cuda=True and batch=True to get the problem.
# Turning either off makes it disappear.
cuda = True
batch = True
# These other flags don't affect the outcome but are there as sanity checks.
index_0_batch = False
fixed_noise = True
learn_additional_noise = False

device = "cuda" if cuda else "cpu"
bs = torch.Size([11]) if batch else torch.Size([])  # batch of 11 independent GPs

train_x = torch.randn(100, device=device, dtype=float)
train_y = torch.randn(100, device=device, dtype=float)
train_y_std = torch.randn(100, device=device, dtype=float) ** 2 * .1
test_x = torch.randn(200, device=device, dtype=float)
test_y = torch.randn(200, device=device, dtype=float)


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=bs)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=bs)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


if fixed_noise:
    likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
        noise=train_y_std, batch_shape=bs,
        learn_additional_noise=learn_additional_noise).to(device)
else:
    likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=bs).to(device)

model = ExactGPModel(train_x, train_y, likelihood).to(device)

model.train()
likelihood.train()

# Includes GaussianLikelihood parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(1):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    if index_0_batch:
        loss = loss[0]
    else:
        loss = loss.mean(dim=0)
    loss.backward()
    optimizer.step()

model.eval()
likelihood.eval()

f_preds = model(test_x)
y_preds = likelihood(model(test_x))
f_covar = f_preds.covariance_matrix  # posterior predictive covariance

print(f"cuda: {cuda}")
print(f"batch mode: {batch}")
print(f"fixed y noise: {fixed_noise}")
if fixed_noise:
    print(f"learn additional y noise: {learn_additional_noise}")
print("first batch index of loss" if index_0_batch else "mean loss over batch")
print(f"Min covariance eigenvalue: {torch.linalg.eigvalsh(f_covar).min()}")
```
Output:

```
1.5.1
1.9.1+cu102
cuda: True
batch mode: True
fixed y noise: True
learn additional y noise: False
mean loss over batch
Min covariance eigenvalue: -173010.1313435591
```
Expected Behavior
The expected behaviour is simply what one gets with torch==1.9.0: any numerical differences in the covariance matrices between CPU and CUDA should be minor, rather than the drastic discrepancy seen here.
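One way to make "minor" concrete is to rebuild the same model on CPU with the same trained state and compare smallest eigenvalues. A minimal sketch, assuming `ExactGPModel`, `model`, `bs`, and the tensors from the repro script above:

```python
import gpytorch, torch

# Rebuild an identical model on CPU and copy over the trained parameters.
likelihood_cpu = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
    noise=train_y_std.cpu(), batch_shape=bs)
model_cpu = ExactGPModel(train_x.cpu(), train_y.cpu(), likelihood_cpu)
model_cpu.load_state_dict({k: v.cpu() for k, v in model.state_dict().items()})
model_cpu.eval()
likelihood_cpu.eval()

with torch.no_grad():
    f_covar_cpu = model_cpu(test_x.cpu()).covariance_matrix
print(f"Min covariance eigenvalue on CPU: {torch.linalg.eigvalsh(f_covar_cpu).min()}")
```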
System information
- gpytorch==1.5.1
- torch=={1.9.0, 1.9.1}
- Google Colab notebook.
Additional context
I found this while experimenting with Bayesian hyperparameter inference using variational inference. The predictive posterior and the MLL require computing intractable integrals against the variational posterior. I'm using MC integration, so each of the batch GPs corresponds to an i.i.d. sample from the hyperparameter variational posterior (a sketch of this mapping follows below).
The code I've supplied above isn't actually doing this, but I think I've stripped it back to the bare minimum required to reproduce the numerical problem.
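For concreteness, here is a hypothetical sketch of how the batch dimension maps to MC samples; the variational parameters `q_mean` and `q_log_std` (a Gaussian posterior over the log-lengthscale) are made-up names for illustration, not from the code above:

```python
import torch
import gpytorch

num_samples = 11               # batch size = number of MC samples
q_mean = torch.zeros(())       # hypothetical variational mean
q_log_std = torch.zeros(())    # hypothetical variational log-std

# One i.i.d. log-lengthscale sample per batch GP
log_ls = q_mean + q_log_std.exp() * torch.randn(num_samples)

kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_samples]))
kernel.lengthscale = log_ls.exp().view(num_samples, 1, 1)
```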
I'm not very familiar with the PyTorch code base, but the only differences between 1.9.0 and 1.9.1 that look at all related to linear algebra are some shape checks prior to matrix multiplication (https://github.com/pytorch/pytorch/compare/v1.9.0...v1.9.1).
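If the shape checks really are the only relevant change, one way to probe this (purely a guess at where to look, not a confirmed diagnosis) is to compare a batched float64 matmul between CPU and CUDA directly:

```python
import torch

# Batched float64 matmul, roughly the shapes the GP posterior would hit.
torch.manual_seed(0)
a = torch.randn(11, 200, 100, dtype=torch.float64)
b = torch.randn(11, 100, 200, dtype=torch.float64)
cpu_result = a @ b
gpu_result = (a.cuda() @ b.cuda()).cpu()  # requires a CUDA device
print(f"max abs CPU/CUDA difference: {(cpu_result - gpu_result).abs().max()}")
```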
Top GitHub Comments
Thanks for debugging @IvanYashchuk, would you file a more detailed issue about what’s causing the problem in the PyTorch GitHub so we can review it there?
I ran with 1.9.0 (as confirmed by the printed version), then ran

```
pip install torch==1.9.1
```

and restarted the kernel. So unless I'm missing something, torch 1.9.0 <-> 1.9.1 is the only change?