More measurements supported with JAX-JIT

Feature details

When shots is set to a value other than None, the JAX device only supports variance and expectation values. It would be a great addition if the JAX device could support other measurements, such as probabilities and state.

I hope to have things like this work:

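import jax
import pennylane as qml

# JAX-backed simulator with a finite number of shots
dev = qml.device('default.qubit.jax', wires=2, shots=10)

@jax.jit
@qml.qnode(dev, interface='jax')
def circ(x):
    qml.PauliZ(wires=0)
    qml.RY(x, wires=0)
    # Requested: sampled probabilities for wire 0
    return qml.probs(0)
print(circ(1))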
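
For reference, the quantity this snippet asks for can be sketched without PennyLane or JAX. The block below is only an illustration under stated assumptions: analytic_probs and sampled_probs are hypothetical helper names, the sampling uses Python's standard random module, and none of it reflects how the device computes probabilities internally. It relies on the fact that PauliZ leaves |0> unchanged and RY(x) maps |0> to cos(x/2)|0> + sin(x/2)|1>.

```python
import math
import random


def analytic_probs(x):
    """Exact outcome probabilities for PauliZ followed by RY(x) on |0>."""
    # PauliZ leaves |0> unchanged; RY(x)|0> = cos(x/2)|0> + sin(x/2)|1>.
    p1 = math.sin(x / 2) ** 2
    return [1.0 - p1, p1]


def sampled_probs(x, shots, seed=0):
    """Estimate the same probabilities from a finite number of simulated shots."""
    rng = random.Random(seed)  # hypothetical stand-in for the device's sampler
    p1 = analytic_probs(x)[1]
    # Each shot yields outcome 1 with probability p1, otherwise outcome 0.
    ones = sum(rng.random() < p1 for _ in range(shots))
    return [(shots - ones) / shots, ones / shots]


print(analytic_probs(1.0))
print(sampled_probs(1.0, shots=10))
```

With shots=10, every estimated probability is a multiple of 0.1, so the sampled result generally differs from the analytic one.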

Implementation

No response

How important would you say this feature is?

3: Very important! Blocking work.

Additional information

qml.probs() used to work, but it now raises InterfaceUnsupportedError.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
antalszava commented, Apr 1, 2022

Hi @ankit27kh, thanks for the report! It’s good to see that you would like to use these features.

A workaround when using default.qubit.jax is to set interface=None with shots>0. This is necessary because when shots>0, we always apply the interface (regardless of the device), and the JAX jittable interface doesn’t support returning probabilities and states.

This addition is also in the works in https://github.com/PennyLaneAI/pennylane/pull/2034.

0 reactions
antalszava commented, Jun 10, 2022

Hi @ankit27kh, the original example should now work with the master branch (and will be released soon) 👍 Note, though, that the example makes use of the JAX jit interface.
