Model and guide event_dims disagree at site 'module$$$out.weight': 0 vs 1
I want to implement the tutorial "Making Your Neural Network Say “I Don’t Know” — Bayesian NNs using Pyro and PyTorch", but when I run it, I get this error:
ValueError: Model and guide event_dims disagree at site 'module$$$out.weight': 0 vs 1
Here are the model and guide I copied from it:
import torch
import pyro
from pyro.distributions import Normal, Categorical
# net, log_softmax and softplus are assumed to be defined earlier, as in the tutorial notebook.

def model(x_data, y_data):
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    print(priors['out.weight'])
    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a regressor (which also samples w and b)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

def guide(x_data, y_data):
    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Output layer weight distribution priors
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    # .independent(1) gives this site 1 event dim; the model's out.weight prior has 0
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)
    # outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param)
    # Output layer bias distribution priors
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    print(priors['out.weight'])
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
How can I resolve this problem? I tried adding .to_event() to the distributions (such as fc1w_prior), but that caused more errors.
Environment
- Ubuntu, Python 3.8.8
- PyTorch 1.8.1+cu102
- Pyro version: 1.6.0
Any help would be appreciated. Thanks a lot!
Issue Analytics
- Created 2 years ago
- Comments:5 (2 by maintainers)
Hi, I eventually fixed the code by installing earlier versions of all the necessary packages. Here is my code: https://github.com/arm-on/interpretable-image-classification/blob/main/bayesian-nn/MNIST_Bayesian_Training.ipynb
I had the same issue and could not fix it by setting the event dimensions manually. As a temporary fix, I recommend installing an earlier version of Pyro in the notebook (e.g. !pip3 install pyro-ppl==1.4.0). I am still trying to figure out a proper fix.