loading pytorch confuses gaussian process kernel
Describe the bug
Loading pytorch, e.g., with import torch
seems to deteriorate the accuracy of the gaussian process regression.
Steps/Code to Reproduce
I adopt the example from https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpr_noisy_targets.html.
Example loading pytorch before making the kernel:
import numpy as np
from matplotlib import pyplot as plt
import scipy
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process import kernels
import torch # take out this line to fix the result
np.random.seed(1)
def f(x):
    """The function to predict."""
    return x * np.sin(x)
# Observations
X = np.atleast_2d(np.random.normal(scale=5, size=1000)).T
y = f(X).ravel()
dy = 0.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise
# Mesh the input space for evaluations of the real function, the prediction and
# its MSE
x = np.atleast_2d(np.linspace(np.min(X), np.max(X), 1000)).T
# Instantiate a Gaussian Process model
kernel = (
    kernels.ConstantKernel(1.0, constant_value_bounds='fixed')
    * kernels.RBF(1., (1e-2, 1e2))
    + kernels.WhiteKernel()
)
gp = GaussianProcessRegressor(kernel=kernel)
# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)
# Make the prediction on the meshed x-axis (ask for MSE as well)
y_pred, sigma = gp.predict(x, return_std=True)
# Plot the function, the prediction and the 95% confidence interval based on
# the MSE
plt.figure()
plt.plot(x, f(x), 'r:', label=r'$f(x) = x\,\sin(x)$')
plt.plot(X, y, 'r.', markersize=10, label='Observations')
plt.plot(x, y_pred, 'b-', label='Prediction')
plt.fill(np.concatenate([x, x[::-1]]),
         np.concatenate([y_pred - 1.9600 * sigma,
                         (y_pred + 1.9600 * sigma)[::-1]]),
         alpha=.5, fc='b', ec='None', label='95% confidence interval')
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.ylim(-10, 20)
plt.legend(loc='upper left')
Example Results
Expected Results
Actual Results
Further, passing n_restarts_optimizer=9
tends to crash the optimization if pytorch was loaded before a kernel is defined for the first time.
Further Results
In my application the discrepancy between the result with pytorch loaded and the result of the same code without loading pytorch becomes more dramatic.
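Because the plots are compared by eye, a numeric comparison may make the discrepancy easier to report. The following is a minimal sketch, not part of the original report: `fit_summary` is a hypothetical helper, the sample size of 200 is an assumption chosen for speed, and the data is drawn from a local `RandomState`, so the numbers will not match the script above exactly. It fits the same kernel and returns the optimized kernel and its log-marginal likelihood.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor, kernels


def fit_summary(n_samples=200, seed=1, n_restarts_optimizer=0):
    """Fit the GP from the report on synthetic data and return numeric diagnostics.

    Hypothetical helper: n_samples and seed are illustrative assumptions.
    """
    rng = np.random.RandomState(seed)
    # Noisy observations of f(x) = x * sin(x), as in the reproduction script.
    X = rng.normal(scale=5, size=(n_samples, 1))
    y = (X * np.sin(X)).ravel()
    dy = 0.5 + 1.0 * rng.random_sample(y.shape)
    y = y + rng.normal(0, dy)

    # Same kernel structure as in the report.
    kernel = (
        kernels.ConstantKernel(1.0, constant_value_bounds="fixed")
        * kernels.RBF(1.0, (1e-2, 1e2))
        + kernels.WhiteKernel()
    )
    gp = GaussianProcessRegressor(
        kernel=kernel,
        n_restarts_optimizer=n_restarts_optimizer,
        random_state=seed,
    )
    gp.fit(X, y)
    return {
        "kernel": str(gp.kernel_),
        "log_marginal_likelihood": float(gp.log_marginal_likelihood_value_),
    }


summary = fit_summary()
print(summary)
```

Running this once in a fresh interpreter with `import torch` placed at the top and once without it should show whether the optimized hyperparameters or the likelihood actually differ between the two runs.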
Versions
System:
python: 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) [GCC 9.3.0]
executable: /home/dotto/.conda/envs/biopytorch-7/bin/python
machine: Linux-4.15.0-136-generic-x86_64-with-glibc2.10
Python dependencies:
pip: 21.0.1
setuptools: 49.6.0.post20210108
sklearn: 0.24.1
numpy: 1.20.1
scipy: 1.6.2
Cython: 0.29.23
pandas: 1.2.2
matplotlib: 3.3.4
joblib: 1.0.1
threadpoolctl: 2.1.0
Built with OpenMP: True
torch: 1.7.1
Issue Analytics
- State:
- Created 2 years ago
- Comments:6 (4 by maintainers)
Top GitHub Comments
I’m not good either 😬. I don’t think they patch numpy and scipy but I wouldn’t be surprised if they patched something (like the backend or something like that)
In a new env using the conda package
pytorch-1.8.1-py3.9_cuda10.2_cudnn7.6.5_0
I also cannot reproduce the error. So it may just be an issue with the specific pytorch version. Thank you for considering and testing!!
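One way to look into the "backend" guess from the comment above is to list the native BLAS/OpenMP libraries loaded in the process before and after importing torch. This is only a diagnostic sketch under assumptions: it uses `threadpool_info` from threadpoolctl (listed in the versions above), `native_libraries` is a hypothetical helper name, and a difference in the output would be a hint, not proof of the cause.

```python
import importlib

import numpy as np  # importing numpy loads its BLAS library
from threadpoolctl import threadpool_info


def native_libraries():
    """Return (user_api, internal_api, filepath) for each detected library."""
    return [
        (info["user_api"], info["internal_api"], info["filepath"])
        for info in threadpool_info()
    ]


before = native_libraries()

# torch is optional here so the sketch also runs in an env without it.
try:
    importlib.import_module("torch")
    torch_available = True
except ImportError:
    torch_available = False

after = native_libraries()
new_libs = [lib for lib in after if lib not in before]

print("torch available:", torch_available)
print("libraries before:", before)
print("libraries added by importing torch:", new_libs)
```

If importing torch adds a second OpenMP runtime or a different BLAS build, it may be worth including that output in the report along with the exact package builds.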