Dynamically sized arrays with maximum length
Jax doesn't seem to like dynamically sized arrays when jitting. That seems reasonable, since I suppose static shapes make it easier for the compiler. E.g. the following yields an error:
```python
import jax.numpy as jnp
from jax import random
from jax import jit

rng = random.PRNGKey(0)
K = random.randint(rng, (1,), 1, 100)

def dynamic_array(K):
    return jnp.arange(K)

# Raises an error: the output shape would depend on the traced value of K
jit(dynamic_array)(K)
```
Similar considerations exist when using `lax.iota` or `linspace`. Is there a way to specify some maximum array size and get away with creating an array of dynamic length below this maximum size? I thought about allocating an array of maximum length and using `lax.dynamic_slice`, but this throws a similar error and, moreover, seems pretty wasteful.
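One common pattern, sketched here under the assumption that a static upper bound is known at trace time (a hypothetical `MAX_K = 100`, matching the `randint` range above), is to allocate the full-length array once and carry a boolean mask for the first `K` entries. Only the shape has to be static; the value of `K` can stay traced, so the function compiles once for all `K`. `lax.dynamic_slice` fails in the same way because it accepts traced start indices but requires static `slice_sizes`.

```python
import jax
import jax.numpy as jnp
from jax import random

MAX_K = 100  # hypothetical static upper bound on K (randint above samples K < 100)

@jax.jit
def masked_arange(K):
    idx = jnp.arange(MAX_K)   # static shape (MAX_K,), independent of K
    mask = idx < K            # True exactly for the first K entries; K may be traced
    return idx, mask

@jax.jit
def masked_sum_of_squares(K):
    idx, mask = masked_arange(K)
    # Padded entries are replaced by 0 so they do not contribute to the reduction
    return jnp.sum(jnp.where(mask, idx ** 2, 0))

K = random.randint(random.PRNGKey(0), (), 1, 100)  # scalar K, traced under jit
print(masked_sum_of_squares(K))
```

The trade-off is that every call does `MAX_K` worth of work regardless of `K`, so how wasteful this is depends on how tight the bound is.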
Right now I get around this by specifying the size `K` as a static argnum to `jit`:
```python
from functools import partial

@partial(jit, static_argnums=(0,))
def dynamic_array(K):
    return jnp.arange(K)
```
but then, if I understand correctly, this recompiles for each different value of `K`, which seems pretty wasteful if there can be many different sizes `K`.
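The recompilation concern can be checked directly by counting traces. In the sketch below, `trace_log` is a hypothetical helper list; Python side effects inside a jitted function run only while JAX traces it, so the list grows once per new static value of `K` and not on cached calls.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each value of K that triggers a trace

@partial(jax.jit, static_argnums=(0,))
def dynamic_array(K):
    trace_log.append(K)  # Python side effect: executes at trace time only
    return jnp.arange(K)

dynamic_array(3)
dynamic_array(3)  # same static K: served from the compilation cache
dynamic_array(5)  # new static K: traced and compiled again
print(trace_log)  # [3, 5]
```

So the number of compilations grows with the number of distinct `K` values seen, which confirms the worry above.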
For some context, I'm doing this for Russian Roulette estimators (https://projecteuclid.org/euclid.ss/1449670853), where one can obtain an unbiased estimate of an infinite series by randomly truncating the series and weighting the terms appropriately. So here `K` would be the number of terms to evaluate in the series.
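For summing a series specifically, an array of length `K` may not be needed at all: a loop whose trip count is a traced value compiles once. The following is a minimal sketch, not the estimator from the cited paper verbatim. It assumes a geometric truncation level `N` with P(N >= k) = p^k and divides term k by that probability, which gives an unbiased estimate when the series converges absolutely. `russian_roulette_sum`, `term_fn` and `p` are hypothetical names. Note that `lax.fori_loop` with a traced bound lowers to `lax.while_loop`, which does not support reverse-mode differentiation.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def russian_roulette_sum(key, term_fn, p):
    """Single-sample estimate of sum_{k>=0} term_fn(k) with random truncation.

    N is drawn so that P(N >= k) = p**k (up to the tiny clamp on u), and term k
    is divided by that probability so that E[estimate] equals the full series.
    """
    # Clamp away from 0 so log(u) is finite and the loop always terminates
    u = random.uniform(key, minval=1e-12, maxval=1.0)
    # For u ~ Uniform(0, 1): P(floor(log u / log p) >= k) = P(u <= p**k) = p**k
    n = jnp.floor(jnp.log(u) / jnp.log(p)).astype(jnp.int32)

    def body(k, acc):
        kf = k.astype(jnp.float32)
        return acc + term_fn(kf) / jnp.power(p, kf)

    # Upper bound n + 1 is traced, so fori_loop lowers to a while_loop:
    # one compilation serves every truncation level, and no length-n array exists.
    return lax.fori_loop(0, n + 1, body, jnp.float32(0.0))

# Geometric series sum_k 0.5**k = 2, estimated by averaging many single-sample estimates
keys = random.split(random.PRNGKey(0), 20_000)
estimates = jax.jit(jax.vmap(lambda k: russian_roulette_sum(k, lambda j: 0.5 ** j, 0.8)))(keys)
print(estimates.mean())  # should be close to 2.0
```

Choosing `p` trades compute for variance: a larger `p` evaluates more terms on average but makes the reweighting factors 1 / p^k grow more slowly.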
Sorry for not providing any useful updates. Actually I didn’t notice until now that this issue was getting pinged! My GitHub notifications are not organized, to say the least.
We’ve worked on a couple prototypes in this direction, but no solutions so far. Momentum is picking back up on this front so I hope to provide a substantial update in the next several months. (I’d love to detail and discuss all our work-in-progress ideas here, but we’re pretty bandwidth-maxed right now!)
So I don’t have a useful update right now, but we haven’t forgotten!
Any updates? I was trying to re-implement IRM with CMNIST in Jax, but it's kind of impossible to reproduce the original version because of this issue.