Backpropagating `take` fails for complex arrays
When I run the following script
```python
import jax
import jax.numpy as jnp
from netket.nn.symmetric_linear import DenseSymmMatrix

symmetries = jnp.asarray([[0, 1, 2, 3], [3, 2, 1, 0]])
layer = DenseSymmMatrix(symmetries=symmetries, features=2, dtype=complex)
σ = jnp.asarray([[[1., 1., -1., -1.]]])
pars = layer.init(jax.random.PRNGKey(0), σ)

def f(p):
    return layer.apply(p, σ)

primals, vjp_fun = jax.vjp(f, pars)
print(vjp_fun(primals))
```
on newer versions of JAX, I get the following error:
```
Traceback (most recent call last):
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 23, in <module>
    primals, vjp_fun = jax.vjp(f, pars)
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 22, in f
    return layer.apply(p, σ)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/nn/symmetric_linear.py", line 120, in __call__
    kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3339, in take
    return _take(a, indices, None if axis is None else operator.index(axis), out,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3397, in _take
    return lax.gather(a, indices[..., None], dimension_numbers=dnums,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: 'NoneType' object is not iterable

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 25, in <module>
    print(vjp_fun(primals))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/tree_util.py", line 284, in __call__
    return self.fun(*args, **kw)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/api.py", line 2333, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/tree_util.py", line 284, in __call__
    return self.fun(*args, **kw)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 126, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 232, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 601, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **new_params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 1711, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 1724, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 616, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 501, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 650, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 738, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 2065, in _scatter_add_lower_gpu
    (indices,), = clip_fn(ctx, ctx.avals_in, None, operand, indices, updates,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 774, in f_lowered
    *map(wrap_singleton_ir_values, args))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 418, in wrap_singleton_ir_values
    return (x,) if isinstance(x, ir.Value) else tuple(x)
TypeError: 'NoneType' object is not iterable
```
`DenseSymmMatrix` is a Flax module, defined here. The error comes from line 120, which is a call to `jnp.take`; replacing it with the equivalent `kernel = kernel[:, :, jnp.asarray(self.symmetries)]` fixes the issue.
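For concreteness, a sketch of that workaround as a change to the line quoted in the traceback (the surrounding code in `symmetric_linear.py` is not shown here and may differ):

```python
# netket/nn/symmetric_linear.py, line 120 (as quoted in the traceback):
kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)

# Equivalent advanced indexing along axis 2; per this report, the
# indexing form does not trigger the failing transpose lowering:
kernel = kernel[:, :, jnp.asarray(self.symmetries)]
```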
The issue is only present:

- when backpropagating (calling `f(pars)` in the above script runs fine)
- if `dtype=complex`, not if `dtype=float`
- on GPU backends
- on newer versions of JAX (the same script runs fine with v0.2.25)
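Since the failing operation is `jnp.take` itself, the bug should reproduce without NetKet. A minimal NetKet-free sketch along these lines (an assumption, not verified here; it substitutes a bare `jnp.take` for the `DenseSymmMatrix` layer):

```python
import jax
jax.config.update("jax_enable_x64", True)  # complex128 is the failing dtype
import jax.numpy as jnp

# Differentiate through jnp.take on a complex array. Per this report,
# the forward pass works everywhere; the pullback fails only on GPU
# backends, only for complex dtypes, on recent JAX versions.
x = jnp.ones((2, 2, 4), dtype=jnp.complex128)
idx = jnp.asarray([[0, 1, 2, 3], [3, 2, 1, 0]])

def f(a):
    return jnp.take(a, idx, axis=2)

primals, vjp_fun = jax.vjp(f, x)  # forward pass: fine
print(vjp_fun(primals))           # on an affected GPU build:
                                  # TypeError: 'NoneType' object is not iterable
```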
Okay, I fixed it. The function wrapped by `mlir.lower_fun` has had a different signature from its XLA counterpart since https://github.com/google/jax/commit/a87b21148c6d7eb9b46c751dde40b17ca0e7b03e

More precisely, it only fails with `np.complex128`, not `np.complex64`.
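As a rough illustration of that mismatch (simplified, hypothetical signatures, not JAX's actual internals), compare the two calling conventions visible in the traceback:

```python
# Hypothetical sketch only. An XLA-style lowering rule received the
# input/output avals as explicit arguments:
def xla_style_rule(ctx, avals_in, avals_out, *operands):
    ...

# whereas a rule produced by mlir.lower_fun takes just (ctx, *operands):
def mlir_style_rule(ctx, *operands):
    ...

# The traceback shows _scatter_add_lower_gpu invoking the wrapped clip
# function with the old convention:
#   clip_fn(ctx, ctx.avals_in, None, operand, indices, updates, ...)
# so ctx.avals_in and the None placeholder are treated as operands, and
# wrap_singleton_ir_values eventually calls tuple(None), raising
#   TypeError: 'NoneType' object is not iterable
```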