Support for Gamma sampler in JAX
The Gamma sampler is one of the most popular samplers. Many popular distributions can be derived from it (e.g. Beta, Dirichlet, StudentT). This thread is created to discuss the details of an implementation in JAX.
Implement Gamma sampler
I made an initial attempt in https://github.com/google/jax/pull/551. Currently, the implementation is a bit slow, but I believe that with some help from the JAX devs it will be fast.
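For context, here is a minimal sketch of one common rejection scheme for shape `a >= 1` (Marsaglia & Tsang, 2000), written against today's `jax.random` / `jax.lax` API. This is an illustration only, not necessarily what the PR implements:

```python
import jax
import jax.numpy as jnp

def gamma_marsaglia_tsang(key, a):
    """Draw one sample from Gamma(a, 1) for a >= 1 by rejection (sketch)."""
    d = a - 1.0 / 3.0
    c = 1.0 / jnp.sqrt(9.0 * d)

    def propose(state):
        key, _, _ = state
        key, k_norm, k_unif = jax.random.split(key, 3)
        x = jax.random.normal(k_norm)
        v = (1.0 + c * x) ** 3
        u = jax.random.uniform(k_unif)
        # Log-space acceptance test from Marsaglia & Tsang; v <= 0 is rejected.
        accept = (v > 0) & (jnp.log(u) < 0.5 * x**2 + d - d * v + d * jnp.log(v))
        return key, d * v, accept

    init = propose((key, 0.0, False))
    _, z, _ = jax.lax.while_loop(lambda s: ~s[2], propose, init)
    return z
```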
TODO:
- Explore warp-based parallelism as suggested by @slinderman
- Implement JVP rule for the shape parameter a (or alpha).
There are two recent papers which address this:
- Ref 1. Implicit Reparameterization Gradients, which is implemented in TensorFlow.
- Ref 2. Pathwise Derivatives Beyond the Reparameterization Trick, which is implemented in PyTorch.
Both references arrive at the same formula for the gradient of z ~ Gamma(a) w.r.t. a:
dz/da = - dCDF(z; a)/da / pdf(z; a)
Further derivation leads to:
# dz/da = - (dCDF(z; a)/da) / pdf(z; a)
# pdf(z; a) = z^(a-1) * e^(-z) / Gamma(a)
# CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)
# dCDF(z; a)/da = (dIncompleteGamma/da - IncompleteGamma(a, z) * Digamma(a)) / Gamma(a)
# => dz/da = (IncompleteGamma(a, z) * Digamma(a) - dIncompleteGamma/da) * e^z / z^(a-1)
So to compute the unnormalized dCDF, a derivative of the IncompleteGamma function w.r.t. a is required. [1] and [2] provide different ways to compute this derivative.
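As a concrete illustration, here is a minimal sketch of how the formula above could be wired into a JAX custom JVP. The finite-difference estimate of dCDF/da is just a stand-in for a proper series evaluation from [1] or [2], and `jax.random.gamma` is assumed as the forward sampler (it did not exist when this issue was opened):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc       # regularized lower incomplete gamma
from jax.scipy.stats import gamma as gamma_dist

@jax.custom_jvp
def gamma_sample(key, a):
    return jax.random.gamma(key, a)          # z ~ Gamma(a, 1)

@gamma_sample.defjvp
def gamma_sample_jvp(primals, tangents):
    key, a = primals
    _, a_dot = tangents                      # the key carries no tangent
    z = gamma_sample(key, a)
    # dz/da = -(dCDF(z; a)/da) / pdf(z; a); the a-derivative of the CDF is
    # approximated here by a central difference -- replace with a series
    # evaluation of dIncompleteGamma/da from [1] or [2] for real use.
    eps = 1e-4
    dcdf_da = (gammainc(a + eps, z) - gammainc(a - eps, z)) / (2.0 * eps)
    dz_da = -dcdf_da / jnp.exp(gamma_dist.logpdf(z, a))
    return z, dz_da * a_dot
```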
Two questions come to my mind:
- If we have implemented a reparameterized version of Gamma, do we need to compute pathwise derivatives for Beta, Dirichlet, and StudentT too? In [2], the authors compute pathwise derivatives for Gamma and Beta/Dirichlet separately, and do not deal with StudentT. In [1], the authors only provide a method to compute the pathwise derivative for Gamma.
- Is it true (as reported in [1]) that the algorithm in [1] is faster than the algorithm in [2]? Or does the performance difference come from using different frameworks? It seems to me that the authors of [2] use only the first 5 terms of the IncompleteGamma series (and correct the accuracy in extreme cases using polynomials), while the authors of [1] use up to 200 terms. So I would expect the method in [2] to be faster than the one in [1], but possibly less accurate (see the sketch after this list).
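On that accuracy/speed trade-off: since every term of the lower-incomplete-gamma power series is differentiable in a, one way to explore it in JAX is to evaluate a truncated series and let autodiff produce dP/da, varying the term count. A sketch, assuming the standard series P(a, z) = z^a e^(-z) * sum_k z^k / Gamma(a + k + 1):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def lower_gammainc(a, z, num_terms=32):
    # Truncated power series for the regularized lower incomplete gamma
    # P(a, z); num_terms trades speed for accuracy, mirroring the 5-term
    # ([2]) vs ~200-term ([1]) choices discussed above.
    k = jnp.arange(num_terms)
    log_terms = a * jnp.log(z) - z + k * jnp.log(z) - gammaln(a + k + 1.0)
    return jnp.sum(jnp.exp(log_terms))

# Each term is differentiable in a, so autodiff yields dP/da directly.
dP_da = jax.grad(lower_gammainc)(2.0, 1.5)
```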
cc @neerajprad
Currently these distributions are sampled via rejection sampling and reparameterized via implicit reparameterization gradients: https://arxiv.org/abs/1805.08498 (which in essence means overriding the gradient of the samples with the gradient of the quantile).
In principle one can use inverse-CDF sampling for these truncated distributions, and this will work well in some parameter regimes. The issue is that the CDFs/quantiles of these distributions become ill-behaved in other regimes. For instance, in the TruncatedNormal case, if the bounds are really far away from the mean, you can get bad samples due to the 'non-invertibility' of these functions far enough out.
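To make the failure mode concrete, here is a minimal inverse-CDF sampler for a TruncatedNormal (a sketch; `truncated_normal_icdf` is a hypothetical helper, not a JAX API). When the bounds sit far in a tail, `norm.cdf(lo)` and `norm.cdf(hi)` saturate to the same floating-point value and `ndtri` can no longer separate the samples:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtri      # standard normal quantile function
from jax.scipy.stats import norm

def truncated_normal_icdf(key, lo, hi, shape=()):
    # Map U(0, 1) into [CDF(lo), CDF(hi)] and invert the normal CDF.
    u = jax.random.uniform(key, shape)
    cdf_lo, cdf_hi = norm.cdf(lo), norm.cdf(hi)
    return ndtri(cdf_lo + u * (cdf_hi - cdf_lo))

key = jax.random.PRNGKey(0)
ok = truncated_normal_icdf(key, -1.0, 1.0, (4,))   # well-behaved regime
bad = truncated_normal_icdf(key, 9.0, 10.0, (4,))  # tail: both CDFs round to 1.0
```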
Generally, I think one would need to come up with hand-tailored rejection samplers (or other methods) for the distributions in question in order to get good sampling performance (in terms of both speed and sample quality).
Sorry, my comment here was just about the availability of the non-truncated distributions.
Brian’s suggestion contains a way to do generic truncation, but it does the naive thing of sampling and checking whether the draw falls within the truncation bounds (which can be sample-inefficient).
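For reference, a sketch of that naive approach: draw a fixed budget of candidates and keep the first one inside the bounds. The `truncated_rejection` name and `budget` parameter are invented for illustration; when the bounds cover a low-probability region, most of the budget is wasted, which is the sample inefficiency mentioned above:

```python
import jax
import jax.numpy as jnp

def truncated_rejection(key, sampler, lo, hi, budget=1000):
    # Naive generic truncation: sample a batch and take the first draw
    # that lands inside [lo, hi]; NaN signals that the budget ran out.
    xs = sampler(key, (budget,))
    ok = (xs >= lo) & (xs <= hi)
    idx = jnp.argmax(ok)                  # index of first accepted draw
    return jnp.where(ok[idx], xs[idx], jnp.nan)

key = jax.random.PRNGKey(0)
x = truncated_rejection(key, jax.random.normal, 1.0, 2.0)
```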