
[Bug] "weight=class_weight" should be changed to "pos_weight=class_weight"


loss = F.binary_cross_entropy_with_logits(pred, label, weight=class_weight, reduction='none')

It should be changed to "loss = F.binary_cross_entropy_with_logits(pred, label, pos_weight=class_weight, reduction='none')".
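For context, here is a minimal sketch in plain PyTorch (tensor values are made up) showing that the two keyword arguments are not interchangeable: weight rescales every element of the loss, while pos_weight rescales only the positive (label == 1) term.

import torch
import torch.nn.functional as F

pred = torch.tensor([[0.2, -1.5], [1.0, 0.3]])   # (N, C) logits
label = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # (N, C) binary targets
class_weight = torch.tensor([2.0, 0.5])          # (C,) per-class weights

# weight= multiplies the whole per-element loss.
loss_w = F.binary_cross_entropy_with_logits(
    pred, label, weight=class_weight, reduction='none')

# pos_weight= multiplies only the positive term of the loss.
loss_pw = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=class_weight, reduction='none')

print(loss_w)   # the two results differ wherever label == 0
print(loss_pw)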

import torch.nn.functional as F

from .utils import weight_reduce_loss  # helper in mmcls/models/losses/utils.py


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*).
        weight (torch.Tensor, optional): Element-wise weight of loss with shape
            (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none' , loss
            is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert pred.dim() == label.dim()
    # Ensure that the size of class_weight is consistent with pred and label
    # to avoid unintended automatic broadcasting.
    if class_weight is not None:
        N = pred.size()[0]
        class_weight = class_weight.repeat(N, 1)
    loss = F.binary_cross_entropy_with_logits(
        pred, label, weight=class_weight, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            weight = weight.reshape(-1, 1)
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
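A quick usage sketch for the function above (shapes are illustrative, and it assumes weight_reduce_loss from mmcls is importable):

import torch

pred = torch.randn(4, 3)                      # (N, C) logits
label = torch.randint(0, 2, (4, 3)).float()   # (N, C) binary targets
class_weight = torch.tensor([1.0, 2.0, 0.5])  # (C,) per-class weights

# class_weight is repeated to (N, C) internally and passed as the
# element-wise `weight` argument of F.binary_cross_entropy_with_logits.
loss = binary_cross_entropy(pred, label, class_weight=class_weight)
print(loss)  # scalar, since reduction defaults to 'mean'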

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
mzr1996 commented, Sep 7, 2021

Hello, in binary cross-entropy, pos_weight means "a weight of positive examples", not "a weight given to each class". In regular cross-entropy, PyTorch provides weight and describes it as "a manual rescaling weight given to each class". Therefore, to keep the same behavior, we define class_weight as a class-wise weight, which differs from pos_weight (see the sketch after the table below).

class-wise

           hat   glasses
positive    0       1
negative    1       0
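To make the distinction concrete, here is a small hand-computed sketch (values made up for illustration): a class-wise weight w scales both terms of the per-class loss, whereas pos_weight scales only the positive term.

import torch

sigma = torch.sigmoid(torch.tensor(0.7))  # sigmoid of an example logit
w, y = 2.0, 0.0                           # class weight; a negative target

# class-wise weight: w * (y * -log(sigma) + (1 - y) * -log(1 - sigma))
class_wise = w * -(y * sigma.log() + (1 - y) * (1 - sigma).log())

# pos_weight: w * y * -log(sigma) + (1 - y) * -log(1 - sigma)
pos_weighted = -(w * y * sigma.log() + (1 - y) * (1 - sigma).log())

print(class_wise, pos_weighted)  # differ by 2x on this negative target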

As for your need, we will add a pos_weight parameter to our BCE loss soon.

0 reactions
Ezra-Yu commented, Dec 7, 2021

We have added pos_weight to the BCE loss; refer to #515.
