[Documentation] Document how to @beartype JAX dataclasses with Equinox

This works without type checking, but fails when I add the @typechecker decorator:

import jax.numpy as jnp
from jax import vmap
from typing import NamedTuple, Optional, Union
from beartype import beartype as typechecker
#from typeguard import typechecked as typechecker


@typechecker
class MyParams(NamedTuple):
    mean: float

theta = MyParams(42.0)

def f(params, y):
    return params.mean + y

data = jnp.ones((10,2))
foo = vmap(f, (None, 0))(theta, data)  # fails here once @typechecker is applied
print(foo)

It raises the error:

@beartyped namedtuple_MyParams.MyParams.__new__()
 parameter mean=<object object at 0x7fb770c3a840> violates type hint <class 'float'>, 
as object <object object at 0x7fb770c3a840> not instance of float.

If I use typeguard instead, I get the error:

type of argument "mean" must be either float or int; got object instead

I think the theta variable is getting reshaped by JAX to have a leading batch dimension, and this violates the type signature.

@slinderman @patrick-kidger FYI

Issue Analytics

  • State: open
  • Created 10 months ago
  • Comments: 11 (3 by maintainers)

Top GitHub Comments

murphyk commented, Nov 15, 2022 (1 reaction)

Yes, Equinox does indeed solve this problem! 😃 Here is a demonstration, for the record (since the benefits of Equinox may not be obvious to others) 😃


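import equinox as eqx
import jax.numpy as jnp
from jax import vmap
from typing import Union
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped

# Scalar is not defined in the original comment; this alias is an assumption.
Scalar = Union[float, Float[Array, ""]]

@jaxtyped
@typechecker
#@dataclass(frozen=True) # bad
class MyParams(eqx.Module): # good
    mean: Scalar
    mean2: Float[Array, "ndim"]

theta = MyParams(1.0, jnp.ones(2))

def f(params, y):
    p = MyParams(params.mean + y, y*jnp.ones(3))
    return p

data = jnp.arange(5)

foo = vmap(lambda y: f(theta, y))(data)
print(foo)
print('mean ', foo.mean)
print('mean2 ', foo.mean2)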

This type checks and produces the desired output:

MyParams(mean=f32[5], mean2=f32[5,3])
mean  [1. 2. 3. 4. 5.]
mean2  [[0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]
 [4. 4. 4.]]

patrick-kidger commented, Nov 15, 2022 (1 reaction)

That’s because, as the error suggests, dataclasses aren’t JAX types. You should register it as a PyTree. (Or use equinox.Module, which is basically @dataclass + pytree registration + a few edge-case improvements all wrapped together.)

FWIW this definitely has nothing to do with beartype, so to spare @leycec I suggest we discuss this elsewhere? 😃

Source: beartype/beartype issue #185, opened by murphyk.