
kwargs sometimes cannot work well with jit


Minimal repro:

```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(3,))
def f(a, b, c, d):
  if d > 0:
    return a + b - c
  else:
    return d
```

Calling f(a=1, b=2, c=3, d=4) fails with TypeError: Jitted function has static_argnums=(3,) but was called with only 0 positional arguments.

By contrast, f(1, 2, 3, 4) works.

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 2
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

mattjj commented, Aug 10, 2019 (1 reaction)

This error is intentional, and unfortunately (AFAICT) necessary.

The underlying issue is that in Python there’s no way to reliably determine a function’s signature (it can be changed or hidden by decorators, functools.partial, etc.), which means there’s no way to match named arguments to their positions (and vice-versa) given a calling set of args/kwargs. That in turn means static_argnums doesn’t work with names. See #595 for a bit more info.
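To make that concrete, here is a small illustration using only the standard library's inspect and functools modules (the decorator names and the param_names helper are hypothetical, made up for this example). The same underlying function reports different parameter lists depending on how it is wrapped, so an index like 3 cannot, in general, be mapped reliably to the name d:

```python
import functools
import inspect


def hiding_decorator(fn):
    # No functools.wraps: inspect.signature only sees (*args, **kwargs).
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def preserving_decorator(fn):
    # functools.wraps sets __wrapped__, which inspect.signature follows back to fn.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def f(a, b, c, d):
    return a + b - c if d > 0 else d


def param_names(fn):
    """Parameter names as reported by inspect.signature."""
    return list(inspect.signature(fn).parameters)


print(param_names(f))                        # ['a', 'b', 'c', 'd']
print(param_names(hiding_decorator(f)))      # ['args', 'kwargs']
print(param_names(preserving_decorator(f)))  # ['a', 'b', 'c', 'd']
print(param_names(functools.partial(f, 1)))  # ['b', 'c', 'd']
```

With the undecorated wrapper no parameter names are recoverable at all, and after functools.partial the parameter d has moved from index 3 to index 2.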

One solution is to use a wrapper like this:

```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(3,))
def _f(a, b, c, d):
  if d > 0:
    return a + b - c
  else:
    return d

def f(a, b, c, d):
  return _f(a, b, c, d)
```

Now callers of f can write f(1, 2, 3, 4) or f(a=1, b=2, c=3, d=4) or whatever calling convention they like; the single layer of indirection from f to _f means that Python standardizes the arguments for us. We use this trick inside JAX in a lot of places, like in random.py but also for every lax primitive, which is one reason why you see things like def sin(x): return sin_p.bind(x) in lax.py.

WDYT?

jakevdp commented, Jun 21, 2022

The original issue has now been fixed for jit, and similar improvements for vmap and other transforms are tracked in #10614.
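For reference, here is a minimal sketch of the fixed behavior. It assumes a recent JAX release; the thread does not say which version first shipped the fix. jax.jit accepts a static_argnames parameter, and when only static_argnums is given it looks up the matching parameter names with inspect.signature. The keyword call from the original repro should therefore now be accepted:

```python
from functools import partial

import jax


# Option 1: keep static_argnums; recent jax.jit infers that index 3 is named "d",
# so passing d as a keyword still marks it static.
@partial(jax.jit, static_argnums=(3,))
def f(a, b, c, d):
    if d > 0:
        return a + b - c
    else:
        return d


# Option 2: name the static argument explicitly.
@partial(jax.jit, static_argnames=("d",))
def g(a, b, c, d):
    if d > 0:
        return a + b - c
    else:
        return d


print(f(a=1, b=2, c=3, d=4))
print(g(1, 2, 3, d=-5))
```

Naming the static argument explicitly (option 2) avoids relying on positional indices, which also sidesteps the signature ambiguity mattjj describes above whenever the function's real signature is visible to inspect.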

