FR Streaming MCMC interface for big models
This issue proposes a streaming architecture for MCMC on models with a large memory footprint.
The problem this addresses is that in models with high-dimensional latents (say, more than 1M latent variables), it becomes difficult to store a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on those samples and discard them during inference.
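A back-of-the-envelope calculation makes the memory argument concrete. It assumes float32 latents (4 bytes each), 1,000 retained samples from one chain, and a streaming summary that keeps two running moments per latent (for example a mean and a sum of squared deviations). The helper names are illustrative only, and the estimate ignores the sampler's own working memory (gradients, momenta, etc.).

```python
def sample_storage_bytes(num_latents: int, num_samples: int, num_chains: int = 1,
                         bytes_per_element: int = 4) -> int:
    """Bytes needed to keep every posterior sample (float32 by default)."""
    return num_latents * num_samples * num_chains * bytes_per_element


def streaming_storage_bytes(num_latents: int, num_moments: int = 2,
                            bytes_per_element: int = 4) -> int:
    """Bytes for running moments per latent; independent of the number of samples."""
    return num_latents * num_moments * bytes_per_element


full = sample_storage_bytes(1_000_000, 1_000)    # 4_000_000_000 bytes, i.e. 4 GB
streamed = streaming_storage_bytes(1_000_000)    # 8_000_000 bytes, i.e. 8 MB
print(full, streamed)
```

The stored-sample cost grows linearly with the number of samples (and chains), while the streaming cost stays fixed.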
@fehiepsi suggested creating a new MCMC class (say `StreamingMCMC`) with an interface similar to `MCMC` that is still independent of the kernel (using either `HMC` or `NUTS`) but follows an internal streaming architecture. Since large models like these usually run on a GPU or are otherwise memory-constrained, it is reasonable to omit multiprocessing support from `StreamingMCMC`.
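A minimal sketch of what a kernel-agnostic streaming loop could look like, in plain Python. Everything here is hypothetical: `StreamingDriver`, `RunningMean` and `random_walk_kernel` are illustrative names, not Pyro APIs, and a toy random-walk Metropolis step stands in for `HMC`/`NUTS`. The point is the shape of the loop: each post-warmup sample is folded into the statistics and then dropped, so nothing grows with `num_samples`.

```python
import math
import random
from typing import Callable, Dict, List

Sample = Dict[str, float]  # site name -> value (a real version would hold tensors)


class RunningMean:
    """Hypothetical streaming statistic: per-site running mean plus a sample count."""

    def __init__(self) -> None:
        self.count = 0
        self.mean: Sample = {}

    def update(self, sample: Sample) -> None:
        self.count += 1
        for name, value in sample.items():
            old = self.mean.get(name, 0.0)
            # Incremental mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n
            self.mean[name] = old + (value - old) / self.count

    def get(self) -> Sample:
        return dict(self.mean)


def random_walk_kernel(log_prob: Callable[[Sample], float], rng: random.Random,
                       step_size: float = 0.5) -> Callable[[Sample], Sample]:
    """Toy Metropolis kernel standing in for HMC/NUTS (hypothetical helper)."""

    def step(state: Sample) -> Sample:
        proposal = {k: v + rng.gauss(0.0, step_size) for k, v in state.items()}
        accept_prob = math.exp(min(0.0, log_prob(proposal) - log_prob(state)))
        return proposal if rng.random() < accept_prob else state

    return step


class StreamingDriver:
    """Hypothetical kernel-agnostic sampler loop that never stores samples."""

    def __init__(self, kernel: Callable[[Sample], Sample], num_samples: int,
                 warmup_steps: int) -> None:
        self.kernel = kernel
        self.num_samples = num_samples
        self.warmup_steps = warmup_steps

    def run(self, init: Sample, statistics: List[RunningMean]) -> List[RunningMean]:
        state = dict(init)
        for _ in range(self.warmup_steps):
            state = self.kernel(state)  # warmup draws are simply discarded
        for _ in range(self.num_samples):
            state = self.kernel(state)
            for stat in statistics:
                stat.update(state)  # fold the draw into each statistic...
            # ...and keep no reference to it: memory does not grow with num_samples
        return statistics


def log_prob(sample: Sample) -> float:
    return -0.5 * sample["x"] ** 2  # unnormalized standard normal target


driver = StreamingDriver(random_walk_kernel(log_prob, random.Random(0)),
                         num_samples=2000, warmup_steps=200)
(mean_stat,) = driver.run({"x": 0.0}, [RunningMean()])
print(mean_stat.count, mean_stat.get())
```

Because the loop only calls the kernel and `update()`, swapping in a different kernel or a different statistic leaves the loop untouched, which matches the "independent of kernel" requirement above.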
Along with the new `StreamingMCMC` class, I think there should be a set of helpers to compute statistics from sample streams in a streaming fashion, e.g. mean, variance, covariance, and `r_hat`.
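For mean and variance, Welford's online algorithm is a standard, numerically stable way to update statistics one sample at a time. The sketch below uses a hypothetical class name, treats each site as a plain float, and assumes every sample contains the same sites. It keeps only a count, a running mean and a running sum of squared deviations per site; covariance can be tracked analogously with a running co-moment.

```python
import math
from typing import Dict


class RunningMeanVariance:
    """Hypothetical streaming helper: Welford's online mean/variance per sample site."""

    def __init__(self) -> None:
        self.count = 0  # shared count: assumes every sample has the same sites
        self.mean: Dict[str, float] = {}
        self.m2: Dict[str, float] = {}  # running sum of squared deviations

    def update(self, sample: Dict[str, float]) -> None:
        self.count += 1
        for name, x in sample.items():
            mean = self.mean.get(name, 0.0)
            delta = x - mean                 # deviation from the old mean
            mean += delta / self.count
            self.mean[name] = mean
            # Welford: M2 += (x - old_mean) * (x - new_mean)
            self.m2[name] = self.m2.get(name, 0.0) + delta * (x - mean)

    def get(self) -> Dict[str, Dict[str, float]]:
        denom = self.count - 1  # unbiased variance; undefined for a single sample
        return {
            name: {"mean": self.mean[name],
                   "variance": self.m2[name] / denom if denom > 0 else math.nan}
            for name in self.mean
        }


stats = RunningMeanVariance()
for sample in ({"x": 1.0}, {"x": 2.0}, {"x": 4.0}):
    stats.update(sample)
print(stats.get())  # mean and variance are both 7/3, up to rounding
```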
Tasks (to be split into multiple PRs)
- #2857 Create a `StreamingMCMC` class with an interface identical to `MCMC` (except disallowing parallel chains).
- #2857 Generalize unit tests of `MCMC` to parametrize over both `MCMC` and `StreamingMCMC`.
- Add some tests ensuring `StreamingMCMC` and `MCMC` perform identical computations, up to numerical precision.
- Create a tutorial using `StreamingMCMC` on a big model.
- #2856 Create streaming helpers for mean, variance, etc.
- Add `r_hat` to `pyro.ops.streaming`.
- Add `n_eff = ess` to `pyro.ops.streaming`.
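Since the classic Gelman-Rubin diagnostic depends only on each chain's sample count, mean and variance, `r_hat` can in principle be computed from per-chain streaming moments without keeping any samples. Below is a sketch of the basic (non-split) formula, taking `(count, mean, variance)` tuples with an n - 1 variance denominator. The function name is hypothetical, and this is not necessarily the variant Pyro uses; split-R-hat would need the same moments for each half of each chain. `n_eff`/`ess` is harder to stream, since it depends on autocorrelations that these moments do not capture.

```python
import math
from typing import Sequence, Tuple


def r_hat_from_moments(chains: Sequence[Tuple[int, float, float]]) -> float:
    """Basic (non-split) Gelman-Rubin R-hat from per-chain (count, mean, variance)."""
    m = len(chains)
    if m < 2:
        raise ValueError("need at least two chains")
    n = chains[0][0]
    if n < 2 or any(c[0] != n for c in chains):
        raise ValueError("chains must have equal length >= 2")
    means = [c[1] for c in chains]
    within = sum(c[2] for c in chains) / m                               # W
    grand_mean = sum(means) / m
    between = n * sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)  # B
    if within == 0.0:
        # every chain is constant: R-hat is infinite if chains disagree, else undefined
        return math.inf if between > 0 else math.nan
    var_plus = (n - 1) / n * within + between / n  # pooled posterior variance estimate
    return math.sqrt(var_plus / within)


print(r_hat_from_moments([(1000, 0.02, 1.01), (1000, -0.01, 0.98)]))
```

Each tuple is exactly what a per-chain streaming mean/variance helper would report at the end of sampling.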
Issue Analytics
- Created 2 years ago
- Comments:9 (9 by maintainers)
@mtsokol, answering your latest questions:
`StreamingMCMC` will typically be used with large, memory-bound models with huge tensors, where the Python overhead is negligible. For the same reason, I think we should avoid batching, since that increases memory overhead. (In fact, I suspect the bottleneck will be in `pyro.ops.streaming`, where we may need to refactor to perform tensor operations in place.)
`.merge()` operation in #2856 to make this easy for you. The main motivation is to compute cross-chain statistics like `r_hat`.
https://github.com/pyro-ppl/pyro/blob/4a61ef2f9050ef81d1b0aa148d14ecc876f24a51/pyro/infer/mcmc/api.py#L389-L392
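A `.merge()` operation for cross-chain statistics can rest on the pairwise update of Chan et al. for combining the moments of two disjoint streams. The sketch below is a hypothetical scalar version operating on `(count, mean, M2)` tuples, where M2 is the sum of squared deviations from the mean; it is not the code from #2856.

```python
from typing import Sequence, Tuple

Moments = Tuple[int, float, float]  # (count, mean, M2 = sum of squared deviations)


def merge_moments(a: Moments, b: Moments) -> Moments:
    """Combine running moments of two disjoint streams (Chan et al.'s pairwise update)."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # Extra term accounts for the spread between the two stream means.
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (n, mean, m2)


def moments_of(xs: Sequence[float]) -> Moments:
    """Reference (non-streaming) moments, used here only to check the merge."""
    if not xs:
        return (0, 0.0, 0.0)
    mean = sum(xs) / len(xs)
    return (len(xs), mean, sum((x - mean) ** 2 for x in xs))


chain_a, chain_b = [1.0, 2.0, 3.0], [10.0, 12.0]
print(merge_moments(moments_of(chain_a), moments_of(chain_b)))
print(moments_of(chain_a + chain_b))  # same values, up to rounding
```

Because the merge is associative (up to floating-point rounding), each chain can be summarized independently and the summaries combined afterwards to obtain pooled statistics across chains.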
@fritzo thanks for the guidance! Right now I’m looking at the current implementation and starting to work on this.
This abstraction with `StreamingStatistic` seems sound to me. `StreamingMCMC` will only iterate over the passed objects implementing that interface and call their methods.
Sure! I can start working on `StreamingMCMC` and already follow the `StreamingStatistic` notion. When your PR is ready, I will adjust my implementation.
Should I introduce an `AbstractMCMC` interface that both the existing `MCMC` and `StreamingMCMC` will implement?
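One possible shape for a shared `AbstractMCMC` base, sketched with hypothetical names (`AbstractMCMCSketch`, `StoringMCMCSketch`, `StreamingMCMCSketch`) and none of the real `MCMC` machinery (chains, diagnostics, adaptation). The base class owns the warmup/sampling loop, and subclasses only decide what happens to each sample. The streaming subclass just calls `update(sample)` on the objects it was given, in the spirit of the `StreamingStatistic`-style contract discussed above.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List

Sample = Dict[str, float]


class AbstractMCMCSketch(ABC):
    """Hypothetical shared base: owns the sampling loop, not the storage policy."""

    def __init__(self, kernel: Callable[[Sample], Sample], num_samples: int,
                 warmup_steps: int = 0) -> None:
        self.kernel = kernel
        self.num_samples = num_samples
        self.warmup_steps = warmup_steps

    def run(self, init: Sample) -> None:
        state = dict(init)
        for _ in range(self.warmup_steps):
            state = self.kernel(state)
        for _ in range(self.num_samples):
            state = self.kernel(state)
            self._handle_sample(state)

    @abstractmethod
    def _handle_sample(self, sample: Sample) -> None:
        """Subclasses decide whether to store the sample or only summarize it."""


class StoringMCMCSketch(AbstractMCMCSketch):
    """Keeps every sample, as a list-based MCMC would."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.samples: List[Sample] = []

    def _handle_sample(self, sample: Sample) -> None:
        self.samples.append(dict(sample))  # copy, in case the kernel reuses its dict


class StreamingMCMCSketch(AbstractMCMCSketch):
    """Forwards each sample to statistic objects exposing update(sample)."""

    def __init__(self, kernel, num_samples, warmup_steps=0, statistics=()) -> None:
        super().__init__(kernel, num_samples, warmup_steps)
        self.statistics = list(statistics)

    def _handle_sample(self, sample: Sample) -> None:
        for stat in self.statistics:
            stat.update(sample)


class RunningSum:
    """Minimal statistic for the demo: sample count and per-site sum."""

    def __init__(self) -> None:
        self.count = 0
        self.total: Sample = {}

    def update(self, sample: Sample) -> None:
        self.count += 1
        for k, v in sample.items():
            self.total[k] = self.total.get(k, 0.0) + v


def shift(state: Sample) -> Sample:
    return {k: v + 1.0 for k, v in state.items()}  # deterministic toy "kernel"


batch = StoringMCMCSketch(shift, num_samples=3, warmup_steps=1)
batch.run({"x": 0.0})
running = RunningSum()
StreamingMCMCSketch(shift, num_samples=3, warmup_steps=1, statistics=[running]).run({"x": 0.0})
print([s["x"] for s in batch.samples], running.count, running.total)
```

Keeping the loop in one place would also let unit tests parametrize over both classes, since they differ only in `_handle_sample`.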