Hidden Markov Models
DRAFT
Please comment if you see issues with this design, have ideas, or know of use cases I did not think about!
Finding the right abstraction: the HMM distribution
We would like to simplify the expression of hidden Markov models (HMMs) in MCX. The underlying idea is that HMMs are made of units that are repeated.
Simple HMM
In their simplest form:
x[t-1] ---> x[t] ---> x[t+1]
  |           |          |
y[t-1]      y[t]       y[t+1]
And the elementary unit is:
x[t-1] ---> x[t]
             |
            y[t]
Let us see if it is possible to build a model from the expression of one unit as a generative model:
def hmm_unit(x_prev):
    x <~ Categorical(x_probs[x_prev])
    y <~ Bernoulli(y_probs[x])
    return y
Knowing the previous value of x and the observation y, we can compute the posterior distribution of x. How do we combine the units? We can create a new distribution!
class HMM(mcx.Distribution):
    def __init__(self, unit):
        pass

    def sample(self):
        # to be defined
        pass

    def logpdf(self):
        # to be defined
        pass
Let us assume this HMM distribution exists. A simple HMM would thus be loosely expressed in MCX as:
@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ dist.Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, num_units))

    @mcx.model
    def hmm_unit(x_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x])
        return y

    x_init = np.zeros(num_units)
    obs <~ HMM(hmm_unit, x_init)
    return obs
HMM where observations depend on previous observations
To challenge this abstraction, let us assume a more complex model:
x[t-1] ---> x[t] ---> x[t+1]
  |           |          |
y[t-1] ---> y[t] ---> y[t+1]
The elementary unit becomes:
x[t-1] ---> x[t]
             |
y[t-1] ---> y[t]
So the model can be written:
@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, 2, num_units))

    @mcx.model
    def hmm_unit(x_prev, y_prev):
        x <~ Categorical(x_probs[x_prev])
        y <~ Bernoulli(y_probs[x, y_prev])
        return y

    x_init = np.zeros(num_units)
    y_init = np.zeros(num_units)
    obs <~ HMM(hmm_unit, x_init, y_init)
    return obs
Factorial HMM
What if we have a Factorial HMM instead:
x[t-1] ---> x[t] ---> x[t+1]
  |           |          |
  v           v          v
y[t-1]      y[t]       y[t+1]
  ^           ^          ^
  |           |          |
v[t-1] ---> v[t] ---> v[t+1]
The elementary unit is:
x[t-1] ---> x[t]
             |
             v
            y[t]
             ^
             |
v[t-1] ---> v[t]
And in code:
@mcx.model
def mymodel(hidden_dims, num_units):
    x_probs <~ Dirichlet(0.5 * np.eye(hidden_dims))
    v_probs <~ Dirichlet(0.3 * np.eye(hidden_dims))
    y_probs <~ Beta(1, 1, batch_size=(hidden_dims, 2, num_units))

    @mcx.model
    def hmm_unit(x_prev, v_prev):
        x <~ Categorical(x_probs[x_prev])
        v <~ Categorical(v_probs[v_prev])
        y <~ Bernoulli(y_probs[x, v])
        return y

    x_init = np.zeros(num_units)
    v_init = np.zeros(num_units)
    obs <~ HMM(hmm_unit, x=x_init, v=v_init)
    return obs
The abstraction seems to be robust.
Implementing the HMM distribution
We need to provide an implementation for the sample and logpdf methods of the HMM distribution.
Sample
When parsing the model to compile it into a sampling function, MCX will transform hmm_unit into the sample_hmm_unit function below:
def sample_hmm_unit(rng_key, x):
    x_new = Categorical(x_probs[x]).sample(rng_key)
    y_new = Bernoulli(y_probs[x_new]).sample(rng_key)
    return x_new, y_new
HMM.sample(rng_key) should return samples from the prior distribution of x and y. We can achieve this with:
def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_units)
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)
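To check that this scan pattern composes, here is a minimal, self-contained JAX sketch outside of MCX. The concrete x_probs and y_probs values below are placeholders, and splitting the key between the two sampling sites is an assumption about how the compiled sampler would handle randomness:

import jax
import jax.numpy as jnp

x_probs = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # placeholder transition matrix
y_probs = jnp.array([0.1, 0.7])                # placeholder emission probabilities
num_units = 10

def sample_hmm_unit(rng_key, x):
    # One key per stochastic site (assumption about key handling).
    key_x, key_y = jax.random.split(rng_key)
    x_new = jax.random.categorical(key_x, jnp.log(x_probs[x]))
    y_new = jax.random.bernoulli(key_y, y_probs[x_new])
    return x_new, y_new

def scan_update(x, rng_key):
    x_new, y = sample_hmm_unit(rng_key, x)
    return x_new, (x_new, y)

keys = jax.random.split(jax.random.PRNGKey(0), num_units)
x_init = jnp.array(0)
_, (x_samples, y_samples) = jax.lax.scan(scan_update, x_init, keys)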
Logpdf
When parsing the model to compile it into a loglikelihood function, MCX will transform hmm_unit into the logpdf_hmm_unit function below:
def logpdf_hmm_unit(x_prev, x, y):
    loglikelihood = 0
    loglikelihood += Categorical(x_probs[x_prev]).logpdf(x)
    loglikelihood += Bernoulli(y_probs[x]).logpdf(y)
    return loglikelihood
HMM.logpdf(x, y) should return the loglikelihood of the model given the values of x_probs, y_probs (in the higher-level context), x, and y=obs. We could use:
def scan_update(x_prev, inputs):
    x, y = inputs
    loglikelihood = logpdf_hmm_unit(x_prev, x, y)
    return x, loglikelihood

_, loglikelihoods = jax.lax.scan(scan_update, x_init, (x, obs))
loglikelihood = np.sum(loglikelihoods)
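The same scan can be exercised end to end with a small, self-contained JAX sketch. Everything below (x_probs, y_probs, x, obs) is placeholder data, and jax.scipy.stats.bernoulli stands in for MCX's Bernoulli logpdf:

import jax
import jax.numpy as jnp
from jax.scipy.stats import bernoulli

x_probs = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # placeholder transition matrix
y_probs = jnp.array([0.1, 0.7])                # placeholder emission probabilities

def logpdf_hmm_unit(x_prev, x, y):
    loglikelihood = jnp.log(x_probs[x_prev, x])       # Categorical logpdf
    loglikelihood += bernoulli.logpmf(y, y_probs[x])  # Bernoulli logpdf
    return loglikelihood

def scan_update(x_prev, inputs):
    x, y = inputs
    return x, logpdf_hmm_unit(x_prev, x, y)

x = jnp.array([1, 1, 0])    # hypothetical hidden states
obs = jnp.array([1, 0, 0])  # hypothetical observations
x_init = jnp.array(0)
_, loglikelihoods = jax.lax.scan(scan_update, x_init, (x, obs))
loglikelihood = jnp.sum(loglikelihoods)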
Note: How do you implement time-dependent transitions?
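One possible answer, purely as an assumption and not something the design above specifies: pass the time index through the inputs that scan iterates over, so the unit can index into per-time-step parameters. A rough JAX sketch with made-up values:

import jax
import jax.numpy as jnp

num_steps, hidden_dims = 5, 2
# Hypothetical stack of per-time-step transition matrices,
# shape (num_steps, hidden_dims, hidden_dims).
x_probs_t = jnp.stack([jnp.eye(hidden_dims) * 0.8 + 0.1] * num_steps)

def scan_update(x, inputs):
    t, rng_key = inputs
    # The transition distribution now depends on the time step t.
    x_new = jax.random.categorical(rng_key, jnp.log(x_probs_t[t, x]))
    return x_new, x_new

keys = jax.random.split(jax.random.PRNGKey(0), num_steps)
_, hidden_path = jax.lax.scan(scan_update, jnp.array(0), (jnp.arange(num_steps), keys))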
Top GitHub Comments
Look at the batch size: y_probs is a 3D array with independent rvs that are Beta-distributed. I first came across that on your blog actually 😃

I believe Pyro has an expand method that does this, maybe it's more readable?

Yes. We did similar things with the unirep model!
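For reference, the Pyro pattern alluded to above would look roughly like this (a sketch; the dimension values are made up):

import pyro.distributions as dist

hidden_dims, num_units = 3, 10
# A batch of independent Beta(1, 1) variables, similar to what
# Beta(1, 1, batch_size=(hidden_dims, num_units)) expresses in the models above.
y_probs = dist.Beta(1.0, 1.0).expand([hidden_dims, num_units]).sample()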