FR MarkovianCategorical distribution

Feature request: a more flexible version of DiscreteHMM.

I’m working with a model where several variables depend on a discrete Markovian state:

p(z_k | z_{k-1}) and Categorical(m_k | z_k), Gamma(h_k | z_k), Beta(x_k | z_k)
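Written out, this is an HMM with several emission terms per step. The factorization below is standard HMM algebra rather than something stated in the issue. K (the duration), the shorthand e_k and the forward variables α_k are notation introduced here. Marginalizing z amounts to the forward recursion:

```latex
\begin{aligned}
p(m_{1:K}, h_{1:K}, x_{1:K}) &= \sum_{z_{1:K}} p(z_1) \prod_{k=2}^{K} p(z_k \mid z_{k-1}) \prod_{k=1}^{K} e_k(z_k),
\qquad e_k(z) = \mathrm{Categorical}(m_k \mid z)\,\mathrm{Gamma}(h_k \mid z)\,\mathrm{Beta}(x_k \mid z), \\
\alpha_1(z) &= p(z_1 = z)\, e_1(z), \qquad
\alpha_k(z) = e_k(z) \sum_{z'} \alpha_{k-1}(z')\, p(z_k = z \mid z_{k-1} = z'), \\
p(m_{1:K}, h_{1:K}, x_{1:K}) &= \sum_{z} \alpha_K(z).
\end{aligned}
```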

I propose introducing a new distribution, MarkovianCategorical, for z, and integrating out z with MarkovProduct at the very end (similar to how enumeration works in TraceEnum_ELBO). Below is a snippet of MarkovianCategorical that I’ve written, but I don’t know how to make it funsor-compatible. The logic of log_prob depends on whether the value is enumerated or not. If it is, log_prob returns a tensor of shape (duration, states, states, ...), which can later be used with MarkovProduct. Is this approach feasible?
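The following is an illustrative sketch, not part of the original issue. It shows what "integrating out z" means once each step's terms are stacked into a log-factor table of shape (duration, states, states). The procedure is a forward (sum-product) recursion along time. Roughly, this gives the same quantity as sequential_sum_product/MarkovProduct followed by summing out the remaining prev/curr variables. The code is pure Python with no Pyro or Funsor. eliminate_markov and brute_force are hypothetical helper names. The t = 0 convention (only the prev = 0 row is used) is borrowed from the MarkovianCategorical snippet.

```python
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def eliminate_markov(log_factors):
    """Sum out every z_t from log factors indexed as log_factors[t][prev][curr].

    Assumption mirroring the snippet: at t == 0 only the prev == 0 row is used,
    so log_factors[0][0] holds log p(z_0) (plus any step-0 emission terms).
    """
    num_states = len(log_factors[0][0])
    alpha = list(log_factors[0][0])  # alpha[curr]: log-weight of all paths ending in curr
    for factor in log_factors[1:]:
        alpha = [
            logsumexp([alpha[prev] + factor[prev][curr] for prev in range(num_states)])
            for curr in range(num_states)
        ]
    return logsumexp(alpha)


def brute_force(log_factors):
    """The same quantity by enumerating every state path (exponential in duration)."""
    duration, num_states = len(log_factors), len(log_factors[0][0])
    totals = []
    for path in itertools.product(range(num_states), repeat=duration):
        prev, total = 0, 0.0
        for t, curr in enumerate(path):
            total += log_factors[t][prev][curr]
            prev = curr
        totals.append(total)
    return logsumexp(totals)


random.seed(0)
duration, num_states = 4, 3
log_factors = [[[random.gauss(0.0, 1.0) for _ in range(num_states)]
                for _ in range(num_states)] for _ in range(duration)]
print(eliminate_markov(log_factors), brute_force(log_factors))
```

The recursion costs O(duration × states²), while enumeration costs O(states^duration). This gap is why eliminating z with a Markov-aware sum-product matters.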

Code Snippet

import torch
from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.indexing import Vindex


class MarkovianCategorical(TorchDistribution):
    def __init__(self, initial_logits, transition_logits, duration, validate_args=None):
        if initial_logits.dim() < 1:
            raise ValueError("expected initial_logits to have at least one dim, "
                             "actual shape = {}".format(initial_logits.shape))
        if transition_logits.dim() < 2:
            raise ValueError("expected transition_logits to have at least two dims, "
                             "actual shape = {}".format(transition_logits.shape))
        batch_shape = broadcast_shape(initial_logits.shape[:-1],
                                      transition_logits.shape[:-3])
        event_shape = torch.Size((duration,))
        trans_shape = broadcast_shape(initial_logits.shape[-1:],
                                      transition_logits.shape[-2:])
        # normalize over the last (current-state) dim
        self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True)
        # logits[..., t, prev, curr]; at t = 0 only the prev = 0 row holds the initial logits
        self.logits = initial_logits.new_zeros(batch_shape + event_shape + trans_shape)
        self.logits[..., 0, 0, :] = self.initial_logits
        self.logits[..., 1:, :, :] = self.transition_logits

        self._duration = duration
        self._num_events = trans_shape[-1]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        value = value.long()
        # normal interpretation: value is a full sequence z_0, ..., z_{duration-1}
        if value.shape[-1] == self._duration:
            # previous states, with state 0 standing in for the nonexistent z_{-1}
            value_prev = torch.cat(
                (value.new_zeros(value.shape[:-1] + (1,)), value[..., :-1]),
                dim=-1)
            result = Vindex(self.logits)[..., value_prev, value].sum(-1)
        # enumerated interpretation: value.unsqueeze(-1) and value index prev and curr
        # along different dims, and time is placed to the left of both
        elif value.shape[-1] == 1:
            time = torch.arange(self._duration, device=value.device)
            time = time.view((-1,) + (1,) * (value.dim() + 1))
            result = Vindex(self.logits)[..., time, value.unsqueeze(-1), value]
        else:
            raise ValueError("expected value.shape[-1] to be {} or 1, actual shape = {}"
                             .format(self._duration, value.shape))
        return result

    def enumerate_support(self, expand=True):
        num_events = self._num_events
        values = torch.arange(num_events, dtype=torch.long, device=self.logits.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values
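The following is a framework-free sketch of the table layout that the MarkovianCategorical constructor builds. It also reproduces the "normal interpretation" branch of log_prob. It assumes a time-homogeneous S × S transition matrix. log_normalize, build_logits and sequence_log_prob are hypothetical names. The sketch shows why prepending state 0 as the "previous" value at t = 0 selects the initial distribution.

```python
import math


def log_normalize(logits):
    """Like logits - logits.logsumexp(-1, True) for a single row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def build_logits(initial_logits, transition_logits, duration):
    """Return a duration x S x S table indexed as table[t][prev][curr]."""
    S = len(initial_logits)
    init = log_normalize(initial_logits)
    trans = [log_normalize(row) for row in transition_logits]
    table = [[[0.0] * S for _ in range(S)] for _ in range(duration)]
    table[0][0] = list(init)  # t = 0: only the prev == 0 row carries log p(z_0)
    for t in range(1, duration):
        table[t] = [list(row) for row in trans]
    return table


def sequence_log_prob(table, value):
    """Normal interpretation: prepend 0 as the 'previous' state of z_0, then sum over t."""
    value_prev = [0] + list(value[:-1])
    return sum(table[t][p][c] for t, (p, c) in enumerate(zip(value_prev, value)))


initial = [0.0, 1.0, -1.0]
transition = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
table = build_logits(initial, transition, duration=4)
z = [1, 1, 2, 0]
print(sequence_log_prob(table, z))
```

At t = 0, every row other than prev = 0 is left at log 1 = 0. The enumerated branch pairs every prev with every curr, so at the first step it would see those unnormalized rows. To avoid this, the prev = 0 row would need to be masked, or the prev dimension at t = 0 handled separately.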

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 19 (19 by maintainers)

Top GitHub Comments

1 reaction
eb8680 commented, Nov 13, 2020

> do you think it is possible to generalize partial_sum_product algorithm to eliminate markov variables (i.e. use sequential_sum_product to eliminate time dim and pairs of (prev,curr) for markov sites )?

Definitely! I think that’s the most general solution on the inference side to your original problem - the ELBO implementation I had sketched above effectively does this manually. For reference, you can find details on partial_sum_product in our paper Tensor Variable Elimination for Plated Factor Graphs - partial_sum_product is basically a line-for-line translation of Algorithm 1 in the paper.

You could even replace sequential_sum_product with funsor.sum_product.sarkka_bilmes_product, a version of sequential_sum_product extended to dependencies that span more than one time step. This would enable parallel-scan elimination for a very general class of higher-order dynamic graphical models. If you get that working, we’d be able to write a nice paper about it, entirely independent of the original motivating application, if that’s something that interests you.
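Parallel-scan elimination is possible because combining two adjacent time steps is a matrix product in the (logsumexp, +) semiring, and that product is associative. The minimal pure-Python sketch below illustrates this point. It assumes a single first-order Markov variable. log_matmul, sequential_eliminate and tree_eliminate are hypothetical names, and the code is not Funsor's implementation.

```python
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(a, b):
    """Matrix product in log space: out[i][k] = logsumexp_j(a[i][j] + b[j][k])."""
    n, m = len(b), len(b[0])
    return [[logsumexp([row[j] + b[j][k] for j in range(n)]) for k in range(m)] for row in a]


def sequential_eliminate(factors):
    """Left to right: ((f0 * f1) * f2) * ...; T - 1 dependent steps."""
    result = factors[0]
    for f in factors[1:]:
        result = log_matmul(result, f)
    return result


def tree_eliminate(factors):
    """Combine neighbours pairwise. The products within one round are independent
    of each other, so they could run in parallel (about log2(T) rounds)."""
    while len(factors) > 1:
        paired = [log_matmul(factors[i], factors[i + 1]) for i in range(0, len(factors) - 1, 2)]
        if len(factors) % 2:
            paired.append(factors[-1])  # odd one out is carried to the next round
        factors = paired
    return factors[0]


random.seed(1)
S, T = 3, 7
factors = [[[random.gauss(0.0, 1.0) for _ in range(S)] for _ in range(S)] for _ in range(T)]
seq = sequential_eliminate(factors)  # S x S factor over (first prev, last curr)
par = tree_eliminate(factors)
print(max(abs(x - y) for rs, rp in zip(seq, par) for x, y in zip(rs, rp)))
```

Both functions leave a factor over the first prev and the last curr state, which can then be summed out. Only the order of the pairwise combinations differs, and that changes the result only up to floating-point rounding.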

One thing that jumps out to me in looking at the code you have there: I think you might be missing one step in your algorithm - you should also eliminate plates that are entirely local to a single time step before eliminating the time variable. It’s not clear to me whether this is being done correctly at the moment.
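The small pure-Python illustration below shows why the elimination order matters. Its toy structure is assumed, not taken from the thread. There is a Markov chain z_t, and at each time step a plate of N independent local discrete variables y_{t,n} that depend only on z_t. Summing out the plate-local variables and multiplying over the plate before eliminating time matches brute-force enumeration. Eliminating time separately for each plate member first is one way of getting the order wrong, and it does not match. All names are hypothetical.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


random.seed(2)
T, S, N, Y = 3, 2, 2, 2  # time steps, z states, plate size, states of local y
log_init = [random.gauss(0.0, 1.0) for _ in range(S)]                        # log p(z_0)
log_trans = [[random.gauss(0.0, 1.0) for _ in range(S)] for _ in range(S)]   # [prev][curr]
# log_g[t][n][z][y]: factor for local variable y_{t,n} in a plate of size N at step t
log_g = [[[[random.gauss(0.0, 1.0) for _ in range(Y)] for _ in range(S)]
          for _ in range(N)] for _ in range(T)]


def forward(log_emit):
    """Forward recursion over z given per-step log factors log_emit[t][z]."""
    alpha = [log_init[z] + log_emit[0][z] for z in range(S)]
    for t in range(1, T):
        alpha = [logsumexp([alpha[p] + log_trans[p][c] for p in range(S)]) + log_emit[t][c]
                 for c in range(S)]
    return logsumexp(alpha)


# Plate first: sum out y_{t,n}, multiply over n (add in log space), then eliminate time.
plate_first = forward([[sum(logsumexp(log_g[t][n][z]) for n in range(N)) for z in range(S)]
                       for t in range(T)])

# Time first for each plate member, then multiply over n: wrongly gives each n its own chain.
time_first = sum(forward([[logsumexp(log_g[t][n][z]) for z in range(S)] for t in range(T)])
                 for n in range(N))


def brute_force():
    """Enumerate every z path and every joint assignment of all local y_{t,n}."""
    terms = []
    for zs in itertools.product(range(S), repeat=T):
        base = log_init[zs[0]] + sum(log_trans[zs[t - 1]][zs[t]] for t in range(1, T))
        for ys in itertools.product(range(Y), repeat=T * N):
            terms.append(base + sum(log_g[t][n][zs[t]][ys[t * N + n]]
                                    for t in range(T) for n in range(N)))
    return logsumexp(terms)


exact = brute_force()
print(plate_first, time_first, exact)
```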

Tracking down issues like that is surprisingly difficult, so regardless of how general you decide to get (and as a Funsor developer I’m all for extreme generality 😉), I suggest writing some direct tests of your modified partial_sum_product that do not involve Pyro or enumeration. @fritzo’s and my strategy in developing these rather complex variable elimination algorithms has been to test the fancy parallel versions against “unrolled” sequential versions whose correctness we can trust.

There are a bunch of tests of this flavor in Funsor’s test_sum_product.py, e.g. this one for sequential_sum_product. In my experience this will save you a tremendous amount of sweat and tears, even if it seems like more work up front; in fact, @fritzo and I often start by writing these tests and an interface or algorithm sketch, then implement the algorithms to make them pass.

In your case, a good strategy might involve writing some tests that construct a bunch of funsor.Tensors representing factor graphs with the appropriate markov structure (these correspond to the log_prob tensors produced by Pyro) and comparing your new algorithm against the old partial_sum_product applied to an “unrolled” set of factors, where “unrolling” means slicing time-parallel factors along the time axis:

for factor in log_factors:
    if "time" in factor.inputs:
        slice_factors = [factor(time=t) for t in range(factor.inputs["time"].size)]
        ...  # rename other variables, remove factor from log_factors and add slice_factors to log_factors

and renaming any {var}_curr/{var}_prev dimensions in each new slice to {var}_t_{t}/{var}_t_{t-1} or something similar. The unrolled factors should no longer contain a “time” input. You can see a more concrete example of this unrolling procedure applied to a single factor in the test I linked above, test_sequential_sum_product.
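Below is a framework-free sketch of the unrolling step described above, of the kind such a test might use. It rests on several assumptions. A factor is modeled as a plain (inputs, table) pair standing in for funsor.Tensor, not Funsor's API. Only first-order x_prev/x_curr pairs are handled. make_factor, unroll and sum_product are hypothetical helpers. The brute-force elimination of the unrolled factors serves as the trusted baseline. A forward recursion over the original time-parallel factor plays the role of the algorithm under test.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def make_factor(inputs, fn):
    """inputs: {name: domain size}; table maps value tuples (in inputs order) to log-values."""
    names = list(inputs)
    table = {vals: fn(dict(zip(names, vals)))
             for vals in itertools.product(*(range(inputs[n]) for n in names))}
    return inputs, table


def unroll(factor, markov_vars):
    """Slice a time-parallel factor along "time"; rename {var}_prev/{var}_curr per step."""
    inputs, table = factor
    names = list(inputs)
    slices = []
    for t in range(inputs["time"]):
        rename = {}
        for var in markov_vars:
            rename[var + "_prev"] = "{}_t_{}".format(var, t - 1)
            rename[var + "_curr"] = "{}_t_{}".format(var, t)
        new_inputs = {rename.get(n, n): inputs[n] for n in names if n != "time"}
        new_table = {}
        for vals, logv in table.items():
            a = dict(zip(names, vals))
            if a["time"] == t:
                new_table[tuple(a[n] for n in names if n != "time")] = logv
        slices.append((new_inputs, new_table))
    return slices


def sum_product(factors):
    """Sum out every variable by brute-force enumeration (exponential; tests only)."""
    domains = {}
    for inputs, _ in factors:
        domains.update(inputs)
    names = sorted(domains)
    terms = []
    for vals in itertools.product(*(range(domains[n]) for n in names)):
        a = dict(zip(names, vals))
        terms.append(sum(table[tuple(a[n] for n in inputs)] for inputs, table in factors))
    return logsumexp(terms)


random.seed(3)
T, S = 4, 2
trans = make_factor({"time": T, "x_prev": S, "x_curr": S}, lambda a: random.gauss(0.0, 1.0))
unrolled = unroll(trans, ["x"])

# Stand-in for the algorithm under test: forward recursion over the time-parallel factor.
_, table = trans
alpha = [0.0] * S  # x_t_-1 appears in no other factor, so it is simply summed out
for t in range(T):
    alpha = [logsumexp([alpha[p] + table[(t, p, c)] for p in range(S)]) for c in range(S)]
reference = logsumexp(alpha)
print(sum_product(unrolled), reference)
```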

> To have it running I had to change to eager interpretation in TraceEnum_ELBO, otherwise I was getting an error:

What was the problem? The lazy expression generated by partial_sum_product should be handled correctly by funsor.optimize. What happens if you replace sequential_sum_product with funsor.sum_product.MarkovProduct (a first-class Funsor term wrapping sequential_sum_product)?

1 reaction
ordabayevy commented, Nov 13, 2020

@eb8680 do you think it is possible to generalize partial_sum_product algorithm to eliminate markov variables (i.e. use sequential_sum_product to eliminate time dim and pairs of (prev,curr) for markov sites )? Something like this:

     while ordinal_to_factors:
         leaf = max(ordinal_to_factors, key=len)
         leaf_factors = ordinal_to_factors.pop(leaf)
         leaf_reduce_vars = ordinal_to_vars[leaf]
         for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars):
-            f = reduce(prod_op, group_factors).reduce(sum_op, group_vars)
+            nonmarkov_vars = frozenset(v for v in group_vars if not v.endswith("curr"))
+            markov_vars = {v.replace("curr", "prev"): v for v in group_vars if v.endswith("curr")}
+            f = reduce(prod_op, group_factors).reduce(sum_op, nonmarkov_vars)
+            if markov_vars:
+                time = Variable("time", f.inputs["time"])
+                f = sequential_sum_product(sum_op, prod_op, f, time, markov_vars)
+                f = f.reduce(sum_op, frozenset(markov_vars.values()))
+                f = f.reduce(sum_op, frozenset(markov_vars.keys()))
             remaining_sum_vars = sum_vars.intersection(f.inputs)
             if not remaining_sum_vars:
                 results.append(f.reduce(prod_op, leaf & eliminate))
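The variable-splitting step in the diff can be isolated and unit-tested on its own. The hedged sketch below uses a hypothetical helper, split_markov_vars, that is not part of Pyro or Funsor. It builds the {prev_name: curr_name} dict that the diff passes as the step argument to sequential_sum_product. It differs from the diff in two deliberate ways. First, it rewrites only the trailing suffix, because str.replace("curr", "prev") would also rewrite an earlier "curr" inside a name such as "current_curr". Second, it keeps any *_prev names out of the plain sum variables in case they ever appear in group_vars.

```python
def split_markov_vars(group_vars, curr_suffix="_curr", prev_suffix="_prev"):
    """Split names into plain sum variables and a {prev_name: curr_name} step dict."""
    nonmarkov_vars = set()
    step = {}
    for v in group_vars:
        if v.endswith(curr_suffix):
            # rewrite only the trailing suffix: "x_curr" -> "x_prev"
            step[v[: -len(curr_suffix)] + prev_suffix] = v
        elif not v.endswith(prev_suffix):
            nonmarkov_vars.add(v)
    return frozenset(nonmarkov_vars), step


nonmarkov, step = split_markov_vars({"x_curr", "y_curr", "w", "x_prev"})
print(sorted(nonmarkov), sorted(step.items()))
```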

Vectorized versions of model_1, model_2, model_3, and model_4 from the HMM examples.

Markov site names end with _curr, and the corresponding auxiliary Markov sites end with _prev. The _prev sites are not recorded in the trace by TraceMessenger (enumerate: markov is no longer used):

     def _pyro_post_sample(self, msg):
-        if msg["name"] in self.trace:
+        if msg["name"] in self.trace or \
+                msg["name"].endswith("prev"):

To have it running I had to change to eager interpretation in TraceEnum_ELBO, otherwise I was getting an error:

-        with funsor.interpreter.interpretation(funsor.terms.lazy):
+        with funsor.interpreter.interpretation(funsor.terms.eager):