[BUG] JIT on qml.broadcast fails due to NumPy array conversion in `single_dispatch.py` (`_to_numpy_jax(x)`)

Expected behavior

I should be able to JIT-compile, with JAX, a QNode that uses the `qml.broadcast` function. The non-JITed version works (try commenting out `@jax.jit` below). I suspect that a NumPy conversion call happening inside the QNode makes JIT fail.

The error message indicates that the conversion to NumPy happens in `_to_numpy_jax(x)` in `single_dispatch.py`.
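A minimal sketch in plain JAX (no PennyLane involved) of the suspected failure mode: converting a traced value to a NumPy array works eagerly but fails under `jax.jit`, while the same computation written with `jax.numpy` compiles. The function names `via_numpy` and `via_jax_numpy` are hypothetical; this illustrates the general JAX rule, not PennyLane's internals.

```python
import jax
import jax.numpy as jnp
import numpy as np


def via_numpy(x):
    # np.asarray calls x.__array__(); a JAX tracer refuses this under jit
    return jnp.sin(np.asarray(x))


def via_jax_numpy(x):
    # jax.numpy operations are traceable, so jit can compile them
    return jnp.sin(x)


x = jnp.array(0.5)

# Eager execution: x is a concrete array, so the NumPy round trip succeeds
print(via_numpy(x))

# Under jit, only the jax.numpy version can be traced
print(jax.jit(via_jax_numpy)(x))

try:
    jax.jit(via_numpy)(x)
except jax.errors.TracerArrayConversionError as err:
    print("jit failed:", type(err).__name__)
```

This mirrors the report: the eager call succeeds, and the JIT-compiled call fails as soon as something inside it asks for a NumPy array.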

Actual behavior

Under `jax.jit`, the call fails inside `_to_numpy_jax(x)` (`pennylane/math/single_dispatch.py`, line 589): `np.array(...)` is called on a JAX tracer, which raises `TracerArrayConversionError`, and PennyLane re-raises it as `ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.` The full traceback is in the Tracebacks section.

Additional information

I wanted to use this function to build a circuit that is the opposite of `qml.StronglyEntanglingLayers` by broadcasting a set of single-qubit unitaries over the wires.

Source code

import jax
import jax.numpy as jnp


from pennylane import broadcast
import pennylane as qml


dev = qml.device("default.qubit.jax", wires=[0, 1, 2], shots=None)

def mytemplate(pars, wires):
    qml.RY(pars[0], wires=wires)
    qml.RX(pars[1], wires=wires)

@jax.jit
@qml.qnode(dev)
def circuit(pars):
    broadcast(unitary=mytemplate, pattern="single", wires=[0, 1, 2], parameters=pars)
    return qml.expval(qml.PauliZ(0))

parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])
print(circuit(parameters))

qml.about()
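To see why `parameters` has shape `(3, 1, 2)`, here is a sketch of what `pattern="single"` does with it, assuming `qml.broadcast` calls the template once per wire as `unitary(*parameters[i], wires=wires[i])`. `broadcast_single` and the recording version of `mytemplate` are hypothetical stand-ins, not PennyLane code; they only log which angles `qml.RY` and `qml.RX` would receive.

```python
import jax.numpy as jnp

calls = []  # (wire, RY angle, RX angle) for every template application


def mytemplate(pars, wires):
    # Recording stand-in for the issue's template: instead of applying
    # qml.RY(pars[0]) and qml.RX(pars[1]), store the angles they would get
    calls.append((wires, float(pars[0]), float(pars[1])))


def broadcast_single(unitary, wires, parameters):
    # Hypothetical sketch of pattern="single": one call per wire, with
    # parameters[i] unpacked into the template's positional arguments
    for i, wire in enumerate(wires):
        unitary(*parameters[i], wires=wire)


parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])
print(parameters.shape)  # (3, 1, 2): wires x template arguments x angles

broadcast_single(mytemplate, wires=[0, 1, 2], parameters=parameters)
for wire, ry_angle, rx_angle in calls:
    print(wire, ry_angle, rx_angle)
```

Under the same assumption, the middle axis of length 1 matches the single `pars` argument of `mytemplate`; a template taking two array arguments would need a middle axis of length 2.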

Tracebacks

TracerArrayConversionError                Traceback (most recent call last)
~/miniconda3/envs/jaxnn/lib/python3.9/site-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
    588     try:
--> 589         return np.array(getattr(x, "val", x))
    590     except TracerArrayConversionError as e:

~/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py in __array__(self, *args, **kw)
    482   def __array__(self, *args, **kw):
--> 483     raise TracerArrayConversionError(self)
    484 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function circuit at /var/folders/8s/tfpsk_fx609f8w7z__yzz9vh0000gn/T/ipykernel_91532/3484084806.py:15 for jit, this concrete value was not available in Python because it depends on the value of the argument 'pars'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
/var/folders/8s/tfpsk_fx609f8w7z__yzz9vh0000gn/T/ipykernel_91532/3484084806.py in <module>
     20 
     21 parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])
---> 22 print(circuit(parameters))
     23 
     24 qml.about()
...
--> 591         raise ValueError(
    592             "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    593         ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

System information

Name: PennyLane
Version: 0.24.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Forest, PennyLane-Lightning, PennyLane-Qchem, PennyLane-qiskit

Platform info:           macOS-12.4-x86_64-i386-64bit
Python version:          3.9.7
Numpy version:           1.20.1
Scipy version:           1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.24.0)
- default.gaussian (PennyLane-0.24.0)
- default.mixed (PennyLane-0.24.0)
- default.qubit (PennyLane-0.24.0)
- default.qubit.autograd (PennyLane-0.24.0)
- default.qubit.jax (PennyLane-0.24.0)
- default.qubit.tf (PennyLane-0.24.0)
- default.qubit.torch (PennyLane-0.24.0)
...
- qiskit.ibmq.sampler (PennyLane-qiskit-0.21.0)
- forest.numpy_wavefunction (PennyLane-Forest-0.20.0)
- forest.qvm (PennyLane-Forest-0.20.0)
- forest.wavefunction (PennyLane-Forest-0.20.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

quantshah commented, Sep 29, 2022 (1 reaction)

Thank you @albi3ro, that fixed it. I usually just use `"jax-jit"` and things work, but I forgot it here. I am a bit rusty on what the different interfaces do. I thought setting the device to `default.qubit.jax` would make it clear that I want JIT 😃. Feel free to close the issue, @CatalinaAlbornoz. I also tried the latest version of PennyLane, and the issue remains if one does not set `interface="jax"`.

albi3ro commented, Sep 29, 2022 (0 reactions)

I found the problem. You need to specify interface="jax" when constructing the QNode:

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(pars):
    broadcast(unitary=mytemplate, pattern="single", wires=[0, 1, 2], parameters=pars)
    return qml.expval(qml.PauliZ(0))

Since `interface="jax"` was not specified, the QNode tries to convert all the parameters to vanilla NumPy.

While that was a fairly unhelpful error message, we plan to detect the interface automatically in the future. Hopefully, that will stop this kind of thing from occurring.
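The explanation above can be sketched in plain JAX. `to_numpy_jax` copies the logic visible in the traceback (`single_dispatch.py`, lines 588 to 593); `expval_z0` is a hypothetical, heavily simplified QNode that converts its parameters to NumPy unless `interface="jax"`. PennyLane's real dispatch is more involved. As a stand-in for the device: `pattern="single"` applies only single-qubit gates, so the wires stay unentangled and the expectation of `PauliZ(0)` depends only on wire 0, where `RY(a)` followed by `RX(b)` on |0> gives `cos(a) * cos(b)`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.errors import TracerArrayConversionError


def to_numpy_jax(x):
    # Same logic as single_dispatch.py lines 588-593 in the traceback
    try:
        return np.array(getattr(x, "val", x))
    except TracerArrayConversionError as e:
        raise ValueError(
            "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
        ) from e


def expval_z0(pars, interface=None):
    # Hypothetical, simplified QNode: without interface="jax" the
    # parameters are first converted to vanilla NumPy
    if interface != "jax":
        pars = to_numpy_jax(pars)
    # Stand-in for the device: <Z_0> after RY(a) then RX(b) on |0>
    a, b = pars[0, 0, 0], pars[0, 0, 1]
    return jnp.cos(a) * jnp.cos(b)


parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])

# With interface="jax" nothing is converted, so jit traces cleanly
jitted = jax.jit(lambda p: expval_z0(p, interface="jax"))
print(jitted(parameters))

# Without it, the conversion meets a tracer and is re-raised as ValueError
try:
    jax.jit(expval_z0)(parameters)
except ValueError as err:
    print("ValueError:", err)
```

Calling `expval_z0(parameters)` without `jax.jit` still works, because the parameters are concrete arrays at that point; this matches the report that only the JIT-compiled version fails.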
