Use a Bayesian CNN on the MNIST dataset

Blackjax already has an example that uses SGLD to sample from a 3-layer MLP; it reaches very decent accuracy when the uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:

from flax import linen as nn  

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)

    return x
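
To get a feel for how large the posterior is, the number of weights in this CNN can be worked out by hand. The sketch below assumes 28x28x1 MNIST inputs, Flax's default `padding="SAME"` for `nn.Conv` and the default `"VALID"` padding for `nn.avg_pool`; `conv_params`, `dense_params` and `count_cnn_params` are hypothetical helpers, not part of Flax.

```python
def conv_params(kernel_h, kernel_w, in_ch, out_ch):
    # One kernel_h x kernel_w x in_ch filter per output channel, plus one bias each
    return kernel_h * kernel_w * in_ch * out_ch + out_ch


def dense_params(in_features, out_features):
    # Weight matrix plus bias vector
    return in_features * out_features + out_features


def count_cnn_params(height=28, width=28, channels=1):
    """Hypothetical helper: number of weights in the CNN above for a given input size."""
    total = 0
    # Conv(32) with SAME padding keeps the spatial size
    total += conv_params(3, 3, channels, 32)
    # avg_pool with a 2x2 window, stride 2 and VALID padding halves each side (floor)
    height, width = height // 2, width // 2
    total += conv_params(3, 3, 32, 64)
    height, width = height // 2, width // 2
    flat = height * width * 64  # size of the flattened feature map (7 * 7 * 64 for MNIST)
    total += dense_params(flat, 256)
    total += dense_params(256, 10)
    return total


print(count_cnn_params())  # 824458
```

So the sampler has to explore a space of roughly 824k dimensions, about 97% of which belong to the first `Dense(256)` layer.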

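And the log-probability function (not tested); MNIST has 10 classes, so the likelihood is categorical:

from jax.flatten_util import ravel_pytree
import distrax

def logpdf(params, images, categories, model):
    # logits: (batch_size, 10), one unnormalised score per digit class
    logits = model.apply(params, images)
    # Standard normal prior on every weight, flattened into a single vector
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    # Categorical likelihood of the integer labels (0-9)
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()

    return log_prior + log_likelihood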
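
For reference, here is the same kind of log-posterior written out in plain NumPy, with a categorical likelihood over the 10 digit classes: a standard normal log-density summed over every weight, plus the log-softmax of each example's logits evaluated at its true label (the quantity `distrax.Categorical(logits=...).log_prob(...)` returns). The rescaling of the minibatch log-likelihood by `data_size / batch_size` is an assumption about how the function would be fed to a stochastic-gradient sampler such as SGLD, so that one minibatch gives an unbiased estimate of the full-data term; `log_posterior_estimate` is a hypothetical name.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def log_posterior_estimate(flat_params, logits, labels, data_size):
    """Hypothetical helper: unnormalised log-posterior estimated from one minibatch.

    flat_params: 1-D array of all network weights (e.g. from jax.flatten_util.ravel_pytree)
    logits:      (batch_size, 10) network outputs for the minibatch
    labels:      (batch_size,) integer digit labels in 0..9
    data_size:   number of examples in the full training set (assumed known)
    """
    # Standard normal prior: log N(w | 0, 1) = -w^2 / 2 - log(2 pi) / 2, summed over weights
    log_prior = np.sum(-0.5 * flat_params**2 - 0.5 * np.log(2 * np.pi))
    # Categorical likelihood: log-probability of the true class in each row
    log_probs = log_softmax(logits)
    batch_log_lik = log_probs[np.arange(len(labels)), labels].sum()
    # Rescale so the minibatch sum estimates the full-data log-likelihood
    batch_size = len(labels)
    return log_prior + (data_size / batch_size) * batch_log_lik


# Usage with dummy values: 3 weights, a batch of 2 examples, 10 classes
params = np.array([0.0, 1.0, -1.0])
logits = np.zeros((2, 10))  # uniform predictions: log p = -log(10) for every class
labels = np.array([3, 7])
print(log_posterior_estimate(params, logits, labels, data_size=60_000))
```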

We should look at:

  • Comparison between SGLD and SGHMC (#211)
  • Raw accuracy compared to a solution that uses SGD (with Optax)
  • Show the distribution of “confidence” in predictions
  • Accuracy once we’ve removed examples where the model is not sure
  • Examples where the model is not sure / proportion of examples where it is not sure
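
The last three items can all be computed from one quantity: the posterior predictive probabilities, i.e. the softmax of each posterior sample's logits averaged over samples. A minimal NumPy sketch, assuming the logits of every posterior sample have been stacked into an array of shape `(num_samples, num_examples, 10)`; `selective_accuracy` and the 0.9 confidence threshold are illustrative choices, not values from the issue.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def selective_accuracy(sample_logits, labels, threshold=0.9):
    """Hypothetical helper for the evaluation items listed above.

    sample_logits: (num_samples, num_examples, 10) logits, one slice per posterior sample
    labels:        (num_examples,) integer labels
    Returns (raw accuracy, accuracy on confident examples,
             fraction flagged as unsure, per-example confidence).
    """
    # Posterior predictive: average the class probabilities over samples
    predictive = softmax(sample_logits).mean(axis=0)
    predictions = predictive.argmax(axis=-1)
    confidence = predictive.max(axis=-1)  # histogram this for the "confidence" plot
    correct = predictions == labels
    confident = confidence >= threshold
    raw_accuracy = correct.mean()
    # Accuracy restricted to examples the model is sure about (NaN if there are none)
    confident_accuracy = correct[confident].mean() if confident.any() else float("nan")
    unsure_fraction = 1.0 - confident.mean()
    return raw_accuracy, confident_accuracy, unsure_fraction, confidence


# Dummy example: 2 posterior samples, 3 test images
sample_logits = np.zeros((2, 3, 10))
sample_logits[:, 0, 1] = 10.0  # image 0: both samples confidently say "1"
sample_logits[:, 2, 3] = 10.0  # image 2: both samples confidently say "3"
labels = np.array([1, 0, 4])   # image 1 has uniform predictions, image 2 is wrong
raw, conf_acc, unsure, confidence = selective_accuracy(sample_logits, labels)
print(raw, conf_acc, unsure)  # 2/3 raw accuracy, 1/2 on confident images, 1/3 unsure
```

Sweeping `threshold` and plotting the retained accuracy against the fraction of images kept gives the full trade-off curve. An SGD baseline can be scored the same way by passing its single set of logits with a leading axis of size 1.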

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 2
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

gerdm commented, Aug 30, 2022 (1 reaction)

Hey @rlouf. Yes, still planning to work on it. Expect updates in September.

gerdm commented, Jun 30, 2022 (1 reaction)

Hi @rlouf, I’ll work on this issue!
