
Complex-valued layers (e.g. Dense)

See original GitHub issue

Hello. I am one of the main contributors to the netket project, a machine-learning package for quantum physics. Together with other developers, we have migrated most of our code to jax, and I am considering integrating deeply with flax too.

Note: this issue is a sister issue to jax#5312

However, in netket we make extensive use of complex-valued neural networks, which at the moment are not supported by flax. To integrate properly with flax, I would like to be able to define layers such as Dense, Conv and others that use complex values. While we can sometimes split the real and imaginary parts of a dense layer by hand (a sketch of such a workaround is included after the problem description below), it would often be much easier to have built-in support.

Ideally, I would like the example below to work:

Problem you have encountered:

Passing a complex-valued dtype to most layers silently returns real-valued parameters. See for example:

>>> import jax, flax
>>> m=flax.linen.Dense(3, dtype=jax.numpy.complex64)
>>> _, weights = m.init_with_output(jax.random.PRNGKey(0), (3,))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>> weights
FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.67380387, -0.3294223 , -0.9614107 ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})

Where I would have expected a complex-valued network.

The issue probably lies in jax's truncated_normal (called by lecun_normal) not supporting complex values. I am opening this issue here, however, because I am not sure it would ever make sense for lecun_normal to support complex numbers natively; it might therefore make more sense for flax to have its own initialisers that dispatch to different implementations depending on the requested dtype.
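
To make this concrete, here is a minimal sketch of what such a dispatching initialiser could look like. This is not existing flax or jax API: the helper name complex_aware is made up, and drawing independent real and imaginary parts with half the variance each is just one reasonable convention.

import jax
import jax.numpy as jnp
from jax.nn.initializers import lecun_normal

def complex_aware(real_init=lecun_normal()):
    # Hypothetical wrapper: use the real initialiser unchanged for real
    # dtypes, and assemble complex samples from two real draws otherwise.
    def init(key, shape, dtype=jnp.float32):
        if jnp.issubdtype(dtype, jnp.complexfloating):
            real_dtype = jnp.finfo(dtype).dtype  # e.g. complex64 -> float32
            key_re, key_im = jax.random.split(key)
            # Halve the variance of each part so E[|w|^2] keeps the scale
            # the underlying initialiser intended.
            re = real_init(key_re, shape, real_dtype) / jnp.sqrt(2)
            im = real_init(key_im, shape, real_dtype) / jnp.sqrt(2)
            return (re + 1j * im).astype(dtype)
        return real_init(key, shape, dtype)
    return init

Whether a layer forwards its dtype to the initialiser depends on the flax version, so one might have to call complex_aware()(key, shape, jnp.complex64) by hand or pass it explicitly as kernel_init.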

In any case, silently returning the wrong output is confusing.
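
For reference, the by-hand splitting mentioned at the top can be sketched as below. ComplexDense is a hypothetical module, not part of flax: it stores two real kernels so the default real-valued initialisers keep working, and combines them into a complex weight at call time.

import jax
import jax.numpy as jnp
import flax.linen as nn

class ComplexDense(nn.Module):
    # Hypothetical workaround: complex weights assembled from two real params.
    features: int

    @nn.compact
    def __call__(self, x):
        shape = (x.shape[-1], self.features)
        k_re = self.param('kernel_re', nn.initializers.lecun_normal(), shape)
        k_im = self.param('kernel_im', nn.initializers.lecun_normal(), shape)
        bias = self.param('bias', nn.initializers.zeros, (self.features,))
        kernel = k_re + 1j * k_im  # complex kernel from two real halves
        return x @ kernel + bias

# e.g.: variables = ComplexDense(3).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))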

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

1 reaction
PhilipVinc commented, Jan 6, 2021

Thanks for chiming in.

Indeed, Flax works just fine with complex parameters, but that is not my problem: in quantum physics we often mix and match several layers, some with only real weights and some with complex weights, and it would be handy if one could simply specify a complex dtype and let flax dispatch to a different initialiser.

I see now what dtype is supposed to mean, thank you for clarifying that. However, I'd say the interface is then a bit… cumbersome if one wants to change the type of the parameters (even between float32 and float64): instead of simply specifying a kwarg like param_dtype, one has to check how many (and what kind of) weight initialisers must be passed to every layer, look up the defaults, and essentially call the default ones again with a different dtype.

Imagine one wants to create a float64 version of a Dense layer:

m = nn.Dense(
    3,
    dtype=jax.numpy.float64,
    kernel_init=flax.linen.initializers.lecun_normal(dtype=jax.numpy.float64),
    bias_init=flax.linen.initializers.normal(dtype=jax.numpy.float64),
)

and compare this to a much simpler

m = nn.Dense(3, dtype=jax.numpy.float64, param_dtype=jax.numpy.float64)

Moreover, if that were supported, one could also think about adding custom default initialisers to flax that correctly support complex dtypes.
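
As a rough illustration of the user-side difference, a convenience wrapper along these lines would already help; dense_with_param_dtype and its defaults are made up for the example, not flax API.

import jax.numpy as jnp
import flax.linen as nn
from jax.nn.initializers import lecun_normal, normal

def dense_with_param_dtype(features, param_dtype=jnp.float32):
    # Hypothetical helper: rebuild the default initialisers with the
    # requested parameter dtype, so the caller names the dtype only once.
    return nn.Dense(
        features,
        dtype=param_dtype,
        kernel_init=lecun_normal(dtype=param_dtype),
        bias_init=normal(dtype=param_dtype),
    )

# A float64 Dense without spelling out each initialiser
# (float64 params also require the jax_enable_x64 flag):
m = dense_with_param_dtype(3, param_dtype=jnp.float64)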

0 reactions
PhilipVinc commented, Jan 19, 2022

Yes, I’m aware. jax will transition to using float32 by default in most applications, but most of those inconsistencies will still be there.

#1776 would address most of them.
