[BUG] unable to checkpoint model with qml.qnn.TorchLayer
Expected behavior
something like:
Average loss over epoch 1: 0.4803
Average loss over epoch 2: 0.3553
Accuracy: 78.0%
The output file is written to PATH with no errors.
Actual behavior
Returned
Average loss over epoch 1: 0.4803
Average loss over epoch 2: 0.3553
Accuracy: 78.0%
However, although the checkpoint file is created, it is incomplete and does not include the full model.
PyTorch is unable to pickle the qnode: _pickle.PicklingError: Can't pickle <function qnode at 0x7fc7e4169160>: it's not the same object as __main__.qnode
Additional information
Error occurs every time.
Source code
# Copying and pasting code from: https://pennylane.ai/qml/demos/tutorial_qnn_module_torch.html
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons
# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
X, y = make_moons(n_samples=200, noise=0.1)
y_ = torch.unsqueeze(torch.tensor(y), 1) # used for one-hot encoded labels
y_hot = torch.scatter(torch.zeros((200, 2)), 1, y_, 1)
c = ["#1f77b4" if y_ == 0 else "#ff7f0e" for y_ in y] # colours for each class
#plt.axis("off")
#plt.scatter(X[:, 0], X[:, 1], c=c)
#plt.show()
import pennylane as qml
n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)
@qml.qnode(dev)
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]
n_layers = 6
weight_shapes = {"weights": (n_layers, n_qubits)}
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
clayer_1 = torch.nn.Linear(2, 2)
clayer_2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)
layers = [clayer_1, qlayer, clayer_2, softmax]
model = torch.nn.Sequential(*layers)
opt = torch.optim.SGD(model.parameters(), lr=0.2)
loss = torch.nn.L1Loss()
X = torch.tensor(X, requires_grad=True).float()
y_hot = y_hot.float()
batch_size = 5
batches = 200 // batch_size
data_loader = torch.utils.data.DataLoader(
    list(zip(X, y_hot)), batch_size=5, shuffle=True, drop_last=True
)
epochs = 2
for epoch in range(epochs):
    running_loss = 0
    for xs, ys in data_loader:
        opt.zero_grad()
        loss_evaluated = loss(model(xs), ys)
        loss_evaluated.backward()
        opt.step()
        running_loss += loss_evaluated
    avg_loss = running_loss / batches
    print("Average loss over epoch {}: {:.4f}".format(epoch + 1, avg_loss))
y_pred = model(X)
predictions = torch.argmax(y_pred, axis=1).detach().numpy()
correct = [1 if p == p_true else 0 for p, p_true in zip(predictions, y)]
accuracy = sum(correct) / len(correct)
print(f"Accuracy: {accuracy * 100}%")
# Saving a checkpoint after training
PATH = "./checkpoint.pth"
torch.save({
    'model': model,
    'model_state_dict': model.state_dict(),
    'optimizer': opt,
    'optimizer_state_dict': opt.state_dict(),
    'loss': running_loss,
}, PATH)
Tracebacks
Average loss over epoch 1: 0.4943
Average loss over epoch 2: 0.4226
Accuracy: 75.5%
Traceback (most recent call last):
  File "/home/PL_save.py", line 84, in <module>
    torch.save({
  File "/home/miniconda3/envs/py9/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/miniconda3/envs/py9/lib/python3.9/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <function qnode at 0x7fc7e4169160>: it's not the same object as __main__.qnode
System information
>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.21.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/miniconda3/envs/py9/lib/python3.9/site-packages
Requires: autoray, retworkx, cachetools, semantic-version, scipy, pennylane-lightning, networkx, numpy, toml, appdirs, autograd
Required-by: PennyLane-Lightning
Platform info: Linux-4.18.0-348.7.1.el8_5.x86_64-x86_64-with-glibc2.28
Python version: 3.9.7
Numpy version: 1.22.2
Scipy version: 1.8.0
Installed devices:
- default.gaussian (PennyLane-0.21.0)
- default.mixed (PennyLane-0.21.0)
- default.qubit (PennyLane-0.21.0)
- default.qubit.autograd (PennyLane-0.21.0)
- default.qubit.jax (PennyLane-0.21.0)
- default.qubit.tf (PennyLane-0.21.0)
- default.qubit.torch (PennyLane-0.21.0)
- lightning.qubit (PennyLane-Lightning-0.21.0)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.
Issue Analytics
- State:
- Created 2 years ago
- Comments:7 (4 by maintainers)

I have a temporary workaround while I’m working on a better fix.
The problem is that we can’t pickle the qnode. This happens because of naming confusion between the QNode and the function that the QNode is wrapping.
If you construct the QNode with qml.QNode (rather than with the @qml.qnode decorator used in the posted code), this would eliminate the naming confusion and create a pickle-able object.
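The naming confusion can be reproduced without PennyLane, using only the standard library. The sketch below is an illustration under assumptions, not PennyLane's actual implementation: FakeQNode and as_node are hypothetical stand-ins for a QNode and for the @qml.qnode(dev) decorator. pickle stores a plain function by reference (module name plus function name), so it fails when that name has been rebound to the wrapper object, and succeeds when the wrapped function keeps its own name.

```python
import pickle


class FakeQNode:
    """Hypothetical stand-in for a QNode: it only stores the function it wraps."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        return self.func(*args)


def as_node(func):
    """Hypothetical decorator playing the role of @qml.qnode(dev)."""
    return FakeQNode(func)


# Decorator style: the name `shadowed` now refers to the FakeQNode, so the
# wrapped function can no longer be looked up under its own name.
@as_node
def shadowed(x):
    return x + 1


# Explicit style: `circuit` keeps its name and the wrapper gets a different one.
# In PennyLane this would correspond to something like qml.QNode(circuit, dev).
def circuit(x):
    return x + 1


node = FakeQNode(circuit)


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False on a PicklingError."""
    try:
        pickle.dumps(obj)
        return True
    except pickle.PicklingError:
        return False


decorator_ok = can_pickle(shadowed)
explicit_ok = can_pickle(node)
print("decorator style pickles:", decorator_ok)
print("explicit style pickles:", explicit_ok)
```

Whether a real QNode built this way pickles cleanly may also depend on the PennyLane version and the device, so treat this only as a picture of the mechanism.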
Hi everyone! I stumbled upon this cool package called cloudpickle today: https://github.com/cloudpipe/cloudpickle
If I replace the call to torch.save in the posted source code with cloudpickle, I get no errors. One can then load the dictionary back with pickle. After unpickling, it also works to continue right where you left off!
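The original snippets from this comment are not preserved here, so the following is only a minimal sketch of the save and load pattern it describes, not the commenter's exact code. It assumes cloudpickle is installed and uses hypothetical stand-ins (FakeQNode, a small checkpoint dictionary, a temporary file path) in place of the real model, optimizer and PATH, so that it runs without torch or PennyLane.

```python
import os
import pickle
import tempfile

import cloudpickle  # third-party package: pip install cloudpickle


class FakeQNode:
    """Hypothetical stand-in for a QNode that stores the wrapped function."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        return self.func(*args)


@FakeQNode
def qnode(x):
    # The decorator rebinds the name `qnode`, which is what trips up the
    # standard pickler in the report above.
    return 2 * x


# Stand-in for the checkpoint dictionary passed to torch.save in the report.
checkpoint = {"model": qnode, "loss": 0.42}

# Hypothetical location; the report uses PATH = "./checkpoint.pth".
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")

# Save with cloudpickle, which can serialize such functions by value.
with open(path, "wb") as f:
    cloudpickle.dump(checkpoint, f)

# Load with the standard pickle module (cloudpickle must still be installed).
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored["model"](3), restored["loss"])
```

In the posted script, the same pattern would presumably apply to the dictionary holding model, opt and running_loss; as with any pickle file, only load checkpoints from sources you trust.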
No issues