kwargs sometimes cannot work well with jit
Minimal repro:

    from functools import partial
    from jax import jit

    @partial(jit, static_argnums=(3,))
    def f(a, b, c, d):
        if d > 0:
            return a + b - c
        else:
            return d

    f(a=1, b=2, c=3, d=4)

This raises TypeError: Jitted function has static_argnums=(3,) but was called with only 0 positional arguments.

But f(1, 2, 3, 4) works.
Issue Analytics
- State:
- Created 4 years ago
- Reactions:2
- Comments:5 (4 by maintainers)
This error is intentional, and unfortunately (AFAICT) necessary.

The underlying issue is that in Python there’s no way to reliably determine a function’s signature (it can be changed or hidden by decorators, functools.partial, etc.), which means there’s no way to match named arguments to their positions (and vice versa) given a calling set of args/kwargs. That in turn means static_argnums doesn’t work with names. See #595 for a bit more info.

One solution is to use a wrapper like this:
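A minimal sketch of such a wrapper, assuming the names f and _f used in the text that follows and reusing the body of the repro above. The split into two functions is an illustration of the idea, not necessarily the exact code from the original comment:

```python
from functools import partial
from jax import jit

# Inner function: jitted, with the 4th positional argument (d) marked static.
@partial(jit, static_argnums=(3,))
def _f(a, b, c, d):
    # d is static, so this Python-level branch is resolved at trace time.
    if d > 0:
        return a + b - c
    else:
        return d

# Outer function: plain Python. Python binds keyword arguments to the
# parameters here, and the call to _f is always purely positional.
def f(a, b, c, d):
    return _f(a, b, c, d)

by_position = f(1, 2, 3, 4)
by_keyword = f(a=1, b=2, c=3, d=4)
```

Both calls reach _f with four positional arguments, so static_argnums=(3,) always refers to d.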
Now callers of f can write f(1, 2, 3, 4) or f(a=1, b=2, c=3, d=4) or whatever calling convention they like; the single layer of indirection from f to _f means that Python standardizes the arguments for us. We use this trick inside JAX in a lot of places, like in random.py but also for every lax primitive, which is one reason why you see things like def sin(x): return sin_p.bind(x) in lax.py.

WDYT?
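The point above about signatures being changed or hidden can be seen with the standard inspect module. This is only an illustrative sketch: the helper names passthrough and g are hypothetical, and it does not show how JAX itself inspects functions.

```python
import functools
import inspect

def g(a, b, c, d):
    return (a, b, c, d)

# A decorator that forwards *args/**kwargs without functools.wraps
# hides the original parameter names from introspection.
def passthrough(fn):
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

original_params = list(inspect.signature(g).parameters)
wrapped_params = list(inspect.signature(passthrough(g)).parameters)

# functools.partial shifts positions: d is no longer at index 3.
partial_params = list(inspect.signature(functools.partial(g, 1)).parameters)

print(original_params)  # ['a', 'b', 'c', 'd']
print(wrapped_params)   # ['args', 'kwargs']
print(partial_params)   # ['b', 'c', 'd']
```

In the second and third cases, a position such as 3 can no longer be mapped back to the name d from the signature alone.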
The original issue has now been fixed for jit, and similar improvements for vmap and other transforms are tracked in #10614.
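On JAX versions where jit accepts a static_argnames parameter (an assumption about the installed version; the thread itself does not mention it), the static argument can be declared by both position and name, so either calling convention should work without a wrapper. A hedged sketch:

```python
from functools import partial
from jax import jit

# Declare d as static both by position and by name, so the function can be
# called positionally or with keywords.
@partial(jit, static_argnums=(3,), static_argnames=("d",))
def f(a, b, c, d):
    if d > 0:
        return a + b - c
    else:
        return d

by_position = f(1, 2, 3, 4)
by_keyword = f(a=1, b=2, c=3, d=4)
```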