
[BUG] Jitting with JAX produces incorrect results with parameter broadcasting


Expected behavior

Consider the following QNode:

import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop", interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

When the QNode is called with a broadcasted (batched) parameter, each element should be evaluated independently:

>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.9553365 0.8253356]
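
For reference, the analytic expectation value of this circuit is <Z> on wire 1 = cos(x), applied elementwise over the broadcasted parameters, so the expected output can be sanity-checked in plain JAX (a quick check added here, not part of the original report):

import jax.numpy as jnp

# cos(0.3) ~ 0.9553365 and cos(0.6) ~ 0.8253356, matching the
# expected broadcasted output above.
x = jnp.array([0.3, 0.6])
print(jnp.cos(x))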

Actual behavior

Instead, the following is output:

>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.95533645 0.95533645]

It seems that only the first parameter is used for all calculations and the rest are ignored.

Additional information

If diff_method="parameter-shift" is set in the QNode, the issue disappears (a runnable sketch of this workaround follows below).

If jax.jit is removed in the above code, the issue disappears.

The issue also appears if the QNode is decorated with qml.transforms.batch_params.
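
A minimal sketch of the parameter-shift workaround, assuming only the diff_method changes from the original circuit:

import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)

# Same circuit as above, but with parameter-shift differentiation,
# which sidesteps the faulty path under jax.jit.
@qml.qnode(dev, diff_method="parameter-shift", interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

x = jnp.array([0.3, 0.6])
print(jax.jit(circuit)(x))  # expected: [0.9553365 0.8253356]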

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.24.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: c:\users\edward.jiang\documents\pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Lightning

Platform info:           Windows-10-10.0.19042-SP0
Python version:          3.8.10
Numpy version:           1.22.3
Scipy version:           1.8.0
Installed devices:
- default.gaussian (PennyLane-0.24.0.dev0)
- default.mixed (PennyLane-0.24.0.dev0)
- default.mixed.autograd (PennyLane-0.24.0.dev0)
- default.qubit (PennyLane-0.24.0.dev0)
- default.qubit.autograd (PennyLane-0.24.0.dev0)
- default.qubit.jax (PennyLane-0.24.0.dev0)
- default.qubit.tf (PennyLane-0.24.0.dev0)
- default.qubit.torch (PennyLane-0.24.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.23.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
dwierichs commented, Jun 24, 2022

Thanks for digging this one out! 👍 The reason seems to be that qml.transforms.broadcast_expand is not JIT compatible.

edit: Actually, that’s not quite true. The problem is that qml.execute, which is called in QNode.__call__, by default makes use of qml.interfaces.cache_execute. That function caches the result, so instead of producing multiple distinct results it produces a single one, which is then retrieved from the cache for every tape. I patched PennyLane to print out the cache values and got the following when calling the example above:

>>> out = circuit(x)
cached values: [Traced<ShapedArray(float64[1])>with<DynamicJaxprTrace(level=0/1)>]

As we can see, only one value is computed and stored in the cache. I think this is because all traced tapes after applying broadcast_transform have the same hash, and I’m not sure it’s possible to change that.
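
To make the failure mode concrete, here is a minimal, self-contained sketch of a hash-keyed result cache (hypothetical names, not PennyLane’s actual cache_execute): when every expanded tape collides on one hash, only the first result is ever computed and all later lookups return it.

import math

# Minimal sketch: a result cache keyed by tape hash. Because all
# broadcast-expanded tapes share one hash, only the first result
# is computed; every later tape silently reuses it.
cache = {}

def cached_execute(tape_hash, compute):
    if tape_hash not in cache:
        cache[tape_hash] = compute()  # computed once
    return cache[tape_hash]           # reused for every later tape

shared_hash = 0  # all expanded tapes collide on this key
results = [cached_execute(shared_hash, lambda p=p: math.cos(p)) for p in [0.3, 0.6]]
print(results)  # [0.9553..., 0.9553...] - the second value is stale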

What next?

  1. With the introduction of parameter broadcasting to DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.
  2. This problem, for example, does not occur when the device’s batch_transform produces multiple tapes because of a Hamiltonian decomposition. That’s because the resulting terms differ in the measured observables, so the tape hashes also differ (if they didn’t, the Hamiltonian would not be decomposed, I suppose).
  3. I think we should deactivate caching when executing broadcasted tapes, but I am not sure how to do this best. One option would be to override (but not overwrite!) the execution kwarg cache in QNode.__call__ if the QNode tape after construction has a batch_size that is not None, and if the QNode device does not support broadcasting (there is a device flag for this); a rough sketch follows this list. @josh146 what do you think?
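
For point 3, a rough sketch of what that override could look like (illustrative only; the helper name and kwarg handling are assumptions, though device.capabilities() and the tape batch_size attribute exist in PennyLane):

# Hypothetical helper sketching the proposed override; not actual
# PennyLane source. `execute_kwargs` stands in for the QNode's
# stored execution options.
def maybe_disable_cache(tape, device, execute_kwargs):
    broadcasted = tape.batch_size is not None
    handled_by_device = device.capabilities().get("supports_broadcasting", False)
    if broadcasted and not handled_by_device:
        # Override (not overwrite) the user's cache setting for
        # this execution only.
        return {**execute_kwargs, "cache": False}
    return execute_kwargs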

Also tagging @antalszava for visibility. Do you have an idea how to proceed best here? 😃

1 reaction
eddddddy commented, Jun 22, 2022

Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.


