[Proposal] Consistent `argnums` and `argnames` parameters for transformations
Hey JAX team,
I have been trying to wrap my head around 'argument annotation' in JAX for a bit in the hopes of finding a more intuitive/consistent implementation, which has led me to the big block of text below. I would be super keen to hear your thoughts as I try to dive deeper into the inner workings of JAX.
Lately there have been a number of issues requesting improvements to the `*_argnums` and `*_argnames` parameters used in transformations, in addition to other ergonomics improvements related to declaring which function arguments should be annotated with a given property. I figured it might be helpful to make an over-arching issue with the end goal of having a consistent, ergonomic way of specifying these parameters. Managing argument ‘annotations’ in transformations has definitely been one of the more frustrating experiences of learning JAX (which is otherwise entirely amazing, of course).
Related issues:
`jax.jit` correctly implements `static_argnames` even for cases with keyword-only arguments, which would suggest that it should be possible to add `argnames` equivalents to any function that currently only implements `argnums`.
An easier but less robust fix could be to map `argnames` to `argnums` using `inspect` (see discussion: #1159). This would likely not work for keyword-only arguments (though it might for things like donate_arg...?)
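To make the `inspect`-based mapping concrete, here is a minimal sketch using only the standard library. The helper name `argnames_to_argnums` is hypothetical and this is not how JAX does it internally; it only illustrates why keyword-only parameters are the sticking point.

```python
import inspect


def argnames_to_argnums(fun, argnames):
    """Map parameter names to positional indices (hypothetical helper)."""
    params = inspect.signature(fun).parameters
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # Only parameters that can be passed by position get an index.
    index_of = {
        name: i
        for i, (name, p) in enumerate(params.items())
        if p.kind in positional_kinds
    }
    unmapped = [name for name in argnames if name not in index_of]
    if unmapped:
        raise ValueError(f"no positional index for: {unmapped}")
    return tuple(index_of[name] for name in argnames)


def f(a, /, b, *, c):
    return a, b, c
```

With this sketch, `argnames_to_argnums(f, ("a", "b"))` gives `(0, 1)`, while asking for `"c"` raises a `ValueError`, because a keyword-only parameter has no positional index to map to.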
Current shortcomings
Currently even the most robust implementation of the ‘argument annotation’ mechanism behaves in a somewhat counter-intuitive way (although this is suggested in the fine print of the docstring, if one reads it with sufficient care):
```python
from jax import jit


def f(a, /, b, *, c):
    print(a, b, c)


# Only static_argnames given: `a` is still traced.
jf = jit(f, static_argnames=("a", "b", "c"))
jf(1, 2, c=3)
# > Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3
# Expected: 1 2 3

# Both given: `b` is traced when passed positionally...
jf2 = jit(f, static_argnames=("b", "c"), static_argnums=(0,))
jf2(1, 2, c=3)
# > 1 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 3
# Expected: 1 2 3

# ...but static when passed by keyword.
jf2(1, b=2, c=3)
# > 1 2 3
# As expected

# Only static_argnums given: `c` is traced in both calls.
jf3 = jit(f, static_argnums=(0, 1, 2))
jf3(1, 2, c=3)
# > 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

jf3(1, b=2, c=3)
# > 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3
```
The fact that we have one instance where we are able to get the expected result gives hope that a solution should be possible by inspecting the function and arguments and modifying `static_argnums` and `static_argnames` accordingly, or perhaps a better solution exists? Ideally we would want to avoid inspecting the arguments at call-time.
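One way this could look is sketched below: the signature is inspected once, when the function is wrapped, and the user's annotations are expanded so that each selected parameter is covered both by index and by name wherever that is possible. The helper `resolve_static_args` is hypothetical, it ignores `*args`/`**kwargs`, and it assumes that an index counts parameters in declaration order (so index 2 refers to `c` in the example above), which is an interpretation rather than documented JAX behaviour.

```python
import inspect


def resolve_static_args(fun, static_argnums=(), static_argnames=()):
    """Expand annotations to cover positional and keyword passing.

    Hypothetical helper, not JAX's implementation.
    """
    P = inspect.Parameter
    params = inspect.signature(fun).parameters.values()
    nums, names = [], []
    for i, p in enumerate(params):
        if i not in static_argnums and p.name not in static_argnames:
            continue  # parameter was not annotated at all
        if p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD):
            nums.append(i)  # may arrive positionally
        if p.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY):
            names.append(p.name)  # may arrive by keyword
    return tuple(nums), tuple(names)


def f(a, /, b, *, c):
    return a, b, c


# The three annotation styles from the example resolve to the same pair.
r1 = resolve_static_args(f, static_argnames=("a", "b", "c"))
r2 = resolve_static_args(f, static_argnums=(0,), static_argnames=("b", "c"))
r3 = resolve_static_args(f, static_argnums=(0, 1, 2))
```

All three calls return `((0, 1), ("b", "c"))`, and nothing is inspected at call-time because the work happens once per wrapped function.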
I have started toying with validation of `static_argnums` and `static_argnames` in #10603.
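As a rough illustration of what such validation might check (this is a standalone sketch with a hypothetical helper name, not the code in #10603), annotations that can never match the signature could be rejected when the function is wrapped:

```python
import inspect


def validate_static_args(fun, static_argnums=(), static_argnames=()):
    """Raise ValueError for annotations that can never match (hypothetical)."""
    P = inspect.Parameter
    params = list(inspect.signature(fun).parameters.values())
    kinds = [p.kind for p in params]
    n_positional = sum(
        k in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD) for k in kinds
    )
    by_keyword = {
        p.name
        for p in params
        if p.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY)
    }
    # With *args any index may be valid, so only check when it is absent.
    if P.VAR_POSITIONAL not in kinds:
        bad = [i for i in static_argnums if not -n_positional <= i < n_positional]
        if bad:
            raise ValueError(f"static_argnums out of range: {bad}")
    # With **kwargs any name may be valid, so only check when it is absent.
    if P.VAR_KEYWORD not in kinds:
        bad = [n for n in static_argnames if n not in by_keyword]
        if bad:
            raise ValueError(f"static_argnames not passable by keyword: {bad}")


def f(a, /, b, *, c):
    return a, b, c
```

Under these rules `static_argnames=("a",)` is rejected for `f`, since `a` is positional-only, and so is `static_argnums=(2,)`, since only two parameters can be passed by position.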
Goals
My suggestion is to first find a solution for `jax.jit` that fixes the inconsistencies above (or, in the worst case, documents them thoroughly).
Once that is done, it would be great to see `*_argnames` and keyword-arg support added to other functions:
- `jax.experimental.pjit`
- `jax.pmap`
- `jax.value_and_grad`
- `jax.custom_vjp`
- `jax.custom_jvp`
- `jax.hessian`
- `jax.jacrev`
- `jax.jacfwd`
- `jax.grad`
Additionally, #10476 can be explored (it could live in `jax.experimental.annotations`, if there is any interest in this feature at all).
Progress
- Get feedback and decide on: (this issue)
  - Interface (potential changes in function signatures for argument annotations)
  - Behaviour
- Document interface and behaviour (initial PR: #10677)
- Make tests and ensure consistency for functions
  - `jax.jit` (PR: #10619)
  - `jax.experimental.pjit`
  - `jax.pmap`
  - `jax.value_and_grad`
  - `jax.custom_vjp`
  - `jax.custom_jvp`
  - `jax.hessian`
  - `jax.jacrev`
  - `jax.jacfwd`
  - `jax.grad`
I second the proposal of https://github.com/google/jax/issues/10614#issuecomment-1132238808 and https://github.com/google/jax/issues/10614#issuecomment-1145151133 to use `*_args: Sequence[int | str]`. Treat each element as an argument number if it’s an `int` and an argument name if it’s a `str`.

Idea - Interface discussion:
Don’t use `*_argnums` and `*_argnames` at all, just a `*_args` parameter of type `Sequence[str | int]` where integers are taken as positions and strings are taken as argument names.
This is not only more succinct, but also allows us to maintain full backwards compatibility: `argnums` and `argnames` would continue to work as they currently do, but would give rise to a deprecation warning for a few versions before being removed (potentially until JAX 1.0, but preferably sooner).
Using the container class approach as proposed (proposal not finished) in #10746 would enable relatively painless support of `argnames` + `argnums` and `args` for the period of deprecation.
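A minimal sketch of how a unified `*_args` parameter could be split by element type while the old parameters keep working with a deprecation warning. The function `split_args` and its keyword names are hypothetical, chosen only to illustrate the idea; it is not the container-class approach from #10746.

```python
import warnings


def _as_tuple(value):
    # Accept None, a bare int/str, or any sequence of them.
    if value is None:
        return ()
    if isinstance(value, (int, str)):
        return (value,)
    return tuple(value)


def split_args(args=None, *, argnums=None, argnames=None):
    """Split a mixed sequence into (argnums, argnames). Hypothetical helper."""
    if argnums is not None or argnames is not None:
        warnings.warn(
            "argnums/argnames are deprecated; use args instead",
            DeprecationWarning,
            stacklevel=2,
        )
    nums = list(_as_tuple(argnums))
    names = list(_as_tuple(argnames))
    for item in _as_tuple(args):
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(item, bool) or not isinstance(item, (int, str)):
            raise TypeError(f"expected int or str, got {type(item).__name__}")
        (nums if isinstance(item, int) else names).append(item)
    return tuple(nums), tuple(names)
```

For example, `split_args((0, "b", 2, "c"))` returns `((0, 2), ("b", "c"))`, and `split_args(argnums=0, argnames=("c",))` returns `((0,), ("c",))` after emitting a `DeprecationWarning`.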