jnp.piecewise compiles on every call in some cases

Arose when submitting PR #7692.

import jax.numpy as jnp

def g(x):
  one = jnp.ones_like(x)
  return jnp.where(x > 0, one, -one)

def f1(x):
  return jnp.piecewise(x, [x < 0, x > 0], [g, g, 0.])

def f2(x):
  return jnp.piecewise(x, [x < 0, x > 0], [-1., 1., 0.])

x = jnp.arange(-2, 3)
for i in range(5):
  print(i, 'f1', f1(x))

for i in range(5):
  print(i, 'f2', f2(x))

Checking the logs, I see:

DEBUG:absl:Compiling _cumulative_reduction (5315785472) for args (ShapedArray(bool[3,5]),).
DEBUG:absl:Compiling <unnamed wrapped function> (5316489152) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316499392) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <lambda> (5316499904) for args (ShapedArray(int32[5]), ShapedArray(int32[])).
0 f1 [-1 -1  0  1  1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316615552) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316614720) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
1 f1 [-1 -1  0  1  1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316651968) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316606656) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
2 f1 [-1 -1  0  1  1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316694400) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316706816) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
3 f1 [-1 -1  0  1  1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316745024) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316666944) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
4 f1 [-1 -1  0  1  1]
0 f2 [-1 -1  0  1  1]
1 f2 [-1 -1  0  1  1]
2 f2 [-1 -1  0  1  1]
3 f2 [-1 -1  0  1  1]
4 f2 [-1 -1  0  1  1]

Every call to f1 triggers fresh compilations (two per call after the first), whereas the equivalent f2 triggers none.
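One user-side workaround is to wrap the caller in jax.jit. This is not taken from the thread and is offered only as a sketch. Under jit, jnp.piecewise and the temporary functions it builds run only while the caller is being traced. The trace_count counter is hypothetical instrumentation added for illustration. It increments only when g is traced, so it should stop growing after the first call with a given input shape and dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only

def g(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while g is being traced
  one = jnp.ones_like(x)
  return jnp.where(x > 0, one, -one)

@jax.jit
def f1_jit(x):
  # piecewise still builds temporary branch functions, but only while f1_jit
  # itself is traced; later calls with the same shape and dtype hit jit's cache
  return jnp.piecewise(x, [x < 0, x > 0], [g, g, 0.])

x = jnp.arange(-2, 3)
for i in range(5):
  print(i, 'f1_jit', f1_jit(x))
```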

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 10 (2 by maintainers)

Top GitHub Comments

jakevdp commented, Oct 13, 2021 (1 reaction)

Revisiting this. With a better understanding of JAX now, I think the root cause is that jnp.piecewise creates temporary functions… https://github.com/google/jax/blob/10af170a85a20b86b44d6d026e9ef1669e9af0ce/jax/_src/numpy/lax_numpy.py#L6489-L6494 …and passes them to lax.switch, which traces every function it receives… https://github.com/google/jax/blob/7fa6b1b5fafaf22503bc194ba66a00a6859bdadd/jax/_src/lax/control_flow.py#L656-L657 …and because all trace caching is keyed on the function ID, temporary functions are essentially never cached. So the jnp.piecewise implementation behaves like the second lax.switch call here:

import jax.numpy as jnp
from jax import lax

def g(x):
  print(f'tracing g({x})')
  return x

def anon(func):
  return lambda *args: func(*args)

# g is traced once here:
lax.switch(0, [g, g], 1)
# tracing g(Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

# g is traced multiple times here:
lax.switch(0, [anon(g), anon(g)], 1)
# tracing g(Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
# tracing g(Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
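The caching argument can be made concrete with a toy model in plain Python, which avoids depending on JAX internals. The model is a cache keyed on the function object itself: functools.lru_cache hashes plain functions by identity. It is not JAX's actual trace cache, and `trace` is a hypothetical stand-in for tracing. It only shows why passing the same function hits the cache while a freshly created wrapper misses every time.

```python
import functools

trace_calls = []  # records each cache miss, i.e. each "trace"

@functools.lru_cache(maxsize=None)
def trace(fun):
  # Hypothetical stand-in for tracing: the cache key is the function object,
  # and plain functions hash and compare by identity.
  trace_calls.append(fun)
  return f"jaxpr for {fun.__name__}"

def g(x):
  return x

def anon(func):
  return lambda *args: func(*args)

for _ in range(3):
  trace(g)        # same object every time: one miss, then hits
for _ in range(3):
  trace(anon(g))  # a new lambda object every time: always a miss

print(len(trace_calls))  # 4
```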

I think the best solution here would be to make jnp.piecewise a wrapper for an underlying JIT-compiled function that separates out the static and non-static arguments it receives, so that the entire lax.switch call would be JIT compiled and appropriately cached.
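Here is a rough sketch of that idea, under stated assumptions. `piecewise_cached` and `_piecewise_impl` are hypothetical names, not JAX's actual implementation. The branch functions go to jax.jit as a static argument (a tuple, hashed by identity), while x and the stacked conditions are dynamic, so lax.switch only sees the branch functions when the jitted function is traced. To stay short, the sketch makes three simplifications:

- It accepts callables only, so the constant 0. branch becomes a module-level `zero` function.
- It requires exactly one more function than conditions, with the last one as the default.
- Later conditions take precedence, as in NumPy.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0  # hypothetical counter, for illustration only

def g(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while g is being traced
  one = jnp.ones_like(x)
  return jnp.where(x > 0, one, -one)

def zero(x):
  # Module-level stand-in for the constant 0. branch
  return jnp.zeros_like(x)

@partial(jax.jit, static_argnames=("funcs",))
def _piecewise_impl(x, condlist, funcs):
  # funcs: hashable tuple of callables, default branch first (static);
  # x and condlist: traced arrays (dynamic)
  n = condlist.shape[0]
  # Per-element branch index: 1 + index of the last true condition, or 0 if none
  ranks = jnp.arange(1, n + 1).reshape((n,) + (1,) * x.ndim)
  idx = jnp.where(condlist, ranks, 0).max(axis=0)
  # Element-wise lax.switch; the branches are traced only when this jit traces
  out = jax.vmap(lambda i, xi: lax.switch(i, funcs, xi))(idx.ravel(), x.ravel())
  return out.reshape(x.shape)

def piecewise_cached(x, condlist, funclist):
  """Hypothetical wrapper: callables only, funclist[-1] is the default branch."""
  x = jnp.asarray(x)
  conds = jnp.stack([jnp.asarray(c, dtype=bool) for c in condlist])
  funcs = (funclist[-1],) + tuple(funclist[:-1])  # move the default to index 0
  return _piecewise_impl(x, conds, funcs)

x = jnp.arange(-2, 3)
for i in range(5):
  print(i, 'piecewise_cached', piecewise_cached(x, [x < 0, x > 0], [g, g, zero]))
```

Because the tuple of branch functions is identical on every call, jit reuses its cached trace. g is therefore traced on the first call only, which is the behaviour f2 already shows above.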

froystig commented, Aug 30, 2021 (0 reactions)

It’s not tracing the single branch at all; it’s simply calling it inline. That’s a correct optimization, and I think a desirable one.

In what sense is this a root cause?
