Please allow @jax.jit on staticmethods

The following code does not work:

import jax

class Foo:
  @jax.jit
  @staticmethod
  def foo(x):
    return x * 2

It fails with the error TypeError: Expected a callable value, got <staticmethod object at 0x7fe007e767f0>. This is because a staticmethod object is NOT callable until Python 3.10; only the function it wraps is. Conceptually, of course, a staticmethod is callable, in the sense that it is something we can call.
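The version dependence described above can be checked directly by calling callable() on a staticmethod object. This is a minimal sketch using only the standard library; double is a hypothetical example function. It also shows that the wrapped function is always reachable through the staticmethod's __func__ attribute, which is one way a library could unwrap it.

```python
import sys


def double(x):
    # Hypothetical example function, used only for illustration.
    return x * 2


wrapped = staticmethod(double)

# Python 3.10 made staticmethod objects callable; on 3.9 and earlier,
# callable() returns False, matching the TypeError shown above.
print(sys.version_info[:2], callable(wrapped))

# On every Python 3 version the underlying function is reachable via __func__.
print(wrapped.__func__ is double)  # True
print(wrapped.__func__(21))        # 42
```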

In #1251 one possible workaround is suggested:

import functools

import jax

class Foo:
  @functools.partial(jax.jit, static_argnums=(0,))
  def foo(self, x):
    # this function should not make use of `self` to be pure-functional
    return x * 2

which is a bit ugly, but basically does the same thing as long as self is not used.
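One consequence of the static_argnums=(0,) workaround: self becomes a static argument, so it must be hashable and have equality defined, and jit caches traced code keyed on it. A plain class uses the default identity-based hash, so the sketch below (assuming a recent JAX release; trace_count is a hypothetical counter added only for illustration) is expected to trace once per instance rather than once overall.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical counter for illustration: Python side effects inside a jitted
# function run only while JAX traces it, not on cached calls.
trace_count = 0


class Foo:
    @functools.partial(jax.jit, static_argnums=(0,))
    def foo(self, x):
        global trace_count
        trace_count += 1
        return x * 2


x = jnp.arange(3.0)

a = Foo()
a.foo(x)
a.foo(x)  # same static `self`, same input shape/dtype: reuses the cached trace
traces_after_a = trace_count

b = Foo()
b.foo(x)  # a different instance hashes differently, so JAX traces again
print(traces_after_a, trace_count)
```

With many short-lived instances this can mean repeated tracing and compilation, which a real staticmethod would avoid.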

I don’t see any reason a staticmethod cannot be jitted like any other non-method function. I think this is just a matter of extending _check_callable(fun) to support staticmethods (which will happen automatically in Python 3.10, where staticmethod objects become callable). Are there any pitfalls I might have missed?
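One possible pitfall with simply accepting staticmethods is placement in the class body: if @jit @staticmethod unwrapped the staticmethod and returned a plain wrapper function, that wrapper would be stored as an ordinary class attribute and, being a descriptor like any Python function, would be bound as an instance method, so Foo().foo(x) would pass the instance in as x. The sketch below shows one way around this by re-wrapping the result in staticmethod. check_callable, staticmethod_aware and toy_transform are hypothetical names, this is not JAX's implementation, and toy_transform stands in for jax.jit so the example runs without JAX.

```python
import inspect


def check_callable(fun):
    """Hypothetical callable check that also accepts staticmethod objects
    by unwrapping them through __func__ (present on all Python 3 versions)."""
    if isinstance(fun, staticmethod):
        fun = fun.__func__
    if not callable(fun):
        raise TypeError(f"Expected a callable value, got {fun}")
    if inspect.isgeneratorfunction(fun):
        raise TypeError(f"Expected a function, got a generator function: {fun}")
    return fun


def staticmethod_aware(transform):
    """Hypothetical decorator factory: applies `transform` to a plain function,
    or to the function inside a staticmethod, and re-wraps the result in
    staticmethod so it is never bound as an instance method."""
    def decorator(fun):
        was_static = isinstance(fun, staticmethod)
        result = transform(check_callable(fun))
        return staticmethod(result) if was_static else result
    return decorator


def toy_transform(fun):
    """Stand-in for jax.jit: wraps `fun` in a new plain Python function."""
    def wrapper(*args):
        return fun(*args)
    return wrapper


class Foo:
    @staticmethod_aware(toy_transform)
    @staticmethod
    def foo(x):
        return x * 2


print(Foo.foo(3), Foo().foo(3))  # 6 6
```

Without the final staticmethod(...) re-wrap, Foo().foo(3) would call wrapper(instance, 3) and fail with a TypeError about the number of arguments.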

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

2 reactions
jakevdp commented, Aug 26, 2021

I think the cleanest approach is probably to use @staticmethod @jit rather than @jit @staticmethod.

If what you say about Python 3.10 is correct, though, the tests referred to above will (I think?) fail when run with Py3.10. If there are no fundamental reasons why staticmethod objects cannot be jitted, I would be fine with a PR that removes those tests and updates _is_callable to recognize static methods.
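The stacking order jakevdp suggests can be sketched as follows, assuming a recent JAX release: jax.jit is applied first to a plain function, and staticmethod wraps the already-jitted result, so jit never sees a staticmethod object and no callable check is involved.

```python
import jax
import jax.numpy as jnp


class Foo:
    @staticmethod  # applied last: wraps the already-jitted function
    @jax.jit       # applied first: receives a plain function
    def foo(x):
        return x * 2


x = jnp.arange(3.0)
print(Foo.foo(x))    # [0. 2. 4.]
print(Foo().foo(x))  # [0. 2. 4.]
```

Because staticmethod is outermost, neither Foo.foo(x) nor Foo().foo(x) passes an instance to the jitted function, and all instances share one compilation cache.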

2 reactions
wookayin commented, Aug 26, 2021

Thank you @jblespiau for the explanation.

> It’s always possible to have a plain, top-level function doing the same thing. So usually, we never have static methods, and thus, never want to jax.jit them.

That is probably how Google usually writes JAX code, but I respectfully disagree that staticmethods should be prohibited: community code that does not strictly follow Google conventions might still want jitted staticmethods. Personally, I don’t like having a top-level (module-level) plain function, because it has to live outside the class, usually far away from the methods that implement closely related logic.

That said, as an alternative one could define a (nested) plain function inside __init__ and assign it as an attribute, but I feel staticmethods would give a bit more flexibility. I can prepare a PR to support this feature.
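The __init__ alternative mentioned in this comment might look like the following minimal sketch, assuming a recent JAX release. The scale parameter is a hypothetical addition to show that the nested function can close over constructor arguments. Functions stored as instance attributes are not bound as methods, so self.foo receives only x.

```python
import jax
import jax.numpy as jnp


class Foo:
    def __init__(self, scale):
        # `scale` is a hypothetical constructor argument; the nested function
        # closes over it, and jit treats it as a constant when tracing.
        def _foo(x):
            return x * scale

        # Instance attributes bypass method binding, so self.foo(x) calls
        # the jitted function with x only.
        self.foo = jax.jit(_foo)


f = Foo(2.0)
print(f.foo(jnp.arange(3.0)))  # [0. 2. 4.]
```

One trade-off compared with a staticmethod: each instance creates its own closure and its own jitted wrapper, so each instance compiles separately.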
