
[Bug] Fully Bayesian Inference with Pyro Does Not Work


🐛 Bug

Right now, our Bayesian inference with Pyro works by looping over all Parameters with priors, calling pyro.sample, and loading the sample into the parameter value. Here's an extremely simplified model, equivalent to what we do, that does NOT work:

import torch
import pyro
from pyro.distributions import Normal, MultivariateNormal

# covar_module is a pre-existing GPyTorch kernel, e.g. ScaleKernel(RBFKernel())

def pyro_model2(x, y):
    raw_lengthscale = pyro.sample('raw_lengthscale', Normal(0, 1))
    raw_outputscale = pyro.sample('raw_outputscale', Normal(0, 1))

    # This is not an okay way to load the results of pyro.sample statements into a Module
    covar_module.initialize(**{'raw_outputscale': raw_outputscale,
                               'base_kernel.raw_lengthscale': raw_lengthscale})

    covar_x = covar_module(x).evaluate()
    pyro.sample("obs",
                MultivariateNormal(torch.zeros(y.size(-1)), covar_x + torch.eye(y.size(-1))),
                obs=y)

See the attached Python notebook, which demonstrates that this leads to completely wrong gradients for the potential energy, so HMC is entirely broken.
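As far as I can tell (my reading, not spelled out in the notebook), the root cause is that initialize() copies the sampled tensors into Parameter.data, and .data writes are invisible to autograd, so the potential energy's gradient never reaches the pyro.sample sites. A minimal sketch of that detachment, independent of GPyTorch:

import torch

# Stand-in for a value returned by pyro.sample (requires grad so HMC
# can differentiate the potential energy with respect to it).
sample = torch.randn(1, requires_grad=True)
param = torch.nn.Parameter(torch.zeros(1))

# Writing through .data bypasses autograd entirely.
param.data.copy_(sample)

loss = (param ** 2).sum()
loss.backward()
print(param.grad)   # gradient reaches the Parameter...
print(sample.grad)  # ...but never flows back to the sample: None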

The best way I've found to fix this is to use the (possibly deprecated?) pyro.random_module primitive. Here's a Pyro model using a full GPyTorch GP that gets correct potential-energy derivatives:

def pyro_model5(x, y):
    priors = {
        'covar_module.base_kernel.raw_lengthscale': Normal(0, 1).expand([1, 1]),
        'covar_module.raw_outputscale': Normal(0, 1),
        'likelihood.noise_covar.raw_noise': Normal(0, 1).expand([1]),
        'mean_module.constant': Normal(0, 1),
    }
    # `model` is the full GPyTorch ExactGP whose parameters get sampled
    fn = pyro.random_module("model", model, prior=priors)
    sampled_model = fn()

    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)

Now, we could definitely just change our existing GPyTorch interface to do the above for any parameter that has a prior registered. The only problem is that I don't know whether this will let us place priors over functions of Parameters (i.e., if we want to place a prior over the lengthscale rather than the raw_lengthscale).

Thoughts? I feel like the above is almost there, but it’d be great to still support placing priors over derived values of the parameters.
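For concreteness, one way this could still work, sketched below. It assumes GPyTorch's default Positive constraint, which maps raw_lengthscale through a softplus; the Gamma prior is just a placeholder. The idea is to sample the constrained value from its prior, then apply the inverse transform before loading it into the raw parameter:

import torch
import pyro

def inv_softplus(y):
    # Inverse of softplus(x) = log(1 + exp(x)):
    # x = y + log(1 - exp(-y)), written with expm1 for stability.
    return y + torch.log(-torch.expm1(-y))

def sample_raw_lengthscale():
    # Place the prior on the *lengthscale* itself...
    lengthscale = pyro.sample('lengthscale', pyro.distributions.Gamma(2.0, 2.0))
    # ...then map it back to the unconstrained raw parameter the module stores.
    return inv_softplus(lengthscale)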

cc/ @rmgarnett

also cc/ @fritzo @eb8680 @martinjankowiak in case any of the Pyro devs have a thought about how to do this properly that I'm just missing?

Sampling Bug.ipynb.txt

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 7 (2 by maintainers)

Top GitHub Comments

2 reactions
sdaulton commented, Apr 17, 2021

1 reaction
jacobrgardner commented, Apr 19, 2021

@neerajprad Thanks! This is actually super helpful. At a minimum, the reset_params_pyro solution is much better for us than directly using pyro.random_module because it’s obvious how it allows us to continue placing priors over transformed versions of the Parameters, and is only equally hacky.

Based on what you’ve said, I’m assuming the following workflow would work:

  1. Deepcopy the module.
  2. Call pyro.sample for each prior and apply the inverse transform to get the appropriate value for the underlying raw parameter, as we do now.
  3. Use reset_params_pyro above to replace the parameters with the samples on the copied module.

If so, I will probably implement things this way in the short term. It keeps most or all of the interface we already have for fully Bayesian inference but, you know, actually works 🙃.
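Concretely, a sketch of that workflow, assuming `model` is the GPyTorch GP from pyro_model5 above and using a hand-written stand-in for the reset_params_pyro helper referenced earlier (names here are illustrative, not a final API):

import copy
import pyro
from pyro.distributions import Normal

def reset_params_pyro(module, samples):
    # Stand-in for the referenced helper: replace each nn.Parameter with
    # the sampled (graph-connected) tensor so gradients flow back to the
    # pyro.sample sites.
    for name, value in samples.items():
        *path, leaf = name.split('.')
        submodule = module
        for part in path:
            submodule = getattr(submodule, part)
        delattr(submodule, leaf)         # drop the nn.Parameter...
        setattr(submodule, leaf, value)  # ...and install the plain tensor

def pyro_model6(x, y):
    sampled_model = copy.deepcopy(model)  # step 1
    # Step 2: sample; apply the inverse transform here when the prior is
    # placed on the constrained value (see the softplus sketch above).
    raw_ls = pyro.sample('raw_lengthscale', Normal(0, 1)).expand(1, 1)
    # Step 3: load the samples into the copied module (a real version
    # would loop over every parameter with a registered prior).
    reset_params_pyro(sampled_model,
                      {'covar_module.base_kernel.raw_lengthscale': raw_ls})
    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)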

Regarding better support for nn.Module within Pyro's HMC: this would definitely be terrific to have. I imagine I'm not the only one who wants to take the parameters of a complex, nested, pre-existing nn.Module and do HMC over them.

