[Bug] Prior losses are incorrectly added to the mll in batch-mode
🐛 Bug
Prior losses are currently being added up incorrectly in ExactMarginalLogLikelihood. The line:

res.add_(prior.log_prob(closure()).sum())

sums all of the prior losses and then adds that scalar to the mll. If you are using a batch model, this sum is broadcast onto every batch dimension, so the losses are counted multiple times when loss.sum().backward() is eventually called. It also looks like the priors may not support batch mode, which leads to a large variety of different shapes, but the .sum() call masks this issue since it just sums everything up anyway.
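A minimal sketch of the over-counting (the tensors here are hypothetical stand-ins, not the actual mll internals): when the scalar sum of all prior terms is added to a batched mll, every batch entry receives the full amount, so loss.sum() counts the prior contribution once per batch element.

import torch

batch_mll = torch.zeros(2)                     # per-batch marginal log likelihoods
prior_log_probs = torch.tensor([-1.0, -3.0])   # one prior term per batch model

# Current behavior: the summed prior (-4.0) is broadcast onto both batch entries,
res = batch_mll.add(prior_log_probs.sum())
print(res.sum())  # tensor(-8.) -- the prior is counted twice

# whereas adding the per-batch terms directly keeps the total correct.
res = batch_mll.add(prior_log_probs)
print(res.sum())  # tensor(-4.)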
To reproduce
Code snippet (taken from test_train_on_batch_test_on_batch):
import math
import torch
import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
train_x1 = torch.linspace(0, 2, 11).unsqueeze(-1)
train_y1 = torch.sin(train_x1 * (2 * math.pi)).squeeze()
train_x2 = torch.linspace(0, 1, 11).unsqueeze(-1)
train_y2 = torch.sin(train_x2 * (2 * math.pi)).squeeze()
train_x12 = torch.cat((train_x1.unsqueeze(0), train_x2.unsqueeze(0)), dim=0).contiguous()
train_y12 = torch.cat((train_y1.unsqueeze(0), train_y2.unsqueeze(0)), dim=0).contiguous()
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape=torch.Size()):
        super(ExactGPModel, self).__init__(train_inputs, train_targets, likelihood)
        self.mean_module = ConstantMean(batch_shape=batch_shape, prior=gpytorch.priors.SmoothedBoxPrior(-1, 1))
        self.covar_module = ScaleKernel(
            RBFKernel(
                batch_shape=batch_shape,
                lengthscale_prior=gpytorch.priors.NormalPrior(
                    loc=torch.zeros(*batch_shape, 1, 1), scale=torch.ones(*batch_shape, 1, 1)
                ),
            ),
            batch_shape=batch_shape,
            outputscale_prior=gpytorch.priors.SmoothedBoxPrior(-2, 2),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


# We're manually going to set the hyperparameters to something they shouldn't be
likelihood = GaussianLikelihood(
    noise_prior=gpytorch.priors.NormalPrior(loc=torch.zeros(2), scale=torch.ones(2)),
    batch_shape=torch.Size([2]),
)
gp_model = ExactGPModel(train_x12, train_y12, likelihood, batch_shape=torch.Size([2]))

for name, prior, closure, _ in gp_model.named_priors():
    print(name, prior.log_prob(closure()).shape)
Output:
likelihood.noise_covar.noise_prior torch.Size([2, 2])
mean_module.mean_prior torch.Size([2])
covar_module.outputscale_prior torch.Size([])
covar_module.base_kernel.lengthscale_prior torch.Size([2, 1, 1])
Expected Behavior
The prior losses should all have the same (batch) shape and be added up via res.add_(prior.log_prob(closure())), without the inner .sum() call.
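One possible shape-safe variant (a sketch under the assumption that the mll knows its own batch shape, not the fix that actually landed; add_prior_terms and mll_batch_shape are hypothetical names): sum each prior's log_prob only over its trailing non-batch dimensions, so the term added to res keeps the batch shape, and let scalar terms broadcast.

def add_prior_terms(res, model, mll_batch_shape):
    # res: tensor with shape mll_batch_shape; model: an ExactGP with registered priors.
    for _, prior, closure, _ in model.named_priors():
        log_prob = prior.log_prob(closure())
        # Reduce any dimensions beyond the leading batch dimensions.
        extra_dims = log_prob.dim() - len(mll_batch_shape)
        if extra_dims > 0:
            log_prob = log_prob.sum(dim=tuple(range(-extra_dims, 0)))
        # Scalar terms (e.g. the outputscale prior above) broadcast across the batch.
        res = res + log_prob
    return res

With the shapes printed above and mll_batch_shape = torch.Size([2]), every term reduces to (or broadcasts over) shape [2] before being added.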
System information
GPyTorch Version: 1.2.0
PyTorch Version: 1.6.0
OS: macOS
Additional context
This was originally discovered in PR #1314.
cc: @Balandat
Top GitHub Comments
Oof okay this is pretty bad. I’ll get #1317 and #1318 fixed early this week, and we can push out a 1.2.1.
Oops, I commented on #1317 without looking at updates here. What about the case where we’re using a batched univariate prior for a vector-valued hyperparam (like lengthscales with ARD)? Seems like we shouldn’t always expand.
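For what it's worth, a small illustration of the ARD case being raised (hypothetical shapes, with torch.distributions.Normal standing in for gpytorch.priors.NormalPrior): a batched univariate prior applied elementwise to ARD lengthscales yields one log-prob per lengthscale, and summing the trailing non-batch dimensions, rather than expanding, reduces them to one term per batch element.

import torch
from torch.distributions import Normal

batch, ard_dims = 2, 3
lengthscales = torch.rand(batch, 1, ard_dims)  # shape of an ARD RBFKernel lengthscale
prior = Normal(torch.zeros(batch, 1, ard_dims), torch.ones(batch, 1, ard_dims))

log_prob = prior.log_prob(lengthscales)
print(log_prob.shape)                    # torch.Size([2, 1, 3])
print(log_prob.sum(dim=(-2, -1)).shape)  # torch.Size([2])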