
Hidden Markov Models


DRAFT

Please comment if you see issues with this design or have ideas, know use cases I did not think about!

Finding the right abstraction: the HMM distribution

We would like to simplify the expression of hidden Markov models (HMMs) in MCX. The underlying idea is that HMMs are made of units that are repeated.

Simple HMM

In their simplest form:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
y[t-1]      y[t]      y[t+1]

And the elementary unit is:

x[t-1] ---> x[t]
             |
            y[t]

Let us see whether it is possible to build a model from the expression of one unit as a generative model:

def hmm_unit(x_prev):
    x <~ Categorical(x_probs[x_prev])
    y <~ Bernoulli(y_probs[x])
    return y

Knowing the previous value of x and the observation y, we can compute the posterior distribution of x. How do we combine the units? We can create a new distribution!

class HMM(mcx.Distribution):
    def __init__(self, unit):
        pass

    def sample(self):
        # to be defined
        pass

    def logpdf(self):
        # to be defined
        pass

Let us assume this HMM distribution exists. A simple HMM would thus be loosely expressed in MCX as:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, num_units))

    @mcx.model
    def hmm_unit(x_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x])
        return y

    x_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x_init)

    return obs

HMM where observations depend on previous observations

To challenge this abstraction, let us consider a more complex model:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
y[t-1] ---> y[t] ---> y[t+1]

The elementary unit becomes:

x[t-1] ---> x[t]
             |           
y[t-1] ---> y[t]

So the model can be written:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, 2, num_units))

    @mcx.model
    def hmm_unit(x_prev, y_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x, y_prev])
        return y

    x_init = np.zeros(num_units)
    y_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x=x_init, y=y_init)

    return obs

Factorial HMM

What if we have a Factorial HMM instead:

x[t-1] ---> x[t] ---> x[t+1]
  |          |           |
  v          v           v
y[t-1]     y[t]       y[t+1]
  ^          ^           ^
  |          |           |
v[t-1] ---> v[t] ---> v[t+1]

The elementary unit is:

x[t-1] ---> x[t]
             | 
             v 
            y[t]
             ^  
             |  
v[t-1] ---> v[t]

And in code:

@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    v_probs <~ Dirichlet(0.3 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, 2, num_units))

    @mcx.model
    def hmm_unit(x_prev, v_prev):
        x <~ Categorical(x_probs[x_prev])
        v <~ Categorical(v_probs[v_prev])
        y <~ Bernoulli(y_probs[x, v])
        return y

    x_init = np.zeros(num_units)
    v_init = np.zeros(num_units)

    obs <~ HMM(hmm_unit, x=x_init, v=v_init)

    return obs

The abstraction seems to be robust.

Implementing the HMM distribution

We need to provide an implementation for the sample and logpdf methods of the HMM distribution.

Sample

When parsing the model to compile it into a sampling function, MCX will transform hmm_unit into the sample_hmm_unit function below:

def sample_hmm_unit(rng_key, x):
    x_new = Categorical(x_probs[x]).sample(rng_key)
    y_new = Bernoulli(y_probs[x_new]).sample(rng_key)
    return x_new, y_new

HMM.sample(rng_key) should return samples from the prior distribution of x and y. We can achieve this with:

def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_units)
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)
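
As a sanity check that this scan pattern composes end to end, here is a self-contained sketch of the same loop written with raw jax.random primitives in place of the MCX distributions. The concrete x_probs and y_probs arrays, the single-chain scalar carry, and the key splitting inside the unit are illustrative assumptions, not part of the design:

import jax
import jax.numpy as jnp

num_steps = 10
hidden_dims = 3
x_probs = jnp.full((hidden_dims, hidden_dims), 1.0 / hidden_dims)  # transition matrix
y_probs = jnp.array([0.1, 0.5, 0.9])                               # emission probabilities

def sample_hmm_unit(rng_key, x):
    # split the key so the two draws are independent
    key_x, key_y = jax.random.split(rng_key)
    x_new = jax.random.categorical(key_x, jnp.log(x_probs[x]))
    y_new = jax.random.bernoulli(key_y, y_probs[x_new])
    return x_new, y_new

def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_steps)
x_init = jnp.zeros((), dtype=jnp.int32)  # single chain starting in state 0
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)

x_samples and y_samples both have shape (num_steps,), one draw per unit.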

Logpdf

When parsing the model to compile it into a loglikelihood function, MCX will transform hmm_unit into the logpdf_hmm_unit function below:

def logpdf_hmm_unit(x_prev, x, y):
    loglikelihood = 0
    loglikelihood += Categorical(x_probs[x_prev]).logpdf(x)
    loglikelihood += Bernoulli(y_probs[x]).logpdf(y)
    return loglikelihood

HMM.logpdf(x, y) should return the loglikelihood of the model given the values of x_probs and y_probs (taken from the enclosing model), x, and y = obs. We could use:

def scan_update(x_prev, xy):
    x, y = xy
    loglikelihood = logpdf_hmm_unit(x_prev, x, y)
    return x, loglikelihood

_, loglikelihoods = jax.lax.scan(scan_update, x_init, (x, obs))
loglikelihood = np.sum(loglikelihoods)
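
As with sampling, a concrete stand-in helps check the pattern. Below, the MCX distributions are replaced with hand-written log-densities and the parameters are passed explicitly so the sketch is self-contained; this is an illustration, not the proposed API:

import jax
import jax.numpy as jnp

def categorical_logpdf(probs, x):
    # log-probability of category x under the given probability vector
    return jnp.log(probs[x])

def bernoulli_logpdf(p, y):
    # log-probability of observation y under Bernoulli(p)
    return y * jnp.log(p) + (1 - y) * jnp.log1p(-p)

def logpdf_hmm_unit(x_probs, y_probs, x_prev, x, y):
    loglikelihood = categorical_logpdf(x_probs[x_prev], x)
    loglikelihood += bernoulli_logpdf(y_probs[x], y)
    return loglikelihood

def hmm_logpdf(x_probs, y_probs, x_init, x, obs):
    def scan_update(x_prev, xy):
        x_t, y_t = xy
        return x_t, logpdf_hmm_unit(x_probs, y_probs, x_prev, x_t, y_t)

    _, loglikelihoods = jax.lax.scan(scan_update, x_init, (x, obs))
    return jnp.sum(loglikelihoods)

# reusing x_probs, y_probs, x_samples, y_samples from the sampling sketch above:
# total = hmm_logpdf(x_probs, y_probs, jnp.zeros((), jnp.int32), x_samples, y_samples)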

Note: How do you implement time-dependent transitions?
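
One possible answer: jax.lax.scan can iterate over any stack of per-step inputs, so time-dependent transition parameters (or the time index itself) can simply be scanned over alongside the PRNG keys. A rough sketch, reusing the names from the sampling sketch above and assuming a hypothetical stacked array x_probs_t of shape (num_steps, hidden_dims, hidden_dims):

# stack of per-step transition matrices; here the same matrix repeated, for illustration
x_probs_t = jnp.tile(x_probs, (num_steps, 1, 1))

def scan_update(x, inputs):
    rng_key, x_probs_step = inputs  # this step's PRNG key and transition matrix
    key_x, key_y = jax.random.split(rng_key)
    x_new = jax.random.categorical(key_x, jnp.log(x_probs_step[x]))
    y_new = jax.random.bernoulli(key_y, y_probs[x_new])
    return x_new, (x_new, y_new)

_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, (keys, x_probs_t))

The HMM distribution could expose the same idea by accepting time-indexed parameters and scanning over them together with the keys.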

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
rlouf commented, Oct 24, 2020

Look at the batch size, y_probs is a 3D array with independent rvs that are Beta-distributed. I first came across that on your blog actually 😃

I believe Pyro has an expand method that does this, maybe it’s more readable?
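
For reference, NumPyro spells this with expand; the snippet below is only a comparison sketch, not MCX code or a proposal for its API:

import jax
import numpyro.distributions as dist

hidden_dims, num_units = 3, 10
# a (hidden_dims, num_units) batch of independent Beta(1, 1) random variables
y_probs = dist.Beta(1.0, 1.0).expand([hidden_dims, num_units]).sample(jax.random.PRNGKey(0))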

1 reaction
ericmjl commented, Oct 21, 2020

Yes. We did similar things with the unirep model!
