jnp.piecewise compiles on every call in some cases
Arose when submitting PR #7692.
import jax.numpy as jnp

def g(x):
    one = jnp.ones_like(x)
    return jnp.where(x > 0, one, -one)

def f1(x):
    # Branches given as callables
    return jnp.piecewise(x, [x < 0, x > 0], [g, g, 0.])

def f2(x):
    # Equivalent branches given as constants
    return jnp.piecewise(x, [x < 0, x > 0], [-1., 1., 0.])

x = jnp.arange(-2, 3)
for i in range(5):
    print(i, 'f1', f1(x))
for i in range(5):
    print(i, 'f2', f2(x))
Checking the logs, I see:
DEBUG:absl:Compiling _cumulative_reduction (5315785472) for args (ShapedArray(bool[3,5]),).
DEBUG:absl:Compiling <unnamed wrapped function> (5316489152) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316499392) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <lambda> (5316499904) for args (ShapedArray(int32[5]), ShapedArray(int32[])).
0 f1 [-1 -1 0 1 1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316615552) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316614720) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
1 f1 [-1 -1 0 1 1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316651968) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316606656) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
2 f1 [-1 -1 0 1 1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316694400) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316706816) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
3 f1 [-1 -1 0 1 1]
DEBUG:absl:Compiling <unnamed wrapped function> (5316745024) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
DEBUG:absl:Compiling <unnamed wrapped function> (5316666944) for args (ShapedArray(bool[5]), ShapedArray(int32[]), ShapedArray(int32[])).
4 f1 [-1 -1 0 1 1]
0 f2 [-1 -1 0 1 1]
1 f2 [-1 -1 0 1 1]
2 f2 [-1 -1 0 1 1]
3 f2 [-1 -1 0 1 1]
4 f2 [-1 -1 0 1 1]
Every call to f1 causes a compilation, whereas the equivalent f2 does not.
Revisiting this. With a better understanding of JAX now, I think the root cause of the issue is that jnp.piecewise is creating temporary functions… https://github.com/google/jax/blob/10af170a85a20b86b44d6d026e9ef1669e9af0ce/jax/_src/numpy/lax_numpy.py#L6489-L6494 …and passing them to lax.switch, which traces the functions it receives… https://github.com/google/jax/blob/7fa6b1b5fafaf22503bc194ba66a00a6859bdadd/jax/_src/lax/control_flow.py#L656-L657 …and since all trace caching is keyed on the function ID, temporary functions are essentially never cached. So the jnp.piecewise implementation is similar to the second block here:

I think the best solution here would be to make jnp.piecewise a wrapper for an underlying JIT-compiled function that separates out the static and non-static arguments it receives, so that the entire lax.switch call would be JIT compiled and appropriately cached.

It’s not tracing the single branch at all, it’s simply calling it inline. That’s a correct optimization, and I think a desirable one.

In what sense is this a root cause?
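
For readers following the caching argument above, here is a minimal sketch of the two ideas: that a freshly created function misses the trace cache on every call, and that passing long-lived functions as static arguments to a jitted wrapper lets the whole lax.switch call be cached. The names trace_log, neg_one, pos_one and apply_switch are hypothetical and chosen only for illustration; this is not the actual jnp.piecewise implementation, and it assumes a JAX version that supports static_argnames.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Python side effects run only while JAX traces a function,
# so this list grows on a trace and stays unchanged on a cache hit.
trace_log = []


def neg_one(v):
    trace_log.append("neg_one")
    return -jnp.ones_like(v)


def pos_one(v):
    trace_log.append("pos_one")
    return jnp.ones_like(v)


x = jnp.arange(-2, 3)

# 1) A temporary function created on every call has a new identity each
#    time, so the cache keyed on the function object never hits.
for _ in range(3):
    jax.jit(lambda v: pos_one(v))(x)
fresh_traces = len(trace_log)

# 2) One long-lived jitted function is traced once for this shape and dtype.
trace_log.clear()
stable = jax.jit(pos_one)
for _ in range(3):
    stable(x)
stable_traces = len(trace_log)


# 3) Hypothetical wrapper in the spirit of the proposal: the branch
#    functions are a static (hashable) argument, the index and operand are
#    traced, and the whole lax.switch call sits inside one jitted function.
@partial(jax.jit, static_argnames=("funcs",))
def apply_switch(index, operand, funcs):
    return lax.switch(index, funcs, operand)


trace_log.clear()
branches = (neg_one, pos_one)  # a tuple of module-level functions is hashable
result = apply_switch(1, x, funcs=branches)
traces_after_first = len(trace_log)
for _ in range(3):
    result = apply_switch(1, x, funcs=branches)
traces_after_repeat = len(trace_log)

print("fresh:", fresh_traces, "stable:", stable_traces)
print("switch traces:", traces_after_first, "then", traces_after_repeat)
```

The design choice in the third part is that the cache key includes the identity of the functions in funcs, so the caller has to pass the same function objects each time. Wrapping them in new lambdas or closures per call would bring the retracing back.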