
Improve Error Message: Jitting linen.Module


Example code:

import jax

# Model, params, test_ds and compute_metrics are not shown in the issue; assumed defined elsewhere.
model = Model(1)

@jax.jit  # passing the Module as a regular (non-static) argument triggers the error below
def eval_step(model, params, batch):
  logits = model.apply({'params': params}, batch['X'])
  return compute_metrics(logits, batch['y'])

def eval_model(model, params, test_ds):
  metrics = eval_step(model, params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

print(eval_model(model, params, test_ds))

Throws the following general JAX error:

TypeError: Argument 'Model(
    # attributes
    features = 3
)' of type <class '__main__.Model'> is not a valid JAX type

Modules in Linen aren’t pytrees, so they can’t be flattened/unflattened as needed when entering and exiting JAX transformations. The common pattern is to pass static_argnums to jax.jit, which is equivalent to having the module instance behave like a constant inside the transformed function.

We should consider registering Modules as pytree nodes solely so that we can throw an error explaining this.
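A minimal sketch of that proposal in plain JAX, assuming a hypothetical stand-in class FakeModule rather than a real linen.Module. The class is registered with jax.tree_util.register_pytree_node using a flatten function that always raises a descriptive TypeError. As a result, passing an instance as a regular jax.jit argument fails with an actionable message, while passing it through static_argnums (which hashes the argument instead of flattening it) still works. The helper names _flatten_with_error and _unflatten_unreachable are illustrative, not Flax APIs.

```python
import jax
import jax.numpy as jnp


class FakeModule:
  """Hypothetical stand-in for a linen.Module: hashable, but not an array."""

  def __init__(self, scale):
    self.scale = scale

  def __hash__(self):
    return hash(self.scale)

  def __eq__(self, other):
    return isinstance(other, FakeModule) and self.scale == other.scale

  def __call__(self, x):
    return self.scale * x


def _flatten_with_error(module):
  # Called whenever JAX tries to flatten an instance, e.g. for a non-static
  # jax.jit argument. Raise an actionable message instead of the generic
  # "is not a valid JAX type" error.
  raise TypeError(
      f"{type(module).__name__} cannot be passed as a regular argument to "
      "jax.jit; mark it static with static_argnums instead.")


def _unflatten_unreachable(aux_data, children):
  # Never reached, because flattening always raises.
  raise TypeError("FakeModule cannot be unflattened.")


jax.tree_util.register_pytree_node(
    FakeModule, _flatten_with_error, _unflatten_unreachable)


def eval_step(module, x):
  return module(x)


jitted_bad = jax.jit(eval_step)                       # module would be flattened: raises
jitted_ok = jax.jit(eval_step, static_argnums=(0,))  # module is hashed: works
```

Calling jitted_bad(FakeModule(2.0), jnp.ones(3)) should surface the custom message, while jitted_ok(FakeModule(2.0), jnp.ones(3)) returns the scaled array. A trade-off of this design is that any other code that tree-maps over structures containing such objects will also hit the error.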

Issue Analytics

  • State: open
  • Created 3 years ago
  • Reactions: 2
  • Comments: 13 (11 by maintainers)

Top GitHub Comments

1 reaction
jheek commented, Apr 19, 2022

Normally you would use jax.jit(jitted_step_bad, static_argnums=(0,)) so that the model is a static argument. You cannot pass a model as a normal argument because it is not a JAX array or a pytree of JAX arrays.

0 reactions
melissatan commented, Jul 6, 2022

I was caught up with other responsibilities for some time; picking this back up now.

I’ve made the changes @jheek suggested to the 3 tree_map() calls in get_module_scopes(), but this appears to cause failures in several tests in linen_transforms_test that use nn.jit(), e.g.:

def test_compact_aliasing_collision(self):
    class Foo(nn.Module):
      m1: nn.Module
      m2: nn.Module
      @nn.compact
      def __call__(self, x):
        x = self.m2(self.m1(x))
        return x
    class Bar(nn.Module):
      @nn.compact
      def __call__(self, x):
        dense = nn.Dense(2)
        x = nn.jit(Foo)(dense, dense)(x)  # <-- fails here with the jit pytree error I'm defining in PR #2270.
        return x
    k = random.PRNGKey(0)
    x = jnp.zeros((2, 2))
    _ = Bar().init(k, x)

I tried specifying static_argnums for the nn.jit() call in the test above, but then got a different failure: “Non-hashable static arguments are not supported.”
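For context, jax.jit uses static arguments as part of its compilation cache key, so every value passed through static_argnums must be hashable; otherwise JAX raises the “Non-hashable static arguments are not supported” error. A minimal plain-JAX sketch, using a hypothetical weighted_sum function, shows the requirement with a tuple (hashable) versus a list (unhashable). It only illustrates the rule and is not a fix for the nn.jit test.

```python
import jax
import jax.numpy as jnp


def weighted_sum(weights, x):
  # `weights` is static: a plain Python sequence baked into the trace.
  return sum(w * x for w in weights)


jitted = jax.jit(weighted_sum, static_argnums=(0,))

x = jnp.ones(2)
ok = jitted((1.0, 2.0), x)  # tuples are hashable, so this traces and runs

try:
  jitted([1.0, 2.0], x)  # lists are unhashable
except ValueError as err:
  print(err)  # message starts with "Non-hashable static arguments are not supported"
```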

I’m not sufficiently familiar with the internals of linen_transforms. Is there another workaround that prevents the failure in module_lifecycle.rst (https://github.com/google/flax/issues/853#issuecomment-1113396067)?


