
Support for Gamma sampler in JAX

See original GitHub issue

The Gamma sampler is one of the most popular samplers: many popular distributions can be derived from it (e.g. Beta, Dirichlet, StudentT). This thread was created to discuss the details of an implementation in JAX.
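As a quick illustration of why the Gamma sampler is such a useful building block, here is a sketch (using NumPy rather than JAX, purely for illustration) of how Beta and Dirichlet samples can be built from Gamma draws:

```python
import numpy as np

rng = np.random.default_rng(0)

# Beta(a, b) from two Gamma draws: X / (X + Y) with X ~ Gamma(a), Y ~ Gamma(b)
a, b = 2.0, 3.0
x = rng.gamma(a, size=50_000)
y = rng.gamma(b, size=50_000)
beta_samples = x / (x + y)

# Dirichlet(alpha) by normalizing independent Gamma(alpha_i) draws
alpha = np.array([1.0, 2.0, 3.0])
g = rng.gamma(alpha, size=(50_000, 3))
dirichlet_samples = g / g.sum(axis=1, keepdims=True)

print(beta_samples.mean())  # should be close to a / (a + b) = 0.4
```

The same constructions are why a differentiable Gamma sampler would immediately give pathwise gradients for these derived distributions as well.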

Implement Gamma sampler

I made an initial attempt in https://github.com/google/jax/pull/551. Currently the implementation is a bit slow, but I believe that with some help from the JAX devs it can be made fast.

TODO:

  • Explore warp-based parallelism as suggested by @slinderman

  • Implement the JVP rule for the shape parameter a (or alpha).

There are two recent papers, [1] and [2], which address this. Both arrive at the same formula for the gradient of z ~ Gamma(a) with respect to a:

dz/da = - dCDF(z; a)/da / pdf(z; a)

Further derivation leads to:

# dz/da = - (dCDF(z; a)/da) / pdf(z; a)
# pdf(z; a) = z^(a-1) * e^(-z) / Gamma(a)
# CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)    (lower incomplete gamma)
# dCDF(z; a)/da = (dIncompleteGamma/da - IncompleteGamma(a, z) * Digamma(a)) / Gamma(a)
# => dz/da = (IncompleteGamma(a, z) * Digamma(a) - dIncompleteGamma/da) * e^z / z^(a-1)

So to compute the unnormalized dCDF, a derivative of the IncompleteGamma function with respect to a is required. [1] and [2] provide different ways to compute this derivative.
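As a sanity check on the formula above, here is a sketch using SciPy (note that SciPy's gammainc is the *regularized* lower incomplete gamma, i.e. the CDF itself, so no separate Gamma(a) normalization is needed). It compares the pathwise derivative -(dCDF/da)/pdf against a direct finite difference of the quantile function, which is the gradient we ultimately want:

```python
import numpy as np
from scipy.special import gammainc, gammaincinv
from scipy.stats import gamma as gamma_dist

a, u = 2.5, 0.7          # shape parameter and a fixed uniform draw
z = gammaincinv(a, u)    # z ~ Gamma(a) via the inverse CDF at u

eps = 1e-6
# Pathwise derivative from the formula: dz/da = -(dCDF/da) / pdf(z; a)
dCDF_da = (gammainc(a + eps, z) - gammainc(a - eps, z)) / (2 * eps)
dz_da_formula = -dCDF_da / gamma_dist.pdf(z, a)

# Direct finite difference of the quantile function (same u, perturbed a)
dz_da_direct = (gammaincinv(a + eps, u) - gammaincinv(a - eps, u)) / (2 * eps)

print(dz_da_formula, dz_da_direct)  # the two should agree closely
```

Of course, the whole point of [1] and [2] is to compute dCDF/da without finite differences; this sketch only confirms the identity.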

Two questions come to my mind:

  • If we have implemented a reparameterized version of Gamma, do we also need to compute pathwise derivatives for the Beta, Dirichlet, and StudentT distributions? In [2], the authors compute pathwise derivatives for Gamma and Beta/Dirichlet separately, and do not deal with StudentT. In [1], the authors only provide a method to compute the pathwise derivative for Gamma.
  • Is it true (as reported in [1]) that the algorithm in [1] is faster than the algorithm in [2]? Or does the performance difference come from using different frameworks? It seems to me that the authors of [2] use only the first 5 terms of the IncompleteGamma series (and correct the accuracy in extreme cases using polynomials), while the authors of [1] use up to 200 terms. So I suspect the method in [2] would be faster than the one in [1], but it might be less accurate.
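To get a feel for that accuracy/speed trade-off, here is a toy sketch (my own illustration, not the algorithm from either paper) of the standard power series for the regularized lower incomplete gamma, truncated at different term counts:

```python
import numpy as np
from scipy.special import gammainc, gammaln

def p_series(a, z, n_terms):
    # Power series for the regularized lower incomplete gamma:
    # P(a, z) = z^a * e^(-z) / Gamma(a) * sum_{k>=0} z^k / (a * (a+1) * ... * (a+k))
    total, term = 0.0, 1.0 / a
    for k in range(n_terms):
        total += term
        term *= z / (a + k + 1)
    return np.exp(a * np.log(z) - z - gammaln(a)) * total

a, z = 1.5, 3.0
err_5 = abs(p_series(a, z, 5) - gammainc(a, z))
err_60 = abs(p_series(a, z, 60) - gammainc(a, z))
print(err_5, err_60)  # truncating after 5 terms costs noticeable accuracy here
```

Fewer terms means less work per sample but a larger truncation error for moderate-to-large z, which is presumably why [2] needs polynomial corrections in extreme cases.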

cc @neerajprad

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 3
  • Comments: 13 (9 by maintainers)

Top GitHub Comments

1 reaction
srvasude commented, Feb 19, 2021

Currently these distributions are sampled via rejection sampling, reparameterized via the reparameterization trick of https://arxiv.org/abs/1805.08498 (which in essence means overriding the gradient of the samples with the gradient of the quantile).

In principle one can use inverse-CDF sampling for these truncated distributions, and this works well in some parameter regimes. The issue is that the CDFs/quantiles of these distributions become ill-behaved in other regimes. For instance, in the TruncatedNormal case, if the bounds are really far away from the mean, you can get bad samples due to the 'non-invertibility' of these functions far enough out.
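To make that failure mode concrete, here is a SciPy sketch with hypothetical bounds (not TFP's actual implementation). Naive inverse-CDF sampling of a TruncatedNormal breaks down when both bounds sit far in the upper tail, because the CDF saturates at 1.0 in float64, while the survival-function parameterization still resolves the tail:

```python
import numpy as np
from scipy.stats import norm

lo, hi = 9.0, 10.0   # truncation bounds far from the mean of N(0, 1)
u = 0.5              # a fixed uniform draw

# Naive inverse-CDF sampling: map u into [F(lo), F(hi)], then invert.
# Here F(lo) and F(hi) both round to exactly 1.0 in float64, so the
# "sample" is garbage (infinite).
F_lo, F_hi = norm.cdf(lo), norm.cdf(hi)
x_naive = norm.ppf(F_lo + u * (F_hi - F_lo))

# Working with the survival function S = 1 - F keeps full relative
# precision in the tail, so the same draw inverts to a valid sample.
S_lo, S_hi = norm.sf(lo), norm.sf(hi)
x_stable = norm.isf((1.0 - u) * S_lo + u * S_hi)

print(x_naive, x_stable)
```

The survival-function trick only rescues the upper tail, though; in general each parameter regime needs its own numerically careful path, which is the point being made here.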

Generally, I think one would need to come up with hand-tailored rejection samplers (or other methods) for the distributions in question in order to get good sampling performance (in terms of both speed and sample quality).
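One classic example of such a hand-tailored rejection sampler is the Marsaglia–Tsang squeeze method for Gamma(a) with a >= 1. A NumPy sketch (for illustration only, not JAX's or TFP's actual implementation):

```python
import numpy as np

def gamma_marsaglia_tsang(a, size, rng):
    # Marsaglia-Tsang rejection sampler for Gamma(a, 1), valid for a >= 1.
    assert a >= 1.0
    d = a - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    out = np.empty(size)
    filled = 0
    while filled < size:
        x = rng.standard_normal()
        v = (1.0 + c * x) ** 3
        if v <= 0.0:          # proposal outside the support; reject
            continue
        u = rng.random()
        # log-space acceptance test against the Gamma density
        if np.log(u) < 0.5 * x * x + d - d * v + d * np.log(v):
            out[filled] = d * v
            filled += 1
    return out

rng = np.random.default_rng(0)
samples = gamma_marsaglia_tsang(3.0, 20_000, rng)
print(samples.mean(), samples.var())  # both should be close to a = 3
```

Its acceptance rate is very high (above 95% for a >= 1), which is exactly the kind of per-distribution tailoring being described.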

1 reaction
srvasude commented, Feb 19, 2021

Sorry, my comment here was just about the availability of the non-truncated distributions.

Brian’s suggestion contains a way to do generic truncation, but it does the naive thing of sampling and checking if it is in the truncation bounds (which can be sample inefficient).
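A sketch of that naive approach (with hypothetical shape and bounds), which makes the sample-inefficiency concrete:

```python
import numpy as np

rng = np.random.default_rng(0)
a, lo, hi = 2.0, 6.0, 8.0   # Gamma shape, and truncation bounds in the far tail

# Naive truncation: draw from the untruncated Gamma and keep only
# the samples that land inside the truncation bounds.
draws = rng.gamma(a, size=100_000)
accepted = draws[(draws >= lo) & (draws <= hi)]
acceptance_rate = accepted.size / draws.size
print(acceptance_rate)  # only a small fraction of draws survives
```

When the bounds carry little probability mass, almost all work is wasted, and the expected cost per accepted sample blows up as 1 / P(lo <= Z <= hi).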
