Complex-valued layers (e.g. Dense)
Hello. I am one of the main contributors of the netket project, a machine-learning package for quantum physics. Together with other developers, we have migrated most of our code to jax, and I am considering integrating deeply with flax too.
Note: this issue is a sister issue to jax#5312
However, in netket we make extensive use of complex-valued neural networks, which at the moment are not supported by flax.
To properly support flax, I would like to be able to define layers such as Dense, Conv and others that use complex values.
While we can sometimes split the real and imaginary parts of a dense layer by hand, it would often be much more convenient to have built-in support for this.
Ideally, I would like the example below to work:
Problem you have encountered:
Complex-valued dtypes with most layers silently return real-valued parameters. See for example:
>>> import jax, flax
>>> m=flax.linen.Dense(3, dtype=jax.numpy.complex64)
>>> _, weights = m.init_with_output(jax.random.PRNGKey(0), (3,))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>> weights
FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.67380387, -0.3294223 , -0.9614107 ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
Where I would have expected a complex-valued network.
The issue probably lies in jax's truncated_normal (called by lecun_normal) not supporting complex values.
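For example, calling the default initialiser directly with a complex dtype does not give complex weights; in the jax version I am using, truncated_normal rejects non-float dtypes, so the call fails (a rough sketch):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.lecun_normal()
init(jax.random.PRNGKey(0), (1, 3), jnp.float32)    # works: real-valued kernel
init(jax.random.PRNGKey(0), (1, 3), jnp.complex64)  # fails: truncated_normal rejects complex dtypes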
I am opening this issue, however, because I am not sure that it would ever make sense for lecun_normal to support complex numbers natively, and therefore it might make more sense for flax to have its own initialisers dispatch to different implementations depending on the required dtype.
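For concreteness, here is a minimal sketch of what such a dispatching initialiser could look like; the function name and the choice to split the variance evenly between real and imaginary parts are my own assumptions, not something flax or jax provides:

import jax
import jax.numpy as jnp

def lecun_normal_complex_aware():
    # Fall back to the standard initialiser for real dtypes; for complex dtypes,
    # draw independent real and imaginary parts with half the variance each,
    # so that the total variance matches the real-valued case.
    real_init = jax.nn.initializers.lecun_normal()

    def init(key, shape, dtype=jnp.float32):
        if jnp.issubdtype(dtype, jnp.complexfloating):
            real_dtype = jnp.finfo(dtype).dtype  # complex64 -> float32, complex128 -> float64
            key_re, key_im = jax.random.split(key)
            re = real_init(key_re, shape, real_dtype)
            im = real_init(key_im, shape, real_dtype)
            return ((re + 1j * im) / jnp.sqrt(2)).astype(dtype)
        return real_init(key, shape, dtype)

    return init

Such an initialiser could then be passed as kernel_init to a layer, or picked automatically when the requested dtype is complex.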
In any case, silently returning the wrong output is confusing.
Issue Analytics
- Created: 3 years ago
- Reactions: 2
- Comments: 9 (9 by maintainers)
Top GitHub Comments
Thanks for chiming in.
Indeed Flax works just fine with complex parameters, but that is not my problem: in quantum physics we often mix and match several layers, some of which are only real and some of which have complex weights, and it would be handy if it were easy to only specify a complex dtype and let flax dispatch to a different initialiser.

I see now what dtype is supposed to mean, thank you for clarifying that. However, I'd say that the interface is then a bit… cumbersome if one desires to change the type (even between float32 and float64) of the parameters: instead of simply specifying a kwarg like param_dtype, one has to check how many (and what type of) weight initialisers must be passed to every layer, maybe check the default ones, and essentially call the default ones again with a different dtype. Imagine one wants to create a float64 version of a Dense layer, and compare this to the much simpler param_dtype version; see the sketch below.
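Something along these lines (a rough sketch; param_dtype is the proposed keyword, which Dense does not accept here, and float64 additionally requires enabling x64 support in jax):

import jax.numpy as jnp
import flax.linen as nn
from jax.nn import initializers

# Cumbersome: re-specify every default initialiser just to change the parameter dtype.
dense64 = nn.Dense(
    features=3,
    kernel_init=lambda key, shape, dtype=jnp.float64: initializers.lecun_normal()(key, shape, jnp.float64),
    bias_init=lambda key, shape, dtype=jnp.float64: initializers.zeros(key, shape, jnp.float64),
)

# Much simpler, with the proposed keyword:
dense64 = nn.Dense(features=3, param_dtype=jnp.float64)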
Moreover, if that was supported, one could also think about adding custom default initialisers in flax that correctly support complex dtypes…

Yes, I'm aware. jax will transition to using float32 by default in most applications, but most of those inconsistencies will still be there.
#1776 would address most of them