One array type to rule them all!
I use Python 3's typing features as much as possible. Unfortunately, JAX's value hierarchy makes this a bit challenging. Consider the following snippet:
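```python
from typing import Any, NamedTuple

import jax.numpy as jp
from jax import lax, random

# Placeholder for the "mystery" array type this issue is asking about.
ArrayType = Any

class Normal(NamedTuple):
    loc: ArrayType
    scale: ArrayType

    def sample(self, rng, sample_shape=()) -> ArrayType:
        # Broadcast loc against scale to get the distribution's batch shape.
        batch_shape = lax.broadcast_shapes(self.loc.shape, self.scale.shape)
        return self.loc + self.scale * random.normal(
            rng, shape=sample_shape + batch_shape)
```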
I'd like to be able to fill in the mystery `ArrayType`s with something like a made-up `jp.Array`, but AFAICT from the array class hierarchy there is no such type that really fits. At first glance `jp.DeviceArray` looks like an eligible candidate, but then there are also `ConcreteArray`, `ShapedArray`, and `UnshapedArray`. I'm not really sure what the differences between them are, but some of them derive from `jax.core.AbstractValue`, while `DeviceArray` does not… If there are other cases, I'd certainly like to avoid limiting my type signatures to arrays that live on-device. To make matters more confusing, there also seem to be `_FilledConstant` and `DeviceConstant`:
```
In [10]: jp.ones((2, 3))
Out[10]:
_FilledConstant([[1., 1., 1.],
                 [1., 1., 1.]], dtype=float32)
```
All in all, it's not clear to me how these types fit together and how (if at all) they could be unified. What's the appropriate type to use here? And if it doesn't exist yet, could we create one?
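One possible answer to "could we create such a type?", sketched here purely as an assumption and not as anything JAX provides: a structural `typing.Protocol` that matches any object exposing `shape` and `dtype`, so on-device arrays and abstract values would both fit without sharing a base class. The names `ArrayType`, `FakeDeviceArray`, `FakeShapedArray`, and `broadcast_batch_shape` are hypothetical, and the two fake classes only stand in for the real JAX types.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ArrayType(Protocol):
    """Hypothetical structural array type: anything with .shape and .dtype."""

    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


# Stand-ins for unrelated concrete and abstract array classes (not real JAX types).
class FakeDeviceArray:
    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape
        self.dtype = "float32"


class FakeShapedArray:
    def __init__(self, shape: Tuple[int, ...], dtype: str) -> None:
        self.shape = shape
        self.dtype = dtype


def broadcast_batch_shape(loc: ArrayType, scale: ArrayType) -> Tuple[int, ...]:
    """Broadcast two shapes with NumPy rules, accepting any ArrayType."""
    a, b = tuple(loc.shape), tuple(scale.shape)
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + a
    b = (1,) * (ndim - len(b)) + b
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible shapes {loc.shape} and {scale.shape}")
    return tuple(out)


loc = FakeDeviceArray((3,))
scale = FakeShapedArray((2, 1), "float32")
# Both match, although neither class inherits from ArrayType.
assert isinstance(loc, ArrayType) and isinstance(scale, ArrayType)
print(broadcast_batch_shape(loc, scale))  # (2, 3)
```

Because `runtime_checkable` protocols only check that the attributes exist, `isinstance` is a loose test here; the main payoff is for static checkers such as mypy, which also verify the attribute types.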
Have you looked at https://github.com/numpy/numpy-stubs and https://github.com/ramonhagenaars/nptyping?
FWIW, my solution here has been to start a `.pyi` stub module that is symlinked for both `numpy` and `jax.numpy`. That certainly doesn't solve the problem for `jax.lax` etc., but it's a start. It also suggests one way to approach this: a separate `jax-types` package containing a set of `.pyi` definitions, instead of adding type annotations to jax itself, if that is not desirable.
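A minimal sketch of how user code might consume such a stub-only package without making it a runtime dependency, assuming a hypothetical package named `jax_types` that exports an `Array` type: the import sits behind `typing.TYPE_CHECKING`, which is `False` at runtime, so only static checkers ever resolve it. The function `scale_and_shift` is likewise just an illustration.

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by static checkers such as mypy; never executed at runtime,
    # so the hypothetical stub package need not be installed to run this.
    from jax_types import Array  # hypothetical stub-only package
else:
    # At runtime, fall back to Any so the annotations cost nothing.
    Array = Any


def scale_and_shift(x: Array, scale: float, shift: float) -> Array:
    """Hypothetical helper; the Array annotation only matters to checkers."""
    return x * scale + shift


print(scale_and_shift(2.0, 3.0, 1.0))  # 7.0
```

PEP 561 also defines stub-only distributions named `<package>-stubs` (for example a hypothetical `jax-stubs`), which type checkers apply automatically to `import jax`; that layout would avoid the separate import entirely.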