[BUG] Update `default.qubit.jax` to align with jax 0.2.21 (multiple jit defs)


Expected behavior

Defining a jitted function multiple times should work. For example, running the following in a Jupyter cell several times should succeed:

import jax
import pennylane as qml


@jax.jit
def sample_circuit(phi, theta, key):
    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100)

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta, wires=0)
        return qml.sample()  # Here, we take samples instead.

    return circuit(phi, theta)


# Get the samples from the jitted method.
samples = sample_circuit([0.5, 0.5], 0.0, jax.random.PRNGKey(0))

Actual behavior

On the second definition (and every one after it), the following error is raised:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2, 2) and dtype complex64 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.

The trace points to:

~/xanadu/pennylane/pennylane/devices/jax_ops.py in RX(theta)
     74         array[complex]: unitary 2x2 rotation matrix :math:`e^{-i \sigma_x \theta/2}`
     75     """
---> 76     return jnp.cos(theta / 2) * I + 1j * jnp.sin(-theta / 2) * X
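The following pure-JAX sketch shows one mechanism that produces exactly this pattern, where the first jitted definition works and later ones fail. It assumes, without confirmation from this issue, that a constant is built with `jnp.array` lazily during the first trace and then cached in global state. The names `_CONSTS`, `pauli_x`, `first` and `second` are hypothetical and are not PennyLane code.

```python
import jax
import jax.numpy as jnp

# Hypothetical global cache standing in for module-level state that is first
# initialised *inside* a jax.jit trace (e.g. a device built within the jit).
_CONSTS = {}


def pauli_x():
    if "X" not in _CONSTS:
        # Under jax.jit, jnp.array stages this constant into the current
        # trace, so the cached value is a tracer rather than a concrete array.
        _CONSTS["X"] = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
    return _CONSTS["X"]


@jax.jit
def first(theta):
    return theta * pauli_x()  # first trace: creates and caches the tracer


@jax.jit
def second(theta):
    return theta * pauli_x()  # second trace: reuses the escaped tracer


first(0.5)  # works
try:
    second(0.5)
    failed = False
except jax.errors.UnexpectedTracerError:
    failed = True  # the tracer from the first trace has leaked

print("second definition failed:", failed)
```

The cache is only populated once, so every jitted function traced after the first one sees a stale tracer, which matches the "2nd definition and onwards" behavior reported above.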

Additional information

The following breaking change in jax 0.2.21 is likely related:

When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
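A minimal sketch of the workaround described in that changelog entry, assuming a simple flatten operation as the example (the function name `flatten` is hypothetical): static shape arithmetic is done with classic NumPy on the Python shape tuple, so it stays concrete while `jax.jit` traces the function.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def flatten(x):
    # x.shape is a plain Python tuple, so NumPy arithmetic on it stays
    # concrete during tracing and can be used as a static shape.
    n = int(np.prod(x.shape))
    # By contrast, jnp.prod(jnp.array(x.shape)) would be staged into the
    # trace under jax.jit and could not be used as a reshape target.
    return jnp.reshape(x, (n,))


print(flatten(jnp.ones((2, 3))).shape)  # (6,)
```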

The likely solution going forward is to revisit the jax operations used by default.qubit.jax.
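One possible direction, sketched under the assumption that the `I` and `X` used by `RX` in `jax_ops.py` are module-level constants created with `jnp.array`: build them with classic NumPy instead, so they remain concrete even if the module is first imported inside a `jax.jit` trace. This is a hypothetical rewrite for illustration, not the actual PennyLane fix; `first` and `second` are placeholder names.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Constants built with classic NumPy never become tracers, no matter
# where or when this module is first imported.
I = np.eye(2, dtype=np.complex64)
X = np.array([[0, 1], [1, 0]], dtype=np.complex64)


def RX(theta):
    """Unitary 2x2 rotation matrix exp(-i sigma_x theta / 2)."""
    return jnp.cos(theta / 2) * I + 1j * jnp.sin(-theta / 2) * X


@jax.jit
def first(theta):
    return RX(theta)


@jax.jit
def second(theta):
    # A second jitted definition reusing the same constants now works too.
    return RX(theta)


print(first(0.0))
print(second(np.pi))
```

Mixing NumPy constants into `jnp` expressions is safe here because JAX converts them to constants of each trace independently, instead of sharing one traced value across traces.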

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.19.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /xanadu/pennylane
Requires: numpy, scipy, networkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Cirq, PennyLane-Orquestra, PennyLane-SF, pennylane-qulacs, PennyLane-IonQ, amazon-braket-pennylane-plugin, PennyLane-Forest, PennyLane-Honeywell, PennyLane-qiskit, PennyLane-AQT, PennyLane-Lightning, PennyLane-Qchem
Platform info:           Linux-5.11.0-37-generic-x86_64-with-glibc2.10
Python version:          3.8.5
Numpy version:           1.20.3
Scipy version:           1.7.1
Installed devices:
- cirq.mixedsimulator (PennyLane-Cirq-0.17.1)
- cirq.pasqal (PennyLane-Cirq-0.17.1)
- cirq.qsim (PennyLane-Cirq-0.17.1)
- cirq.qsimh (PennyLane-Cirq-0.17.1)
- cirq.simulator (PennyLane-Cirq-0.17.1)
- orquestra.forest (PennyLane-Orquestra-0.15.0)
- orquestra.ibmq (PennyLane-Orquestra-0.15.0)
- orquestra.qiskit (PennyLane-Orquestra-0.15.0)
- orquestra.qulacs (PennyLane-Orquestra-0.15.0)
- strawberryfields.fock (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gaussian (PennyLane-SF-0.16.0.dev0)
- strawberryfields.gbs (PennyLane-SF-0.16.0.dev0)
- strawberryfields.remote (PennyLane-SF-0.16.0.dev0)
- strawberryfields.tf (PennyLane-SF-0.16.0.dev0)
- qulacs.simulator (pennylane-qulacs-0.17.0.dev0)
- ionq.qpu (PennyLane-IonQ-0.17.0.dev0)
- ionq.simulator (PennyLane-IonQ-0.17.0.dev0)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.4.1.dev0)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.4.1.dev0)
- forest.numpy_wavefunction (PennyLane-Forest-0.17.0.dev0)
- forest.qvm (PennyLane-Forest-0.17.0.dev0)
- forest.wavefunction (PennyLane-Forest-0.17.0.dev0)
- honeywell.hqs (PennyLane-Honeywell-0.16.0.dev0)
- qiskit.aer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.basicaer (PennyLane-qiskit-0.18.0.dev0)
- qiskit.ibmq (PennyLane-qiskit-0.18.0.dev0)
- aqt.noisy_sim (PennyLane-AQT-0.18.0)
- aqt.sim (PennyLane-AQT-0.18.0)
- lightning.qubit (PennyLane-Lightning-0.19.0.dev0)
- default.gaussian (PennyLane-0.19.0.dev0)
- default.mixed (PennyLane-0.19.0.dev0)
- default.qubit (PennyLane-0.19.0.dev0)
- default.qubit.autograd (PennyLane-0.19.0.dev0)
- default.qubit.jax (PennyLane-0.19.0.dev0)
- default.qubit.tf (PennyLane-0.19.0.dev0)
- default.qubit.torch (PennyLane-0.19.0.dev0)
- default.tensor (PennyLane-0.19.0.dev0)
- default.tensor.tf (PennyLane-0.19.0.dev0)

JAX versions:

jax==0.2.21
jaxlib==0.1.71


- [X] I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
josh146 commented, Oct 12, 2021

Having the same code block twice (starting from @jax.jit) will reproduce the error in a Python script.

Thanks Antal, I reduced it down to a minimal non-working example 🙂

import jax
import pennylane as qml


@jax.jit
def circuit(x):
    dev = qml.device('default.qubit.jax', wires=2, shots=100)

    @qml.qnode(dev, interface='jax')
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(x)


# An identical second definition.
@jax.jit
def circuit2(x):
    dev = qml.device('default.qubit.jax', wires=2, shots=100)

    @qml.qnode(dev, interface='jax')
    def circuit(x):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(x)


result = circuit(0.5)   # works
result = circuit2(0.5)  # raises UnexpectedTracerError

Interestingly, if you replace default.qubit.jax with default.qubit, it works well! It just doesn’t support qml.sample 🙁

0 reactions
antalszava commented, Oct 20, 2021

Seems to have no effect 🤔 The issue seems to depend on PennyLane alone. Just double-checked with:

jax==0.2.24
jaxlib==0.1.73