
[Bug] Reloading saved parameters into a variational model hurts performance


🐛 Bug

The performance of a saved SVGP model differs a lot depending on whether the saved parameters are loaded into a newly instantiated model or into the existing, already-trained model instance.

To reproduce

**Code snippet to reproduce**

import os

from math import floor

import tqdm
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from scipy.io import loadmat
from sklearn.metrics import mean_squared_error
import torch
from torch.utils.data import TensorDataset, DataLoader
import urllib.request


if not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')

data = torch.Tensor(loadmat('../elevators.mat')['data'])
X = data[:, :-1]
X = X - X.min(0)[0]            # shift each feature to start at zero
X = 2 * (X / X.max(0)[0]) - 1  # scale each feature to [-1, 1]
y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()


train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


inducing_points = train_x[:500, :]  # initialize inducing points at the first 500 training inputs
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

num_epochs = 4

model.train()
likelihood.train()

# Optimize the model and likelihood parameters jointly with Adam
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

# Our loss object. We're using the VariationalELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))


epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()


model.eval()
likelihood.eval()
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        y_pred_list.append(preds.mean)
        y_true_list.append(y_batch)

y_pred = torch.cat(y_pred_list, dim=0).numpy()
y_true = torch.cat(y_true_list, dim=0).numpy()

test_mse = mean_squared_error(y_true, y_pred)
print(test_mse)


# torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')    # uncomment once to save the model to disk
state_dict = torch.load('my_gp_with_nn_model.pth')

# inducing_points = train_x[:500, :]
# model = GPModel(inducing_points=inducing_points)
# likelihood = gpytorch.likelihoods.GaussianLikelihood()

model.load_state_dict(state_dict)

model.eval()
likelihood.eval()
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        y_pred_list.append(preds.mean)
        y_true_list.append(y_batch)

y_pred = torch.cat(y_pred_list, dim=0).numpy()
y_true = torch.cat(y_true_list, dim=0).numpy()

test_mse = mean_squared_error(y_true, y_pred)
print(test_mse)

**Stack trace/error message**

0.010551954
0.23515734   # or 0.5873753 or 0.40703666, unpredictable

Expected Behavior

The performance of a saved model is expected to be consistent after loading it from the saved file. However, whether the model is freshly instantiated before loading seems to affect the performance a lot.

In the snippet, the interesting part is the following (commented-out) block:

inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

If I uncomment it, i.e., instantiate a new model and load the weights into it, the performance is consistent: the test MSE of the reloaded model is always very close to that of the trained model. However, if I don’t instantiate a new model, and instead load the weights into the existing trained model, the performance is unpredictable.
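
For reference, this is a minimal sketch of the reload pattern that behaves consistently, reusing the file name and variable names from the snippet above:

state_dict = torch.load('my_gp_with_nn_model.pth')

# Build a fresh model and likelihood, then load the saved parameters into them
inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model.load_state_dict(state_dict)

model.eval()
likelihood.eval()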

System information

Please complete the following information:

  • GPyTorch Version: 1.2.0
  • PyTorch Version: 1.6.0
  • Debian 10

Additional context

I came across the bug when I saved the model to disk based on a validation score, reloaded it after training finished, and tested it on a test set in the same script. The weights were reloaded into the existing trained model, rather than into a newly instantiated one. When I start a new script, instantiate a new model, and load the saved weights to double-check the test performance, the number turns out to be different. Roughly, the workflow looked like the sketch below.
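
(In this sketch, evaluate() is an illustrative stand-in for my validation code, not a real helper.)

best_mse = float('inf')
for epoch in range(num_epochs):
    # ... one epoch of the minibatch training loop shown above ...
    val_mse = evaluate(model, likelihood, val_loader)  # illustrative validation helper
    if val_mse < best_mse:
        best_mse = val_mse
        torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')

# After training: reload the best checkpoint into the SAME model instance
model.load_state_dict(torch.load('my_gp_with_nn_model.pth'))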

I also tested the same thing with ExactGP, which doesn’t have this problem. I am not sure whether this behavior is expected by the library itself; pure PyTorch modules don’t seem to behave this way.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 12 (5 by maintainers)

Top GitHub Comments

gpleiss commented on Oct 15, 2020 (2 reactions)

Huh this is very strange. Thanks @ZhiliangWu and @irum for letting us know. I’ll take a look later today!

gpleiss commented on Oct 15, 2020 (1 reaction)

The reason for this behavior is that the variational models cache some of the expensive computations. Loading from the state dict does not clear these precomputed caches (though it should), and so the loaded model is using an outdated cache. This is what causes the discrepancies in test error.

I’m putting up a PR now to fix this. A hacky fix for now is to run:

model.load_state_dict(state_dict)
model.train()  # this clears any precomputed caches
model.eval()

After this PR is in, you won’t have to call train then eval.
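
Putting the workaround together with the reporter’s script, a minimal sketch of the full reload-and-evaluate sequence, reusing the variable names from the snippet above:

state_dict = torch.load('my_gp_with_nn_model.pth')
model.load_state_dict(state_dict)

# Workaround for the stale variational cache: entering train mode clears
# any precomputed terms, and eval mode recomputes them from the loaded state
model.train()
model.eval()
likelihood.eval()

with torch.no_grad():
    preds = model(test_x)
print(mean_squared_error(test_y.numpy(), preds.mean.numpy()))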
