[BUG] unable to checkpoint model with qml.qnn.TorchLayer

Expected behavior

Something like:

Average loss over epoch 1: 0.4803
Average loss over epoch 2: 0.3553
Accuracy: 78.0%

The checkpoint file is written to PATH without errors.

Actual behavior

The script prints:

Average loss over epoch 1: 0.4803
Average loss over epoch 2: 0.3553
Accuracy: 78.0%

However, although the checkpoint file is created, it does not contain the full model: PyTorch is unable to pickle the QNode.

_pickle.PicklingError: Can't pickle <function qnode at 0x7fc7e4169160>: it's not the same object as __main__.qnode

Additional information

The error occurs every time.

Source code

# Copied and pasted from: https://pennylane.ai/qml/demos/tutorial_qnn_module_torch.html

import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

X, y = make_moons(n_samples=200, noise=0.1)
y_ = torch.unsqueeze(torch.tensor(y), 1)  # used for one-hot encoded labels
y_hot = torch.scatter(torch.zeros((200, 2)), 1, y_, 1)

c = ["#1f77b4" if y_ == 0 else "#ff7f0e" for y_ in y]  # colours for each class
#plt.axis("off")
#plt.scatter(X[:, 0], X[:, 1], c=c)
#plt.show()

import pennylane as qml

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

n_layers = 6
weight_shapes = {"weights": (n_layers, n_qubits)}

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

clayer_1 = torch.nn.Linear(2, 2)
clayer_2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)
layers = [clayer_1, qlayer, clayer_2, softmax]
model = torch.nn.Sequential(*layers)

opt = torch.optim.SGD(model.parameters(), lr=0.2)
loss = torch.nn.L1Loss()

X = torch.tensor(X, requires_grad=True).float()
y_hot = y_hot.float()

batch_size = 5
batches = 200 // batch_size

data_loader = torch.utils.data.DataLoader(
    list(zip(X, y_hot)), batch_size=batch_size, shuffle=True, drop_last=True
)

epochs = 2

for epoch in range(epochs):

    running_loss = 0

    for xs, ys in data_loader:
        opt.zero_grad()

        loss_evaluated = loss(model(xs), ys)
        loss_evaluated.backward()

        opt.step()

        running_loss += loss_evaluated.item()  # .item() avoids keeping every batch's autograd graph alive

    avg_loss = running_loss / batches
    print("Average loss over epoch {}: {:.4f}".format(epoch + 1, avg_loss))

y_pred = model(X)
predictions = torch.argmax(y_pred, axis=1).detach().numpy()

correct = [1 if p == p_true else 0 for p, p_true in zip(predictions, y)]
accuracy = sum(correct) / len(correct)
print(f"Accuracy: {accuracy * 100}%")

# Saving a checkpoint after training

PATH = "./checkpoint.pth"

torch.save({
            'model': model,
            'model_state_dict': model.state_dict(),
            'optimizer': opt,
            'optimizer_state_dict': opt.state_dict(),
            'loss': running_loss,
            }, PATH)

Tracebacks

Average loss over epoch 1: 0.4943
Average loss over epoch 2: 0.4226
Accuracy: 75.5%
Traceback (most recent call last):
  File "/home/PL_save.py", line 84, in <module>
    torch.save({
  File "/home/miniconda3/envs/py9/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/miniconda3/envs/py9/lib/python3.9/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <function qnode at 0x7fc7e4169160>: it's not the same object as __main__.qnode

System information

>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.21.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/miniconda3/envs/py9/lib/python3.9/site-packages
Requires: autoray, retworkx, cachetools, semantic-version, scipy, pennylane-lightning, networkx, numpy, toml, appdirs, autograd
Required-by: PennyLane-Lightning
Platform info:           Linux-4.18.0-348.7.1.el8_5.x86_64-x86_64-with-glibc2.28
Python version:          3.9.7
Numpy version:           1.22.2
Scipy version:           1.8.0
Installed devices:
- default.gaussian (PennyLane-0.21.0)
- default.mixed (PennyLane-0.21.0)
- default.qubit (PennyLane-0.21.0)
- default.qubit.autograd (PennyLane-0.21.0)
- default.qubit.jax (PennyLane-0.21.0)
- default.qubit.tf (PennyLane-0.21.0)
- default.qubit.torch (PennyLane-0.21.0)
- lightning.qubit (PennyLane-Lightning-0.21.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments:7 (4 by maintainers)

Top GitHub Comments

albi3ro commented, Feb 17, 2022 (2 reactions)

I have a temporary workaround while I’m working on a better fix.

The problem is that we can’t pickle the qnode. This happens because of naming confusion between the QNode and the function that the QNode is wrapping.

Construct the QNode with qml.QNode directly, instead of with the @qml.qnode decorator, so that the wrapped function keeps its own name:

def qfunc(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

qnode = qml.QNode(qfunc, dev)

This would eliminate the naming confusion and create a pickle-able object.
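The mechanism behind the error can be reproduced with the standard library alone. In the sketch below, `Wrapper` and `wrap` are hypothetical stand-ins for `qml.QNode` and `@qml.qnode(dev)`; they are not PennyLane code. `pickle` stores a plain function by reference (its module plus qualified name) and, while saving, looks that name back up to check it resolves to the same object. With the decorator, the module-level name `circuit` is rebound to the wrapper, so the lookup finds the wrapper instead of the function and pickling fails. Giving the function its own name, as in the workaround, avoids the clash. This assumes the file is run directly as a script, so its module is `__main__`.

```python
import pickle


class Wrapper:
    """Hypothetical stand-in for a QNode: an object that stores the function it wraps."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        return self.func(*args)


def wrap(func):
    """Hypothetical decorator, analogous to @qml.qnode(dev)."""
    return Wrapper(func)


# Decorator form: the module-level name `circuit` now refers to the Wrapper,
# while the wrapped function's __qualname__ is still "circuit".
@wrap
def circuit(x):
    return 2 * x


# Explicit form, as in the workaround: the function keeps its own name.
def qfunc(x):
    return 2 * x


circuit_ok = Wrapper(qfunc)


def try_pickle(obj):
    """Return "ok" if obj pickles, otherwise a short description of the error."""
    try:
        pickle.dumps(obj)
        return "ok"
    except pickle.PicklingError as err:
        return f"PicklingError: {err}"


print(try_pickle(circuit))     # fails: __main__.circuit is the Wrapper, not the function
print(try_pickle(circuit_ok))  # ok
```

The first call should report a PicklingError of the same "it's not the same object as __main__..." form as the traceback in the report, while the explicitly constructed wrapper pickles and unpickles normally.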

isaacdevlugt commented, Mar 30, 2022 (0 reactions)

Hi everyone! I stumbled upon this cool package called cloudpickle today: https://github.com/cloudpipe/cloudpickle

If I replace the call to torch.save in the posted source code with the following code, I get no errors.

import cloudpickle

dictionary = {
    "model": model,
    "model_state_dict": model.state_dict(),
    "optimizer": opt,
    "optimizer_state_dict": opt.state_dict(),
    "loss": running_loss,
}

with open(PATH, "wb") as f:
    cloudpickle.dump(dictionary, f)

Then, one can load back the dictionary with pickle as follows.

import pickle

with open(PATH, "rb") as f:
    dictionary = pickle.load(f)

print(dictionary)

After unpickling, it also works to continue right where you left off!

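model = dictionary["model"]
# Use the unpickled optimizer: it references the unpickled model's parameters,
# whereas the original `opt` still points at the old model's parameters
opt = dictionary["optimizer"]
for epoch in range(epochs):

    running_loss = 0
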
    for xs, ys in data_loader:
        opt.zero_grad()

        loss_evaluated = loss(model(xs), ys)
        loss_evaluated.backward()

        opt.step()

        running_loss += loss_evaluated.item()

    avg_loss = running_loss / batches
    print("Average loss over epoch {}: {:.4f}".format(epoch + 1, avg_loss))

y_pred = model(X)
predictions = torch.argmax(y_pred, axis=1).detach().numpy()

correct = [1 if p == p_true else 0 for p, p_true in zip(predictions, y)]
accuracy = sum(correct) / len(correct)
print(f"Accuracy: {accuracy * 100}%")

No issues
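A different route, not one of the workarounds above but the checkpointing pattern the PyTorch documentation recommends in general, is to save only the `state_dict`s and rebuild the model in code when loading, so the QNode itself is never pickled. The sketch below uses torch only: `Wrapper`, `FuncLayer`, and `build_model` are hypothetical stand-ins for the decorated QNode, `qml.qnn.TorchLayer`, and the model-building code from the report. That the real `TorchLayer` weights round-trip the same way is an assumption here, not something verified in this thread.

```python
import io

import torch


class Wrapper:
    """Hypothetical stand-in for a QNode: stores the function it wraps."""

    def __init__(self, func):
        self.func = func

    def __call__(self, x):
        return self.func(x)


@Wrapper  # decorator form, which rebinds `circuit` to the Wrapper instance
def circuit(x):
    return torch.tanh(x)


class FuncLayer(torch.nn.Module):
    """Hypothetical layer that, like TorchLayer, holds a callable plus trainable weights."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.weights = torch.nn.Parameter(torch.ones(2))

    def forward(self, x):
        return self.fn(x * self.weights)


def build_model():
    # The architecture is rebuilt in code, so no callable needs to be pickled.
    return torch.nn.Sequential(torch.nn.Linear(2, 2), FuncLayer(circuit))


model = build_model()
opt = torch.optim.SGD(model.parameters(), lr=0.2)

# Saving the whole model pickles FuncLayer.fn and fails, as in the report.
try:
    torch.save({"model": model}, io.BytesIO())
    whole_model_saved = True
except Exception:  # a PicklingError with the default pickle module
    whole_model_saved = False

# Saving only the state_dicts stores tensors and plain Python values.
buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer)
restored = build_model()
restored.load_state_dict(checkpoint["model_state_dict"])
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.2)
restored_opt.load_state_dict(checkpoint["optimizer_state_dict"])
```

The trade-off is that the code that builds the model (including the QNode) must be available when loading, whereas pickling the whole model, for example with cloudpickle, stores that structure in the file.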
