[BUG] Jax compiled default.qubit.jax device raises ConversionError for qml.QubitStateVector
Expected behavior
The _apply_state_vector method does not seem to have been adapted for JAX-compiled (jax.jit) code when a state vector is set with qml.QubitStateVector.
Actual behavior
Specifically, in the normalization check
if not np.allclose(np.linalg.norm(state, ord=2), 1.0, atol=tolerance):
    raise ValueError("Sum of amplitudes-squared does not equal one.")
the norm of the state vector is computed with np.linalg.norm, which raises a jax._src.errors.TracerArrayConversionError when it is called on the traced state inside a jitted function.
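The failure can be reproduced without PennyLane. The sketch below assumes only jax and numpy are installed; norm_with_numpy is an illustrative name, not part of either library. Inside jax.jit the argument is an abstract tracer, so NumPy's implicit array conversion in np.linalg.norm is rejected by JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np


def norm_with_numpy(state):
    # Under jax.jit, `state` is a tracer; np.linalg.norm tries to convert it
    # to a concrete NumPy array, which JAX refuses.
    return np.linalg.norm(state, ord=2)


state_vector = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

try:
    jax.jit(norm_with_numpy)(state_vector)
    raised = None
except jax.errors.TracerArrayConversionError as err:
    raised = err

print(type(raised).__name__)  # TracerArrayConversionError
```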
A possible fix is to use the jax.numpy equivalent, jnp.linalg.norm, instead.
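A hedged sketch of why swapping in jnp.linalg.norm may not be sufficient on its own: the jax.numpy norm traces fine, but the surrounding `if not ...allclose(...)` still needs a concrete Python bool, which a traced value cannot provide under jax.jit. The sketch assumes only jax; check_norm is a hypothetical stand-in for the validation in _apply_state_vector, and the tolerance value is an assumption.

```python
import jax
import jax.numpy as jnp

state_vector = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])

# jnp.linalg.norm is traceable, so jitting it works.
norm = jax.jit(jnp.linalg.norm)(state_vector)


def check_norm(state, tolerance=1e-6):
    # Hypothetical stand-in for the check in _apply_state_vector.
    if not jnp.allclose(jnp.linalg.norm(state, ord=2), 1.0, atol=tolerance):
        raise ValueError("Sum of amplitudes-squared does not equal one.")
    return state


# Eager call: the comparison result is concrete, so the `if` works.
check_norm(state_vector)

try:
    jax.jit(check_norm)(state_vector)
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    # Branching in Python on a traced boolean is not allowed under jit.
    jit_error = err

print(float(norm))
```

In other words, both the norm and the branch on the comparison have to be handled for the check to survive compilation.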
Additional information
No response
Source code
import pennylane as qml
import jax
import numpy as np
def circuit(x):
    wires = list(range(2))
    qml.QubitStateVector(x, wires=wires)
    return [qml.expval(qml.PauliX(wires=i)) for i in wires]
dev = qml.device("default.qubit.jax", wires=list(range(2)))
qnode = jax.jit(qml.QNode(circuit, dev, interface="jax"))
state_vector = np.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])
f_norm = jax.jit(jax.numpy.linalg.norm) # works
#f_norm = jax.jit(np.linalg.norm) # does not work, raises same error
print(f_norm(state_vector))
qnode(state_vector)
Tracebacks
No response
System information
Name: PennyLane
Version: 0.17.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/fabian/.local/lib/python3.8/site-packages
Requires: appdirs, networkx, semantic-version, numpy, scipy, toml, autoray, autograd
Required-by: pennylane-qulacs, PennyLane-qiskit
Platform info: Linux-5.11.0-34-generic-x86_64-with-glibc2.29
Python version: 3.8.10
Numpy version: 1.19.2
Scipy version: 1.7.1
Installed devices:
- default.gaussian (PennyLane-0.17.0)
- default.mixed (PennyLane-0.17.0)
- default.qubit (PennyLane-0.17.0)
- default.qubit.autograd (PennyLane-0.17.0)
- default.qubit.jax (PennyLane-0.17.0)
- default.qubit.tf (PennyLane-0.17.0)
- default.tensor (PennyLane-0.17.0)
- default.tensor.tf (PennyLane-0.17.0)
- qulacs.simulator (pennylane-qulacs-0.15.0)
- qiskit.aer (PennyLane-qiskit-0.17.0)
- qiskit.basicaer (PennyLane-qiskit-0.17.0)
- qiskit.ibmq (PennyLane-qiskit-0.17.0)
- I have searched existing GitHub issues to make sure the issue does not already exist.

Hi @bonfab, with #1683 merged, this should be resolved in the master branch.

One approach to fix this could be to make sure that the qml.math.linalg.norm and qml.math.allclose functions both work with the JAX jit; once this is the case, we can modify this default.qubit method to use these functions instead.
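One possible shape for a jit-tolerant check, sketched with jax.numpy directly rather than qml.math. This is an assumption for illustration and is not the change merged in #1683; check_norm_if_concrete is a hypothetical name. It computes the norm and comparison with traceable functions and raises only when the result is concrete, so validation is skipped at trace time instead of breaking compilation.

```python
import jax
import jax.numpy as jnp


def check_norm_if_concrete(state, tolerance=1e-6):
    """Hypothetical helper: validate normalization eagerly, skip it under jit."""
    is_normalized = jnp.allclose(jnp.linalg.norm(state, ord=2), 1.0, atol=tolerance)
    try:
        ok = bool(is_normalized)
    except jax.errors.ConcretizationTypeError:
        # Traced value (e.g. inside jax.jit): the check cannot be decided at
        # trace time, so it is skipped rather than failing compilation.
        return state
    if not ok:
        raise ValueError("Sum of amplitudes-squared does not equal one.")
    return state


state_vector = jnp.array([0.5 + 0.5j, 0.5 + 0.5j, 0, 0])
jitted_result = jax.jit(check_norm_if_concrete)(state_vector)
```

The trade-off is that jitted calls receive no normalization check at all; if validation inside compiled code is required, JAX's jax.experimental.checkify module is one alternative worth evaluating.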