[BUG] Update `default.qubit.jax` to align with jax 0.2.21 (multiple jit defs)
Expected behavior
A jitted function can be defined multiple times. For example, running the following in a Jupyter cell repeatedly should work:
import jax
import pennylane as qml

@jax.jit
def sample_circuit(phi, theta, key):
    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100)

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta, wires=0)
        return qml.sample()  # Here, we take samples instead.

    return circuit(phi, theta)

# Get the samples from the jitted method.
samples = sample_circuit([0.5, 0.5], 0.0, jax.random.PRNGKey(0))
Actual behavior
From the second definition onwards, the following error is raised:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 2) and dtype complex64 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The trace points to:
~/xanadu/pennylane/pennylane/devices/jax_ops.py in RX(theta)
74 array[complex]: unitary 2x2 rotation matrix :math:`e^{-i \sigma_x \theta/2}`
75 """
---> 76 return jnp.cos(theta / 2) * I + 1j * jnp.sin(-theta / 2) * X
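For reference, the failing line is the closed form of the matrix exponential named in the docstring. Because the Pauli matrix satisfies X² = I (writing X for σ_x), the power series splits into even terms (a cosine times I) and odd terms (a sine times X). Both I and X are constant 2x2 complex matrices, which presumably is why the escaped value in the error has shape (2, 2) and dtype complex64:

$$
e^{-i \sigma_x \theta/2}
= \sum_{k=0}^{\infty} \frac{(-i\theta/2)^k}{k!}\,\sigma_x^k
= \cos(\theta/2)\, I - i \sin(\theta/2)\, \sigma_x
= \cos(\theta/2)\, I + i \sin(-\theta/2)\, X .
$$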
Additional information
The following breaking change in jax 0.2.21 is likely related:
When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
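A minimal sketch of the quoted behavior change, assuming a JAX release at or after 0.2.21; the function names `flatten_with_numpy` and `flatten_with_jnp` are hypothetical. Shape arithmetic done with classic NumPy stays static under `jax.jit`, whereas the same arithmetic routed through `jnp.array` is staged into the trace and can no longer be turned into a Python `int`:

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def flatten_with_numpy(x):
    # x.shape is a static tuple; classic NumPy keeps the product concrete.
    n = int(np.prod(np.array(x.shape)))
    return x.reshape(n)


@jax.jit
def flatten_with_jnp(x):
    # jnp.array stages its output into the trace, so the product is a tracer
    # and int() raises a ConcretizationTypeError during tracing.
    n = int(jnp.prod(jnp.array(x.shape)))
    return x.reshape(n)


x = jnp.ones((2, 3))
print(flatten_with_numpy(x).shape)  # (6,)

try:
    flatten_with_jnp(x)
except jax.errors.ConcretizationTypeError as err:
    print("staged, not static:", type(err).__name__)
```

This is the pattern the changelog's workaround refers to: any value that must be known at trace time (shapes, indices, cached constants) is safer computed with `numpy` than with `jax.numpy`.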
The likely solution going forward is to revisit the JAX operations that `default.qubit.jax` uses (such as those in `jax_ops.py` from the trace above).
Source code
No response
Tracebacks
No response
System information
Name: PennyLane
Version: 0.19.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /xanadu/pennylane
Requires: numpy, scipy, networkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Cirq, PennyLane-Orquestra, PennyLane-SF, pennylane-qulacs, PennyLane-IonQ, amazon-braket-pennylane-plugin, PennyLane-Forest, PennyLane-Honeywell, PennyLane-qiskit, PennyLane-AQT, PennyLane-Lightning, PennyLane-Qchem
Platform info: Linux-5.11.0-37-generic-x86_64-with-glibc2.10
Python version: 3.8.5
Numpy version: 1.20.3
Scipy version: 1.7.1
Installed devices:
- cirq.mixedsimulator (PennyLane-Cirq-0.17.1)
- cirq.pasqal (PennyLane-Cirq-0.17.1)
- cirq.qsim (PennyLane-Cirq-0.17.1)
- cirq.qsimh (PennyLane-Cirq-0.17.1)
- cirq.simulator (PennyLane-Cirq-0.17.1)
- orquestra.forest (PennyLane-Orquestra-0.15.0)
- orquestra.ibmq (PennyLane-Orquestra-0.15.0)
- orquestra.qiskit (PennyLane-Orquestra-0.15.0)
- orquestra.qulacs (PennyLane-Orquestra-0.15.0)
- strawberryfields.fock (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gaussian (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gbs (PennyLane-SF-0.16.0.dev0)
- strawberryfields.remote (PennyLane-SF-0.16.0.dev0)
- strawberryfields.tf (PennyLane-SF-0.16.0.dev0)
- qulacs.simulator (pennylane-qulacs-0.17.0.dev0)
- ionq.qpu (PennyLane-IonQ-0.17.0.dev0)
- ionq.simulator (PennyLane-IonQ-0.17.0.dev0)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.4.1.dev0)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.4.1.dev0)
- forest.numpy_wavefunction (PennyLane-Forest-0.17.0.dev0)
- forest.qvm (PennyLane-Forest-0.17.0.dev0)
- forest.wavefunction (PennyLane-Forest-0.17.0.dev0)
- honeywell.hqs (PennyLane-Honeywell-0.16.0.dev0)
- qiskit.aer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.basicaer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.ibmq (PennyLane-qiskit-0.18.0.dev0)
- aqt.noisy_sim (PennyLane-AQT-0.18.0)
- aqt.sim (PennyLane-AQT-0.18.0)
- lightning.qubit (PennyLane-Lightning-0.19.0.dev0)
- default.gaussian (PennyLane-0.19.0.dev0)
- default.mixed (PennyLane-0.19.0.dev0)
- default.qubit (PennyLane-0.19.0.dev0)
- default.qubit.autograd (PennyLane-0.19.0.dev0)
- default.qubit.jax (PennyLane-0.19.0.dev0)
- default.qubit.tf (PennyLane-0.19.0.dev0)
- default.qubit.torch (PennyLane-0.19.0.dev0)
- default.tensor (PennyLane-0.19.0.dev0)
- default.tensor.tf (PennyLane-0.19.0.dev0)
JAX versions:
jax==0.2.21
jaxlib==0.1.71
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.
Thanks Antal, I reduced it down to a minimal non-working example 🙂
Interestingly, if you replace `default.qubit.jax` with `default.qubit`, it works well! It just doesn't support `qml.sample` 🙁
Seems to have no effect 🤔 It seems to depend on PennyLane alone. Just double-checked with: