
[Proposal] Consistent `argnums` and `argnames` parameters for transformations


Hey JAX team,

I have been trying to wrap my head around ‘argument annotation’ in JAX for a while, in the hope of finding a more intuitive and consistent implementation, which has led me to the big block of text below. I would be super keen to hear your thoughts as I try to dive deeper into the inner workings of JAX.

Lately there have been a number of issues requesting improvements to the *_argnums and *_argnames parameters used in transformations, as well as other ergonomic improvements related to declaring which function arguments should be annotated with a given property. I figured it might be helpful to open an overarching issue with the end goal of having a consistent, ergonomic way of specifying these parameters. Managing argument ‘annotations’ in transformations has definitely been one of the more frustrating parts of learning JAX (which is otherwise entirely amazing, of course).

Related issues:

jax.jit correctly implements static_argnames even for keyword-only arguments, which suggests that it should be possible to add argnames equivalents to every function that currently only implements argnums.

An easier but less robust fix could be to map argnames to argnums using inspect (see discussion: #1159). This would likely not work for keyword-only arguments (though it might for things like donate_arg...?)
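To make the inspect-based mapping concrete, here is a minimal sketch using only Python's standard library. The helper name argnames_to_argnums is hypothetical (it is not part of JAX). It also shows why this approach breaks down for keyword-only parameters: they simply have no positional index to map to.

```python
import inspect
from typing import Callable, Sequence, Tuple

# Parameter kinds that can be passed by position.
_POSITIONAL = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def argnames_to_argnums(fun: Callable, argnames: Sequence[str]) -> Tuple[int, ...]:
    """Hypothetical helper: translate parameter names into positional indices."""
    params = inspect.signature(fun).parameters.values()
    # Positional parameters always come first in a signature, so their
    # enumeration index is also their position.
    index = {p.name: i for i, p in enumerate(params) if p.kind in _POSITIONAL}
    argnums = []
    for name in argnames:
        if name not in index:
            # Keyword-only parameters (and names absorbed by **kwargs) have no
            # position, so this approach cannot express them.
            raise ValueError(f"{name!r} has no positional index in {fun.__name__}")
        argnums.append(index[name])
    return tuple(argnums)


def f(a, /, b, *, c):
    print(a, b, c)


print(argnames_to_argnums(f, ("a", "b")))  # (0, 1)
```

Calling argnames_to_argnums(f, ("c",)) raises ValueError, which is exactly the keyword-only gap mentioned above.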

Current shortcomings

Currently, even the most robust implementation of the ‘argument annotation’ mechanism (jax.jit) behaves in a somewhat counter-intuitive way, although this is hinted at in the fine print of the docstring if one reads it with sufficient care:

from jax import jit

def f(a, /, b, *, c):
    print(a, b, c)

jf = jit(f, static_argnames=("a", "b", "c"))
jf(1, 2, c=3)
# Prints: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3
# Expected: 1 2 3

jf2 = jit(f, static_argnames=("b", "c"), static_argnums=(0,))
jf2(1, 2, c=3)
# Prints: 1 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 3
# Expected: 1 2 3

jf2(1, b=2, c=3)
# Prints: 1 2 3
# As expected

jf3 = jit(f, static_argnums=(0, 1, 2))
jf3(1, 2, c=3)
# Prints: 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

jf3(1, b=2, c=3)
# Prints: 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

The fact that one of these cases does produce the expected result gives hope that a solution is possible by inspecting the function and its arguments and adjusting static_argnums and static_argnames accordingly. Or perhaps a better solution exists? Ideally, we would want to avoid inspecting the arguments at call time.
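One way to inspect the function without touching the arguments at call time is to do it once, when the function is wrapped. The sketch below uses a hypothetical helper, normalize_static_args (not JAX's actual implementation), which rewrites the two parameters so that every annotated parameter is listed under its position, its name, or both, depending on its kind in the signature.

```python
import inspect
from typing import Callable, Sequence, Tuple


def normalize_static_args(
    fun: Callable, argnums: Sequence[int], argnames: Sequence[str]
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    """Hypothetical: make argnums/argnames agree with the signature of `fun`.

    Runs once when `fun` is wrapped, so nothing is inspected at call time.
    """
    nums, names = set(argnums), set(argnames)
    params = inspect.signature(fun).parameters.values()
    for i, p in enumerate(params):
        if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # May arrive by position or by name: record it under both.
            if i in nums or p.name in names:
                nums.add(i)
                names.add(p.name)
        elif p.kind is inspect.Parameter.POSITIONAL_ONLY:
            # Can only arrive by position: a name here must mean its index.
            if p.name in names:
                nums.add(i)
                names.discard(p.name)
        # KEYWORD_ONLY parameters can only be reached through argnames, and
        # *args / **kwargs are left untouched.
    return tuple(sorted(nums)), tuple(sorted(names))


def f(a, /, b, *, c):
    print(a, b, c)


# The three configurations from the jit examples:
print(normalize_static_args(f, (), ("a", "b", "c")))  # ((0, 1), ('b', 'c'))
print(normalize_static_args(f, (0,), ("b", "c")))     # ((0, 1), ('b', 'c'))
print(normalize_static_args(f, (0, 1, 2), ()))        # ((0, 1, 2), ('b',))
```

Assuming the wrapper then treats an argument as static whenever its position is in the first tuple or its keyword is in the second, the first two configurations would make a, b and c static for both f(1, 2, c=3) and f(1, b=2, c=3). The third shows that an index can never reach the keyword-only c, which is arguably better reported as an error than silently ignored.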

I have started toying with validation of static_argnums and static_argnames in #10603.
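For illustration only, and not a description of what #10603 actually does, these are the kinds of wrap-time checks that could turn the silent cases from the jit examples into errors. validate_static_args is a hypothetical name, and the choice to reject positional-only names (rather than translate them) is an assumption.

```python
import inspect
from typing import Callable, Sequence

_POSITIONAL = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def validate_static_args(
    fun: Callable, argnums: Sequence[int], argnames: Sequence[str]
) -> None:
    """Hypothetical checks: raise ValueError for annotations that can never match."""
    params = inspect.signature(fun).parameters
    kinds = [p.kind for p in params.values()]
    n_positional = sum(k in _POSITIONAL for k in kinds)
    has_var_args = inspect.Parameter.VAR_POSITIONAL in kinds
    has_var_kwargs = inspect.Parameter.VAR_KEYWORD in kinds
    for i in argnums:
        # This sketch only accepts non-negative indices. Without *args, an index
        # past the last positional parameter can never be hit.
        if i < 0 or (not has_var_args and i >= n_positional):
            raise ValueError(f"argnum {i} is out of range for {fun.__name__}")
    for name in argnames:
        p = params.get(name)
        if p is None and not has_var_kwargs:
            raise ValueError(f"{name!r} is not a parameter of {fun.__name__}")
        if p is not None and p.kind is inspect.Parameter.POSITIONAL_ONLY:
            # Positional-only parameters can never be passed by name.
            raise ValueError(f"{name!r} is positional-only; use its argnum instead")


def f(a, /, b, *, c):
    print(a, b, c)


validate_static_args(f, (0,), ("b", "c"))  # passes silently
for bad in [((0, 1, 2), ()), ((), ("a",))]:  # the jf3 and jf configurations
    try:
        validate_static_args(f, *bad)
    except ValueError as err:
        print(err)
```

With these checks, the jf3 configuration reports "argnum 2 is out of range for f" and the jf configuration reports that 'a' is positional-only.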

Goals

My suggestion is to first find a solution for jax.jit that fixes the inconsistencies above (or, in the worst case, documents them thoroughly).

Once that is done, it would be great to see *_argnames and keyword-arg support added to other functions:

  • jax.experimental.pjit
  • jax.pmap
  • jax.value_and_grad
  • jax.custom_vjp
  • jax.custom_jvp
  • jax.hessian
  • jax.jacrev
  • jax.jacfwd
  • jax.grad

Additionally, #10476 could be explored (it could live in jax.experimental.annotations, if there is any interest in this feature at all).

Progress

  • Get feedback and decide on (this issue):
    • Interface (potential changes in function signatures for argument annotations)
    • Behaviour
  • Document interface and behaviour (initial PR: #10677)
  • Write tests and ensure consistency for:
    • jax.jit (PR: #10619)
    • jax.experimental.pjit
    • jax.pmap
    • jax.value_and_grad
    • jax.custom_vjp
    • jax.custom_jvp
    • jax.hessian
    • jax.jacrev
    • jax.jacfwd
    • jax.grad

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 18 (16 by maintainers)

Top GitHub Comments

2 reactions
carlosgmartin commented, Nov 26, 2022

I second the proposal of https://github.com/google/jax/issues/10614#issuecomment-1132238808 and https://github.com/google/jax/issues/10614#issuecomment-1145151133 to use *_args : Sequence[int | str]. Treat each element as an argument number if it’s an int and argument name if it’s a str.

2 reactions
JeppeKlitgaard commented, May 19, 2022

Idea - Interface discussion:

Don’t use *_argnums and *_argnames at all, just a *_args parameter of type Sequence[str | int] where integers are taken as positions and strings are taken as argument names.
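A minimal sketch of how a single *_args value could be split back into the existing argnums/argnames pair, so that the rest of a transformation would not need to change. split_args is a hypothetical helper, not a JAX API.

```python
from typing import Sequence, Tuple, Union

ArgSpec = Union[int, str, Sequence[Union[int, str]]]


def split_args(spec: ArgSpec) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    """Hypothetical: split one `*_args` value into (argnums, argnames)."""
    # A bare int or str is shorthand for a one-element sequence. The str check
    # matters: a str is itself a Sequence and would otherwise split into letters.
    items = (spec,) if isinstance(spec, (int, str)) else spec
    argnums, argnames = [], []
    for item in items:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(item, bool) or not isinstance(item, (int, str)):
            raise TypeError(f"expected int or str, got {type(item).__name__}")
        if isinstance(item, int):
            argnums.append(item)
        else:
            argnames.append(item)
    return tuple(argnums), tuple(argnames)


print(split_args((0, "b", 2)))  # ((0, 2), ('b',))
print(split_args("c"))          # ((), ('c',))
```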

This is not only more succinct, but also allows us to maintain full backwards compatibility: argnums and argnames would continue to work as they currently do, but would give rise to a deprecation warning for a few versions before being removed (potentially until JAX 1.0, but preferably sooner).
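The backwards-compatibility path could look roughly like the sketch below. Here static_args stands for the proposed (hypothetical) parameter, while static_argnums and static_argnames are the existing ones, still accepted but emitting a DeprecationWarning. The function name and warning text are assumptions for illustration, and element type checks are omitted for brevity.

```python
import warnings
from typing import Optional, Sequence, Tuple, Union

ArgSpec = Union[int, str, Sequence[Union[int, str]]]


def resolve_static_args(
    static_args: Optional[ArgSpec] = None,
    static_argnums: Union[int, Sequence[int], None] = None,
    static_argnames: Union[str, Sequence[str], None] = None,
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    """Hypothetical: accept the proposed `static_args` plus the legacy pair."""
    nums, names = [], []
    if static_args is not None:
        items = (static_args,) if isinstance(static_args, (int, str)) else static_args
        for item in items:
            if isinstance(item, int):
                nums.append(item)
            else:
                names.append(item)
    if static_argnums is not None or static_argnames is not None:
        # Old parameters keep working, but warn during the deprecation period.
        warnings.warn(
            "static_argnums/static_argnames are deprecated; use static_args",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(static_argnums, int):
            static_argnums = (static_argnums,)
        if isinstance(static_argnames, str):
            static_argnames = (static_argnames,)
        nums.extend(static_argnums or ())
        names.extend(static_argnames or ())
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return tuple(dict.fromkeys(nums)), tuple(dict.fromkeys(names))


print(resolve_static_args(static_args=(0, "b")))                  # ((0,), ('b',))
print(resolve_static_args(static_argnums=0, static_argnames="b"))  # ((0,), ('b',))
```

Both calls resolve to the same pair; only the second emits a DeprecationWarning (whether it is displayed depends on the active warning filters).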

Using the container-class approach proposed in #10746 (the proposal is not finished yet) would make it relatively painless to support both argnums/argnames and args during the deprecation period.


