Improve Error Message: Jitting linen.Module
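Example code:
import jax
# Model, compute_metrics, params and test_ds come from the issue author's setup and are not shown.

model = Model(1)

# @jax.jit  # uncommenting this decorator is what triggers the TypeError below
def eval_step(model, params, batch):
    logits = model.apply({'params': params}, batch['X'])
    return compute_metrics(logits, batch['y'])

def eval_model(model, params, test_ds):
    metrics = eval_step(model, params, test_ds)
    metrics = jax.device_get(metrics)
    # jax.tree_util.tree_map is the long-standing spelling of jax.tree_map
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

print(eval_model(model, params, test_ds))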
Throws the following general JAX error:
TypeError: Argument 'Model(
    # attributes
    features = 3
)' of type <class '__main__.Model'> is not a valid JAX type
Modules in Linen aren’t pytrees, so they can’t be flattened/unflattened as needed when entering and exiting JAX transformations. The common pattern is to pass the module to jit via static_argnums, which is equivalent to having the module instance behave like a constant inside the transformed function.
We should consider registering Modules as pytrees solely so that JAX raises an error explaining this.
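The proposal can be sketched with JAX's public pytree registry: register a type whose flatten function always raises a descriptive error, so any attempt to flatten it (which jit does to its non-static arguments) fails with an explanation instead of a generic message. This is a minimal sketch under assumptions, not Flax's actual change; NotAPytree, _refuse_to_flatten and _unflatten are hypothetical names, and the class only stands in for a Linen Module.

```python
import jax


class NotAPytree:
    """Hypothetical stand-in for a Linen Module: a plain Python object
    that should never be passed to a JAX transformation as a traced arg."""

    def __init__(self, features: int):
        self.features = features


def _refuse_to_flatten(obj):
    # Called whenever JAX tries to flatten this object as a pytree,
    # e.g. when it is passed as a non-static argument to jax.jit.
    raise TypeError(
        f"{type(obj).__name__} is not a pytree of arrays and cannot be a "
        "traced argument. Pass it via static_argnums or close over it."
    )


def _unflatten(aux_data, children):
    # Never reached, because flattening always fails.
    raise AssertionError("unreachable")


jax.tree_util.register_pytree_node(NotAPytree, _refuse_to_flatten, _unflatten)

message = None
try:
    jax.tree_util.tree_flatten(NotAPytree(3))
except TypeError as err:
    message = str(err)
print(message)
```

One trade-off of this design: a flatten function that always raises also breaks every tree_map over a structure that happens to contain such an object, so internal code that maps over containers holding modules would need to treat them as leaves (for instance via the is_leaf argument).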
Issue analytics: created 3 years ago; 2 reactions; 13 comments (11 by maintainers).
Top GitHub Comments
Normally you would use jax.jit(jitted_step_bad, static_argnums=(0,)) so that the model is a static argument. You cannot pass a model as a normal argument because it is not a pytree of JAX arrays.
I was caught up with other responsibilities for a while; picking this back up now.
I’ve made the changes @jheek suggested to the three tree_map() calls in get_module_scopes(). However, this appears to cause failures in several tests in linen_transforms_test that use nn.jit(), e.g.:
I tried specifying static_argnums for the nn.jit() call in the test above, but then got a different failure: “Non-hashable static arguments are not supported.”
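That message comes from plain jax.jit: static arguments become part of the compilation cache key, so they must be hashable. A minimal sketch showing the rule with a tuple versus a list; scale is a hypothetical function, and the exact exception class is not asserted here because it is a JAX implementation detail.

```python
import functools

import jax
import jax.numpy as jnp


# `factors` is static: jit hashes it instead of tracing it.
@functools.partial(jax.jit, static_argnums=1)
def scale(x, factors):
    # Because `factors` is a concrete Python value, sum() runs at trace time.
    return x * sum(factors)


x = jnp.arange(3.0)
result = scale(x, (1, 2))  # a tuple is hashable, so this works
print(result)

error_type = None
try:
    scale(x, [1, 2])  # a list is unhashable, so jit rejects it
except (TypeError, ValueError) as err:
    error_type = type(err).__name__
    print(error_type, err)
```

Any object passed through static_argnums, including anything nn.jit forwards as static, has to satisfy the same hashing requirement.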
I’m not sufficiently familiar with the internals of linen_transforms. Is there another workaround that prevents the failure in module_lifecycle.rst (https://github.com/google/flax/issues/853#issuecomment-1113396067)?