`jit` input validation can lead to silently dynamic variables
`jax.jit` (and presumably other similar transformations) currently does not validate `static_argnames` and `static_argnums`, which leads to silent failures in some cases. These can be particularly frustrating to debug, and the additional input validation can be done with relatively little overhead.
A particularly headache-inducing example is using `jit` in combination with positional-only (or keyword-only) arguments.
Examples:
```python
from jax import jit

def f(dyn, stat):
    print(dyn)
    print(stat)
    ...

def g(dyn, stat, /):
    print(dyn)
    print(stat)
    ...

jf = jit(f, static_argnames=("stat",))
jf(1, 2)
# > Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# > 2

jg = jit(g, static_argnames=("stat",))
jg(1, 2)
# > Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# > Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# This is somewhat expected (`stat` is positional-only, so it can never be
# passed by name), but it should at least raise an error.

jit(f, static_argnames=("some_arg_that_does_not_exist",))
# No error!

jit(f, static_argnums=(9,))
# No error!
```
Solution
Input validation on `jit` and similar functions could be made stricter. This may require using `inspect` where possible.
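A minimal sketch of what `inspect`-based validation of `static_argnames` could look like is below. `validate_static_argnames` is a hypothetical helper, not part of JAX; it assumes the check runs once when the function is wrapped, and it deliberately skips validation when the signature cannot be inspected or when the function accepts `**kwargs` (where any name could be valid). It rejects both names that do not exist and names that exist but can never be passed by keyword, which covers the positional-only case from the examples.

```python
import inspect
from typing import Callable, Sequence


def validate_static_argnames(fun: Callable, static_argnames: Sequence[str]) -> None:
    """Hypothetical helper (not a JAX API): reject names that can never bind by keyword."""
    try:
        sig = inspect.signature(fun)
    except (TypeError, ValueError):
        # No inspectable signature (e.g. some builtins): skip validation.
        return

    params = sig.parameters
    # With **kwargs, any keyword is accepted, so no name can be ruled out.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return

    fname = getattr(fun, "__name__", repr(fun))
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    for name in static_argnames:
        if name not in params:
            raise ValueError(
                f"static_argnames contains {name!r}, but {fname} has no parameter with that name"
            )
        if params[name].kind not in keyword_kinds:
            # Positional-only parameters (and *args) can never be passed by name,
            # so marking them static by name would silently have no effect.
            raise ValueError(
                f"static_argnames contains {name!r}, but it cannot be passed by keyword "
                f"to {fname}; consider static_argnums instead"
            )


# The two functions from the examples in the issue.
def f(dyn, stat):
    print(dyn)
    print(stat)


def g(dyn, stat, /):
    print(dyn)
    print(stat)


validate_static_argnames(f, ("stat",))  # passes: `stat` can be passed by name
for fun, names in [(f, ("some_arg_that_does_not_exist",)), (g, ("stat",))]:
    try:
        validate_static_argnames(fun, names)
    except ValueError as err:
        print(err)
```

Raising an error here, rather than silently ignoring the name, is what would turn the positional-only example into an immediate, debuggable failure.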
I would be happy to try undertaking this.
Would this not also apply to `argnums`?
I am working on argnum validation right now, and I think it has already spotted an error in the JAX source: https://github.com/google/jax/blob/ab7a60b3abdc1e4aec210d79926824b85082d2d7/jax/_src/lax/qdwh.py#L68-L69
I assume (and believe I recall) that `static_argnums` is zero-indexed, so making `argnum=3` will never make sense. Could you confirm this, @hawkinsp? Would the correct `arg_nums` be `(0, 1, 2)` or `(1, 2)` in this case?
Those two cases nicely motivate stricter input validation, I think! I will make a PR with the additional validation.
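The zero-indexing point can be made concrete with a similar sketch for `static_argnums`. `validate_static_argnums` and the three-argument function `h` are hypothetical, not taken from JAX. The sketch counts only parameters that can be passed positionally (keyword-only parameters have no index), skips the check for functions taking `*args`, and assumes negative indices count from the end as in Python sequence indexing, which may not match JAX's exact rules.

```python
import inspect
from typing import Callable, Sequence


def validate_static_argnums(fun: Callable, static_argnums: Sequence[int]) -> None:
    """Hypothetical helper (not a JAX API): reject indices that can never select a positional argument."""
    try:
        sig = inspect.signature(fun)
    except (TypeError, ValueError):
        return  # no inspectable signature: skip validation

    params = list(sig.parameters.values())
    # With *args, any number of positional arguments is accepted: no upper bound.
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return

    # Keyword-only parameters cannot be addressed by position, so they are not counted.
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    n = sum(p.kind in positional_kinds for p in params)
    fname = getattr(fun, "__name__", repr(fun))
    for i in static_argnums:
        # Zero-based; negative indices are assumed to count from the end, as in Python.
        if not -n <= i < n:
            raise ValueError(
                f"static_argnums contains {i}, but {fname} takes only {n} positional "
                f"arguments (valid indices: {-n} to {n - 1})"
            )


def h(x, y, z):  # hypothetical function with three positional parameters
    return x + y + z


validate_static_argnums(h, (0, 1, 2))  # passes: indices 0..2 exist
validate_static_argnums(h, (-1,))      # passes under the negative-index assumption
try:
    validate_static_argnums(h, (3,))   # zero-indexed, so 3 is one past the end
except ValueError as err:
    print(err)
```

With three positional parameters, the valid non-negative indices are 0, 1 and 2, so an index of 3 is rejected at wrap time instead of being silently ignored.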