FR: MarkovianCategorical distribution

This is a feature request for a more flexible version of DiscreteHMM.
I'm working with a model where several variables depend on a discrete Markovian state:

p(z_k | z_{k-1})
Categorical(m_k | z_k), Gamma(h_k | z_k), Beta(x_k | z_k)
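For concreteness, here is a minimal generative sketch of this structure in Pyro (the parameter names, sizes, and priors below are hypothetical, only meant to illustrate the dependencies):

import torch
import pyro
import pyro.distributions as dist


def model(num_steps=10, num_states=3):
    # Hypothetical parameters of the latent Markov chain.
    probs_init = torch.ones(num_states) / num_states
    probs_trans = torch.ones(num_states, num_states) / num_states
    # Hypothetical emission parameters, one entry per latent state.
    probs_m = torch.ones(num_states, 4) / 4
    conc_h, rate_h = torch.ones(num_states), 2.0 * torch.ones(num_states)
    alpha_x, beta_x = torch.ones(num_states), torch.ones(num_states)

    z = None
    for k in pyro.markov(range(num_steps)):
        probs_z = probs_init if k == 0 else probs_trans[z]
        z = pyro.sample(f"z_{k}", dist.Categorical(probs_z))
        # In practice m_k, h_k, x_k would be conditioned on data via obs=.
        pyro.sample(f"m_{k}", dist.Categorical(probs_m[z]))
        pyro.sample(f"h_{k}", dist.Gamma(conc_h[z], rate_h[z]))
        pyro.sample(f"x_{k}", dist.Beta(alpha_x[z], beta_x[z]))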
I propose to introduce a new distribution, MarkovianCategorical, for z, and to integrate out z with MarkovProduct at the very end (similar to how enumeration works in TraceEnum_ELBO). Below is a snippet of the MarkovianCategorical I've written, but I don't know how to make it funsor-compatible. The logic of log_prob depends on whether value is enumerated or not: if it is, log_prob returns a tensor of shape (duration, states, states, ...) which can later be contracted with MarkovProduct. Is this approach feasible?
Code Snippet
import torch

from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.indexing import Vindex


class MarkovianCategorical(TorchDistribution):
    def __init__(self, initial_logits, transition_logits, duration, validate_args=None):
        if initial_logits.dim() < 1:
            raise ValueError("expected initial_logits to have at least one dim, "
                             "actual shape = {}".format(initial_logits.shape))
        if transition_logits.dim() < 2:
            raise ValueError("expected transition_logits to have at least two dims, "
                             "actual shape = {}".format(transition_logits.shape))
        batch_shape = broadcast_shape(initial_logits.shape[:-1],
                                      transition_logits.shape[:-3])
        event_shape = torch.Size((duration,))
        trans_shape = broadcast_shape(initial_logits.shape[-1:],
                                      transition_logits.shape[-2:])
        self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True)
        # Pack initial and transition logits into a single (duration, state, state)
        # table; time step 0 stores the initial distribution in row 0.
        self.logits = torch.zeros(batch_shape + event_shape + trans_shape)
        self.logits[..., 0, 0, :] = self.initial_logits
        self.logits[..., 1:, :, :] = self.transition_logits
        self._duration = duration
        self._num_events = trans_shape[-1]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        value = value.long()
        if value.shape[-1] == self._duration:
            # Normal interpretation: value is a full chain of states; pair each
            # state with its predecessor (the first state is paired with a dummy 0).
            value_prev = torch.cat(
                (torch.zeros(value.shape[:-1] + (1,), dtype=torch.long,
                             device=value.device), value[..., :-1]),
                dim=-1)
            result = Vindex(self.logits)[..., value_prev, value].sum(-1)
        elif value.shape[-1] == 1:
            # Enumerated interpretation: value is assumed to come from
            # enumerate_support with shape (num_events,) + (1,) * len(batch_shape);
            # return logits of shape (duration, states, states, ...) to be
            # contracted later with MarkovProduct.
            time = torch.arange(self._duration).view((-1,) + (1,) * (value.dim() + 1))
            result = Vindex(self.logits)[..., time, value.unsqueeze(-1), value]
        else:
            raise ValueError("unexpected value shape {}".format(value.shape))
        return result

    def enumerate_support(self, expand=True):
        num_events = self._num_events
        values = torch.arange(num_events, dtype=torch.long, device=self.logits.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values
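To make the intended final step concrete, the kind of contraction I'd like to end up with looks roughly like this in plain PyTorch (just a sketch of eliminating a (duration, states, states) table of pairwise log-probs; MarkovProduct / sequential_sum_product would be used to do this lazily and in parallel):

import torch


def total_log_prob(logits):
    # logits: (duration, states, states) table of log p(z_k = curr | z_{k-1} = prev),
    # with the initial distribution stored at logits[0, 0, :] as in MarkovianCategorical.
    log_alpha = logits[0, 0]                                  # (states,)
    for k in range(1, logits.shape[0]):
        # Marginalize the previous state: alpha_curr = logsumexp_prev(alpha_prev + trans).
        log_alpha = torch.logsumexp(log_alpha.unsqueeze(-1) + logits[k], dim=0)
    return torch.logsumexp(log_alpha, dim=0)                  # scalar log-normalizer


# With normalized rows the chain sums to 1, so the result should be ~0.
duration, states = 5, 3
logits = torch.randn(duration, states, states).log_softmax(-1)
print(total_log_prob(logits))  # tensor close to 0.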
Comments (19, all by maintainers)
Definitely! I think that's the most general solution on the inference side to your original problem - the ELBO implementation I had sketched above effectively does this manually. For reference, you can find details on partial_sum_product in our paper Tensor Variable Elimination for Plated Factor Graphs - partial_sum_product is basically a line-for-line translation of Algorithm 1 in the paper. You could even replace sequential_sum_product with funsor.sum_product.sarkka_bilmes_product, a version of sequential_sum_product extended to multiple longer time dependencies, which would enable parallel-scan elimination for a very general class of higher-order dynamic graphical models. If you get that working, we'd be able to write a nice paper about it entirely independent of the original motivating application, if that's something that interests you.

One thing that jumps out to me in looking at the code you have there: I think you might be missing one step in your algorithm - you should also eliminate plates that are entirely local to a single time step before eliminating the time variable. It's not clear to me whether this is being done correctly at the moment.

Tracking down issues like that is surprisingly difficult, so regardless of how general you decide to get (and as a Funsor developer I'm all for extreme generality 😉) I suggest writing some direct tests of your modified partial_sum_product that do not involve Pyro or enumeration. @fritzo and my strategy in developing these rather complex variable elimination algorithms has been to test the fancy parallel versions against "unrolled" sequential versions whose correctness we can trust. There are a bunch of tests of this flavor in Funsor's test_sum_product.py, e.g. this one for sequential_sum_product. In my experience this will save you a tremendous amount of sweat and tears, even if it seems like more work up front; in fact, @fritzo and I often start by writing these tests and an interface or algorithm sketch, then implement the algorithms to make them pass.

In your case, a good strategy might involve writing some tests that construct a bunch of funsor.Tensor objects representing factor graphs with the appropriate markov structure (these correspond to the log_prob tensors produced by Pyro) and comparing your new algorithm against the old partial_sum_product applied to an "unrolled" set of factors, where "unrolling" means slicing time-parallel factors along the time axis and renaming any {var}_curr/{var}_prev dimensions in each new slice to {var}_t_{t}/{var}_t_{t-1} or something similar. The unrolled factors should no longer contain a "time" input. You can see a more concrete example of this unrolling procedure applied to a single factor in the test I linked above, test_sequential_sum_product.
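As a toy illustration of that strategy (written in plain PyTorch rather than the funsor API, and not taken from the actual funsor tests), one can check a parallel pairwise contraction of a chain of transition factors against a naively unrolled sequential contraction:

import torch


def naive_sequential(trans):
    # Unrolled reference: eliminate one time step at a time.
    # trans: (duration, states, states) log-factors.
    result = trans[0]
    for t in range(1, trans.shape[0]):
        # Compose (prev, mid) with (mid, curr), logsumexp over mid.
        result = torch.logsumexp(result.unsqueeze(-1) + trans[t].unsqueeze(0), dim=1)
    return result  # (states, states): total log-factor from first prev to last curr


def parallel_scan(trans):
    # "Fancy" parallel version: repeatedly contract adjacent pairs of factors,
    # halving the chain length each round (O(log duration) sequential steps).
    while trans.shape[0] > 1:
        if trans.shape[0] % 2:  # odd length: set the last factor aside
            tail, trans = trans[-1:], trans[:-1]
        else:
            tail = None
        even, odd = trans[0::2], trans[1::2]
        trans = torch.logsumexp(even.unsqueeze(-1) + odd.unsqueeze(1), dim=2)
        if tail is not None:
            trans = torch.cat([trans, tail])
    return trans[0]


trans = torch.randn(7, 4, 4)
assert torch.allclose(naive_sequential(trans), parallel_scan(trans), atol=1e-4)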
What was the problem? The lazy expression generated by partial_sum_product should be handled correctly by funsor.optimize. What happens if you replace sequential_sum_product with funsor.sum_product.MarkovProduct (a first-class Funsor term wrapping sequential_sum_product)?
@eb8680 do you think it is possible to generalize the partial_sum_product algorithm to eliminate markov variables (i.e. use sequential_sum_product to eliminate the time dim and pairs of (prev, curr) for markov sites)? Something like this:

Vectorized examples of model_1, model_2, model_3, and model_4 from the HMM examples. Markov site names end with _curr, and the corresponding auxiliary markov sites end with _prev and are not recorded in the trace in the TraceMessenger (not using enumerate: markov anymore):

To have it running I had to change to eager interpretation in TraceEnum_ELBO, otherwise I was getting an error: