
The average should be taken over log probability rather than logits

See original GitHub issue

https://github.com/IntelLabs/bayesian-torch/blob/7abcfe7ff3811c6a5be6326ab91a8d5cb1e8619f/bayesian_torch/examples/main_bayesian_cifar.py#L363-L367

I think the average across the MC runs should be taken over the log probabilities. However, the output here is the logits before the softmax operation. I think we should first apply output = F.log_softmax(output, dim=1) and then take the average.

There are two equivalent ways to take this average, which I think is the more reasonable approach. The first is:

output_ = []
kl_ = []
for mc_run in range(args.num_mc):
    output, kl = model(input_var)
    output = F.log_softmax(output, dim=1)  # convert logits to log-probabilities
    output_.append(output)
    kl_.append(kl)
output = torch.mean(torch.stack(output_), dim=0)  # average log-probabilities over MC runs
loss = F.nll_loss(output, target_var)  # this replaces the original cross_entropy loss

Or equivalently, we can first take the cross-entropy loss for each MC run, and average the losses at the end:

kl_ = []
loss = 0
for mc_run in range(args.num_mc):
    output, kl = model(input_var)
    loss = loss + F.cross_entropy(output, target_var)  # cross_entropy takes no dim argument
    kl_.append(kl)
loss = loss / args.num_mc  # this replaces the original cross_entropy loss
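As a quick sanity check that the two quantities really differ, here is a small, self-contained sketch (the logit values and the 3-class setup are made up purely for illustration) comparing the cross-entropy of averaged logits, as in the current script, with the NLL of averaged log-probabilities:

import torch
import torch.nn.functional as F

# Two made-up MC forward passes for a single input with 3 classes.
logits_mc = torch.stack([
    torch.tensor([[2.0, 0.0, -1.0]]),
    torch.tensor([[-1.0, 3.0, 0.0]]),
])
target = torch.tensor([0])

# Current script: average the raw logits across MC runs, then cross-entropy.
loss_from_logits = F.cross_entropy(logits_mc.mean(dim=0), target)

# Proposed: log-softmax each MC run, average the log-probabilities, then NLL.
loss_from_logprobs = F.nll_loss(F.log_softmax(logits_mc, dim=-1).mean(dim=0), target)

print(loss_from_logits, loss_from_logprobs)  # the two losses differ in general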

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
piEsposito commented, Nov 26, 2021

There is not much of a rule for that during training, but I think it would be worth trying to see whether it improves training.

For inference, I'm 100% sure the average should be taken over probabilities.
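For what it's worth, a minimal sketch of what averaging over probabilities at inference time could look like, assuming a model that returns (logits, kl) per forward pass as in the example script (the function and argument names are illustrative):

import torch
import torch.nn.functional as F

@torch.no_grad()
def predict_mc(model, x, num_mc=20):
    # Run the stochastic model several times and average the softmax
    # probabilities across runs, not the raw logits.
    probs = []
    for _ in range(num_mc):
        logits, _kl = model(x)
        probs.append(F.softmax(logits, dim=1))
    mean_probs = torch.stack(probs).mean(dim=0)  # predictive distribution
    return mean_probs.argmax(dim=1), mean_probs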

0 reactions
ranganathkrishnan commented, Dec 1, 2021

@Nebularaid2000 If multiple MC samples are used during training, I think it would be better to calculate the cross-entropy loss for each MC run and then average them, if that helps with training convergence. The training run script can be modified as in the snippet below. There was no difference in the current example run script since num_mc=1.

# another way of computing gradients with multiple MC samples
cross_entropy_ = []
kl_ = []
output_ = []
for mc_run in range(args.num_mc):
    output, kl = model(input_var)
    cross_entropy_.append(criterion(output, target_var))
    kl_.append(kl)
    output_.append(output)
output = torch.mean(torch.stack(output_), dim=0)
loss = torch.mean(torch.stack(cross_entropy_), dim=0) + torch.mean(torch.stack(kl_), dim=0)/args.batch_size
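For context, a minimal sketch of how the resulting loss would feed into the rest of the training step (the optimizer variable is an assumption following the standard PyTorch pattern, not a quote from the example script):

# Backpropagate the MC-averaged loss and update the variational parameters.
optimizer.zero_grad()
loss.backward()
optimizer.step()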
Read more comments on GitHub >

Top Results From Across the Web

What does the logit value actually mean? - Cross Validated
Negative logit values indicate probabilities smaller than 0.5, positive logits indicate probabilities greater than 0.5.

The Logit-Normal: A ubiquitous but strange distribution!
In this post we explore the very curious logit-normal distribution, which appears frequently in statistics and yet has no analytical ...

Linear vs. Logistic Probability Models: Which is Better, and ...
For the logistic model to fit better than the linear model, it must be the case that the log odds are a linear...

Probability, log-odds, and odds
Values of x ranging from -1 to +1 create probabilities that range from about 0.25 to 0.75. The material below will let you...

'Logit' of Logistic Regression; Understanding the Fundamentals
The base of the logarithm is not important but taking logarithm of odds is. We can retrieve the probability of success from eq....
