[Documentation] Document how to @beartype JAX dataclasses with Equinox
This works without type checking, but fails when I add the @typechecker decorator:
```python
import jax.numpy as jnp
from jax import vmap
from typing import NamedTuple, Optional, Union
from beartype import beartype as typechecker
#from typeguard import typechecked as typechecker

@typechecker
class MyParams(NamedTuple):
    mean: float

theta = MyParams(42.0)

def f(params, y):
    return params.mean + y

data = jnp.ones((10, 2))
foo = vmap(f, (None, 0))(theta, data)
print(foo)
```
It raises the error:
```
@beartyped namedtuple_MyParams.MyParams.__new__()
parameter mean=<object object at 0x7fb770c3a840> violates type hint <class 'float'>,
as object <object object at 0x7fb770c3a840> not instance of float.
```
If I use typeguard, I get the error:
```
type of argument "mean" must be either float or int; got object instead
```
I think the theta variable is being reshaped by JAX to have a leading batch dimension, and this violates the type signature.
Issue analytics: issue #185 in beartype/beartype, opened by murphyk; created 10 months ago; 11 comments (3 by maintainers).
Top GitHub Comments
Yes, equinox does indeed solve this problem! 😃 Here is a demonstration of this, for the record (since the benefits of equinox may not be obvious to others 😃).
This type checks and produces the desired output:
That's because, as the error suggests, dataclasses aren't JAX types. You should register it as a PyTree. (Or use equinox.Module, which is basically @dataclass + pytree registration + a few edge-case improvements all wrapped together.) FWIW, this definitely has nothing to do with beartype, so to spare @leycec, I suggest we discuss this elsewhere? 😃