Backpropagating `take` fails for complex arrays


When I run the following script

import jax.numpy as jnp
import jax
from netket.nn.symmetric_linear import DenseSymmMatrix

symmetries = jnp.asarray([[0,1,2,3],[3,2,1,0]])
layer = DenseSymmMatrix(symmetries=symmetries, features=2, dtype=complex)

σ = jnp.asarray([[[1.,1.,-1.,-1.]]])

pars = layer.init(jax.random.PRNGKey(0), σ)
def f(p):
  return layer.apply(p, σ)
primals, vjp_fun = jax.vjp(f, pars)

print(vjp_fun(primals))

on newer versions of JAX, I get the following error:

Traceback (most recent call last):
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 23, in <module>
    primals, vjp_fun = jax.vjp(f, pars)
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 22, in f
    return layer.apply(p, σ)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/nn/symmetric_linear.py", line 120, in __call__
    kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3339, in take
    return _take(a, indices, None if axis is None else operator.index(axis), out,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3397, in _take
    return lax.gather(a, indices[..., None], dimension_numbers=dnums,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: 'NoneType' object is not iterable

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/vol06/scarf1036/NQS/Fabien/MNWE.py", line 25, in <module>
    print(vjp_fun(primals))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/tree_util.py", line 284, in __call__
    return self.fun(*args, **kw)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/api.py", line 2333, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/tree_util.py", line 284, in __call__
    return self.fun(*args, **kw)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 126, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 232, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/ad.py", line 601, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **new_params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 1711, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 1724, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/core.py", line 616, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 501, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 650, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 738, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 2065, in _scatter_add_lower_gpu
    (indices,), = clip_fn(ctx, ctx.avals_in, None, operand, indices, updates,
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 774, in f_lowered
    *map(wrap_singleton_ir_values, args))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/interpreters/mlir.py", line 418, in wrap_singleton_ir_values
    return (x,) if isinstance(x, ir.Value) else tuple(x)
TypeError: 'NoneType' object is not iterable

DenseSymmMatrix is a flax module, defined here. The error comes from line 120, which is a call to jnp.take; replacing it with the equivalent kernel = kernel[:,:,jnp.asarray(self.symmetries)] fixes the issue.
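
The equivalence between the two spellings can be checked in isolation. The sketch below is not NetKet code: `kernel` is a hypothetical stand-in array with a made-up shape, and only the `symmetries` table is taken from the script above.

```python
import jax.numpy as jnp

# Symmetry table from the script above: two permutations of four sites.
symmetries = jnp.asarray([[0, 1, 2, 3], [3, 2, 1, 0]])

# Hypothetical stand-in for the layer's kernel (shape chosen for illustration only).
kernel = jnp.arange(24.0).reshape(2, 3, 4)

# Gather along the last axis, written both ways.
via_take = jnp.take(kernel, symmetries, axis=2)
via_index = kernel[:, :, symmetries]

print(via_take.shape)
print(bool(jnp.array_equal(via_take, via_index)))
```

Both forms replace the last axis of `kernel` with the shape of `symmetries`, so for in-bounds indices the forward results should agree and the workaround should not change what the layer computes.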

The issue is only present

  • when backpropagating (calling f(pars) in the above script runs fine)
  • if dtype=complex, not if dtype=float
  • on GPU backends
  • on new versions of JAX (the same script runs fine with v0.2.25)

cf. https://github.com/netket/netket/issues/1171

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

2 reactions
YouJiacheng commented, Apr 22, 2022

Okay, I fixed it. The function wrapped by mlir.lower_fun has had a different signature from its XLA counterpart since https://github.com/google/jax/commit/a87b21148c6d7eb9b46c751dde40b17ca0e7b03e

1 reaction
YouJiacheng commented, Apr 22, 2022

More precisely, it only fails with np.complex128, not np.complex64.
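
To probe this on a given installation without NetKet, a rough sketch along the following lines may help. It backpropagates through `jnp.take` for both complex dtypes and prints the JAX version and backend. The kernel shape and the helper name `take_grad` are assumptions made for illustration, and the sketch has not been run against the affected JAX versions or on a GPU, so it is not confirmed to reproduce the failure.

```python
import jax
import jax.numpy as jnp

# complex128 is only available when 64-bit mode is enabled;
# otherwise JAX truncates it to complex64.
jax.config.update("jax_enable_x64", True)

# Symmetry table from the original script.
symmetries = jnp.asarray([[0, 1, 2, 3], [3, 2, 1, 0]])


def take_grad(dtype):
    """Pull an all-ones cotangent back through jnp.take (hypothetical helper)."""
    # Hypothetical stand-in for the layer's kernel.
    kernel = jnp.arange(8.0).reshape(2, 1, 4).astype(dtype)
    out, vjp_fun = jax.vjp(lambda k: jnp.take(k, symmetries, axis=2), kernel)
    (grad,) = vjp_fun(jnp.ones_like(out))
    return grad


print("JAX", jax.__version__, "on", jax.default_backend())
grads = {
    "complex64": take_grad(jnp.complex64),
    "complex128": take_grad(jnp.complex128),
}
for name, grad in grads.items():
    print(name, grad.dtype, grad.shape)
```

On an unaffected setup both calls should succeed, and because every site index appears twice in `symmetries`, each gradient entry should come out as 2. On an affected setup the expectation, based on the report above, is that only the complex128 call raises.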

