Saving kernels (pickle)
Howdy folks,
I am aware that one can save the state_dict using torch.save(). But since I am experimenting with different kernel structures, I thought it would be nice to keep track of them by putting them in a simple database like shelve, and then be able to recall them later when I decide I want to train on (new) data with one of them.
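(For a single kernel, the state_dict route I mentioned would look roughly like the sketch below; the kernel and filename are just placeholders.)

import torch
import gpytorch

kernel = gpytorch.kernels.RBFKernel()  # placeholder; any kernel works the same way

# torch.save() on the state_dict only persists the hyperparameters, not the kernel structure
torch.save(kernel.state_dict(), "kernel_state.pt")

# later: rebuild the same kernel structure in code, then restore the hyperparameters
kernel.load_state_dict(torch.load("kernel_state.pt"))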
So I was thinking of something simple like this:
import shelve

class KernelRepo:
    """Tiny kernel store backed by a shelve database (shelve pickles the stored objects)."""

    def __init__(self, shelve_db_filename="default_kernel_shelve.db"):
        self._db = shelve_db_filename

    def store_kernel(self, key, kernel):
        assert isinstance(key, str)
        s = shelve.open(self._db)
        try:
            s[key] = kernel
        finally:
            s.close()

    def get_kernel(self, key):
        assert isinstance(key, str)
        kernel = None
        s = shelve.open(self._db, flag='r')  # open read-only for lookups
        try:
            kernel = s[key]
        finally:
            s.close()
        return kernel
And then when I have a kernel that I would like to store, I would do something like this:
from prototype.kernel_repository import KernelRepo
import gpytorch
default_kernel = gpytorch.kernels.RBFKernel()*gpytorch.kernels.PolynomialKernel(power=2) + \
gpytorch.kernels.ScaleKernel(gpytorch.kernels.LinearKernel())
kernel_repo = KernelRepo()
kernel_repo.store_kernel("default kernel", default_kernel)
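(And later I would recall it along these lines:)

# look the kernel up again by its key
default_kernel = kernel_repo.get_kernel("default kernel")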
Unfortunately I am running into the following issue.
_pickle.PicklingError: Can't pickle <built-in function softplus>: import of module 'torch._C._nn' failed
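(shelve uses pickle under the hood, and presumably the same error shows up when pickling the kernel directly, e.g.:)

import pickle

# pickling the kernel object itself raises the same PicklingError
pickle.dumps(default_kernel)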
I am relatively new to Python so I am still trying to work out what the root cause of this issue is, but perhaps someone has already run into something similar, or has a better solution for “managing kernels”.
My ultimate goal would be to automatically create key tags so that I can keep track of different kernels, and to automate the kernel selection process using something like an ABC-SMC method. But for now I would be really happy with being able to store kernels so I can recall them programmatically, using shelve or something else.
Any ideas, suggestions or solutions?
Thanks in advance
Galto
Top GitHub Comments
Hi @PhilippThoelke – if you want to put up a PR, I think this is a reasonable solution. I think using torch.nn.Softplus would be the best solution, as we don't use non-default parameters beta or threshold?

I ran into the same issue and can't exchange pickle for dill, as pickle is being used inside pytorch-lightning to copy the model for distributed training, and I don't really want to switch out pickle for dill inside pytorch-lightning.
The line that seems to be causing the problem is https://github.com/cornellius-gp/gpytorch/blob/8f9b44fc57dbb0a13b568946f07a37e9332f92c4/gpytorch/constraints/constraints.py#L8
The functional version of softplus in PyTorch is not defined in Python but inside a C extension (torch._C._nn), and pickle doesn't seem to like this. Removing this import and defining softplus inside constraints.py fixes the pickling problem. It even works if you just define softplus as PyTorch's Softplus module instead of the functional version:
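(a minimal sketch of that replacement inside constraints.py; the surrounding code is assumed)

import torch

# instead of `from torch.nn.functional import softplus`, use a picklable module instance;
# calls like softplus(x) keep working because the module is callable
softplus = torch.nn.Softplus()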
This removes the beta and threshold parameters from the softplus call though, so just implementing softplus by hand might be better here (e.g. the sketch below). I can create a pull request if you think this is a good enough fix.