Precision issues when using JAX interface
Issue description
The default precision in JAX is float32. Using the JAX interface in backprop mode causes non-trivial deviations in the results of variational circuits.
- Expected behavior: Evaluating QNodes with the JAX interface in all contexts should match the precision of other interfaces.
- Actual behavior: Evaluating QNodes with JAX yields differences greater than the standard tolerance of 1e-8, even on QNodes with a modest number of rotations. Furthermore, the returned value is of type float32 even when float64 support is manually enabled. (Both issues are resolved by setting diff_method="parameter-shift" for the QNode.)
- Reproduces how often: Always
- System information: (post the output of import pennylane as qml; qml.about())
Name: PennyLane
Version: 0.16.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /home/olivia/Code/pennylane
Requires: numpy, scipy, networkx, autograd, toml, appdirs, semantic-version, autoray
Required-by: pennylane-qulacs, PennyLane-qsharp, PennyLane-qiskit, PennyLane-Qchem, PennyLane-Forest, PennyLane-Cirq, PennyLane-SF
Platform info: Linux-5.4.0-73-generic-x86_64-with-glibc2.10
Python version: 3.8.5
Numpy version: 1.19.5
Scipy version: 1.4.1
Installed devices:
- qulacs.simulator (pennylane-qulacs-0.14.0)
- microsoft.QuantumSimulator (PennyLane-qsharp-0.8.0)
- qiskit.aer (PennyLane-qiskit-0.15.0)
- qiskit.basicaer (PennyLane-qiskit-0.15.0)
- qiskit.ibmq (PennyLane-qiskit-0.15.0)
- forest.numpy_wavefunction (PennyLane-Forest-0.15.0)
- forest.qvm (PennyLane-Forest-0.15.0)
- forest.wavefunction (PennyLane-Forest-0.15.0)
- cirq.mixedsimulator (PennyLane-Cirq-0.13.0)
- cirq.pasqal (PennyLane-Cirq-0.13.0)
- cirq.qsim (PennyLane-Cirq-0.13.0)
- cirq.qsimh (PennyLane-Cirq-0.13.0)
- cirq.simulator (PennyLane-Cirq-0.13.0)
- strawberryfields.fock (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gaussian (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gbs (PennyLane-SF-0.16.0.dev0)
- strawberryfields.remote (PennyLane-SF-0.16.0.dev0)
- strawberryfields.tf (PennyLane-SF-0.16.0.dev0)
- default.gaussian (PennyLane-0.16.0.dev0)
- default.mixed (PennyLane-0.16.0.dev0)
- default.qubit (PennyLane-0.16.0.dev0)
- default.qubit.autograd (PennyLane-0.16.0.dev0)
- default.qubit.jax (PennyLane-0.16.0.dev0)
- default.qubit.tf (PennyLane-0.16.0.dev0)
- default.tensor (PennyLane-0.16.0.dev0)
- default.tensor.tf (PennyLane-0.16.0.dev0)
Source code and tracebacks
Here is some starting code that enables float64 support and creates a circuit:
import pennylane as qml
from pennylane import numpy as np
import jax
from jax.config import config
config.update("jax_enable_x64", True)
def circuit(weights, inpt):
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(inpt[0], wires=0)
    qml.RX(inpt[1], wires=1)
    qml.CNOT(wires=[1, 0])
    return qml.expval(qml.PauliZ(0))
dev = qml.device("default.qubit", wires=2)
Evaluating the circuit with a standard QNode yields:
>>> qnode = qml.QNode(circuit, dev)
>>> weights = np.array([0.5, 0.2])
>>> inpt = np.array([0.6, 0.5])
>>> res = qnode(weights, inpt)
>>> res
0.6229628309572718
>>> res.dtype
float64
Evaluating with JAX yields a value that differs from the 7th decimal place onward:
>>> qnode_jax = qml.QNode(circuit, dev, interface="jax")
>>> weights = jax.numpy.array([0.5, 0.2])
>>> inpt = jax.numpy.array([0.6, 0.5])
>>> weights.dtype
float64
>>> res = qnode_jax(weights, inpt)
>>> res
0.6229629516601562
>>> res.dtype
float32
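To put the size of this deviation in context, here is a minimal sketch that uses only the Python standard library and the two values printed above. The helper `to_float32` is a hypothetical name introduced for illustration; it is not part of PennyLane or JAX. The sketch compares the deviation with the 1e-8 tolerance and with the spacing between adjacent float32 values near 0.62.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32, returned as a float64."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Values copied from the sessions above.
reference = 0.6229628309572718   # default interface, float64
jax_result = 0.6229629516601562  # JAX interface in backprop mode, float32

deviation = abs(jax_result - reference)

# Spacing between adjacent float32 values in the interval [0.5, 1) is 2**-24.
float32_spacing = 2.0 ** -24

print(f"deviation: {deviation:.3e}")
print(f"deviation in float32 spacings: {deviation / float32_spacing:.2f}")
print("exceeds the 1e-8 tolerance:", deviation > 1e-8)
print("JAX value is exactly representable in float32:", to_float32(jax_result) == jax_result)
```

The deviation is roughly 1.2e-7, about two float32 spacings, which is what one would expect if the simulation itself ran in float32 rather than the float64 result being rounded once at the end.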
Setting diff_method="parameter-shift" gives the expected results:
>>> qnode_jax_param_shift = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
>>> res = qnode_jax_param_shift(weights, inpt)
>>> res
0.6229628309572718
>>> res.dtype
float64
Issue Analytics
- Created: 2 years ago
- Comments: 6 (6 by maintainers)
Can we close this now that #1485 was merged?
I was looking into a similar issue with finite-diff and JAX, see PR #1349. The Jacobian tape itself uses NumPy and float64, whereas the initial parameter enters as float32 and the JAX device uses float32. I found a way to fix the finite-diff case at least.
We should edit the devices to allow users to specify the datatype where possible. Sometimes users want 1e-8 accuracy, and sometimes users want to fit a large state into their computer's RAM.
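The accuracy-versus-memory trade-off mentioned in the last comment can be sketched with simple arithmetic. The example below assumes a dense state vector of 2**n complex amplitudes and ignores any simulator overhead; `statevector_bytes` is a hypothetical helper written for this illustration, not a PennyLane function.

```python
def statevector_bytes(num_wires: int, bytes_per_amplitude: int) -> int:
    """Memory needed for a dense state vector with 2**num_wires amplitudes."""
    return (2 ** num_wires) * bytes_per_amplitude


# A complex64 amplitude holds two float32 values; complex128 holds two float64 values.
COMPLEX64_BYTES = 8
COMPLEX128_BYTES = 16

for wires in (2, 20, 30):
    single = statevector_bytes(wires, COMPLEX64_BYTES)
    double = statevector_bytes(wires, COMPLEX128_BYTES)
    print(
        f"{wires} wires: complex64 needs {single / 2**30:.6f} GiB, "
        f"complex128 needs {double / 2**30:.6f} GiB"
    )
```

Under these assumptions double precision always costs twice the memory, so a 30-wire state goes from 8 GiB to 16 GiB. That is why a user-selectable datatype may be preferable to a single fixed default.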