Precision issues when using JAX interface

Issue description

The default precision in JAX is float32. Using the JAX interface in backprop mode causes a non-trivial deviation in the results of variational circuits.

  • Expected behavior: Evaluating QNodes with the JAX interface in all contexts should match the precision of other interfaces.

  • Actual behavior: Evaluating QNodes with JAX yields differences greater than the standard tolerance of 1e-8, even on QNodes with a modest number of rotations. Furthermore, the returned value is of type float32 even when float64 support is manually enabled. (Both issues are resolved by setting diff_method="parameter-shift" for the QNode.)

  • Reproduces how often: Always

  • System information: (post the output of import pennylane as qml; qml.about())

Name: PennyLane
Version: 0.16.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /home/olivia/Code/pennylane
Requires: numpy, scipy, networkx, autograd, toml, appdirs, semantic-version, autoray
Required-by: pennylane-qulacs, PennyLane-qsharp, PennyLane-qiskit, PennyLane-Qchem, PennyLane-Forest, PennyLane-Cirq, PennyLane-SF
Platform info:           Linux-5.4.0-73-generic-x86_64-with-glibc2.10
Python version:          3.8.5
Numpy version:           1.19.5
Scipy version:           1.4.1
Installed devices:
- qulacs.simulator (pennylane-qulacs-0.14.0)
- microsoft.QuantumSimulator (PennyLane-qsharp-0.8.0)
- qiskit.aer (PennyLane-qiskit-0.15.0)
- qiskit.basicaer (PennyLane-qiskit-0.15.0)
- qiskit.ibmq (PennyLane-qiskit-0.15.0)
- forest.numpy_wavefunction (PennyLane-Forest-0.15.0)
- forest.qvm (PennyLane-Forest-0.15.0)
- forest.wavefunction (PennyLane-Forest-0.15.0)
- cirq.mixedsimulator (PennyLane-Cirq-0.13.0)
- cirq.pasqal (PennyLane-Cirq-0.13.0)
- cirq.qsim (PennyLane-Cirq-0.13.0)
- cirq.qsimh (PennyLane-Cirq-0.13.0)
- cirq.simulator (PennyLane-Cirq-0.13.0)
- strawberryfields.fock (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gaussian (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gbs (PennyLane-SF-0.16.0.dev0)
- strawberryfields.remote (PennyLane-SF-0.16.0.dev0)
- strawberryfields.tf (PennyLane-SF-0.16.0.dev0)
- default.gaussian (PennyLane-0.16.0.dev0)
- default.mixed (PennyLane-0.16.0.dev0)
- default.qubit (PennyLane-0.16.0.dev0)
- default.qubit.autograd (PennyLane-0.16.0.dev0)
- default.qubit.jax (PennyLane-0.16.0.dev0)
- default.qubit.tf (PennyLane-0.16.0.dev0)
- default.tensor (PennyLane-0.16.0.dev0)
- default.tensor.tf (PennyLane-0.16.0.dev0)

Source code and tracebacks

Here is some starting code that enables float64 support, and creates a circuit:

import pennylane as qml
from pennylane import numpy as np

import jax
from jax.config import config
config.update("jax_enable_x64", True)

def circuit(weights, inpt):
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(inpt[0], wires=0)
    qml.RX(inpt[1], wires=1)
    qml.CNOT(wires=[1, 0])
    return qml.expval(qml.PauliZ(0))

dev = qml.device("default.qubit", wires=2)

Evaluating the circuit with a standard QNode yields:

>>> qnode = qml.QNode(circuit, dev)
>>> weights = np.array([0.5, 0.2])
>>> inpt = np.array([0.6, 0.5])
>>> res = qnode(weights, inpt)
>>> res
0.6229628309572718
>>> res.dtype
float64

Evaluating with JAX yields a value that differs from the seventh decimal place onward:

>>> qnode_jax = qml.QNode(circuit, dev, interface="jax")
>>> weights = jax.numpy.array([0.5, 0.2])
>>> inpt = jax.numpy.array([0.6, 0.5])
>>> weights.dtype
float64
>>> res = qnode_jax(weights, inpt)
>>> res
0.6229629516601562
>>> res.dtype
float32
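
The size of this discrepancy is consistent with single-precision rounding. As a rough check that uses only the Python standard library, the sketch below compares the two values reported above against the float32 machine epsilon. The helper name `to_float32` is made up for illustration and is not part of PennyLane or JAX.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    # Hypothetical helper: pack as a 4-byte float, then unpack again.
    return struct.unpack("f", struct.pack("f", x))[0]


# Values copied from the sessions above.
reference = 0.6229628309572718   # standard QNode, float64
jax_result = 0.6229629516601562  # JAX interface in backprop mode, float32

float32_eps = 2.0 ** -23         # spacing of float32 values just above 1.0
difference = abs(reference - jax_result)

# Check whether each value survives a round trip through float32 unchanged.
jax_is_float32 = to_float32(jax_result) == jax_result
reference_is_float32 = to_float32(reference) == reference

print(f"difference      = {difference:.3e}")
print(f"float32 epsilon = {float32_eps:.3e}")
print(f"exceeds 1e-8 tolerance: {difference > 1e-8}")
print(f"JAX value is float32-representable: {jax_is_float32}")
print(f"reference is float32-representable: {reference_is_float32}")
```

A difference on the order of the float32 epsilon suggests that part of the computation ran in single precision, rather than that the circuit logic is wrong.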

Setting diff_method="parameter-shift" gives the expected results:

>>> qnode_jax_param_shift = qml.QNode(circuit, dev, interface="jax", diff_method="parameter-shift")
>>> res = qnode_jax_param_shift(weights, inpt)
>>> res
0.6229628309572718
>>> res.dtype
float64

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
glassnotes commented, Jul 30, 2021

Can we close this now that #1485 was merged?

1 reaction
albi3ro commented, May 27, 2021

I was looking into a similar issue with finite-diff and JAX; see PR #1349. The Jacobian tape itself uses NumPy and float64, whereas the initial parameter enters as float32 and the JAX device uses float32. I found a way to fix the finite-diff case at least.

We should edit the devices to allow users to specify the datatype where possible. Sometimes users want 1e-8 accuracy, and sometimes users want to fit a large state into their computer's RAM.
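
To make the memory side of this trade-off concrete, here is a back-of-the-envelope sketch. It assumes a dense state vector of 2**n complex amplitudes, stored either as complex64 (8 bytes per amplitude) or complex128 (16 bytes per amplitude). The function name `statevector_bytes` is hypothetical, and a real simulator needs working memory beyond this figure.

```python
def statevector_bytes(num_wires: int, bytes_per_amplitude: int) -> int:
    """Memory for a dense state vector with 2**num_wires complex amplitudes."""
    return (2 ** num_wires) * bytes_per_amplitude


COMPLEX64_BYTES = 8    # two float32 components per amplitude
COMPLEX128_BYTES = 16  # two float64 components per amplitude

for num_wires in (20, 25, 30):
    single = statevector_bytes(num_wires, COMPLEX64_BYTES)
    double = statevector_bytes(num_wires, COMPLEX128_BYTES)
    print(
        f"{num_wires} wires: {single / 2**30:.3f} GiB (complex64), "
        f"{double / 2**30:.3f} GiB (complex128)"
    )
```

Under these assumptions, halving the precision halves the footprint. At 30 wires that is the difference between 8 GiB and 16 GiB for the state alone.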
