O(n)-sized jaxpr takes O(exp(n)) work to obtain
The following code has compilation time exponential in the depth (here passed as 10):
import jax
import jax.lax as lax
import jax.numpy as jnp

def fn(x, depth):
    if depth == 1:
        return x
    else:
        _fn = lambda x: fn(x, depth - 1)
        return lax.cond(jnp.array(True), _fn, lambda x: x, x)

jax.make_jaxpr(jax.vmap(lambda x: fn(x, 10)))(jnp.array([1.]))
The reason is that the condition jnp.array(True) is not vmap'd, which means that two calls to batch_jaxpr are made: one here:
https://github.com/google/jax/blob/d8b7bd54be7252673a211b8f4d06a0b4d72256c1/jax/_src/lax/control_flow.py#L857
and one here:
https://github.com/google/jax/blob/d8b7bd54be7252673a211b8f4d06a0b4d72256c1/jax/_src/lax/control_flow.py#L861
So overall 2^depth calls are made to batch_jaxpr.
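The blowup can be modeled without JAX at all: if batching a depth-d conditional recurses into both the true and false branches separately, the number of "batch" calls doubles at every level. A minimal sketch (the function name batch and the counter are illustrative, not JAX internals):

```python
# Toy model of the recursion described above: each level makes two
# uncached recursive calls, one per cond branch, giving 2**depth - 1
# total calls.
calls = {"count": 0}

def batch(depth):
    calls["count"] += 1
    if depth == 1:
        return
    batch(depth - 1)  # true branch
    batch(depth - 1)  # false branch

batch(10)
print(calls["count"])  # 2**10 - 1 == 1023
```

Since the two branch calls at a given depth batch structurally identical sub-jaxprs, the recursion tree has exponentially many nodes but only linearly many distinct inputs, which is why the issue is fixable by caching.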
I’m currently writing something involving a binary tree, but unfortunately this is pretty much a blocker.
Issue Analytics
- Created 2 years ago
- Comments: 28 (26 by maintainers)
Yeah, adding jax._src.util.cache around batching.batch_jaxpr, plus some tuple casts for its arguments, drops the execution time from 80s to milliseconds.

Hrm, it really doesn't seem exponential anymore:
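The fix described above amounts to memoizing the batching function so the two branch calls at each level hit the cache rather than re-batching an identical sub-jaxpr. A hedged sketch using functools.lru_cache in place of the internal jax._src.util.cache (the tuple casts mentioned above exist because cached arguments must be hashable):

```python
import functools

# Same toy recursion as before, but memoized: each distinct depth is
# batched exactly once, so the call count is linear in depth instead
# of exponential.
calls = {"count": 0}

@functools.lru_cache(maxsize=None)
def batch(depth):
    calls["count"] += 1
    if depth == 1:
        return
    batch(depth - 1)  # first branch call does the work
    batch(depth - 1)  # second branch call is a cache hit

batch(10)
print(calls["count"])  # 10: one real call per depth
```

This is only a model of the pattern; the actual patch caches batch_jaxpr itself inside JAX, not user code.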