Adding `class_weights` argument for the loss function of transformers model
🚀 Feature request
Provide a parameter called class_weights when initializing a sequence classification model. The attribute will be used to calculate a weighted loss, which is useful for classification with imbalanced datasets.
from transformers import DistilBertForSequenceClassification

# Note the additional class_weights attribute
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=5,
    class_weights=[5, 3, 2, 1, 1],
)
class_weights will provide the same functionality as the weight parameter of PyTorch losses like torch.nn.CrossEntropyLoss.
Motivation
Similar issues have been raised before on “How to provide class weights for imbalanced classification dataset”. See #297, #1755.
I ended up modifying the transformers code to use class weights (shown below), and it looks like an easy addition that could benefit many users.
Your contribution
This should be possible because the loss for sequence classification is currently initialized in the forward method as follows:
loss_fct = nn.CrossEntropyLoss() # <- Defined without the weight parameter
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
We can instead pass the class_weights recorded during model initialization to the weight parameter of the PyTorch loss:
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
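One practical detail of the change above: nn.CrossEntropyLoss expects weight to be a float tensor with one entry per class, so a plain Python list such as [5, 3, 2, 1, 1] would need converting first. The sketch below is a standalone illustration, not the library's code. The logits and labels are made-up toy values, and it checks the weighted loss against a manual computation using PyTorch's default reduction="mean".

```python
import torch
from torch import nn

# Hypothetical weights mirroring the class_weights list in the request,
# converted to a float tensor as nn.CrossEntropyLoss requires.
class_weights = torch.tensor([5.0, 3.0, 2.0, 1.0, 1.0])
num_labels = class_weights.numel()

# Toy batch: 3 examples, 5 classes (arbitrary illustrative values).
logits = torch.tensor([
    [2.0, 0.5, 0.1, 0.1, 0.1],
    [0.2, 0.1, 0.3, 1.5, 0.4],
    [0.1, 0.2, 0.1, 0.3, 2.5],
])
labels = torch.tensor([0, 3, 4])

loss_fct = nn.CrossEntropyLoss(weight=class_weights)
weighted_loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))

# Manual check: each example's loss is scaled by the weight of its target
# class, and the total is divided by the sum of those weights.
per_example = nn.functional.cross_entropy(logits, labels, reduction="none")
target_weights = class_weights[labels]
manual_loss = (target_weights * per_example).sum() / target_weights.sum()

print(weighted_loss.item(), manual_loss.item())
```

In a real model the weight tensor would also have to live on the same device as the logits.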
I am happy to implement this and provide a PR, although I am new to the transformers package and may need some iterative code review from the senior contributors/members.
It’s a bit hard to know how to guide you when you don’t explain to us how you train your model. Are you using Trainer? Then you should subclass it and override the brand new compute_loss method that I just added to make this use case super easy. There is an example in the docs (note that you will need an install from source for this).
Ok. Super easy. Thanks @sgugger! That’s it! 😃
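The suggestion to subclass Trainer and override compute_loss might look roughly like the sketch below. This is not the example from the docs. The class name WeightedLossTrainer and its class_weights argument are hypothetical, and the compute_loss signature has changed between transformers releases, so the extra **kwargs is an assumption meant to tolerate arguments that some versions pass.

```python
import torch
from torch import nn
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    """Trainer subclass that applies per-class weights in the loss."""

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # class_weights is a hypothetical extra argument, not part of Trainer.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pull the labels out so the model does not compute its own loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        weight = None
        if self.class_weights is not None:
            # Match the dtype and device of the logits.
            weight = torch.as_tensor(
                self.class_weights, dtype=logits.dtype, device=logits.device
            )

        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```

It would be constructed like a regular Trainer with one extra keyword, for example WeightedLossTrainer(model=model, args=training_args, train_dataset=train_dataset, class_weights=[5, 3, 2, 1, 1]), where model, training_args and train_dataset stand in for your own objects.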