Need a way to specify constants
I have a use case where I need to jit a function whose slicing depends on values in a dict argument:
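```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def foo(x, d: dict):
    return x[:d['num']] + d['val']
foo(jnp.arange(10), {'num': 5, 'val': np.arange(5)})  # slice bound d['num'] is traced
```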
Unsurprisingly, this raises IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax.
Marking d as static with static_argnums doesn't work either, because static arguments must be hashable and d is a dict that contains an array:
```python
from functools import partial

@partial(jit, static_argnums=1)
def foo(x, d: dict):
    return x[:d['num']] + d['val']

foo(jnp.arange(10), {'num': 5, 'val': np.arange(5)})
```
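This raises ValueError: Non-hashable static arguments are not supported.
Here is a dirty trick that works: encode the slice length in the shape of a dummy array, since array shapes are static under jit: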
```python
@jit
def bar(x, d: dict):
    # shape[0] is a static Python int, so it can be used as a slice bound
    return x[:d['num'].shape[0]] + d['val']

bar(jnp.arange(10), {'num': jnp.zeros((5, 0)), 'val': np.arange(5)})
```
This returns DeviceArray([0, 2, 4, 6, 8], dtype=int32).
Is there a way to mark d['num'] as a constant?
This use case may look odd, but it is more natural in Flax, where variables are collected into a dict that is passed as an input.
Issue created 2 years ago; 8 comments (3 by maintainers).
Top GitHub Comments
If you put non-hashable data into static fields, the JIT cache lookup will fail. JAX does not yet raise an error when you put unhashable data in static fields.
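As a minimal sketch of what a "static field" looks like, assume a hypothetical SliceSpec class registered with jax.tree_util.register_pytree_node_class. Whatever tree_flatten returns as auxiliary data is treated as static (it becomes part of the tree structure that jit hashes), while the returned children are traced. Putting the hashable int num in the static slot also makes the original slicing example work under jit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SliceSpec:
    """Hypothetical container: num is a static field, val is a traced leaf."""

    def __init__(self, num: int, val):
        self.num = num  # static: stored in aux_data, must be hashable
        self.val = val  # dynamic: returned as a child, traced by jit

    def tree_flatten(self):
        # (children, aux_data): children are traced, aux_data is static
        return (self.val,), self.num

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)


@jax.jit
def foo(x, spec: SliceSpec):
    # spec.num is a plain Python int inside the trace, so the slice is static
    return x[: spec.num] + spec.val


out = foo(jnp.arange(10), SliceSpec(5, np.arange(5)))
print(out)  # [0 2 4 6 8]
```

Because aux_data is hashed as part of the tree structure, it has the same hashability requirements as static_argnums, which is exactly where unhashable data causes trouble.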
Another related issue: #7826.
For any custom class, Python by default creates a hash function based on the object's identity, without reference to any of the class's members or contents.
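A small plain-Python illustration of that default, using a hypothetical Config class that defines neither __hash__ nor __eq__:

```python
class Config:
    """Hypothetical plain class: no __hash__ or __eq__ defined."""

    def __init__(self, n):
        self.n = n


a = Config(2)
b = Config(2)

# Default equality is identity, so equal contents do not make equal objects.
print(a == b)  # False

# The default hash depends only on identity, not on attributes,
# so mutating an attribute leaves the hash unchanged.
h_before = hash(a)
a.n = 3
print(hash(a) == h_before)  # True
```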
This means that if you create a PyTree class, use an instance of it as a static argument, mutate it, and pass it to the function again, you may get stale results: the hash and equality check are unchanged, so the previously compiled function is reused. There's more information and discussion of this foot-gun in this issue: https://github.com/google/jax/issues/9024
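A minimal sketch of that foot-gun, assuming a hypothetical mutable Config class passed through static_argnums (a plain class rather than a registered PyTree, to keep it short). The second call hits the compilation cache, so the mutation is silently ignored:

```python
from functools import partial

import jax
import jax.numpy as jnp


class Config:
    """Hypothetical mutable config; inherits identity-based __hash__/__eq__."""

    def __init__(self, n):
        self.n = n


@partial(jax.jit, static_argnums=1)
def scale(x, cfg):
    # cfg.n is read at trace time and baked into the compiled function
    return x * cfg.n


cfg = Config(2)
x = jnp.arange(3)
first = scale(x, cfg)   # traces and compiles with n = 2
cfg.n = 3               # mutate the static argument in place
second = scale(x, cfg)  # same hash, equal to itself -> cache hit, still n = 2
print(first, second)    # [0 2 4] [0 2 4]
```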
If you instead create a new instance each time, you will have a different object (and thus a different hash), and your function will be recompiled on every call.
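To observe the recompilation, one can count traces with a Python side effect, which runs only when JAX traces the function. This sketch again assumes a hypothetical Config class and a global trace_count counter; the instances are kept in a list so their identities stay distinct:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


class Config:
    """Hypothetical config class with default, identity-based hashing."""

    def __init__(self, n):
        self.n = n


@partial(jax.jit, static_argnums=1)
def scale(x, cfg):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return x * cfg.n


x = jnp.arange(3)
configs = [Config(2) for _ in range(3)]  # equal contents, distinct objects
for cfg in configs:
    scale(x, cfg)  # each new object misses the cache and triggers a retrace

print(trace_count)  # 3
```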
There are ways to hack around using a jnp.ndarray as a static argument (essentially by wrapping it in a class that defines __hash__), but I wouldn't recommend it: it's easy to inadvertently do the wrong thing, and this kind of problem is usually better solved by rewriting your code to follow typical JAX patterns.
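One way to restructure the original example along those lines, assuming the dict can be split before the jitted call: pass the integer num as a static argument and the array val as an ordinary traced argument. The helper name call_foo is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="num")
def foo(x, val, num: int):
    # num is a hashable Python int, so it can be static and used in a slice;
    # val is an array, so it stays a traced (dynamic) argument.
    return x[:num] + val


def call_foo(x, d: dict):
    # Hypothetical helper: split the dict into static and dynamic parts.
    return foo(x, d["val"], num=d["num"])


out = call_foo(jnp.arange(10), {"num": 5, "val": np.arange(5)})
print(out)  # [0 2 4 6 8]
```

A different num triggers one recompilation per distinct value, which is usually acceptable when the set of slice lengths is small.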