[Bug] "weight=class_weight" should be changed to "pos_weight=class_weight"
See original GitHub issue.

loss = F.binary_cross_entropy_with_logits(
    pred, label, weight=class_weight, reduction='none')

It should be changed to:

loss = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=class_weight, reduction='none')
import torch.nn.functional as F

# weight_reduce_loss is a helper defined alongside this function in the
# library's losses module.
from .utils import weight_reduce_loss


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*).
        weight (torch.Tensor, optional): Element-wise weight of the loss with
            shape (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            the loss has the same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), where C is the number of classes. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert pred.dim() == label.dim()
    # Ensure that the size of class_weight is consistent with pred and label
    # to avoid automatic broadcasting.
    if class_weight is not None:
        N = pred.size()[0]
        class_weight = class_weight.repeat(N, 1)
    loss = F.binary_cross_entropy_with_logits(
        pred, label, weight=class_weight, reduction='none')
    # Apply element-wise sample weights and do the reduction.
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            weight = weight.reshape(-1, 1)
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
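For reference, here is a minimal sketch (my own example; the tensor values are made up and not part of the issue) contrasting the two keyword arguments of F.binary_cross_entropy_with_logits: weight rescales every element of the unreduced loss, whereas pos_weight rescales only the positive (label == 1) term.

import torch
import torch.nn.functional as F

pred = torch.tensor([[0.5, -1.0], [2.0, 0.3]])   # logits, shape (N=2, C=2)
label = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # binary targets
class_weight = torch.tensor([2.0, 1.0])          # one weight per class

# Current behavior: every element of class c is scaled by class_weight[c],
# regardless of whether the target is 0 or 1.
loss_weight = F.binary_cross_entropy_with_logits(
    pred, label, weight=class_weight, reduction='none')

# Proposed behavior: only the label == 1 term of class c is scaled.
loss_pos_weight = F.binary_cross_entropy_with_logits(
    pred, label, pos_weight=class_weight, reduction='none')

print(loss_weight)
print(loss_pos_weight)  # agrees where label == 1, differs where label == 0

The two outputs agree wherever label == 1 and differ wherever label == 0, which is exactly the behavioral difference this report is about.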
Hello, in binary cross-entropy, pos_weight means "a weight of positive examples", not "a weight given to each class". In regular cross-entropy, PyTorch provides weight and describes it as "a manual rescaling weight given to each class". Therefore, to keep the same behavior, we define class_weight as a class-wise weight, which differs from pos_weight.

As for your need, we will add the pos_weight parameter to our BCE loss soon.

We have added pos_weight in BCE; refer to #515.