Please allow @jax.jit on staticmethods
The following code does not work,

    class Foo:
        @jax.jit
        @staticmethod
        def foo(x):
            return x * 2

with an error TypeError: Expected a callable value, got <staticmethod object at 0x7fe007e767f0>. This is because a staticmethod object is NOT callable until Python 3.10. Of course, a static method is still something we can call once it is looked up on the class, as in Foo.foo(x).
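The distinction between the raw staticmethod object and the function it wraps can be checked with the standard library alone. The following is a minimal sketch; the names `double`, `wrapped`, and `Foo` are made up for illustration and are not taken from JAX.

```python
import sys


def double(x):
    return x * 2


# The raw descriptor object, as a decorator stacked above @staticmethod sees it.
wrapped = staticmethod(double)

# Before Python 3.10 a bare staticmethod object is not callable;
# from 3.10 onward it is.
is_callable = callable(wrapped)
expected = sys.version_info >= (3, 10)

# The underlying function is always reachable through __func__.
inner = wrapped.__func__


class Foo:
    foo = wrapped


# Attribute access on the class goes through the descriptor protocol,
# which hands back the plain function.
via_class = Foo.foo
```

A decorator placed above `@staticmethod` receives `wrapped`, not `double`, which is why the callable check fails on older Python versions.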
In #1251 one possible workaround is suggested:

    class Foo:
        @functools.partial(jax.jit, static_argnums=(0,))
        def foo(self, x):
            # this function should not make use of `self` to be pure-functional
            return x * 2

which is a bit ugly, but basically does the same thing as long as self is not used.
I don’t see any reason a staticmethod cannot be jit-ed like other non-method functions. I think this is just a matter of extending _check_callable(fun) to support staticmethods (which will happen automatically in Python 3.10). Are there any pitfalls I might have missed?
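One way such an extension could look is sketched below. This is not JAX's actual implementation: `unwrap_staticmethod` and `check_callable` are hypothetical stand-ins, and the sketch leaves open whether the transformed function should be wrapped in staticmethod again afterwards.

```python
def unwrap_staticmethod(fun):
    """Hypothetical helper: return the plain function behind a staticmethod."""
    if isinstance(fun, staticmethod):
        return fun.__func__
    return fun


def check_callable(fun):
    """Stand-in for a callable check that tolerates staticmethod objects."""
    target = unwrap_staticmethod(fun)
    if not callable(target):
        raise TypeError(f"Expected a callable value, got {fun}")
    return target
```

Because it unwraps first, the check behaves the same on Python versions before and after 3.10.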
Issue Analytics
- State:
- Created 2 years ago
- Comments:8 (3 by maintainers)
I think the cleanest approach is probably to use @staticmethod @jit rather than @jit @staticmethod. If what you say about Python 3.10 is correct, though, the tests referred to above will (I think?) fail when run with Py3.10. If there are no fundamental reasons why staticmethod objects cannot be jitted, I would be fine with a PR that removes those tests and updates _is_callable to recognize static methods.

Thank you @jblespiau for the explanation.
That is probably how Google usually writes JAX code, but I respectfully disagree that staticmethods should be prohibited, because community code that does not strictly follow the Google convention might still want jit-ed staticmethods. Personally, I don’t like using top-level (module-level) plain functions, because I have to place them outside the class, usually very far from the methods that implement closely related logic.

That said, as an alternative one could define a (nested) plain function inside __init__ and assign it as an attribute, but I feel using staticmethods would give a bit more flexibility. I can prepare a PR to support this feature.
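For comparison, the two alternatives mentioned in this thread that need no change to JAX can be sketched as follows. The class names `Foo` and `Bar`, the method names, and the `scale` parameter are made up for illustration.

```python
import jax
import jax.numpy as jnp


class Foo:
    # @staticmethod on the outside, @jax.jit on the inside: jit receives
    # a plain function, so no callable check is tripped on any Python version.
    @staticmethod
    @jax.jit
    def double(x):
        return x * 2


class Bar:
    def __init__(self, scale):
        # Nested plain function, jitted and stored as an instance attribute.
        def scaled(x):
            return x * scale

        self.scaled = jax.jit(scaled)


y1 = Foo.double(jnp.float32(3.0))
y2 = Bar(3.0).scaled(jnp.float32(2.0))
```

In the second form the jitted function exists only on instances and closes over `scale`, which is part of the flexibility trade-off described above.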