Modify `default.mixed` to use `qml.math` for native autodiff backpropagation support
Feature details
We would like to create a version of the `default.mixed` simulator that can perform computations natively in the TensorFlow, PyTorch and JAX autodiff frameworks. This can be achieved by modifying the existing device to use the `qml.math` module, which uses autoray to dispatch each array operation to the framework its inputs belong to. This makes it possible to write a single implementation that is compatible with multiple automatic differentiation frameworks.
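The dispatch idea can be illustrated with a toy stand-in for autoray. The helpers `infer_backend` and `do` below are hypothetical and much simpler than autoray's real functions of the same names: they infer the package that owns the input's type and call the same-named function from that package. A function such as `expval` is then written once, without hard-coding NumPy, which is the property a `qml.math`-based `default.mixed` would rely on. Only NumPy inputs are exercised here.

```python
import importlib

import numpy as np


def infer_backend(x):
    # Top-level package that defines x's type, e.g. numpy.ndarray -> "numpy",
    # torch.Tensor -> "torch". Plain Python containers fall back to NumPy.
    package = type(x).__module__.split(".")[0]
    return "numpy" if package == "builtins" else package


def do(fn_name, x, *args, **kwargs):
    # Call the function named `fn_name` from the package that owns `x`.
    # Real autoray also maps array types and function names that do not line
    # up this neatly (e.g. TensorFlow's tf.linalg.trace); this toy does not.
    backend = importlib.import_module(infer_backend(x))
    return getattr(backend, fn_name)(x, *args, **kwargs)


def expval(rho, obs):
    # Tr(rho @ obs), written against the dispatcher rather than against NumPy
    return do("real", do("trace", do("matmul", rho, obs)))


rho = np.array([[1, 0], [0, 0]], dtype=complex)  # |0><0|
pauli_z = np.array([[1, 0], [0, -1]], dtype=complex)
print(expval(rho, pauli_z))
```

The design point is that `expval` never names a framework; swapping in a tensor from another library changes which package is looked up, not the function body.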
There would be two advantages to modifying the `default.mixed` device to use `qml.math`:
- The device would support backpropagation, so `diff_method="backprop"` should be supported.
- Framework-specific features would work with `default.mixed`:
  - TensorFlow: hybrid models using Keras could be supported using backpropagation.
  - JAX: the `jax.vmap` function for vectorizing (parallelizing) quantum circuit evaluations would be supported.
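As a sketch of why backpropagation is feasible here, assuming the standard Kraus-operator description of noise channels used by density-matrix simulators: every step is a differentiable matrix expression in the gate parameters, so an autodiff framework can trace it end to end. For one gate $U(\theta)$ followed by a channel with Kraus operators $\{K_k\}$:

```math
\rho' = \sum_k K_k\, U(\theta)\, \rho\, U(\theta)^\dagger K_k^\dagger,
\qquad
\langle O \rangle = \operatorname{Tr}(O\,\rho'),
\qquad
\frac{\partial \langle O \rangle}{\partial \theta} = \operatorname{Tr}\!\left(O\,\frac{\partial \rho'}{\partial \theta}\right)
```

For the depolarizing channel, assuming PennyLane's Kraus operators $\sqrt{1-p}\,I$ and $\sqrt{p/3}\,X$, $\sqrt{p/3}\,Y$, $\sqrt{p/3}\,Z$, the $Z$ component shrinks uniformly, $\langle Z \rangle \mapsto (1 - \tfrac{4p}{3})\langle Z \rangle$, because $X$ and $Y$ flip the sign of $\langle Z \rangle$ while $I$ and $Z$ preserve it: $(1-p) + \tfrac{p}{3}(-1 - 1 + 1) = 1 - \tfrac{4p}{3}$.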
Implementation
As pointers, the following code snippets should execute without raising any errors:
TensorFlow
```python
import pennylane as qml
import tensorflow as tf

p = 0.01
dev = qml.device("default.mixed", wires=1)

@qml.qnode(dev, interface="tf", diff_method="backprop")
def circuit(x):
    qml.RX(x[1], wires=0)
    qml.Rot(x[0], x[1], x[2], wires=0)
    qml.DepolarizingChannel(p, wires=0)
    return qml.expval(qml.PauliZ(0))

weights = tf.Variable([0.2, 0.5, 0.1])

with tf.GradientTape() as tape:
    res = circuit(weights)

print(tape.gradient(res, weights))
```
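The gradient printed by the TensorFlow snippet can be cross-checked without any autodiff framework. This is a minimal NumPy density-matrix sketch, assuming PennyLane's matrix conventions (`Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)`) and the depolarizing Kraus operators stated above; the helpers `rx`, `rot`, `depolarize`, `circuit` and `analytic_grad` are hypothetical names, not PennyLane API. Tracking the Bloch vector by hand gives $\langle Z\rangle = (1-\tfrac{4p}{3})(\cos^2 x_1 - \sin x_0 \sin^2 x_1)$, and the script compares the derivative of that closed form with central finite differences of the simulated circuit.

```python
import numpy as np

# Single-qubit gates, assuming PennyLane's matrix conventions
def rx(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def rz(t):
    return np.diag([np.exp(-1j * t / 2), np.exp(1j * t / 2)])

def rot(phi, theta, omega):
    # Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)
    return rz(omega) @ ry(theta) @ rz(phi)

I2 = np.eye(2, dtype=complex)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]])
Z = np.diag([1.0, -1.0]).astype(complex)

def depolarize(rho, p):
    # Kraus operators: sqrt(1-p) I and sqrt(p/3) X, Y, Z
    kraus = [np.sqrt(1 - p) * I2] + [np.sqrt(p / 3) * P for P in (X, Y, Z)]
    return sum(K @ rho @ K.conj().T for K in kraus)

def circuit(x, p=0.01):
    rho = np.array([[1, 0], [0, 0]], dtype=complex)  # start in |0><0|
    U = rot(x[0], x[1], x[2]) @ rx(x[1])  # RX is applied first
    rho = depolarize(U @ rho @ U.conj().T, p)
    return np.real(np.trace(Z @ rho))

def finite_diff_grad(f, x, h=1e-6):
    # Central differences, one parameter at a time
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = h
        grad[i] = (f(x + e) - f(x - e)) / (2 * h)
    return grad

def analytic_grad(x, p=0.01):
    # Derivative of (1 - 4p/3) * (cos^2 x1 - sin x0 * sin^2 x1)
    a = 1 - 4 * p / 3
    return np.array([
        -a * np.sin(x[1]) ** 2 * np.cos(x[0]),
        -a * np.sin(2 * x[1]) * (1 + np.sin(x[0])),
        0.0,
    ])

weights = np.array([0.2, 0.5, 0.1])
print(finite_diff_grad(circuit, weights))
print(analytic_grad(weights))
```

A `backprop`-enabled `default.mixed` evaluated at the same weights and `p` would be expected to agree with these numbers up to floating-point precision (TensorFlow defaults to float32).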
PyTorch
```python
import pennylane as qml
import torch

dev = qml.device("default.mixed", wires=2)
p = 0.01

@qml.qnode(dev, interface="torch", diff_method="backprop")
def circuit3(x):
    qml.RX(x[1], wires=0)
    qml.Rot(x[0], x[1], x[2], wires=0)
    qml.DepolarizingChannel(p, wires=0)
    return qml.expval(qml.PauliZ(0))

phi = torch.tensor([0.5, 0.1, 0.4], requires_grad=True)
result = circuit3(phi)
result.backward()
print(phi.grad)
```
JAX
```python
import jax
import pennylane as qml
from jax import numpy as jnp

dev = qml.device("default.mixed", wires=1)
p = 0.01

@qml.qnode(dev, interface="jax", diff_method="backprop")
def circuit(x):
    qml.RX(x[1], wires=0)
    qml.Rot(x[0], x[1], x[2], wires=0)
    qml.DepolarizingChannel(p, wires=0)
    return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.5, 0.1])
grad_fn = jax.grad(circuit)
print(grad_fn(weights))
```
JAX with `jax.vmap`
```python
import jax
import networkx as nx
import pennylane as qml

G = nx.Graph()
G.add_nodes_from([0, 1, 2, 3])
G.add_edges_from([(0, 3), (1, 2), (1, 3)])

def noisy_circuit(prob, **kwargs):
    # apply a bit-flip channel with probability `prob` to every wire
    for k in range(len(G.nodes)):
        qml.BitFlip(prob, wires=k)
    return qml.expval(qml.PauliZ(0))

dev = qml.device("default.mixed", wires=len(G.nodes))
qcircuit = qml.QNode(noisy_circuit, dev, interface="jax")

# evaluate the circuit over a batch of noise probabilities
vcircuit = jax.vmap(qcircuit)
probs = jax.numpy.array([0., 0.05, 0.1])
vcircuit(probs)
```
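The values the vmapped circuit should return can be checked without PennyLane. All qubits start in $|0\rangle$ and only `BitFlip` channels act, so the state stays a product state and $\langle Z_0 \rangle$ depends only on wire 0, giving $1 - 2p$. The NumPy sketch below (hypothetical helper names; assumes `BitFlip` has Kraus operators $\sqrt{1-p}\,I$ and $\sqrt{p}\,X$) carries an explicit batch axis through the density matrix, which is roughly what `jax.vmap` would add automatically.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.diag([1.0, -1.0]).astype(complex)

def bitflip_batched(rho, probs):
    # rho: (B, 2, 2) batch of single-qubit density matrices; probs: (B,)
    # Kraus form sqrt(1-p) I, sqrt(p) X  =>  rho' = (1 - p) rho + p X rho X
    p = probs[:, None, None]
    flipped = np.einsum("ij,bjk,kl->bil", X, rho, X)
    return (1 - p) * rho + p * flipped

def expval_z_batched(probs):
    probs = np.asarray(probs, dtype=float)
    rho0 = np.zeros((len(probs), 2, 2), dtype=complex)
    rho0[:, 0, 0] = 1.0  # every batch element starts in |0><0|
    rho = bitflip_batched(rho0, probs)
    # Tr(Z rho) for each batch element
    return np.real(np.einsum("ij,bji->b", Z, rho))

print(expval_z_batched([0.0, 0.05, 0.1]))  # approximately 1 - 2p for each p
```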
It is worth noting that this issue is highly exploratory, and resolving it will likely require a deeper investigation first.
How important would you say this feature is?
1: Not important. Would be nice to have.
Additional information
See the following discussion forum threads:
TensorFlow & Keras with noise models: https://discuss.pennylane.ai/t/adding-noise-to-a-keras-hybrid-nn/1241
Jax vmap function: https://discuss.pennylane.ai/t/jax-with-default-mixed-device/1170
Selected comments from the discussion
Yep, exactly 😃 The plan is to remove all of `tf_ops.py`, `torch_ops.py`, etc. Ideally, we can remove all the static methods from the devices as well at some point!
If an op is non-parametric, I believe it is okay to use NumPy, since the autodiff frameworks will automatically convert it to a tensor when needed.
Hello @Jaybsoni, thanks for the update. I am using `default.mixed` for noise simulations, which take a very long time. I did all of my ideal simulations using `default.qubit.jax`, so I was hoping to have JAX and JIT support for noise simulations to drastically reduce runtime.