One array type to rule them all!
I use Python 3’s typing features as much as possible. Unfortunately, JAX’s value hierarchy makes this a little challenging. Consider the following snippet:
from typing import Any, NamedTuple
import jax.numpy as jp
from jax import lax, random

ArrayType = Any  # placeholder: the "mystery" type this issue asks about

class Normal(NamedTuple):
    loc: ArrayType
    scale: ArrayType

    def sample(self, rng, sample_shape=()) -> ArrayType:
        # The batch shape is the broadcast of the parameter shapes.
        batch_shape = lax.broadcast_shapes(self.loc.shape, self.scale.shape)
        return self.loc + self.scale * random.normal(
            rng, shape=sample_shape + batch_shape)
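The snippet relies on lax.broadcast_shapes to combine the shapes of loc and scale into a batch shape. Assuming it follows the usual NumPy broadcasting rules, that computation can be sketched in plain Python. Note that broadcast_shapes_sketch is a hypothetical helper written for illustration, not part of JAX.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shapes_sketch(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Combine shapes with NumPy-style broadcasting (hypothetical helper, not a JAX API)."""
    result = []
    # Align shapes from the trailing dimension; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# loc with shape (3,) and scale with shape (2, 1) broadcast to a batch shape of (2, 3),
# so drawing 5 samples yields an array of shape (5, 2, 3).
batch_shape = broadcast_shapes_sketch((3,), (2, 1))
sample_shape = (5,)
print(sample_shape + batch_shape)  # (5, 2, 3)
```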
I’d like to be able to fill in the mystery ArrayTypes with something like a made-up “jp.Array”, but AFAICT from the array class hierarchy, there is no type that really fits. At first glance jp.DeviceArray looks like an eligible candidate, but there are also ConcreteArray, ShapedArray, and UnshapedArray. I’m not really sure what the differences between them are, but some of them derive from jax.core.AbstractValue, while DeviceArray does not. If there are other cases, I’d certainly like to avoid limiting my type signatures to arrays that live on-device. To make matters more confusing, there also seem to be _FilledConstant and DeviceConstant:
In [10]: jp.ones((2, 3))
Out[10]:
_FilledConstant([[1., 1., 1.],
                 [1., 1., 1.]], dtype=float32)
All in all, it’s not clear to me how these types fit together or how (if at all) they could be unified. What’s the appropriate type to use here? And if it does not yet exist, could we create such a type?
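One way to get a single annotation without committing to a concrete class is structural typing. The sketch below assumes Python 3.8+ (for typing.Protocol). It defines a hypothetical ArrayType protocol that only requires shape, dtype, and the + and * operators. FakeArray is a made-up class used purely for demonstration. Later JAX releases added jax.Array as a unified array type, so it is worth checking for that before rolling your own.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ArrayType(Protocol):
    """Structural stand-in for "any array" (hypothetical, not a JAX API)."""

    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...

    def __add__(self, other: Any) -> Any: ...

    def __mul__(self, other: Any) -> Any: ...


class FakeArray:
    """Minimal hypothetical array type, used only to exercise the protocol."""

    def __init__(self, shape: Tuple[int, ...], dtype: str = "float32") -> None:
        self.shape = shape
        self.dtype = dtype

    def __add__(self, other: Any) -> "FakeArray":
        return self

    def __mul__(self, other: Any) -> "FakeArray":
        return self


print(isinstance(FakeArray((2, 3)), ArrayType))  # True
print(isinstance([1.0, 2.0], ArrayType))  # False: a list has no shape or dtype
```

A static type checker would accept any object with these members wherever ArrayType is expected. The runtime isinstance check is weaker: it only tests that the attributes exist, not their types or signatures.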

Have you looked at https://github.com/numpy/numpy-stubs and https://github.com/ramonhagenaars/nptyping?
FWIW, my solution here has been to begin a .pyi module that is symlinked for both numpy and jax.numpy. That certainly doesn’t solve the problem for jax.lax, etc., but it’s a start. I guess it also means that one way to approach this would be to have a jax-types package containing a bunch of .pyi definitions, instead of having to start including type annotations in jax itself if that is not desirable.
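The comment above describes sharing one .pyi stub between numpy and jax.numpy. The sketch below shows what such a stub could declare. The file name, the Array class, and the functions listed are hypothetical, and this is not the commenter’s actual file.

```python
# shared_array_stubs.pyi: hypothetical stub symlinked for numpy and jax.numpy.
# Stubs declare signatures only; bodies are "..." by convention.
from typing import Any, Sequence, Tuple, Union

class Array:
    shape: Tuple[int, ...]
    dtype: Any
    def __add__(self, other: "ArrayLike") -> "Array": ...
    def __mul__(self, other: "ArrayLike") -> "Array": ...

# Anything accepted where an array operand is expected.
ArrayLike = Union[Array, int, float]

def ones(shape: Union[int, Sequence[int]], dtype: Any = ...) -> Array: ...
def zeros(shape: Union[int, Sequence[int]], dtype: Any = ...) -> Array: ...
```

When a stub like this is installed for a module (or placed next to it), type checkers such as mypy read the stub instead of the implementation. The annotations can therefore live in a separate package without changing jax itself.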