[BUG] qml.probs not working with jax.jit

Expected behavior

The jitted QNode should return the estimated probabilities for wire 0, just as the non-jitted version does.

import jax
import pennylane as qml

# Finite shots: probabilities are estimated from samples
dev = qml.device("default.qubit.jax", wires=2, shots=100)

@jax.jit
@qml.qnode(dev, interface="jax")
def my_circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.probs(0)

print(my_circuit(1))

Actual behavior

The jitted call raises a TracerArrayConversionError (full traceback below).

Additional information

Code without jitting works:

@qml.qnode(dev, interface="jax")
def my_circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.probs(0)

>>> print(my_circuit(1))
[0.79 0.21]
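For reference, RX(θ) acting on |0⟩ gives cos(θ/2)|0⟩ - i sin(θ/2)|1⟩. The CNOT that follows does not change the marginal distribution of its control wire, so the exact probabilities for wire 0 are cos²(θ/2) and sin²(θ/2). The small JAX sketch below evaluates them under jax.jit. The helper name exact_wire0_probs is hypothetical. For θ = 1 the values are roughly 0.770 and 0.230, so the 100-shot estimate [0.79 0.21] is within sampling noise.

```python
import jax
import jax.numpy as jnp


@jax.jit
def exact_wire0_probs(theta):
    """Hypothetical helper: exact [p(|0>), p(|1>)] on wire 0 for RX(theta) then CNOT."""
    # RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>; the CNOT that follows
    # does not change the marginal distribution of its control wire (wire 0)
    p0 = jnp.cos(theta / 2) ** 2
    return jnp.stack([p0, 1.0 - p0])


print(exact_wire0_probs(1.0))  # approximately 0.770 and 0.230
```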

Other measurements, such as qml.sample(qml.PauliZ(0)) and qml.state(), work under jax.jit.
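The report notes that qml.sample(qml.PauliZ(0)) works, so one possible workaround is to return PauliZ samples from the QNode and convert them to probabilities in JAX. This is a minimal sketch that assumes PauliZ eigenvalues of +1 for |0⟩ and -1 for |1⟩. The helper probs_from_pauliz_samples is hypothetical, and the commented QNode wiring has not been tested against PennyLane 0.19.

```python
import jax
import jax.numpy as jnp


@jax.jit
def probs_from_pauliz_samples(z_samples):
    """Hypothetical helper: map PauliZ samples (+1/-1) to [p(|0>), p(|1>)]."""
    # PauliZ has eigenvalue +1 on |0> and -1 on |1>, so p(|0>) is the
    # fraction of shots that returned +1
    p0 = jnp.mean(z_samples == 1)
    return jnp.stack([p0, 1.0 - p0])


# Hypothetical wiring with the device from the report (untested):
#
# @jax.jit
# @qml.qnode(dev, interface="jax")
# def sampled_circuit(param):
#     qml.RX(param, wires=0)
#     qml.CNOT(wires=[0, 1])
#     return qml.sample(qml.PauliZ(0))
#
# probs_from_pauliz_samples(sampled_circuit(1.0))

print(probs_from_pauliz_samples(jnp.array([1, -1, 1, 1])))  # 0.75 and 0.25
```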

Tracebacks

Traceback (most recent call last):
  File "/temp.py", line 57, in <module>
    print(my_circuit(1))
  File "/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 687, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/lib/python3.9/site-packages/jax/interpreters/xla.py", line 771, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/lib/python3.9/site-packages/pennylane/qnode.py", line 699, in __call__
    res = self.qtape.execute(device=self.device)
  File "/lib/python3.9/site-packages/pennylane/tape/tape.py", line 1324, in execute
    return self._execute(params, device=device)
  File "/lib/python3.9/site-packages/pennylane/tape/tape.py", line 1355, in execute_device
    res = device.execute(self)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 227, in execute
    results = self.statistics(circuit.observables)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 404, in statistics
    self.probability(wires=obs.wires, shot_range=shot_range, bin_size=bin_size)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 681, in probability
    return self.estimate_probability(wires=wires, shot_range=shot_range, bin_size=bin_size)
  File "/lib/python3.9/site-packages/pennylane/_qubit_device.py", line 657, in estimate_probability
    basis_states, counts = np.unique(indices, return_counts=True)
  File "<__array_function__ internals>", line 5, in unique
  File "/lib/python3.9/site-packages/numpy/lib/arraysetops.py", line 260, in unique
    ar = np.asanyarray(ar)
  File "/lib/python3.9/site-packages/numpy/core/_asarray.py", line 171, in asanyarray
    return array(a, dtype, copy=False, order=order, subok=True)
  File "/lib/python3.9/site-packages/jax/core.py", line 483, in __array__
    raise TracerArrayConversionError(self)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[100])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function my_circuit at /temp.py:45 for jit, this concrete value was not available in Python because it depends on the value of the argument 'param'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
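The traceback points to the root cause. With finite shots, estimate_probability counts samples with np.unique, which must first convert its input to a concrete NumPy array. Under jax.jit, however, the sample indices are an abstract Tracer. The number of unique values also depends on the data, so the output shape cannot be known at trace time.

The sketch below reproduces that failure in plain JAX. It then shows one jit-friendly alternative that counts with jnp.bincount and a static length. It is an illustration under assumptions, not the fix that was merged into PennyLane:

- estimate_probability_jit and counts_with_numpy are hypothetical names.
- Samples are assumed to be a (shots, num_wires) array of 0/1 integers.
- Wire 0 is assumed to be the most significant bit.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def counts_with_numpy(indices):
    # Mirrors the failing call in the traceback: np.unique converts its input
    # to a concrete NumPy array, which a JAX Tracer cannot provide
    _, counts = np.unique(indices, return_counts=True)
    return counts


@partial(jax.jit, static_argnums=1)
def estimate_probability_jit(samples, num_wires):
    """Hypothetical jit-compatible estimator (not PennyLane's API).

    samples: integer array of 0/1 outcomes with shape (shots, num_wires).
    Returns the estimated probability of each of the 2**num_wires basis states.
    """
    # Assumed ordering: wire 0 is the most significant bit
    powers = 2 ** jnp.arange(num_wires - 1, -1, -1)
    indices = jnp.sum(samples * powers, axis=1)  # one basis-state index per shot
    # A static `length` fixes the output shape at trace time, unlike np.unique
    counts = jnp.bincount(indices, length=2 ** num_wires)
    return counts / samples.shape[0]


samples = jnp.array([[0, 0], [1, 1], [1, 1], [0, 0]])

try:
    counts_with_numpy(samples[:, 0])
except jax.errors.TracerArrayConversionError:
    print("np.unique fails under jit")

print(estimate_probability_jit(samples, 2))  # probabilities 0.5, 0, 0, 0.5
```

Passing num_wires as a static argument keeps 2**num_wires a plain Python integer. That is what lets jnp.bincount return a fixed-size array even when the counted values are traced.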

System information

Name: PennyLane
Version: 0.19.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /lib/python3.9/site-packages
Requires: pennylane-lightning, autograd, appdirs, semantic-version, numpy, networkx, cachetools, autoray, scipy, toml
Required-by: PennyLane-SF, PennyLane-qiskit, PennyLane-Lightning
Platform info:           Linux-5.11.0-41-generic-x86_64-with-glibc2.31
Python version:          3.9.6
Numpy version:           1.20.3
Scipy version:           1.7.1
Installed devices:
- default.gaussian (PennyLane-0.19.0)
- default.mixed (PennyLane-0.19.0)
- default.qubit (PennyLane-0.19.0)
- default.qubit.autograd (PennyLane-0.19.0)
- default.qubit.jax (PennyLane-0.19.0)
- default.qubit.tf (PennyLane-0.19.0)
- default.qubit.torch (PennyLane-0.19.0)
- default.tensor (PennyLane-0.19.0)
- default.tensor.tf (PennyLane-0.19.0)
- strawberryfields.fock (PennyLane-SF-0.19.0)
- strawberryfields.gaussian (PennyLane-SF-0.19.0)
- strawberryfields.gbs (PennyLane-SF-0.19.0)
- strawberryfields.remote (PennyLane-SF-0.19.0)
- strawberryfields.tf (PennyLane-SF-0.19.0)
- qiskit.aer (PennyLane-qiskit-0.18.0)
- qiskit.basicaer (PennyLane-qiskit-0.18.0)
- qiskit.ibmq (PennyLane-qiskit-0.18.0)
- lightning.qubit (PennyLane-Lightning-0.18.0)

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 11 (11 by maintainers)

Top GitHub Comments

1 reaction
ankit27kh commented, Dec 13, 2021

Hey @Jaybsoni, this seemed like a hotfix. I’d prefer you do it to ensure no unintended consequences are there and everything works as expected.

0 reactions
Jaybsoni commented, Dec 14, 2021

For sure, I can take a look at it 👍🏼
