[BUG] Jitting with JAX produces incorrect results with parameter broadcasting
Expected behavior
Consider the following QNode:
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop", interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))
When the jitted QNode is called with a batch of parameters, the following should be output:
>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.9553365 0.8253356]
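Because RX(x) followed by CNOT leaves ⟨Z₁⟩ = cos(x), the expected values above can be cross-checked without PennyLane. The following is a minimal sketch in plain JAX, assuming a hand-written two-qubit statevector simulation; the helper names rx, CNOT, Z1 and expval_z1 are illustrative and not PennyLane APIs. It uses jax.vmap so that each batch element gets its own simulation before jax.jit compiles the whole computation:

```python
import jax
import jax.numpy as jnp


def rx(theta):
    # Single-qubit RX(theta) rotation matrix
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


# CNOT with wire 0 as control and wire 1 as target (basis index = 2*q0 + q1)
CNOT = jnp.array(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=jnp.complex64
)

# PauliZ acting on wire 1, i.e. I (x) Z
Z1 = jnp.kron(jnp.eye(2), jnp.diag(jnp.array([1.0, -1.0])))


def expval_z1(x):
    # Start in |00>, apply RX(x) to wire 0, then CNOT(0, 1)
    state = jnp.zeros(4, dtype=jnp.complex64).at[0].set(1.0)
    state = jnp.kron(rx(x), jnp.eye(2)) @ state
    state = CNOT @ state
    # <psi| Z_1 |psi> is real for a Hermitian observable
    return jnp.real(jnp.vdot(state, Z1 @ state))


# vmap gives every batch element its own simulation; jit then compiles it
batched = jax.jit(jax.vmap(expval_z1))

x = jnp.array([0.3, 0.6])
print(batched(x))  # expected to agree with jnp.cos(x)
```

Unlike the buggy output reported below, each entry here depends on its own batch element, which is the behavior the jitted QNode should reproduce.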
Actual behavior
Instead, the following is output:
>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.95533645 0.95533645]
It seems that only the first parameter in the batch is used for all calculations, and the rest are ignored.
Additional information
If diff_method="parameter-shift" is set in the QNode, the issue disappears.
If jax.jit is removed in the above code, the issue disappears.
The issue also appears if the QNode is decorated with qml.transforms.batch_params.
System information
Name: PennyLane
Version: 0.24.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: c:\users\edward.jiang\documents\pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Lightning
Platform info: Windows-10-10.0.19042-SP0
Python version: 3.8.10
Numpy version: 1.22.3
Scipy version: 1.8.0
Installed devices:
- default.gaussian (PennyLane-0.24.0.dev0)
- default.mixed (PennyLane-0.24.0.dev0)
- default.mixed.autograd (PennyLane-0.24.0.dev0)
- default.qubit (PennyLane-0.24.0.dev0)
- default.qubit.autograd (PennyLane-0.24.0.dev0)
- default.qubit.jax (PennyLane-0.24.0.dev0)
- default.qubit.tf (PennyLane-0.24.0.dev0)
- default.qubit.torch (PennyLane-0.24.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.23.0)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.
Issue Analytics
- Created a year ago
- Comments: 6 (2 by maintainers)
Thanks for digging this one out! 👍 The reason seems to be that qml.transforms.broadcast_expand is not JIT compatible.

Edit: Actually, that's not quite true. The problem is that by default, qml.execute, which is called in QNode.__call__, makes use of qml.interfaces.cache_execute, which in turn caches the result. It therefore does not produce multiple different results, but only a single one, which is then retrieved from the cache. I patched PL to print out the cache values and received the following when calling the example above:

As we can see, only one value is computed and stored in the cache. I think this is because all traced tapes have the same hash after applying broadcast_transform, and I'm not sure it's possible to change that.

What next?
- … DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.
- … batch_transform produces multiple tapes because of a Hamiltonian decomposition. That's because the resulting terms differ in the measured observables, so the tape hashes also differ (if they didn't, the Hamiltonian would not be decomposed, I suppose).
- … cache in QNode.__call__ if the QNode tape after construction has a batch_size that is not None, and if the QNode device does not support broadcasting (there is a device flag for this). @josh146 what do you think?

Also tagging @antalszava for visibility. Do you have an idea how to proceed best here? 😃
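The caching failure described in this comment can be reproduced in a toy model. This is a sketch under stated assumptions, not PennyLane's cache_execute: ToyTape, structural_key, should_use_cache and cached_execute are all hypothetical names. structural_key mimics the situation where traced tapes hash identically because the hash cannot see concrete parameter values, and should_use_cache sketches the proposal to skip the cache when the tape has a batch_size and the device does not support broadcasting:

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTape:
    """Hypothetical stand-in for one tape produced by splitting a broadcast batch."""
    ops: tuple    # circuit structure, e.g. ("RX", "CNOT")
    param: float  # the batch element this tape was created for


def structural_key(tape):
    # Mimics a hash that cannot see concrete parameter values (as with
    # abstract tracers under jax.jit): only the circuit structure is hashed.
    return hash(tape.ops)


def should_use_cache(batch_size, device_supports_broadcasting):
    # Hypothetical helper for the proposal in the thread: bypass the cache
    # when the tape is broadcast and the device cannot handle broadcasting.
    return batch_size is None or device_supports_broadcasting


def cached_execute(tapes, run, key_fn, use_cache=True):
    """Execute tapes, reusing a cached result whenever two keys are equal."""
    cache = {}
    results = []
    for tape in tapes:
        if not use_cache:
            results.append(run(tape))  # always recompute, never consult the cache
            continue
        key = key_fn(tape)
        if key not in cache:
            cache[key] = run(tape)  # only the first tape with this key is run
        results.append(cache[key])
    return results, cache


def run(tape):
    # For RX(x) followed by CNOT, <Z_1> = cos(x)
    return math.cos(tape.param)


# Split the broadcast parameter [0.3, 0.6] into one tape per batch element
tapes = [ToyTape(("RX", "CNOT"), x) for x in (0.3, 0.6)]

collided, cache = cached_execute(tapes, run, structural_key)
print(collided, len(cache))  # both results equal cos(0.3); a single cache entry

use_cache = should_use_cache(batch_size=2, device_supports_broadcasting=False)
uncached, _ = cached_execute(tapes, run, structural_key, use_cache=use_cache)
print(uncached)  # cos(0.3) and cos(0.6)
```

In the toy, tapes that measure different observables (as in the Hamiltonian-decomposition case) would get different structural keys, which is consistent with that case not being affected.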
Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.