Print return of jit method in class
Hi all,
I’m trying to use classes with JAX, and one of the problems I have is that I can’t print a value that is modified in a JIT-compiled class method. Example code:
```python
import jax.numpy as np
from jax import jit
from functools import partial

class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @partial(jit, static_argnums=(0,))
    def step(self, dt):
        a = -9.8
        self.v += a * dt
        self.p += self.v * dt

world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
    world.step(0.01)

print(world.p)
```
This prints Traced<ShapedArray(float32[2]):JaxprTrace(level=-1/1)>
I’m aware this is expected behavior when you print something inside the function, but this print is not inside the function, right?
More generally, I’m wondering whether object-oriented programming is well suited to JAX. Should I avoid this kind of stuff? Is JIT capable of working optimally this way?
Thanks for your time!
Issue Analytics: created 4 years ago; reactions: 1; comments: 7 (3 by maintainers)
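The problem with the original version is that `step` mutates `self` inside a `jit`-compiled function. Because `self` is marked static, `jit` traces `step` with abstract tracer values, and the assignments to `self.v` and `self.p` store those tracers on the object, where they outlive the trace. That’s why `print(world.p)` shows a `Traced<...>` value even though the print itself is outside the function.

Here are a couple styles that do work well with JAX: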
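A minimal sketch of a functional style, assuming the state is carried as a plain `(p, v)` tuple of float arrays (the tuple layout is an illustrative choice):

```python
import jax
import jax.numpy as jnp

# Functional sketch: state goes in as an explicit (p, v) tuple and a new
# tuple comes out; nothing is stored on an object or modified in place.
@jax.jit
def step(state, dt):
    p, v = state
    a = -9.8
    v = v + a * dt  # update velocity first...
    p = p + v * dt  # ...then position, using the new velocity
    return p, v     # the new state

state = (jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
for _ in range(1000):
    state = step(state, 0.01)

p, v = state
print(p)  # concrete array values, not a tracer
```

Because `dt` is passed as an ordinary argument, `jit` traces it as an abstract scalar, so `step` is compiled once and reused for all 1000 calls.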
That’s just a functional version of your code. The key is that `step` returns a new World, rather than modifying the existing one. We can organize the same thing into Python classes if we want:
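A sketch of the same idea organized as a class, assuming `World` is made a `typing.NamedTuple` so that JAX treats it as a pytree and `jit` can trace `self` like any other argument (the NamedTuple choice is an assumption; registering a custom class with `jax.tree_util.register_pytree_node_class` would be another option):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# A NamedTuple is a pytree, so jit traces its fields (p and v) as arrays
# instead of needing `self` to be a static argument.
class World(NamedTuple):
    p: jax.Array
    v: jax.Array

    @jax.jit
    def step(self, dt):
        a = -9.8
        v = self.v + a * dt
        p = self.p + v * dt
        return World(p, v)  # a new instance; self is left untouched

world = World(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
for _ in range(1000):
    world = world.step(0.01)

print(world.p)
```

Calling `world.step(0.01)` leaves `world` unchanged and returns the updated instance, which is why the loop rebinds `world` on every iteration.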
The key difference there is that `step` returns a new `World` instance.

Here’s one last pattern that works, using your original `World` class, though it’s a bit more subtle:

(That last one takes much longer to compile, because we’re unrolling 1000 steps into a single XLA computation and compiling that; in practice we’d use something like `lax.fori_loop` or `lax.scan` to avoid those long compile times.)

The reason your original class works in that last example is that we’re only using it under a `jit`, so the `jit` function itself doesn’t have any side-effects.

Of those styles, I personally have grown to like the first. I wrote all my code in grad school in an OOP-heavy style, and I regret it: it was hard to compose with other code, even other code that I wrote, and that really limited its reach. Functional code, by forcing explicit state management, solves the composition problem. It’s also a great fit for numerical computing in general, since numerical computing is much closer to math than, say, writing a web server.
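A sketch of the last pattern, assuming the `@partial(jit, ...)` decorator is removed from `World.step` so that all the mutation happens while a single outer `jit` traces the whole simulation. `run_unrolled` and `run_fori` are hypothetical names; the second function shows the `lax.fori_loop` variant mentioned above, where the loop body is traced once instead of 1000 times:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: the original mutating class, but with no jit on step.
class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    def step(self, dt):
        a = -9.8
        self.v += a * dt  # rebinds self.v to a new array
        self.p += self.v * dt

@jax.jit
def run_unrolled(p, v):
    # The World object only exists while tracing, so from the outside
    # run_unrolled has no side effects. The Python loop is unrolled into
    # 1000 copies of the step, which is what makes compilation slow.
    world = World(p, v)
    for _ in range(1000):
        world.step(0.01)
    return world.p, world.v

@jax.jit
def run_fori(p, v):
    # Same computation driven by lax.fori_loop: the body is traced once
    # and the (p, v) tuple is the loop carry.
    def body(i, state):
        world = World(*state)
        world.step(0.01)
        return world.p, world.v
    return lax.fori_loop(0, 1000, body, (p, v))

p0 = jnp.array([0.0, 0.0])
v0 = jnp.array([1.0, 1.0])
p1, v1 = run_unrolled(p0, v0)
p2, v2 = run_fori(p0, v0)
print(p1, p2)
```

Both functions should give the same result; the difference is only in how much code XLA has to compile.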
Hope that’s helpful 😃
Just wanted to let you know that this answer allowed me to convert from an OOP to a functional mindset for the first time! Was v helpful and illustrative, and I can now `@jax.jit` everything! 😉