
Trouble using JIT when the function is in a class?


Hi,

I noticed some significant slowdowns in my code after switching from numpy to jax.numpy, and from other issues it seems the solution is to use jit. However, while jit seems to work when I use it in a single script file for testing, I run into problems when I move the function I want to jit into a class in a separate file:

# odes.py
import jax.numpy as np
import numpy as onp
from jax import jit, jacfwd, grad

from jax.numpy import sin, cos, exp

class odes:
    def __init__(self):
        print("odes file initialized")
    @jit
    def simpleODE(self, t,q):
        return np.array([[q[1]], [cos(q[0])]])

# main script (imports the odes module above)
from odes import *
from jax import jit, jacfwd, grad

ODE = odes()

Jac = jacfwd(ODE.simpleODE, argnums = (1,))

q = np.ones(2)

A = Jac(0,q)
print(A)

This gives the following error:

TypeError: Argument '<odes.odes object at 0x7fe440250810>' of type <class 'odes.odes'> is not a valid JAX type
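For comparison, here is a minimal sketch of the single-script case that the question says works, assuming it used a plain module-level function (the name simple_ode is hypothetical, and jnp is used in place of the question's np alias). Every argument is a number or an array, so jit has nothing it cannot trace:

```python
import jax.numpy as jnp
from jax import jacfwd, jit


# Hypothetical free-function version of simpleODE: there is no self
# argument, so jit only ever sees a number (t) and an array (q).
@jit
def simple_ode(t, q):
    return jnp.array([[q[1]], [jnp.cos(q[0])]])


q = jnp.ones(2)
# argnums=(1,) differentiates with respect to q; because argnums is a
# tuple, jacfwd returns a tuple holding one Jacobian per listed argument.
(J,) = jacfwd(simple_ode, argnums=(1,))(0.0, q)
print(J.shape)  # (2, 1, 2): output shape (2, 1) followed by input shape (2,)
```

Once the same function becomes a plain @jit method, self is just one more positional argument, and an odes instance is not an array, which is what the TypeError above reports.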

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

41 reactions
mattjj commented, Aug 26, 2019

You might be able to work with this pattern:

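from functools import partial

import jax.numpy as np
from jax import jit
from jax.numpy import cos

class odes:
  def __init__(self):
    print("odes file initialized")

  # Mark self (argument 0) as static so jit does not try to trace it
  @partial(jit, static_argnums=(0,))
  def simpleODE(self, t, q):
    return np.array([[q[1]], [cos(q[0])]])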

In words, we’re marking the first argument (index 0, i.e. self) as a static argument.
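A minimal end-to-end sketch of this pattern, assuming a reasonably recent JAX release (it uses jnp and jnp.cos in place of the question's np and cos aliases). It plugs the static-self method into the question's own jacfwd call; because ODE.simpleODE is a bound method, argnums counts from t, so argnums=(1,) still refers to q:

```python
from functools import partial

import jax.numpy as jnp
from jax import jacfwd, jit


class odes:
    def __init__(self):
        print("odes file initialized")

    # Argument 0 (self) is static: jit uses it as part of the cache key
    # (via hashing and equality) instead of converting it to an array.
    @partial(jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return jnp.array([[q[1]], [jnp.cos(q[0])]])


ODE = odes()
Jac = jacfwd(ODE.simpleODE, argnums=(1,))  # differentiate with respect to q
q = jnp.ones(2)
(A,) = Jac(0, q)  # tuple argnums, so jacfwd returns a tuple of Jacobians
print(A.shape)  # (2, 1, 2)
```

The entries should match the analytic Jacobian: the q[1] row is [0, 1] and the cos(q[0]) row is [-sin(1), 0].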

What do you think?

3 reactions
mattjj commented, Aug 27, 2019

Glad to hear that helped!

Yes, the issue is that jit only knows how to compile numerical computations on arrays (i.e. what XLA can do), not arbitrary Python computations. In particular, that means it only knows how to work with array data types, not arbitrary classes, and in this case the self argument is an instance of odes. By using static_argnums we’re telling jit to compile only the computation that gets applied to the other arguments, and just to re-trace and re-compile every time a different first argument is passed (for an ordinary class, that means a different Python object id). That re-tracing basically means jit lets Python handle everything to do with the self argument.
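The re-tracing behavior can be observed with a Python side effect, which only runs while jit traces the function. A minimal sketch, assuming a hypothetical Scaler class and a module-level trace_count counter used purely for illustration: calling the same instance twice reuses the compiled code, a second instance (a different object) triggers a fresh trace, and mutating an attribute of an instance jit has already seen does not, so the old attribute value stays baked into the compiled code:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_count = 0  # hypothetical counter, incremented only during tracing


class Scaler:  # hypothetical class for illustration
    def __init__(self, scale):
        self.scale = scale  # plain Python float, read at trace time

    @partial(jit, static_argnums=(0,))
    def apply(self, x):
        global trace_count
        trace_count += 1  # Python side effect: runs while tracing only
        return self.scale * x


x = jnp.ones(3)
a = Scaler(2.0)
a.apply(x)
a.apply(x)          # same object, same input shape: cache hit, no retrace
print(trace_count)  # 1

b = Scaler(2.0)
b.apply(x)          # different object id: retrace and recompile
print(trace_count)  # 2

a.scale = 10.0
print(a.apply(x))   # stale: still computes 2.0 * x, since a is the same cache key
print(trace_count)  # 2
```

If an attribute has to change between calls, passing that value as an ordinary (traced) argument instead of reading it from a static self avoids the stale result.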


