Trouble using JIT when the function is in a class?
Hi,
I noticed some significant slowdowns in my code after switching from numpy to jax.numpy, and from the other issues it seems the solution is to use jit. When I use jit in a single script file for testing purposes it works, but when I move the function I want to jit into a class in a separate file I run into problems.
# odes.py
import jax.numpy as np
import numpy as onp
from jax import jit, jacfwd, grad
from jax.numpy import sin, cos, exp

class odes:
    def __init__(self):
        print("odes file initialized")

    @jit
    def simpleODE(self, t, q):
        return np.array([[q[1]], [cos(q[0])]])

# test script that imports the class above
from odes import *
from jax import jit, jacfwd, grad

ODE = odes()
Jac = jacfwd(ODE.simpleODE, argnums=(1,))
q = np.ones(2)
A = Jac(0, q)
print(A)
This gives the following error:
TypeError: Argument '<odes.odes object at 0x7fe440250810>' of type <class 'odes.odes'> is not a valid JAX type
Issue Analytics
- State:
- Created 4 years ago
- Comments:5 (2 by maintainers)
You might be able to work with this pattern:
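A minimal sketch of such a pattern, assuming it wraps jit with functools.partial so that static_argnums=(0,) applies to self. The class body is taken from the question; the jnp alias and the final shape check are illustrative additions.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, jacfwd


class odes:
    def __init__(self):
        print("odes file initialized")

    # Assumption: argument 0 (self) is marked static, so jit never
    # tries to treat the odes instance as an array.
    @partial(jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return jnp.array([[q[1]], [jnp.cos(q[0])]])


ODE = odes()
# On the bound method, self is already supplied, so index 1 is q.
Jac = jacfwd(ODE.simpleODE, argnums=(1,))
q = jnp.ones(2)
A = Jac(0, q)  # a 1-tuple, because argnums was given as a tuple
print(A[0].shape)
```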
In words, we’re marking the first argument (index 0) as a static argument.
What do you think?
Glad to hear that helped!
Yes, the issue is that jit only knows how to compile numerical computations on arrays (i.e. what XLA can do), not arbitrary Python computations. In particular, that means it only knows how to work with array data types, not arbitrary classes, and in this case the self argument is an instance of odes. By using static_argnums we're telling jit to compile only the computation that gets applied to the other arguments, and just to re-trace and re-compile every time the first argument changes its Python object id. That re-tracing basically means jit lets Python handle everything to do with the self argument.
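The re-tracing can be observed with a Python side effect, since such code runs only while jit traces the function and not when a cached compilation is reused. A small sketch, assuming the class keeps Python's default identity-based hashing; the Scaler class and the trace_log list are hypothetical names made up for illustration:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_log = []  # hypothetical helper: records every trace of apply()


class Scaler:
    def __init__(self, factor):
        self.factor = factor

    @partial(jit, static_argnums=(0,))
    def apply(self, x):
        # Plain Python side effect: executes during tracing only.
        trace_log.append(self.factor)
        return self.factor * x


x = jnp.ones(3)

a = Scaler(2.0)
a.apply(x)  # first call with this instance: traced and compiled
a.apply(x)  # same instance, same input shape: cached version is reused

b = Scaler(2.0)
b.apply(x)  # different object, so jit traces and compiles again

print(len(trace_log))
```

Because every new instance triggers a fresh trace and compile, this pattern is probably best suited to objects that are created once and then called many times.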