
Way to use weighted loss in config

See original GitHub issue

I use configs to train my nets, and now I've run into an imbalanced classification problem. It can be solved by training the classifier with a weighted loss, and PyTorch losses support weighting out of the box.

But I couldn't use them because they accept only tensor input. If I provide a list of the desired weights via the config, it gets parsed into a regular Python list, so the loss constructor fails.
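For reference, a minimal reproduction of the mismatch (my assumption of the failure mode; the exact exception and message may differ between PyTorch versions):

import torch
from torch import nn

# Passing the parsed config list directly is rejected by the loss
# constructor, which expects a tensor (or None) for the weight argument.
try:
    nn.CrossEntropyLoss(weight=[0.2491, 1.0])
except Exception as err:  # exact exception type may vary across versions
    print(f"list fails: {err}")

# The same values wrapped in a tensor work as expected.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.2491, 1.0]))
logits, targets = torch.randn(4, 2), torch.randint(0, 2, (4,))
print(criterion(logits, targets))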

To solve this I've written my own wrapper class that constructs a tensor and passes it to the loss:

from typing import List, Optional

import torch
from torch import nn


class WrappedCrossEntropyLoss(nn.CrossEntropyLoss):
    '''Standard CrossEntropyLoss whose weight param can be a list, not a tensor'''

    def __init__(self, weights: Optional[List[float]] = None):
        weights_tensor = torch.tensor(weights) if weights is not None else None
        super().__init__(weights_tensor)

Then the regular config works fine:

    criterion_params:
        criterion: WrappedCrossEntropyLoss
        weights: [0.2491, 1.0]
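
In plain Python, that config block roughly corresponds to the following call (a simplified sketch of what the config-driven instantiation boils down to, assuming the imports and wrapper class above):

# The config resolves WrappedCrossEntropyLoss by name and passes the
# remaining params as kwargs, so `weights` arrives as a plain list.
criterion = WrappedCrossEntropyLoss(weights=[0.2491, 1.0])

logits, targets = torch.randn(8, 2), torch.randint(0, 2, (8,))
loss = criterion(logits, targets)  # works: the wrapper built the tensor itself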

Is this intentional behaviour, and could Catalyst somehow cast such arguments automatically?

I guess it could be solved by enhancing the config parsing algorithm. If you are interested in solving this, I could create a PR.
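
One possible shape of such an enhancement - purely a hypothetical sketch, not an existing Catalyst API - would be to inspect the constructor's type hints and cast list values to tensors wherever a torch.Tensor is expected:

from typing import get_args, get_type_hints

import torch


def _expects_tensor(annotation) -> bool:
    # True for torch.Tensor and Optional[torch.Tensor]-style annotations.
    return annotation is torch.Tensor or torch.Tensor in get_args(annotation)


def build_criterion(criterion_cls, **params):
    """Hypothetical helper: cast list/tuple params to tensors when the
    constructor annotates the corresponding argument as a tensor."""
    hints = get_type_hints(criterion_cls.__init__)
    casted = {
        name: torch.tensor(value)
        if isinstance(value, (list, tuple)) and _expects_tensor(hints.get(name))
        else value
        for name, value in params.items()
    }
    return criterion_cls(**casted)


# e.g. build_criterion(torch.nn.CrossEntropyLoss, weight=[0.2491, 1.0])
# would construct the loss with weight=torch.tensor([0.2491, 1.0]).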

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
v-goncharenko commented, Feb 19, 2021

AFAIK, PyTorch has always accepted only a tensor as input for its loss classes (any other type just fails), so I believe they won't change this (although I will search their issues for details). So a PR to PyTorch doesn't look like a probable solution.

So for now the answer to my initial question is no: one can't use the weight param of nn.CrossEntropyLoss directly from a config (without some wrapper class).

Regarding the config syntax design - I'll contact you =)

0 reactions
Scitator commented, Feb 19, 2021

The config way of object creation is not the pure PyTorch way 😃 So, you need to write your own wrappers.

Btw, speaking of the weights example, I think it should work naturally thanks to a dynamic list -> tensor cast during __init__ in the class. If not, it looks like we need to make a PR to PyTorch 😃

Another option that I see is to add an eval option to the Catalyst Config API, for example:

    criterion_params:
        criterion: WrappedCrossEntropyLoss
        weights: "torch.tensor([0.2491, 1.0])"

or

    criterion_params:
        criterion: WrappedCrossEntropyLoss
        weights: "[0.2491, 1.0]":torch.tensor

Nevertheless, we need to design it properly first. If you would like to make such a contribution, feel free to write me in Slack 😉
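
A rough illustration of what the eval option could look like on the parsing side (hypothetical; resolve_config_value is not an existing Catalyst function, and the namespace restriction is just one possible safety measure):

import torch


def resolve_config_value(value):
    """Hypothetical parser hook: evaluate string config values such as
    "torch.tensor([0.2491, 1.0])" in a restricted namespace."""
    if isinstance(value, str) and value.startswith("torch."):
        # Only expose the torch module, so configs cannot call arbitrary builtins.
        return eval(value, {"__builtins__": {}}, {"torch": torch})
    return value


weights = resolve_config_value("torch.tensor([0.2491, 1.0])")
print(weights)  # tensor([0.2491, 1.0000])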
