[Bug] Reloading saved parameters into a variational model hurts performance
🐛 Bug
The test performance of a saved SVGP model differs substantially depending on whether the saved parameters are loaded into a freshly instantiated model or into the already trained model object.
To reproduce
**Code snippet to reproduce**
import os
from math import floor
import tqdm
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from scipy.io import loadmat
from sklearn.metrics import mean_squared_error
import torch
from torch.utils.data import TensorDataset, DataLoader
import urllib.request
if not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')
data = torch.Tensor(loadmat('../elevators.mat')['data'])
X = data[:, :-1]
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1
y = data[:, -1]
train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()
test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
num_epochs = 4
model.train()
likelihood.train()
# Adam optimizer over both the model parameters and the likelihood parameters
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)
# Our loss object. We're using the VariationalELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()
model.eval()
likelihood.eval()
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        y_pred_list.append(preds.mean)
        y_true_list.append(y_batch)
y_pred = torch.cat(y_pred_list, dim=0).numpy()
y_true = torch.cat(y_true_list, dim=0).numpy()
test_mse = mean_squared_error(y_true, y_pred)
print(test_mse)
# torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')  # uncomment once to save the model to disk
state_dict = torch.load('my_gp_with_nn_model.pth')
# inducing_points = train_x[:500, :]
# model = GPModel(inducing_points=inducing_points)
# likelihood = gpytorch.likelihoods.GaussianLikelihood()
model.load_state_dict(state_dict)
model.eval()
likelihood.eval()
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        y_pred_list.append(preds.mean)
        y_true_list.append(y_batch)
y_pred = torch.cat(y_pred_list, dim=0).numpy()
y_true = torch.cat(y_true_list, dim=0).numpy()
test_mse = mean_squared_error(y_true, y_pred)
print(test_mse)
**Stack trace/error message**
0.010551954  # test MSE right after training
0.23515734   # after reloading into the existing model; or 0.5873753 or 0.40703666, unpredictable
Expected Behavior
The performance of a saved model is expected to be consistent after loading it from the saved file. However, whether the model is instantiated fresh from the class turns out to affect the performance significantly.
In the snippet, the interesting part is:
inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
If I uncomment these lines, i.e., instantiate a new model and load the weights into it, the performance is always consistent: the test MSE of the reloaded model is very close to that of the trained model. However, if I don’t instantiate a new model but instead load the weights into the existing trained model, the performance is unpredictable. (The consistent reload pattern is sketched below.)
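Assembled from the commented-out lines in the snippet above, the reload pattern that gives consistent results is:

# Build a fresh model object so nothing is carried over from training,
# then load the saved parameters into it.
inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

state_dict = torch.load('my_gp_with_nn_model.pth')
model.load_state_dict(state_dict)
model.eval()
likelihood.eval()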
System information
Please complete the following information:
- GPyTorch Version: 1.2.0
- PyTorch Version: 1.6.0
- Debian 10
Additional context
I came across the bug when I saved the model to disk based on some validation score, reloaded it after the whole training, and tested it on a test set in the same script. The reloading was done into an existing trained model rather than a newly instantiated one. When I start a new script, instantiate a new model, and load the saved weights to double-check the test performance, the number turns out to be different.
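For concreteness, the workflow looked roughly like the following sketch (hypothetical: `evaluate` and `val_loader` are illustrative stand-ins, not part of the snippet above):

best_val_mse = float('inf')
for epoch in range(num_epochs):
    # ... run the training minibatches as in the snippet above ...
    val_mse = evaluate(model, likelihood, val_loader)  # hypothetical validation helper
    if val_mse < best_val_mse:
        best_val_mse = val_mse
        # checkpoint the best model seen so far
        torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')

# Later in the same script: reload the best checkpoint into the *same*,
# already-trained model object -- this is where the test MSE becomes unpredictable.
model.load_state_dict(torch.load('my_gp_with_nn_model.pth'))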
I also tested the same thing with ExactGP, which does not have this problem. I am not sure whether this behavior is intended by the library. Pure PyTorch modules don’t seem to behave this way.
Top GitHub Comments
Huh this is very strange. Thanks @ZhiliangWu and @irum for letting us know. I’ll take a look later today!
The reason for this behavior is that the variational models cache some of the expensive computations. Loading from the state dict does not clear these precomputed caches (though it should), so the loaded model uses an outdated cache. This is what causes the discrepancies in test error.
I’m putting up a PR now to fix this. A hacky fix for now is to run:
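Judging from the sentence below, the workaround is presumably to toggle the loaded model back into train mode and then into eval mode, which invalidates the stale caches:

model.load_state_dict(state_dict)
model.train()  # switching to train mode clears the cached variational computations
model.eval()   # switch back to eval mode before predicting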
After this PR is in, you won’t have to call `train` then `eval`.