[BUG] qml.probs not working with jax.jit
Expected behavior
The following should return the estimated probabilities for wire 0:
```python
import jax
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=2, shots=100)

@jax.jit
@qml.qnode(dev, interface="jax")
def my_circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.probs(wires=0)

print(my_circuit(1))
```
Actual behavior
The code raises `jax._src.errors.TracerArrayConversionError` (full traceback below).
Additional information
The same circuit works without `jax.jit`:
```python
@qml.qnode(dev, interface="jax")
def my_circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.probs(wires=0)
```
```
>>> print(my_circuit(1))
[0.79 0.21]
```
Other measurements, such as `qml.sample(qml.PauliZ(0))` and `qml.state()`, work with `jax.jit`.
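As a sanity check on the non-jitted output, the exact (infinite-shot) probabilities can be worked out by hand. Starting from |00⟩, `RX(θ)` gives wire 0 the amplitudes cos(θ/2) and −i·sin(θ/2), and `CNOT` uses wire 0 as its control, so it does not change wire 0's measurement probabilities. That gives P(0) = cos²(θ/2) ≈ 0.770 and P(1) = sin²(θ/2) ≈ 0.230 for θ = 1. With 100 shots the standard error is about √(0.77 · 0.23 / 100) ≈ 0.04, so `[0.79 0.21]` is consistent with it. The sketch below only evaluates that closed form in plain JAX under `jax.jit`; `exact_probs_wire0` is a hypothetical helper, not part of PennyLane.

```python
import jax
import jax.numpy as jnp


@jax.jit
def exact_probs_wire0(param):
    # Hypothetical helper: exact marginal probabilities of wire 0 for
    # RX(param) on wire 0 followed by CNOT(0, 1), starting from |00>.
    # CNOT is controlled on wire 0, so wire 0's populations are unchanged.
    half = param / 2.0
    return jnp.stack([jnp.cos(half) ** 2, jnp.sin(half) ** 2])


print(exact_probs_wire0(1.0))  # approximately [0.770, 0.230]
```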
Tracebacks
```
  File "/temp.py", line 57, in <module>
    print(my_circuit(1))
  File "/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 687, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 771, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/lib/python3.9/site-packages/pennylane/qnode.py", line 699, in __call__
    res = self.qtape.execute(device=self.device)
  File "/lib/python3.9/site-packages/pennylane/tape/tape.py", line 1324, in execute
    return self._execute(params, device=device)
  File "/lib/python3.9/site-packages/pennylane/tape/tape.py", line 1355, in execute_device
    res = device.execute(self)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 227, in execute
    results = self.statistics(circuit.observables)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 404, in statistics
    self.probability(wires=obs.wires, shot_range=shot_range, bin_size=bin_size)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 681, in probability
    return self.estimate_probability(wires=wires, shot_range=shot_range, bin_size=bin_size)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 657, in estimate_probability
    basis_states, counts = np.unique(indices, return_counts=True)
  File "<__array_function__ internals>", line 5, in unique
  File "/lib/python3.9/site-packages/numpy/lib/arraysetops.py", line 260, in unique
    ar = np.asanyarray(ar)
  File "/lib/python3.9/site-packages/numpy/core/_asarray.py", line 171, in asanyarray
    return array(a, dtype, copy=False, order=order, subok=True)
  File "/lib/python3.9/site-packages/jax/core.py", line 483, in __array__
    raise TracerArrayConversionError(self)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[100])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function my_circuit at /temp.py:45 for jit, this concrete value was not available in Python because it depends on the value of the argument 'param'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
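The last PennyLane frames show where it breaks: `estimate_probability` turns the sampled basis states into integer indices and calls NumPy's `np.unique(indices, return_counts=True)`. Under `jax.jit` those indices are a tracer (`Traced<ShapedArray(int64[100])>`), so NumPy's `__array__` conversion raises. A `np.unique`-style call is also awkward to trace in general, because its output shape depends on the sampled data. One jit-friendly way to get the same histogram is to count with a fixed output length. The sketch below assumes samples arrive as a 0/1 integer array of shape `(shots, num_wires)` with wire 0 as the most significant bit; `estimate_probability_jax` is a hypothetical function for illustration, not PennyLane's actual fix.

```python
import jax
import jax.numpy as jnp


def estimate_probability_jax(samples, num_wires):
    """Hypothetical jit-friendly probability estimator (not PennyLane code).

    samples: int array of shape (shots, num_wires) holding 0/1 outcomes,
    with wire 0 treated as the most significant bit.
    """
    # Turn each row of bits into a basis-state index, e.g. [1, 0] -> 2.
    powers = 2 ** jnp.arange(num_wires - 1, -1, -1)
    indices = jnp.sum(samples * powers, axis=1)
    # bincount with a static `length` has a fixed output shape, so it can be
    # traced under jax.jit, unlike np.unique whose output size is data-dependent.
    counts = jnp.bincount(indices, length=2 ** num_wires)
    return counts / samples.shape[0]


# num_wires fixes the output shape, so it must be static under jit.
estimate_jit = jax.jit(estimate_probability_jax, static_argnums=1)

samples = jnp.array([[0, 0], [1, 1], [0, 0], [1, 1]])
print(estimate_jit(samples, 2))  # |00> and |11> each with probability 0.5
```

Making `num_wires` static is the key design choice here: it lets both the bit weights and the `bincount` length be ordinary Python integers at trace time.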
System information
```
Name: PennyLane
Version: 0.19.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /lib/python3.9/site-packages
Requires: pennylane-lightning, autograd, appdirs, semantic-version, numpy, networkx, cachetools, autoray, scipy, toml
Required-by: PennyLane-SF, PennyLane-qiskit, PennyLane-Lightning
Platform info: Linux-5.11.0-41-generic-x86_64-with-glibc2.31
Python version: 3.9.6
Numpy version: 1.20.3
Scipy version: 1.7.1
Installed devices:
- default.gaussian (PennyLane-0.19.0)
- default.mixed (PennyLane-0.19.0)
- default.qubit (PennyLane-0.19.0)
- default.qubit.autograd (PennyLane-0.19.0)
- default.qubit.jax (PennyLane-0.19.0)
- default.qubit.tf (PennyLane-0.19.0)
- default.qubit.torch (PennyLane-0.19.0)
- default.tensor (PennyLane-0.19.0)
- default.tensor.tf (PennyLane-0.19.0)
- strawberryfields.fock (PennyLane-SF-0.19.0)
- strawberryfields.gaussian (PennyLane-SF-0.19.0)
- strawberryfields.gbs (PennyLane-SF-0.19.0)
- strawberryfields.remote (PennyLane-SF-0.19.0)
- strawberryfields.tf (PennyLane-SF-0.19.0)
- qiskit.aer (PennyLane-qiskit-0.18.0)
- qiskit.basicaer (PennyLane-qiskit-0.18.0)
- qiskit.ibmq (PennyLane-qiskit-0.18.0)
- lightning.qubit (PennyLane-Lightning-0.18.0)
```
- I have searched existing GitHub issues to make sure the issue does not already exist.
Issue Analytics
- Created 2 years ago
- Comments: 11 (11 by maintainers)

Hey @Jaybsoni, this looks like it needs a hotfix. I’d prefer you take it, to make sure there are no unintended consequences and everything works as expected.
For sure, I can take a look at it 👍🏼