Use a Bayesian CNN on the MNIST dataset
Blackjax already has an example that uses SGLD to sample from a 3-layer MLP, and it reaches very decent accuracy when the predictive uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:
from flax import linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
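To get a feel for the size of the posterior a sampler would have to explore, the layer shapes can be worked out by hand. The sketch below is plain Python arithmetic, not Flax code. It assumes 28×28×1 MNIST inputs and Flax's default `SAME` padding for `nn.Conv` (so only the pooling layers shrink the image); the helper names are hypothetical.

```python
def conv_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a 2D convolution layer."""
    return kernel_size[0] * kernel_size[1] * in_channels * out_channels + out_channels


def dense_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# Assumption: MNIST images are 28x28 with a single channel.
side = 28
# SAME-padded convolutions keep the spatial size; each 2x2 pool with
# stride 2 halves it: 28 -> 14 -> 7.
side_after_pools = side // 2 // 2
flat_features = side_after_pools * side_after_pools * 64

total_params = (
    conv_params((3, 3), 1, 32)
    + conv_params((3, 3), 32, 64)
    + dense_params(flat_features, 256)
    + dense_params(256, 10)
)
print(flat_features, total_params)
```

Under these assumptions almost all parameters sit in the first dense layer, which is where the standard normal prior would have the most terms.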
And the log-probability function as follows (not tested):
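from jax.flatten_util import ravel_pytree
import distrax


def logpdf(params, images, categories, model):
    # Logits have shape (batch, 10): one score per digit class.
    logits = model.apply(params, images)
    # Flatten the parameter pytree so the prior covers every weight.
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    # MNIST has 10 classes, so use a categorical (not Bernoulli) likelihood;
    # `categories` is assumed to hold integer labels in [0, 9].
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()
    return log_prior + log_likelihood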
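For a 10-class problem such as MNIST, the likelihood term is naturally a categorical distribution over the ten logits of each image. The NumPy sketch below spells out the two terms of the log-posterior on toy inputs so they can be checked by hand. It is only an illustration: the function names are hypothetical, the labels are assumed to be integers, and it does not use Flax, distrax, or Blackjax.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def categorical_log_likelihood(logits, labels):
    """Sum over examples of the log-probability of the true class."""
    log_probs = log_softmax(logits)
    return log_probs[np.arange(labels.shape[0]), labels].sum()


def standard_normal_log_prior(flat_params):
    """Sum of independent N(0, 1) log-densities over all parameters."""
    return (-0.5 * flat_params**2 - 0.5 * np.log(2.0 * np.pi)).sum()


# Toy check: uniform logits give probability 1/10 to every class.
toy_logits = np.zeros((4, 10))
toy_labels = np.array([3, 1, 4, 1])
toy_log_likelihood = categorical_log_likelihood(toy_logits, toy_labels)
toy_log_prior = standard_normal_log_prior(np.zeros(5))
toy_log_posterior = toy_log_prior + toy_log_likelihood
```

With uniform logits, each of the four examples contributes log(1/10), which gives a quick sanity check for any real implementation.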
We should look at:
- Comparison between SGLD and SGHMC (#211)
- Raw accuracy compared to a solution that uses SGD (with Optax)
- Show the distribution of “confidence” in predictions
- Accuracy once we’ve removed examples where the model is not sure
- Examples where the model is not sure / proportion of examples where it is not sure
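The confidence-related items in this list could be computed along the following lines. This is a minimal NumPy sketch, assuming the sampler has produced logits of shape `(num_samples, num_examples, num_classes)` by applying the model once per posterior sample. The function name `confidence_report` and the 0.9 threshold are hypothetical choices, not part of Blackjax.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def confidence_report(sample_logits, labels, threshold=0.9):
    """Summarize accuracy before and after discarding unsure predictions."""
    # Average class probabilities over posterior samples.
    probs = softmax(sample_logits).mean(axis=0)
    predictions = probs.argmax(axis=-1)
    confidence = probs.max(axis=-1)
    correct = predictions == labels
    keep = confidence >= threshold  # examples the model is "sure" about
    return {
        "raw_accuracy": correct.mean(),
        "kept_fraction": keep.mean(),
        "kept_accuracy": correct[keep].mean() if keep.any() else float("nan"),
        "confidence": confidence,
    }


# Toy data: 2 posterior samples, 3 examples, 3 classes.
toy_sample_logits = np.array([
    [[10.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 0.0, 10.0]],
    [[10.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 10.0]],
])
toy_labels = np.array([0, 1, 2])
report = confidence_report(toy_sample_logits, toy_labels)
```

In the toy data the two samples disagree on the second example, so its averaged confidence is close to 0.5 and it falls below the threshold. A histogram of `report["confidence"]` would give the distribution of confidence, and `1 - report["kept_fraction"]` the proportion of examples where the model is not sure.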
Top GitHub Comments
Hey @rlouf. Yes, still planning to work on it. Expect updates in September.
Hi @rlouf, I’ll work on this issue!