vmap into dynamic_slice -- TypeError: unhashable type: 'BatchTracer'

Is it possible to vmap over the inputs to a dynamic indexer? For instance, I’m trying to implement a likelihood function for a non-absorbing random walk by splitting it into a sequence of individual absorbing random walks (see, e.g., p. 4 here).

The most straightforward way I could think of to do this in JAX was to use dynamic_slice_in_dim on the Markov chain, which can be used to index/partition the transition matrix as needed. This is clearly vectorizable, and the input “batch” is the index of the state I’m currently treating as the absorbing state.
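
As a small illustration of the partition-and-solve step (the same math as in the docstring of `P_i`), the sketch below orders the states so the transient ones come first, reads Q and R off the blocks of T, and solves (I - Q)P = R instead of forming the inverse. The chain is a hypothetical four-state gambler’s ruin with fair steps, and `absorption_probs` and `n_transient` are illustrative names, not part of the original code. Note that `n_transient` has to be a plain Python int here, because it fixes the block shapes.

```python
import jax.numpy as jnp

def absorption_probs(T, n_transient):
    """Solve (I - Q) P = R for a transition matrix whose first
    `n_transient` states are transient and the rest absorbing.

    `n_transient` must be a Python int: it determines the block shapes.
    """
    Q = T[:n_transient, :n_transient]   # transient -> transient
    R = T[:n_transient, n_transient:]   # transient -> absorbing
    I = jnp.eye(n_transient, dtype=T.dtype)
    return jnp.linalg.solve(I - Q, R)   # avoids forming (I - Q)^{-1}

# Hypothetical gambler's ruin on {0, 1, 2, 3} with fair steps, reordered
# as [1, 2, 0, 3] so that the transient states 1 and 2 come first.
T = jnp.array([
    [0.0, 0.5, 0.5, 0.0],   # from 1: to 2 or to 0
    [0.5, 0.0, 0.0, 0.5],   # from 2: to 1 or to 3
    [0.0, 0.0, 1.0, 0.0],   # 0 is absorbing
    [0.0, 0.0, 0.0, 1.0],   # 3 is absorbing
])
P = absorption_probs(T, 2)
# Rows: start in 1 or 2; columns: absorbed at 0 or 3.
# P is approximately [[2/3, 1/3], [1/3, 2/3]]
```

Each row of P sums to 1 here because every walk is eventually absorbed at 0 or 3.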

The problem is that I can’t seem to use the array I’m vmapping over as an input to dynamic_slice_in_dim: it throws an unhashable type error. Here’s some example code:

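import jax.numpy as np
from functools import partial
from jax import vmap
from jax.lax import dynamic_slice_in_dim
from jax.scipy import linalg

def P_i(T, a, idx):
    r"""we need to partition the transition matrix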
        $$ T =
        \begin{pmatrix}
         Q & R \\
         0 & I
        \end{pmatrix}
        $$
    where:
        $Q$: the non-absorbing transitions,
        $R$: non-absorbing to absorbing transitions
    Then, probability of being absorbed is given as
        $$P = (I-Q)^{-1} R$$
    However, this inversion may not exist, and may be numerically unstable.
    Instead, this can be rearranged into a linear system of equations,
    $Ax = b$, and solving for $P$ as $x$ gives
        $$(I-Q)P = R$$

    In this case, we only want the probability of transitioning from the
    most recent state to the current absorbing state.
    """

    a_trans = dynamic_slice_in_dim(np.array(a), 0, idx)  # visited
    a_absrb = dynamic_slice_in_dim(np.array(a), idx, len(a)-idx)  # not visited

    Q = T[a_trans, :][:, a_trans]
    R = T[a_trans, :][:, a_absrb]
    I = np.identity(Q.shape[0])

    P = linalg.solve(I - Q, R)  # Probability of absorption...
    return P[-1, 0]  # ...from previous state (P[-1,:] by construction) into next


def invite_likelihood(T, a):
    """Calculate the negative log-likelihood of a transition matrix $T$, given
    censored observed INVITE sequence $a$.
    """
    # need to include all possible states in T at the end of a
    frontload_a = list(a) + list(set(range(T.shape[0]))-set(a))

    transition_prob = vmap(partial(P_i, T, frontload_a))(np.arange(1, len(a)))

    # the sequence likelihood is the product of the per-step absorption probabilities
    return -np.sum(np.log(transition_prob))

The function P_i works correctly on its own: running

for p in map(partial(P_i, T, frontload_a), range(1, len(a))):
    print(p)

prints the expected values. Wrapping the same call in vmap, however, raises the error.
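
One way to see why vmap breaks here: in `jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size)` only `start_index` may be a traced value. `slice_size` determines the output shape, so it has to be a concrete Python integer, and in `P_i` the batched `idx` (and `len(a)-idx`) is passed as a slice size. The minimal sketch below contrasts the two cases on a toy array. The exact exception type and message depend on the JAX version (the 2019 error was the unhashable `BatchTracer` one), so it only records whether an exception was raised.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

# Batching over the *start* index works: every slice has the same static size 3.
starts = jnp.array([0, 2, 5])
windows = jax.vmap(lambda s: lax.dynamic_slice_in_dim(x, s, 3))(starts)
# windows has shape (3, 3): [[0, 1, 2], [2, 3, 4], [5, 6, 7]]

# Batching over the *size* cannot work: each slice would need a different shape.
def batched_size_fails():
    try:
        jax.vmap(lambda n: lax.dynamic_slice_in_dim(x, 0, n))(jnp.array([1, 2, 3]))
    except Exception as err:  # exact type and message vary by JAX version
        return type(err).__name__
    return None

failure = batched_size_fails()  # name of the raised exception, or None
```

The general rule this illustrates is that vmap can batch over values but not over shapes: every element of the batch must produce an output of the same shape.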

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

mattjj commented, Feb 11, 2019 (1 reaction)

Aha, great! This sounds like exactly the same issue that people run into when, for example, trying to batch RNN computations on ragged sequence lengths. If we want to make use of SIMD hardware (which is what vmap is about), we might need to do some kind of padding+masking to express things in a vmap-friendly way (or alternatively, for more memory efficiency, have some other sparse representation).

I say “great” because we’re working on some new function transformations that should help this use case (see the “automasking” branch, which needs updating). But they’re not ready yet, and might not be for a couple months.

In the meantime, this computation is just hard to batch in any system, and vmap on its own doesn’t help. The best we have right now is “check back in a month and JAX might be able to help with this too!”

I’m going to close this issue if that’s alright with you, but please reopen it if you think there’s a JAX bug here (rather than an open challenge).
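
The padding-and-masking idea from the comment above can be sketched for this particular problem. The assumption is that T has already been permuted into visit order (visited states first, then the never-visited ones), so the matrix keeps its full n × n shape for every idx. Instead of slicing out Q, every row and column at or beyond idx is zeroed: the padded part of I - Q becomes an identity block, the right-hand side is zeroed there too, so the padded entries of the solution come out 0 and the transient entries match the unpadded solve. Only column 0 of R is needed for P[-1, 0], so a single column is solved for. This is a hand-written mask, not the automasking transformation mentioned above; `absorb_prob_masked`, `absorb_prob_sliced`, the random 5-state matrix and the visit sequence are all hypothetical, and `sorted(...)` replaces the original `list(set(...))` only to make the state order deterministic.

```python
from functools import partial

import jax
import jax.numpy as jnp

def absorb_prob_masked(Tp, k):
    """Probability that a walk started in state k-1, with states 0..k-1
    transient, is absorbed in state k. `Tp` is the transition matrix already
    permuted into visit order; `k` may be a traced scalar under vmap."""
    n = Tp.shape[0]
    m = (jnp.arange(n) < k).astype(Tp.dtype)   # 1.0 on transient states
    Q = Tp * m[:, None] * m[None, :]           # zero outside the transient block
    A = jnp.eye(n, dtype=Tp.dtype) - Q         # block form [[I - Q, 0], [0, I]]
    r = Tp[:, k] * m                           # column of R for absorbing state k
    x = jnp.linalg.solve(A, r)                 # padded entries of x come out 0
    return x[k - 1]                            # start from the last visited state

def absorb_prob_sliced(Tp, k):
    """Unmasked reference; `k` must be a Python int."""
    Q = Tp[:k, :k]
    r = Tp[:k, k]
    return jnp.linalg.solve(jnp.eye(k, dtype=Tp.dtype) - Q, r)[-1]

# Hypothetical 5-state chain with strictly positive entries, and a visit sequence.
n = 5
W = jax.random.uniform(jax.random.PRNGKey(0), (n, n)) + 0.1
T = W / W.sum(axis=1, keepdims=True)           # row-stochastic
a = [2, 0, 3]
order = a + sorted(set(range(n)) - set(a))     # visited states first
perm = jnp.array(order)
Tp = T[perm][:, perm]

probs = jax.vmap(partial(absorb_prob_masked, Tp))(jnp.arange(1, len(a)))
ref = jnp.array([absorb_prob_sliced(Tp, k) for k in range(1, len(a))])
neg_log_lik = -jnp.sum(jnp.log(probs))
```

The trade-off is that every batch element pays for a full n × n solve regardless of idx, which is the kind of overhead the comment has in mind when it mentions sparse representations as a more memory-efficient alternative.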

mattjj commented, Feb 13, 2019 (0 reactions)

Sure, that’d be great! Actually it might be best to open a new issue in that case, to ensure we notice it.
