Assign a different alpha to each class in Focal Loss
Hi,
First of all, great work with kornia! Currently, focal loss (FL) is implemented for multiclass tasks, but its alpha parameter is shared across all classes, whereas the original paper defines FL for the binary case as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), with α_t a class-dependent weight. See https://arxiv.org/abs/1708.02002 and #493.
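Concretely, the multiclass variant I am proposing (my notation, adapted from the binary formula above) is

\text{FL}(x, y) = -\alpha_y \, (1 - p_y)^{\gamma} \, \log(p_y), \qquad \alpha = (\alpha_1, \dots, \alpha_C), \quad \sum_{c=1}^{C} \alpha_c = 1

where y is the ground-truth class of the sample (or pixel), and the normalization to sum one mirrors the alphas.div_(alphas.sum()) step in the code below.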
I modified the current FL code to support different alphas for every class (see below). The only relevant change is:
# Instead of a single constant, alpha is renamed to alphas: a tensor of per-class
# weights that broadcasts against the (N, C, *) input
alphas = torch.tensor(alphas, dtype=input.dtype, device=input.device).view(1, -1, *[1]*(input.ndim-2))
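To make the broadcast explicit, here is a minimal sketch of what that reshape does for a 4-D, segmentation-style input (the shapes and alpha values below are made up purely for illustration):

import torch

input = torch.randn(4, 3, 8, 8)                       # (N, C, H, W) with C = 3 classes
alphas = torch.tensor([0.25, 0.25, 0.5])              # one weight per class
alphas = alphas.view(1, -1, *[1] * (input.ndim - 2))  # -> shape (1, 3, 1, 1)
weighted = alphas * torch.softmax(input, dim=1)       # broadcasts back to (4, 3, 8, 8)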
Here is the whole code (I haven’t changed docs):
from typing import Optional, List  # <- Change: add List

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils import one_hot

# based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py


def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alphas: Optional[List[float]],  # <- Change: rename alpha to alphas and make it a list of floats
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".format(
                input.device, target.device))

    ### New addition ###
    # Per-class weights reshaped to (1, C, 1, ..., 1) so they broadcast against the input
    alphas = torch.tensor(alphas, dtype=input.dtype, device=input.device).view(1, -1, *[1] * (input.ndim - 2))
    if alphas.size(1) != input.size(1):
        raise ValueError("Invalid alphas shape. We expect {} alpha values. Got: {}"
                         .format(input.size(1), alphas.size(1)))
    # Normalize alphas to sum up to 1
    alphas.div_(alphas.sum())

    # Original code:
    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1) + eps

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1., gamma)

    focal = -alphas * weight * torch.log(input_soft)  # <- Change: alpha -> alphas
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss
class FocalLoss(nn.Module):
    r"""Criterion that computes Focal loss.

    According to [1], the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:

       - :math:`p_t` is the model's estimated probability for each class.

    Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
          output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
          'mean': the sum of the output will be divided by the number of elements
          in the output, 'sum': the output will be summed. Default: 'none'.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
        >>> loss = kornia.losses.FocalLoss(**kwargs)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alphas: Optional[List[float]], gamma: float = 2.0,  # <- Change: alpha to alphas
                 reduction: str = 'none') -> None:
        super(FocalLoss, self).__init__()
        self.alphas: Optional[List[float]] = alphas  # <- Change: alpha to alphas
        self.gamma: float = gamma
        self.reduction: str = reduction
        self.eps: float = 1e-6

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        return focal_loss(input, target, self.alphas,
                          self.gamma, self.reduction, self.eps)  # <- Change: alpha to alphas
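For completeness, here is a small usage sketch of the patched FocalLoss above (the per-class alpha values are arbitrary and only meant for illustration):

num_classes = 3
criterion = FocalLoss(alphas=[0.2, 0.3, 0.5], gamma=2.0, reduction='mean')
logits = torch.randn(2, num_classes, 3, 5, requires_grad=True)
target = torch.empty(2, 3, 5, dtype=torch.long).random_(num_classes)
loss = criterion(logits, target)
loss.backward()

Classes with a larger alpha contribute proportionally more to the loss, which is the point of the change.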
Top GitHub Comments
@hal-314 that sounds good. Let's continue this conversation in #722.
My pleasure! I think so, too. Thanks for your great work.