Print return of jit method in class

Hi all,

I’m trying to use classes with jax and one of the problems I have is that I can’t print a value that is manipulated in a JIT compiled class method. Example code:

import jax.numpy as np
from jax import jit
from functools import partial


class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @partial(jit, static_argnums=(0,))
    def step(self, dt):
        a = - 9.8
        self.v += a * dt
        self.p += self.v *dt


world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
    world.step(0.01)
print(world.p)

This prints Traced<ShapedArray(float32[2]):JaxprTrace(level=-1/1)>. I’m aware this is expected behavior when you print something inside the jitted function, but this print is not inside the function, right?

More generally, I’m wondering if object oriented programming is well suited for jax? Should I avoid this kind of stuff? Is JIT capable of working optimally this way?

Thanks for your time!

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

11 reactions
mattjj commented, Oct 25, 2019

Here are a couple styles that do work well with JAX:

import jax.numpy as np
from jax import jit
from collections import namedtuple

World = namedtuple("World", ["p", "v"])

@jit
def step(world, dt):
  a = -9.8
  new_v = world.v + a * dt
  new_p = world.p + new_v * dt
  return World(new_p, new_v)

world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
  world = step(world, 0.01)
print(world.p)

That’s just a functional version of your code. The key is that step returns a new World, rather than modifying the existing one.

We can organize the same thing into Python classes if we want:

import jax.numpy as np
from jax import jit
from jax.tree_util import register_pytree_node

class World:
  def __init__(self, p, v):
    self.p = p
    self.v = v

  @jit
  def step(self, dt):
    a = -9.8
    new_v = self.v + a * dt
    new_p = self.p + new_v * dt
    return World(new_p, new_v)

# By registering 'World' as a pytree, it turns into a transparent container and
# can be used as an argument to any JAX-transformed functions.
register_pytree_node(World,
                     lambda x: ((x.p, x.v), None),
                     lambda _, tup: World(tup[0], tup[1]))


world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
  world = world.step(0.01)
print(world.p)

The key difference there is that step returns a new World instance.
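
To make the "transparent container" idea a little more concrete, here is a minimal sketch of what registration buys you. It re-declares a World class like the one above so the snippet runs on its own; the float inputs and the variable names leaves and doubled are illustrative assumptions, not part of the original answer.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v


register_pytree_node(
    World,
    lambda w: ((w.p, w.v), None),          # flatten: (children, aux data)
    lambda _, children: World(*children),  # unflatten: rebuild from children
)

world = World(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))

# JAX now sees World as a container holding two array leaves.
leaves = jax.tree_util.tree_leaves(world)

# Tree utilities rebuild a World on the way out.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, world)

print(len(leaves))
print(type(doubled).__name__, doubled.v)
```

Because JAX can take the object apart and put it back together, a World can be passed into and returned from jit-compiled functions, which is what the step method above relies on.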

Here’s one last pattern that works, using your original World class, though it’s a bit more subtle:

class World:
  def __init__(self, p, v):
    self.p = p
    self.v = v

  def step(self, dt):
    a = -9.8
    self.v += a * dt
    self.p += self.v * dt

@jit
def run(init_p, init_v):
  world = World(init_p, init_v)
  for i in range(1000):
    world.step(0.01)
  return world.p, world.v

out = run(np.array([0, 0]), np.array([1, 1]))
print(out)

(That last one takes much longer to compile, because we’re unrolling 1000 steps into a single XLA computation and compiling that; in practice we’d use something like lax.fori_loop or lax.scan to avoid those long compile times.)
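
As a rough sketch of the lax.fori_loop variant mentioned above: the version below combines it with the functional step from the first style, since the loop state has to be passed explicitly as a carry. The float inputs are an assumption on my part; the carry must keep the same dtype on every iteration, so the integer arrays used elsewhere in this thread would not work here unchanged.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import jit, lax

World = namedtuple("World", ["p", "v"])


def step(world, dt):
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)


@jit
def run(init_p, init_v):
    init = World(init_p, init_v)
    # The loop body is traced once rather than unrolled 1000 times.
    final = lax.fori_loop(0, 1000, lambda i, w: step(w, 0.01), init)
    return final.p, final.v


# Float inputs so the carry dtype stays float32 across iterations.
p, v = run(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
print(p, v)
```

lax.scan would be the natural choice instead if you also wanted to collect the intermediate states along the way.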

The reason your original class works in that last example is that we’re only using it under a jit, so the jit function itself doesn’t have any side-effects.
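
For contrast, here is a small sketch of what goes wrong when a jitted function does have a side effect, which is essentially what happened in the original question. The Counter class and leaky_add function are hypothetical names made up for this illustration.

```python
import jax
import jax.numpy as jnp


class Counter:
    def __init__(self):
        self.total = jnp.zeros(2)


counter = Counter()


@jax.jit
def leaky_add(x):
    # Side effect: during tracing this stores an abstract tracer on an
    # object that outlives the trace.
    counter.total = counter.total + x
    return x


leaky_add(jnp.ones(2))

# The attribute now holds a leaked tracer instead of a concrete array.
leaked = isinstance(counter.total, jax.core.Tracer)
print(leaked)
```

Later calls with the same input shape hit the compilation cache, so the Python body does not run again and the attribute never becomes a concrete value. That is why printing it outside the function still shows a Traced object.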

Of those styles, I personally have grown to like the first. I wrote all my code in grad school in an OOP-heavy style, and I regret it: it was hard to compose with other code, even other code that I wrote, and that really limited its reach. Functional code, by forcing explicit state management, solves the composition problem. It’s also a great fit for numerical computing in general, since numerical computing is much closer to math than, say, writing a web server.
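
One way to see the composition point in practice: because the functional step takes all of its state as an argument and returns the new state, it can be handed to other JAX transformations without changes. The sketch below applies vmap to advance a batch of worlds at once; the batch of three worlds and the name batched_step are assumptions for illustration only.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import jit, vmap

World = namedtuple("World", ["p", "v"])


@jit
def step(world, dt):
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)


# Three independent worlds stacked along the leading axis.
batch = World(
    p=jnp.zeros((3, 2)),
    v=jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),
)

# Map over axis 0 of every field in World; dt is shared by all worlds.
batched_step = vmap(step, in_axes=(0, None))
out = batched_step(batch, 0.01)
print(out.p.shape, out.v.shape)
```

The same step could equally be wrapped in grad or used as the body of a lax.scan, which is much harder to arrange when the state lives in mutated attributes.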

Hope that’s helpful 😃

2 reactions
phinate commented, Mar 3, 2020

Just wanted to let you know that this answer allowed me to convert from an OOP to a functional mindset for the first time! Was v helpful and illustrative, and I can now @jax.jit everything! 😉
