
Adding `class_weights` argument for the loss function of transformers model


🚀 Feature request

Provide a parameter called `class_weights` when initializing a sequence classification model. The attribute would be used to compute a weighted loss, which is useful for classification on imbalanced datasets.

from transformers import DistilBertForSequenceClassification

# Note the additional class_weights attribute
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", 
    num_labels=5, 
    class_weights=[5, 3, 2, 1, 1])

`class_weights` would provide the same functionality as the `weight` parameter of PyTorch losses such as `torch.nn.CrossEntropyLoss`.
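For reference, the `weight` parameter in plain PyTorch behaves like this (a minimal sketch; the weight and batch values are arbitrary):

```python
import torch
import torch.nn as nn

# Per-class weights: rarer classes get larger weights, so misclassifying
# them contributes more to the loss.
class_weights = torch.tensor([5.0, 3.0, 2.0, 1.0, 1.0])
loss_fct = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, 5)          # batch of 8 examples, 5 classes
labels = torch.randint(0, 5, (8,))  # ground-truth class indices
loss = loss_fct(logits, labels)     # scalar weighted loss
```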

Motivation

There have been similar issues raised before on how to provide class weights for an imbalanced classification dataset; see #297 and #1755.

I ended up modifying the transformers code myself to apply class weights (shown below), and it looks like an easy addition that could benefit many users.

Your contribution

This should be possible because the loss for sequence classification is currently initialized in the forward method like this:

    loss_fct = nn.CrossEntropyLoss() # <- Defined without the weight parameter
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

We can instead use PyTorch's `weight` argument and pass the `class_weights` recorded during model initialization:

    loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
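Until such a parameter exists, the same effect can be had without patching the library by taking the logits from the model and applying the weighted loss outside it. A sketch (the tiny config is only so the example runs without downloading weights; in practice you would use `from_pretrained("distilbert-base-uncased")`):

```python
import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertForSequenceClassification

# Tiny randomly initialized model, just to make the sketch self-contained.
config = DistilBertConfig(dim=32, hidden_dim=64, n_layers=1, n_heads=2,
                          num_labels=5)
model = DistilBertForSequenceClassification(config)

input_ids = torch.randint(0, config.vocab_size, (4, 16))  # 4 sequences of 16 tokens
labels = torch.randint(0, 5, (4,))

# Compute logits only, then apply the weighted loss ourselves.
logits = model(input_ids=input_ids).logits
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([5.0, 3.0, 2.0, 1.0, 1.0]))
loss = loss_fct(logits.view(-1, config.num_labels), labels.view(-1))
```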

I am happy to implement this and provide a PR, although I am new to the transformers package and may require some iterative code reviews from senior contributors/members.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 6
  • Comments: 18 (11 by maintainers)

Top GitHub Comments

8 reactions
sgugger commented, Sep 12, 2020

It’s a bit hard to know how to guide you when you don’t explain to us how you train your model. Are you using Trainer? Then you should subclass it and override the brand-new compute_loss method that I just added to make this use case super easy. There is an example in the docs (note that you will need an install from source for this).

3 reactions
PhilipMay commented, Sep 12, 2020

Ok. Super easy. Thanks @sgugger! That’s it! 😃
