
Assign different alphas to every class in Focal Loss

See original GitHub issue

Hi

First of all, great work with kornia! Currently, focal loss (FL) is implemented for multiclass tasks. However, its alpha parameter is shared across all classes, whereas the original paper defines FL for the binary task as FL(p_t) = -α_t (1 - p_t)^γ log(p_t). See https://arxiv.org/abs/1708.02002 and #493.
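
For reference, here is a minimal sketch of that binary formulation (my own illustration, not kornia code): with labels y ∈ {0, 1} and the predicted probability p of the positive class, α_t is α for positives and 1 − α for negatives.

import torch

def binary_focal_loss(p: torch.Tensor, y: torch.Tensor,
                      alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    # p: predicted probability of the positive class, y: 0/1 ground-truth labels
    p_t = torch.where(y == 1, p, 1 - p)                    # probability of the true class
    alpha_t = torch.where(y == 1,
                          torch.full_like(p, alpha),
                          torch.full_like(p, 1 - alpha))   # per-example alpha_t
    return -alpha_t * (1 - p_t) ** gamma * torch.log(p_t)  # FL(p_t)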

I modified the current FL code to support different alphas for every class (see below). The only relevant change is:

# Instead of a single constant, build an alphas tensor of shape (1, C, 1, ..., 1)
# that broadcasts against the (N, C, *) input; alpha is renamed to alphas
alphas = torch.tensor(alphas, dtype=input.dtype, device=input.device).view(1, -1, *[1]*(input.ndim-2))
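
To see what that reshape produces (just an illustrative check, not part of the patch): for a segmentation-style input of shape (N, C, H, W), the alphas tensor ends up with shape (1, C, 1, 1), so it broadcasts over the batch and spatial dimensions while weighting each class channel individually.

import torch

input = torch.randn(2, 3, 4, 5)   # (N, C, H, W) with C = 3 classes
alphas = torch.tensor([0.2, 0.3, 0.5]).view(1, -1, *[1] * (input.ndim - 2))
print(alphas.shape)               # torch.Size([1, 3, 1, 1])
print((alphas * input).shape)     # broadcasts back to torch.Size([2, 3, 4, 5])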

Here is the whole code (I haven’t changed docs):

from typing import Optional, List  # <-Change:  Add List

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils import one_hot


# based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py

def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alphas: Optional[List[float]], # <-Change:  rename to alphas and to a list of floats
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))
 
    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}" .format(
                input.device, target.device))

    ### New addition ###
    alphas = torch.tensor(alphas, dtype=input.dtype, device=input.device).view(1, -1, *[1]*(input.ndim-2))      
    if alphas.size(1) != input.size(1):
        raise ValueError("Invalid alphas shape. we expect{} alpha values. Got: {}"
                         .format(input.size(1), alphas.size(1)))

    # Normalize alphas so they sum to 1
    alphas.div_(alphas.sum())

    # Original code:

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1) + eps

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1., gamma)

    focal = -alphas * weight * torch.log(input_soft) # <-Change:  alpha -> alphas
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss



class FocalLoss(nn.Module):
    r"""Criterion that computes Focal loss.

    According to [1], the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:
       - :math:`p_t` is the model's estimated probability for each class.


    Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
         output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
         ‘mean’: the sum of the output will be divided by the number of elements
         in the output, ‘sum’: the output will be summed. Default: ‘none’.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
        >>> loss = kornia.losses.FocalLoss(**kwargs)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alphas: Optional[List[float]], gamma: float = 2.0, # <- Change:  alpha to alphas
                 reduction: str = 'none') -> None:
        super(FocalLoss, self).__init__()
        self.alphas: Optional[List[float]] = alphas # <- Change:  alpha to alphas
        self.gamma: float = gamma
        self.reduction: str = reduction
        self.eps: float = 1e-6

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        return focal_loss(input, target, self.alphas, self.gamma, self.reduction, self.eps) # <- Change:  alpha to alphas
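
For completeness, usage of the modified class would look roughly like the docstring example, only with one weight per class instead of a single alpha (the weight values here are arbitrary and purely illustrative):

import torch

N = 5  # num_classes
kwargs = {"alphas": [0.1, 0.1, 0.2, 0.3, 0.3], "gamma": 2.0, "reduction": 'mean'}
loss = FocalLoss(**kwargs)        # the modified class defined above
input = torch.randn(1, N, 3, 5, requires_grad=True)
target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
output = loss(input, target)
output.backward()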

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 11 (10 by maintainers)

Top GitHub Comments

1 reaction
edgarriba commented, Oct 23, 2020

@hal-314 that sounds good. Let's continue this conversation in #722.

1 reaction
iimog commented, Oct 19, 2020

My pleasure! I think so, too. Thanks for your great work.


Top Results From Across the Web

What is Focal Loss and when should you use it?
Give high weights to the rare class and small weights to the dominating or common class. These weights are referred to as α...

Understanding Focal Loss in 5 mins | Medium | VisionWizard
The focal loss gives less weight to easy examples and gives more weight to hard misclassified examples. This, in turn, helps to solve...

The alpha parameter of focal loss - Cross Validated
"A common method for addressing class imbalance is to introduce a weighting factor α ∈ [0, 1] for class 1 and 1 −...

Focal loss - Hasty.ai
This alpha is set inversely proportional to the number of examples for a particular class or is learned through cross-validation.

Multi-class classification with focal loss for imbalanced datasets
This tutorial will show you how to apply focal loss to train a multi-class classifier model given highly imbalanced datasets.
