vmap into dynamic_slice -- TypeError: unhashable type: 'BatchTracer'
Is it possible to vmap over inputs to a dynamic indexer? E.g., I am trying to implement a likelihood function for a non-absorbing random walk by splitting it into partitions of individual absorbing random walks (see p. 4 here, for example).
The most straightforward way I could think of to do this in JAX was to use dynamic_slice_in_dim on the Markov chain, which can then be used to index/partition the transition matrix as needed. This is obviously vectorizable, but the input "batch" is which state I'm currently treating as the absorbing state.
The issue starts when it turns out I can't use the iterable passed to vmap as an input to the dynamic_slice_in_dim function, as it throws an unhashable-type error. Here's some example code:
from functools import partial

import jax.numpy as np
from jax import vmap
from jax.lax import dynamic_slice_in_dim
from jax.numpy import linalg  # assumed: linalg here is taken to be jax.numpy.linalg


def P_i(T, a, idx):
    r"""We need to partition the transition matrix
    $$ T =
    \begin{pmatrix}
    Q & R \\
    0 & I
    \end{pmatrix}
    $$
    where:
    $Q$: the non-absorbing transitions,
    $R$: non-absorbing to absorbing transitions.
    Then, the probability of being absorbed is given as
    $$P = (I-Q)^{-1} R$$
    However, this inversion may not exist, and may be numerically unstable.
    Instead, this can be rearranged into a linear system of equations,
    $Ax = b$, and solving for $P$ as $x$ gives
    $$(I-Q)P = R$$
    In this case, we only want the probability of transitioning from the
    most recent state to the current absorbing state.
    """
    a_trans = dynamic_slice_in_dim(np.array(a), 0, idx)             # visited
    a_absrb = dynamic_slice_in_dim(np.array(a), idx, len(a) - idx)  # not visited
    Q = T[a_trans, :][:, a_trans]
    R = T[a_trans, :][:, a_absrb]
    I = np.identity(Q.shape[0])
    P = linalg.solve(I - Q, R)  # probability of absorption...
    return P[-1, 0]             # ...from previous state (P[-1, :] by construction) into next


def invite_likelihood(T, a):
    """Calculate the log-likelihood of a transition matrix $T$, given a censored
    observed INVITE sequence $a$.
    """
    # need to include all possible states in T at the end of a
    frontload_a = list(a) + list(set(range(T.shape[0])) - set(a))
    transition_prob = vmap(partial(P_i, T, frontload_a))(np.arange(1, len(a)))
    return -np.log(np.sum(np.array(transition_prob)))
The function P_i works well on its own; running

for i in list(map(partial(P_i, T, frontload_a), range(1, len(a)))):
    print(i)

gives the desired values. Using vmap, however, appears to break it.
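For context, here is a minimal sketch (not from the code above) of the failure mode: the slice size passed to dynamic_slice_in_dim must be a concrete integer, but under vmap the mapped index arrives as an abstract BatchTracer, and the per-example output shapes would be ragged. The array x and function head are made up for illustration, and the exact error text depends on the JAX version.

```python
import jax.numpy as np
from jax import vmap
from jax.lax import dynamic_slice_in_dim

x = np.arange(10)

def head(idx):
    # the slice *size* depends on the mapped value, so each batch element
    # would need a differently shaped output
    return dynamic_slice_in_dim(x, 0, idx)

vmap(head)(np.arange(1, 4))  # fails: the slice size must be concrete, not a tracer
```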
Aha, great! This sounds like exactly the same issue that people run into when, for example, trying to batch RNN computations on ragged sequence lengths. If we want to make use of SIMD hardware (which is what vmap is about), we might need to do some kind of padding+masking to express things in a vmap-friendly way (or alternatively, for more memory efficiency, have some other sparse representation).
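A minimal sketch of the padding + masking idea on a toy prefix reduction (my own illustration, not the automasking transformation mentioned above): instead of slicing out a length-idx prefix, which forces a per-example output shape, keep full-length arrays and mask out everything past idx, so every batch element has the same shape.

```python
import jax.numpy as np
from jax import vmap

x = np.arange(10.0)

def prefix_sum(idx):
    # fixed-shape boolean mask; only its values depend on the mapped idx
    mask = np.arange(x.shape[0]) < idx
    return np.sum(np.where(mask, x, 0.0))

print(vmap(prefix_sum)(np.arange(1, 4)))  # [0. 1. 3.]
```

Note this shows the general trick only; the original linear-solve partitioning is harder to handle this way, since the size of the solve itself varies per batch element.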
I say “great” because we’re working on some new function transformations that should help this use case (see the “automasking” branch, which needs updating). But they’re not ready yet, and might not be for a couple months.
In the meantime, this computation is just hard to batch in any system, and vmap on its own doesn’t help. The best we have right now is “check back in a month and JAX might be able to help with this too!”
I’m going to close this issue if it’s alright with you, but reopen if you think there’s a JAX bug here (rather than an open challenge).
Sure, that’d be great! Actually it might be best to open a new issue in that case, to ensure we notice it.