[BUG] JIT on qml.broadcast fails due to NumPy array conversion in `single_dispatch.py`, i.e., `_to_numpy_jax(x)`
Expected behavior
I should be able to JIT a QNode with JAX when it uses the `qml.broadcast` function. The non-JITed version works (try commenting out the `@jax.jit` decorator below). I suspect a NumPy conversion call inside PennyLane is what makes JIT fail.
The error message points to `_to_numpy_jax(x)` in `single_dispatch.py` as the place where the conversion to NumPy happens.
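The failure mode can be reproduced in pure JAX, independent of PennyLane. This is a minimal sketch: inside `jax.jit` the argument is an abstract tracer, so any call that forces a concrete NumPy array (here `np.asarray`, standing in for the `np.array(...)` call in `_to_numpy_jax`) raises `jax.errors.TracerArrayConversionError`, while code that stays in `jax.numpy` traces fine. The function names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_numpy(x):
    # Illustrative stand-in for a NumPy conversion like the one in the traceback.
    return np.asarray(x)


@jax.jit
def converts_to_numpy(x):
    # Under jax.jit, x is a tracer with no concrete value, so this fails at trace time.
    return jnp.sum(to_numpy(x))


@jax.jit
def stays_in_jax(x):
    # Only jax.numpy operations: fully traceable, so jit compiles it.
    return jnp.sum(jnp.cos(x))


x = jnp.array([1.0, 2.0, 0.1])

try:
    converts_to_numpy(x)
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True

print("NumPy conversion failed under jit:", failed)
print("pure-JAX result:", stays_in_jax(x))
```

The same inputs work without `jax.jit`, because then `x` is a concrete array that NumPy can convert, which matches the observation that the non-JITed circuit runs.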
Actual behavior
Under `jax.jit`, the call fails with a `TracerArrayConversionError` raised by `np.array(getattr(x, "val", x))` inside `_to_numpy_jax(x)` (`pennylane/math/single_dispatch.py`). PennyLane re-raises it as `ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.` The full traceback is under Tracebacks below.
Additional information
I wanted to use this function to build a quantum circuit that is the opposite of `qml.StronglyEntanglingLayers` by broadcasting a set of single-qubit unitaries.
Source code
```python
import jax
import jax.numpy as jnp
from pennylane import broadcast
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=[0, 1, 2], shots=None)

def mytemplate(pars, wires):
    qml.RY(pars[0], wires=wires)
    qml.RX(pars[1], wires=wires)

@jax.jit
@qml.qnode(dev)
def circuit(pars):
    broadcast(unitary=mytemplate, pattern="single", wires=[0, 1, 2], parameters=pars)
    return qml.expval(qml.PauliZ(0))

parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])
print(circuit(parameters))

qml.about()
```
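To see why `parameters` carries an extra level of nesting (shape `(3, 1, 2)`), here is a hypothetical plain-Python stand-in for the `"single"` pattern. It assumes that `broadcast` applies the unitary once per wire and unpacks each entry of `parameters` into positional arguments; that is our reading of the template's behavior, and `broadcast_single` is not a PennyLane function.

```python
import numpy as np


def broadcast_single(unitary, wires, parameters):
    """Hypothetical stand-in for qml.broadcast(..., pattern="single")."""
    # One call per wire; each parameter entry is unpacked as positional arguments.
    for params, wire in zip(parameters, wires):
        unitary(*params, wires=wire)


calls = []


def mytemplate(pars, wires):
    # Record what the template receives instead of applying RY/RX gates.
    calls.append((wires, pars[0], pars[1]))


parameters = np.array([[[1.0, 1.0]], [[2.0, 1.0]], [[0.1, 1.0]]])
broadcast_single(mytemplate, wires=[0, 1, 2], parameters=parameters)

for wire, ry_angle, rx_angle in calls:
    print(f"wire {wire}: RY({ry_angle}) then RX({rx_angle})")
```

Under this assumption, each `parameters[i]` has shape `(1, 2)`, so unpacking it passes a single length-2 `pars` vector to `mytemplate`, matching its `RY`/`RX` pair.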
Tracebacks
```text
TracerArrayConversionError Traceback (most recent call last)
~/miniconda3/envs/jaxnn/lib/python3.9/site-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
588 try:
--> 589 return np.array(getattr(x, "val", x))
590 except TracerArrayConversionError as e:
~/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py in __array__(self, *args, **kw)
482 def __array__(self, *args, **kw):
--> 483 raise TracerArrayConversionError(self)
484
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function circuit at /var/folders/8s/tfpsk_fx609f8w7z__yzz9vh0000gn/T/ipykernel_91532/3484084806.py:15 for jit, this concrete value was not available in Python because it depends on the value of the argument 'pars'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
/var/folders/8s/tfpsk_fx609f8w7z__yzz9vh0000gn/T/ipykernel_91532/3484084806.py in <module>
20
21 parameters = jnp.array([[[1., 1.]], [[2., 1.]], [[0.1, 1.]]])
---> 22 print(circuit(parameters))
23
24 qml.about()
...
--> 591 raise ValueError(
592 "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
593 ) from e
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
```
System information
Name: PennyLane
Version: 0.24.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Forest, PennyLane-Lightning, PennyLane-Qchem, PennyLane-qiskit
Platform info: macOS-12.4-x86_64-i386-64bit
Python version: 3.9.7
Numpy version: 1.20.1
Scipy version: 1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.24.0)
- default.gaussian (PennyLane-0.24.0)
- default.mixed (PennyLane-0.24.0)
- default.qubit (PennyLane-0.24.0)
- default.qubit.autograd (PennyLane-0.24.0)
- default.qubit.jax (PennyLane-0.24.0)
- default.qubit.tf (PennyLane-0.24.0)
- default.qubit.torch (PennyLane-0.24.0)
...
- qiskit.ibmq.sampler (PennyLane-qiskit-0.21.0)
- forest.numpy_wavefunction (PennyLane-Forest-0.20.0)
- forest.qvm (PennyLane-Forest-0.20.0)
- forest.wavefunction (PennyLane-Forest-0.20.0)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.

Thank you @albi3ro, that fixed it. I usually just use “jax-jit” as the interface and things work, but I forgot it here. I am a bit rusty on what the different interfaces do. I thought setting the device to “default.qubit.jax” would make it clear that I want JIT 😃. Feel free to close the issue, @CatalinaAlbornoz. I also tried the latest version of PennyLane, and the issue remains if one does not use `interface="jax"`.
I found the problem. You need to specify `interface="jax"` when constructing the `QNode`, e.g. `@qml.qnode(dev, interface="jax")`.
Since `interface="jax"` was not specified, the QNode tries to convert all the parameters to vanilla NumPy. While that was a fairly unhelpful error message, we plan to detect the interface automatically in the future. Hopefully, that will stop this kind of thing from occurring.