Target loss
Hi, I'm working with a custom dataset that contains images with small defects annotated with bounding boxes. The goal in my case is to detect these defects, so I have set the number of classes to 2: one for defect and one for non-defect. The targets dicts contain the key "labels" containing a tensor of dim [nb_target_boxes], which I have created as torch.ones((num_objs,), dtype=torch.int64).
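For context, a minimal sketch of how such a target dict could be built in a dataset's __getitem__; the function name and the boxes_xyxy / image_id arguments are illustrative, not taken from the original post:

```python
import torch

def make_target(boxes_xyxy, image_id):
    """Build a DETR-style target dict for a single-foreground-class dataset (sketch).

    boxes_xyxy: list of [x0, y0, x1, y1] defect boxes for one image.
    Every box gets label 1 ("defect"); background needs no label here because
    DETR handles it through its implicit "no-object" class.
    """
    num_objs = len(boxes_xyxy)
    return {
        "boxes": torch.as_tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.ones((num_objs,), dtype=torch.int64),  # every box is class 1
        "image_id": torch.tensor([image_id]),
    }
```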
```python
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    """Classification loss (NLL)
    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
    """
    assert 'pred_logits' in outputs
    src_logits = outputs['pred_logits']  # [batch_size, num_queries, num_classes + 1]

    # (batch index, query index) pairs for the queries matched to a ground-truth box
    idx = self._get_src_permutation_idx(indices)
    # ground-truth labels of the matched boxes, in matcher order
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    # every query starts as the "no-object" class (index self.num_classes) ...
    target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                dtype=torch.int64, device=src_logits.device)
    # ... and the matched queries are overwritten with their ground-truth labels
    target_classes[idx] = target_classes_o

    # empty_weight down-weights the over-represented "no-object" class
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {'loss_ce': loss_ce}

    if log:
        # TODO this should probably be a separate loss, not hacked in this one here
        losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
    return losses
```
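For reference, a small sketch of what `indices` and `_get_src_permutation_idx` produce in this function; the query and box indices below are made up:

```python
import torch

# Illustrative matcher output for a batch of 2 images: each tuple holds
# (matched query indices, matched target-box indices) for one image.
indices = [
    (torch.tensor([5, 17]), torch.tensor([1, 0])),  # image 0: 2 ground-truth boxes
    (torch.tensor([42]),    torch.tensor([0])),     # image 1: 1 ground-truth box
]

# _get_src_permutation_idx flattens this into (batch index, query index) pairs,
# which is exactly what target_classes[idx] = target_classes_o indexes into.
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
print(batch_idx)  # tensor([0, 0, 1])
print(src_idx)    # tensor([ 5, 17, 42])
```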
I'm trying to understand how the label loss is computed, but several things in the loss_labels function are not clear to me:
- First of all, I would like to know whether the way I pass the labels is right for a two-class problem, and what I should set num_classes to, 1 or 2. There are only bounding boxes around defects, so I only pass those labels, as explained above.
- In the code above, why does src_logits have shape [batch_size, num_queries, num_classes+1]? Why the +1?
- target_classes_o: I don't understand exactly what this is; is it the ground truth?
- target_classes: I understand that everything is first set to the non-defect class and that the indices matched to a defect are then overwritten with the target label, but this is not entirely clear to me.
- What is the role of empty_weight when computing the loss?
- Since I am working with only two classes, should I use binary_cross_entropy instead of cross_entropy?
Thank you in advance.
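To make the +1 and empty_weight questions concrete, here is a small self-contained sketch of how the classification target and the weighted cross-entropy behave; the sizes, matched query indices, and eos_coef value below are illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, num_queries, num_classes = 2, 100, 2  # illustrative sizes
eos_coef = 0.1  # down-weighting factor for the "no-object" class (DETR's default)

# Logits cover num_classes real classes plus one extra "no-object" class,
# hence the +1 in [batch_size, num_queries, num_classes + 1].
src_logits = torch.randn(batch_size, num_queries, num_classes + 1)

# Every query starts as "no-object" (index num_classes) ...
target_classes = torch.full((batch_size, num_queries), num_classes, dtype=torch.int64)
# ... and only the queries matched to a ground-truth box get a real label (1 = defect).
target_classes[0, 5] = 1
target_classes[1, 42] = 1

# empty_weight gives the over-represented "no-object" class a small weight
# so it does not dominate the loss.
empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = eos_coef

loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)
print(loss_ce)
```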
Top GitHub Comments
So if you decide to use label = 1 for the things you want to detect, you need to:
Hope this helps.
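A minimal sketch of the labeling convention this implies, assuming the standard DETR setup where num_classes is the largest label id + 1 and index num_classes is reserved for "no-object" (this is an illustration, not a quote of the comment's steps):

```python
# Illustrative only: single foreground class labeled 1 ("defect").
num_classes = 2                # largest label id (1) + 1; label 0 is simply unused
no_object_index = num_classes  # index 2 is the implicit "no-object" class

# The classification head then predicts num_classes + 1 = 3 logits per query:
#   index 0 -> unused class, index 1 -> defect, index 2 -> no-object
```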
Hi @alcinos, I hadn't noticed the sigmoid, thank you. Oh, okay, in my case the objects are smaller than 16x16 pixels, so I understand that DETR won't work well…