[Bug] Fully Bayesian Inference with Pyro Does Not Work
🐛 Bug
Right now, the way our Bayesian inference with Pyro works is that we loop over all `Parameter`s with priors, call `pyro.sample`, and load the sample into the parameter value. Here’s an extremely simplified model, equivalent to what we do, that does NOT work:
```python
import torch
import pyro
from pyro.distributions import Normal

def pyro_model2(x, y):
    # `covar_module` is assumed to be defined in the enclosing scope
    raw_lengthscale = pyro.sample('raw_lengthscale', Normal(0, 1))
    raw_outputscale = pyro.sample('raw_outputscale', Normal(0, 1))
    # This is not an okay way to load the results of pyro.sample statements into a Module
    covar_module.initialize(**{'raw_outputscale': raw_outputscale, 'base_kernel.raw_lengthscale': raw_lengthscale})
    covar_x = covar_module(x).evaluate()
    pyro.sample("obs", pyro.distributions.MultivariateNormal(torch.zeros(y.size(-1)), covar_x + torch.eye(y.size(-1))), obs=y)
```
See the attached Python notebook, which demonstrates that this leads to completely wrong gradients for the potential energy, so HMC is entirely broken.
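For reference, the kind of gradient check the notebook performs can be sketched with Pyro's `initialize_model` utility; the exact return signature is assumed here, and `potential_fn` is the function HMC differentiates:

```python
# Sketch of a potential-energy gradient check (assumes
# pyro.infer.mcmc.util.initialize_model returns
# (init_params, potential_fn, transforms, trace); x and y are the training data).
import torch
from pyro.infer.mcmc.util import initialize_model

init_params, potential_fn, transforms, _ = initialize_model(
    pyro_model2, model_args=(x, y)
)
params = {k: v.detach().clone().requires_grad_(True) for k, v in init_params.items()}
potential_energy = potential_fn(params)
grads = torch.autograd.grad(potential_energy, tuple(params.values()))
# For pyro_model2, these gradients come out wrong, which is what the notebook shows.
```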
The best way I’ve found to fix this is to use the (possibly deprecated?) `pyro.random_module` primitive. Here’s a pyro model using a full GPyTorch GP that gets correct potential derivatives:
```python
import pyro
from pyro.distributions import Normal

def pyro_model5(x, y):
    # `model` is the full GPyTorch GP defined in the enclosing scope
    priors = {
        'covar_module.base_kernel.raw_lengthscale': Normal(0, 1).expand([1, 1]),
        'covar_module.raw_outputscale': Normal(0, 1),
        'likelihood.noise_covar.raw_noise': Normal(0, 1).expand([1]),
        'mean_module.constant': Normal(0, 1),
    }
    fn = pyro.random_module("model", model, prior=priors)
    sampled_model = fn()
    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)
```
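For completeness, running NUTS over this model uses Pyro's standard MCMC interface; the sample counts below are placeholders:

```python
from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(pyro_model5)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=100)
mcmc.run(x, y)
samples = mcmc.get_samples()
```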
Now, we could definitely just replace our existing GPyTorch interface to basically do the above instead for any parameter that has a prior registered. The only problem is that I don’t know if this will let us place priors over functions of `Parameter`s (i.e., if we want to place a prior over the `lengthscale` rather than the `raw_lengthscale`).
Thoughts? I feel like the above is almost there, but it’d be great to still support placing priors over derived values of the parameters.
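For illustration, one possible way to keep a prior on the constrained lengthscale while still handing `random_module` a prior on `raw_lengthscale` is to push the prior through the inverse of the constraint transform. This is only a sketch, assuming the default softplus constraint, not something GPyTorch currently does:

```python
# Sketch only: push a prior on the positive lengthscale through the inverse
# softplus to get an equivalent prior on raw_lengthscale (assumes the default
# softplus constraint; SoftplusTransform is Pyro's transform class).
from pyro.distributions import Gamma, TransformedDistribution
from pyro.distributions.transforms import SoftplusTransform

lengthscale_prior = Gamma(2.0, 2.0)               # prior we actually want
raw_lengthscale_prior = TransformedDistribution(  # induced prior on the raw value
    lengthscale_prior, [SoftplusTransform().inv]
)

priors = {
    'covar_module.base_kernel.raw_lengthscale': raw_lengthscale_prior,
    # shape handling (e.g. expanding to [1, 1]) and the remaining raw
    # parameters omitted for brevity; see pyro_model5 above
}
```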
cc/ @rmgarnett
also cc/ @fritzo @eb8680 @martinjankowiak in case any of the pyro devs have a thought about how to do this properly that I’m just missing?
Top GitHub Comments
Cc @jpchen @Balandat @dme65
@neerajprad Thanks! This is actually super helpful. At a minimum, the `reset_params_pyro` solution is much better for us than directly using `pyro.random_module`, because it’s obvious how it allows us to continue placing priors over transformed versions of the Parameters, and it is only equally hacky.

Based on what you’ve said, I’m assuming the following workflow would work:

1. Call `pyro.sample` for each prior and apply the inverse transform to get the appropriate value for the underlying raw parameter, as we do now.
2. Use `reset_params_pyro` above to replace the parameters with the samples on the copied module.

If so, I will probably implement things this way in the short term. It keeps most or all of the interface we already have for fully Bayesian inference but, you know, actually works 🙃 .
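A rough sketch of what that workflow might look like; `reset_params_pyro` below is only a guess at the helper suggested earlier in the thread (its real definition isn't reproduced here), and the prior/constraint choices are illustrative:

```python
# Rough sketch of the two-step workflow above.  The reset_params_pyro below is
# a guess: swap each nn.Parameter for the sampled tensor so the pyro.sample
# results stay on the autograd tape.
import copy
import pyro
from pyro.distributions import Gamma

def reset_params_pyro(module, samples):
    for name, value in samples.items():
        submodule = module
        *path, leaf = name.split('.')
        for part in path:
            submodule = getattr(submodule, part)
        delattr(submodule, leaf)          # drop the nn.Parameter ...
        setattr(submodule, leaf, value)   # ... and attach the sampled tensor
    return module

def pyro_model(x, y):
    sampled_model = copy.deepcopy(model)  # `model` is the GPyTorch GP
    # 1. Sample the constrained value and invert the constraint to get the raw parameter.
    lengthscale = pyro.sample('lengthscale', Gamma(2.0, 2.0))
    constraint = sampled_model.covar_module.base_kernel.raw_lengthscale_constraint
    raw_lengthscale = constraint.inverse_transform(lengthscale).view(1, 1)
    # ... the other raw parameters would be handled the same way ...
    # 2. Replace the parameters with the samples on the copied module.
    reset_params_pyro(sampled_model, {'covar_module.base_kernel.raw_lengthscale': raw_lengthscale})
    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)
```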
Regarding better support for `nn.Module` within Pyro’s HMC: this would definitely be terrific to have. I imagine I’m not the only one that wants to take the parameters in a complex, nested, pre-existing `nn.Module` and do HMC over them.
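As a hedged pointer rather than a tested solution: depending on the Pyro version, `pyro.nn.PyroModule` may already cover part of this, since `pyro.nn.module.to_pyro_module_` converts an existing `nn.Module` in place so that its parameters can be replaced by `PyroSample` sites that HMC treats as latent variables. Whether this plays nicely with GPyTorch's raw/constrained parameter setup is exactly the open question above.

```python
# Hedged sketch: convert the existing GP in place and mark one raw parameter
# as a latent site (assumes pyro.nn.module.to_pyro_module_ and PyroSample;
# not a confirmed GPyTorch integration).
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_

to_pyro_module_(model)  # recursively turns the nn.Module into a PyroModule
model.covar_module.base_kernel.raw_lengthscale = PyroSample(
    dist.Normal(0.0, 1.0).expand([1, 1]).to_event(2)
)
```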