[BUG] With many observables `generate_shifted_tapes()` is called "unreasonably often" resulting in massive performance loss

Expected behavior

When taking the parameter-shift Hessian of a QNode that returns the expectation values of several compatible observables, I expect the shifted tapes to be constructed only once per gate; the derivatives of all observables can then be computed from this single set of shifted tapes.
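The expected behaviour can be illustrated with a minimal NumPy sketch (no PennyLane involved). It assumes a single-qubit circuit RY(θ)|0⟩ and the standard two-term shift rule with shift π/2; the CRY gates in the reproduction may need a different recipe. The point is that only two shifted states are simulated per parameter, and the derivatives of every observable are read off those same two states. All function names here are illustrative.

```python
import numpy as np

# Single-qubit Pauli matrices
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def ry(theta):
    """RY(theta) = exp(-i * theta * Y / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def state(theta):
    """Simulate the one-gate circuit RY(theta)|0>."""
    return ry(theta) @ np.array([1, 0], dtype=complex)


def expvals(psi, observables):
    """Expectation values of all observables on one already-simulated state."""
    return np.array([np.real(np.conj(psi) @ obs @ psi) for obs in observables])


def shared_shift_gradient(theta, observables, shift=np.pi / 2):
    # Exactly two shifted circuits per parameter, however many observables there are
    psi_plus = state(theta + shift)
    psi_minus = state(theta - shift)
    # Every observable's derivative comes from the same pair of shifted states
    return (expvals(psi_plus, observables) - expvals(psi_minus, observables)) / 2
```

For this circuit ⟨Z⟩ = cos θ, ⟨X⟩ = sin θ and ⟨Y⟩ = 0, so `shared_shift_gradient(0.3, [Z, X, Y])` matches the analytic derivatives (-sin 0.3, cos 0.3, 0).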

Actual behavior

In reality, `expval_param_shift` (parameter_shift.py, line 152) is called separately for each observable, so `generate_shifted_tapes()` is called (number of parametrized gates) × (number of observables) times.
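To make the scaling concrete, here is a toy count in plain Python. `toy_generate_shifted_tapes` is a hypothetical stand-in that only copies a parameter list and counts its calls; it is not PennyLane's `generate_shifted_tapes()`. With the 5 parametrized gates and 90 observables of the reproduction below, the reported per-observable pattern makes 5 × 90 = 450 calls, whereas sharing the shifted tapes across observables needs only 5.

```python
import math


def toy_generate_shifted_tapes(params, idx, shift, counter):
    """Hypothetical stand-in: two shifted copies of `params`, counting each call."""
    counter[0] += 1
    plus, minus = list(params), list(params)
    plus[idx] += shift
    minus[idx] -= shift
    return plus, minus


def count_per_observable(num_params, num_observables):
    # Reported behaviour: shifted copies are rebuilt for every (observable, parameter) pair
    counter = [0]
    params = [0.0] * num_params
    for _ in range(num_observables):
        for idx in range(num_params):
            toy_generate_shifted_tapes(params, idx, math.pi / 2, counter)
    return counter[0]


def count_shared(num_params, num_observables):
    # Expected behaviour: shifted copies built once per parameter and reused by
    # all observables, so num_observables does not affect the count
    counter = [0]
    params = [0.0] * num_params
    for idx in range(num_params):
        toy_generate_shifted_tapes(params, idx, math.pi / 2, counter)
    return counter[0]


print(count_per_observable(5, 90), count_shared(5, 90))  # prints "450 5"
```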

This results in a massive performance hit. In the 10-qubit example below, PennyLane spends over 60% of its time copying tapes and only ~4% simulating circuits, even on the slow default.qubit simulator.

Additional information

To get a graphical overview of what is happening, run the example below under cProfile, e.g. like this:
python -m cProfile -o timing test.py && gprof2dot -f pstats timing -o graph.dot && dot -Tsvg graph.dot -o graph.svg
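If gprof2dot and Graphviz are not installed, a text-only view can be obtained with the standard library alone. The helper below is an illustrative sketch (the name `profile_top` is our own): it runs a callable under `cProfile.Profile.runcall` and returns the `pstats` report sorted by cumulative time.

```python
import cProfile
import io
import pstats


def profile_top(func, *args, limit=10):
    """Run func(*args) under cProfile; return (result, report sorted by cumulative time)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Print only the `limit` most expensive entries into the in-memory stream
    stats.sort_stats("cumulative").print_stats(limit)
    return result, stream.getvalue()
```

For the reproduction script below, `profile_top(qml.jacobian(qnode), init_params, limit=20)` would list the 20 most expensive calls, which makes it easy to check how much cumulative time is attributed to `generate_shifted_tapes`.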

Source code

# test.py 
# run with:
# python -m cProfile -o timing test.py && gprof2dot -f pstats timing -o graph.dot && dot -Tsvg graph.dot -o graph.svg
import pennylane as qml
from pennylane import numpy as np

num_wires = 10
wires = range(num_wires)
dev = qml.device('default.qubit', num_wires)

np.random.seed(42)
# one trainable parameter per CRY gate (5 in total)
init_params = np.random.randn(num_wires//2)

@qml.qnode(dev, diff_method='parameter-shift')
def qnode(params):
    for par_idx, wire in enumerate(wires[::2]):
        qml.Hadamard(wire)
        qml.CRY(params[par_idx], wires=[wire, wire+1])
    # one ZZ expectation value per ordered pair of distinct wires
    return [qml.expval(qml.PauliZ(wire1) @ qml.PauliZ(wire2)) for wire1 in wires for wire2 in wires if wire1 != wire2]

# number of observables: 10 * 9 = 90
print(len([0 for wire1 in wires for wire2 in wires if wire1 != wire2]))

for _ in range(4):
    print(qml.jacobian(qnode)(init_params))

Tracebacks

No response

System information

Python 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27)
[GCC 9.3.0] on linux
>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /home/cvjjm/src/covqcstack/qcware/pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: pytket-pennylane, PennyLane-Qchem, PennyLane-Lightning, covvqetools
Platform info:           Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.10
Python version:          3.8.8
Numpy version:           1.20.1
Scipy version:           1.6.1
Installed devices:
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)
- pytket.pytketdevice (pytket-pennylane-0.1.0)
- lightning.qubit (PennyLane-Lightning-0.20.2)
- cov.qubit (covvqetools-0.1.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 20 (20 by maintainers)

Top GitHub Comments

2 reactions
cvjjm commented, Jul 5, 2022

Great! The implementation in #2645 looks very nice.

2 reactions
mlxd commented, Apr 8, 2022

Hi @cvjjm, this is really helpful. Running your provided example takes around 40 s locally. Attacking the offending function allows me to bring this down to ~27 s. It may take a few days, but I’ll aim to put together a general solution that helps here.
