[FR] Add ordered categorical distribution
See original GitHub issue

Issue Description
I’m going through the Statistical Rethinking textbook by McElreath, and in chapter 12 the author discusses how to handle ordered categorical variables in a sensible way. I looked through the lists of distributions available in both `torch` and `pyro`, but could not find anything similar to the `dordlogit` (aka ordered-logit) distribution that he describes in the text. I realize that it’s just a re-parametrization of the categorical distribution, but I think it might be a useful utility to include.

I tried to recreate the distribution myself; does this seem like a sensible implementation, and would it be something you would want to include in `pyro`?
Code Snippet
I’m no expert in implementing torch/pyro distributions, but this seems to do the job, and I have used it successfully to do the trolley problem example in the textbook using NUTS. There are probably other methods/attributes that are missing for a full implementation.
```python
import torch
from torch.distributions import constraints

from pyro.distributions import Categorical, TorchDistribution


class OrderedCategorical(TorchDistribution):
    """
    Alternative parametrization of the distribution over a categorical variable.

    Instead of the typical parametrization of a categorical variable in terms
    of the probability mass of the individual categories ``p``, this provides an
    alternative that is useful in specifying ordered categorical models. This
    accepts ``cutpoints``, a (potentially initially unordered) vector of real
    numbers denoting baseline cumulative log-odds of the individual categories,
    and a model vector ``phi`` which modifies the baselines for each sample
    individually.

    These cumulative log-odds are then transformed into a discrete cumulative
    probability distribution, which is finally differenced to return the
    probability mass function ``p`` that specifies the categorical distribution.
    """

    support = constraints.nonnegative_integer
    arg_constraints = {"phi": constraints.real, "cutpoints": constraints.real}
    has_rsample = False

    def __init__(self, phi, cutpoints):
        assert len(cutpoints.shape) == 1  # cutpoints must be a 1d vector
        assert len(phi.shape) == 1  # model terms must be a 1d vector of samples
        N, K = phi.shape[0], cutpoints.shape[0] + 1
        # sort and reshape for broadcasting against phi
        cutpoints = torch.sort(cutpoints).values.reshape(1, -1)
        # cumulative probabilities: q[n, k] = P(Y <= k | phi[n])
        q = torch.sigmoid(cutpoints - phi.reshape(-1, 1))
        # difference the cumulative probabilities to get the probability
        # mass of the individual categories
        p = torch.zeros((N, K))  # (batch/sample dim, categories)
        p[:, 0] = q[:, 0]
        p[:, 1:-1] = (q - torch.roll(q, 1, dims=1))[:, 1:]
        p[:, -1] = 1 - q[:, -1]
        # self.cum_prob = q
        self.dist = Categorical(p)

    def sample(self, *args, **kwargs):
        return self.dist.sample(*args, **kwargs)

    def log_prob(self, *args, **kwargs):
        return self.dist.log_prob(*args, **kwargs)

    @property
    def _event_shape(self):
        return self.dist._event_shape
```
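For intuition, the cumulative-logit construction used above can be checked in a few lines of plain `torch`: differencing the sigmoids of the shifted cutpoints yields a valid probability vector that sums to one. This is a minimal sketch; the cutpoint and `phi` values below are made up for illustration.

```python
import torch

# hypothetical example: 3 cutpoints -> 4 ordered categories
cutpoints = torch.tensor([-1.0, 0.5, 2.0])  # sorted cumulative log-odds
phi = torch.tensor([0.0, 1.5])              # per-sample linear predictor

# cumulative probabilities q[n, k] = P(Y <= k | phi[n]), shape (2, 3)
q = torch.sigmoid(cutpoints.reshape(1, -1) - phi.reshape(-1, 1))

# difference to get the per-category probability mass, shape (2, 4)
p = torch.cat([q[:, :1], q[:, 1:] - q[:, :-1], 1 - q[:, -1:]], dim=1)

print(p.sum(dim=1))   # each row sums to (numerically) 1
print((p >= 0).all()) # and every entry is a valid probability
```

Because the cutpoints are sorted, the differences are non-negative, so `p` is a well-formed input for `Categorical`.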
Issue Analytics
- State:
- Created: 3 years ago
- Reactions: 3
- Comments: 13 (12 by maintainers)
Wanted to say I am using this in a personal project and it’s very helpful — thank you all (also going through Statistical Rethinking).

Thanks for all the advice implementing the `OrderedLogistic` distribution @fehiepsi and @fritzo, I learned a lot from that. I was also thinking we might want to complement it with an `OrderedTransform` for cutpoints like in numpyro. Does that sound like a good idea?
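The core idea of such a transform can be sketched in plain `torch`: map an unconstrained real vector to a strictly increasing one by keeping the first element and cumulatively summing the exponentials of the rest. This is a hypothetical sketch of the idea only, not numpyro's actual implementation; a real bijector would also supply the inverse and the log-abs-det-Jacobian.

```python
import torch


def ordered_transform(x: torch.Tensor) -> torch.Tensor:
    """Map an unconstrained vector to a strictly increasing (ordered) one.

    The first element passes through unchanged; each subsequent element
    adds exp() of the corresponding input, which is always positive, so
    the output is strictly increasing along the last dimension.
    """
    increments = torch.cumsum(torch.exp(x[..., 1:]), dim=-1)
    return torch.cat([x[..., :1], x[..., :1] + increments], dim=-1)


x = torch.tensor([0.3, -1.0, 2.0])  # unconstrained cutpoint parameters
y = ordered_transform(x)
print((y[1:] > y[:-1]).all())  # output is strictly increasing
```

This would let users place an unconstrained prior on the raw parameters while guaranteeing that the cutpoints passed to the distribution are already ordered.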