How do I compute validation loss during training?

I’m trying to compute the loss on a validation dataset for each iteration during training. To do so, I’ve created my own hook:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")

… which I register with a DefaultTrainer. The hook code is called during training, but fails with the following:

INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-63b308743b7d>", line 8, in after_step
    loss = self.trainer.model(batch)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 123, in forward
    proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 164, in forward
    losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 322, in losses
    gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 262, in _get_ground_truth
    for image_size_i, anchors_i, gt_boxes_i in zip(self.image_sizes, anchors, self.gt_boxes):
TypeError: zip argument #3 must support iteration
INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)
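The last frame of this traceback is a `zip` over `self.gt_boxes`. A likely explanation, assuming detectron2's `DatasetMapper` behaves as in releases from that period, is that the mapper built with `is_train=False` (the default for `build_detection_test_loader`) drops the `annotations` field. The model, still in training mode, then receives no ground-truth `instances`, and `gt_boxes` ends up as `None`. The pure-Python sketch below reproduces only that failure; `match_anchors` is a hypothetical stand-in, not detectron2 code.

```python
def match_anchors(image_sizes, anchors, gt_boxes):
    """Hypothetical stand-in for the per-image loop in _get_ground_truth()."""
    matched = []
    for image_size_i, anchors_i, gt_boxes_i in zip(image_sizes, anchors, gt_boxes):
        # Record image size, number of anchors, and number of ground-truth boxes.
        matched.append((image_size_i, len(anchors_i), len(gt_boxes_i)))
    return matched


image_sizes = [(720, 720)]
anchors = [[(0, 0, 16, 16), (8, 8, 24, 24)]]

# With ground truth present, the loop runs normally.
print(match_anchors(image_sizes, anchors, [[(10, 10, 50, 50)]]))  # [((720, 720), 2, 1)]

# Assumed situation in the question: no annotations, so gt_boxes is None.
try:
    match_anchors(image_sizes, anchors, None)
    missing_gt_error = None
except TypeError as err:
    # The exact message differs between Python versions.
    missing_gt_error = err
    print("TypeError:", err)
```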

The traceback seems to imply that ground truth data is missing, which made me think that the data loader was the problem. However, switching to a training loader produces a different error:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_train_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")

INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-e0d2c509cc72>", line 7, in after_step
    for batch in self._loader:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 109, in __iter__
    for d in self.dataset:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 39, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/utils/serialize.py", line 23, in __call__
    return self._obj(*args, **kwargs)
TypeError: 'str' object is not callable

INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)
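This second error ends in `'str' object is not callable` inside the wrapper around the mapper. A plausible cause: in detectron2 releases from that period, `build_detection_train_loader` took `(cfg, mapper=None)` and read dataset names from `cfg.DATASETS.TRAIN`, so the positional `dataset_name` is bound to `mapper` and later called on each dataset dict. The sketch mimics that with hypothetical stand-ins (`build_train_loader`, `DATASET_REGISTRY`, `default_mapper`); check the actual signature in your installed version.

```python
# Hypothetical registry of dataset dicts, keyed by dataset name.
DATASET_REGISTRY = {
    "my_val": [{"file_name": "val_0.jpg"}, {"file_name": "val_1.jpg"}],
}


def default_mapper(dataset_dict):
    """Hypothetical default mapper: returns a processed copy of the dict."""
    return {**dataset_dict, "mapped": True}


def build_train_loader(cfg, mapper=None):
    """Hypothetical stand-in with the assumed (cfg, mapper=None) signature."""
    if mapper is None:
        mapper = default_mapper
    # Dataset names come from the config, not from a positional argument.
    dicts = [d for name in cfg["DATASETS"]["TRAIN"] for d in DATASET_REGISTRY[name]]
    # The real loader calls the mapper on each dict inside a worker process.
    return [mapper(d) for d in dicts]


cfg = {"DATASETS": {"TRAIN": ("my_val",)}}

# Passing the dataset name positionally binds the string to `mapper`.
try:
    build_train_loader(cfg, "my_val")
    positional_error = None
except TypeError as err:
    positional_error = err
    print("TypeError:", err)  # TypeError: 'str' object is not callable

# Selecting the dataset through the config avoids the problem.
batches = build_train_loader(cfg)
print(batches[0])  # {'file_name': 'val_0.jpg', 'mapped': True}
```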

As a sanity check, inference works just fine:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            with detectron2.evaluation.inference_context(self.trainer.model):
                loss = self.trainer.model(batch)
                log.debug(f"validation loss: {loss}")

INFO:detectron2.engine.train_loop:Starting training from iteration 0
DEBUG:root:validation loss: [{'instances': Instances(num_instances=100, image_height=720, image_width=720, fields=[pred_boxes = Boxes(tensor([[4.4867e+02, 1.9488e+02, 5.1496e+02, 3.9878e+02],
        [4.2163e+02, 1.1204e+02, 6.1118e+02, 5.5378e+02],
        [8.7323e-01, 3.0374e+02, 9.2917e+01, 3.8698e+02],
        [4.3202e+02, 2.0296e+02, 5.7938e+02, 3.6817e+02],
        ...

… but that isn’t what I want, of course. Any thoughts?

Thanks in advance, Tim
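The sanity check returns predictions because `inference_context` temporarily puts the model in eval mode. A minimal sketch, assuming the model follows the pattern of detectron2's meta-architectures: return predictions when `self.training` is false, otherwise require ground-truth `instances` and return a loss dict. `ToyDetector` and `eval_mode` are hypothetical; `eval_mode` only approximates what `inference_context` does. In a real hook, the loss-returning call would stay in training mode and be wrapped in `torch.no_grad()`.

```python
from contextlib import contextmanager


class ToyDetector:
    """Hypothetical model that mimics the train/eval split of a detectron2 meta-architecture."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, batched_inputs):
        if not self.training:
            # Inference path: predictions only, ground truth is ignored.
            return [{"instances": f"predictions for {x['file_name']}"} for x in batched_inputs]
        # Training path: ground truth is required and a loss dict is returned.
        if any("instances" not in x for x in batched_inputs):
            raise ValueError("training-mode forward needs ground-truth 'instances'")
        return {"loss_cls": 0.5, "loss_box_reg": 0.25}  # made-up values


@contextmanager
def eval_mode(model):
    """Approximation of inference_context: switch to eval, then restore the old mode."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)


model = ToyDetector()
batch = [{"file_name": "val_0.jpg", "instances": "ground-truth boxes"}]

with eval_mode(model):
    preds = model(batch)  # list of predictions, like the sanity check
losses = model(batch)  # back in training mode: a dict of losses
print(preds)
print(losses)
```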

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments:35

Top GitHub Comments

48 reactions
mnslarcher commented, Mar 8, 2020

Hi,

I have a hacky solution for this. I’ll leave it here in case anyone needs it or has suggestions on how to improve it.

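import torch

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

# `cfg` is the training config built earlier; DATASETS.VAL is not a default
# detectron2 key, it is added here to name the validation dataset.
cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        # Point the training loader (which keeps annotations) at the validation set.
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        # The model stays in training mode so it returns a loss dict;
        # no_grad() avoids building an autograd graph for this extra pass.
        with torch.no_grad():
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # reduce_dict() averages over processes; the result is only valid on rank 0.
            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)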
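The aggregation in `after_step` can be traced without a multi-GPU setup. A single-process sketch, assuming `comm.reduce_dict` averages each entry over all processes (its default) and that only the main process holds the reduced result, which is why the hook writes scalars only there. `reduce_dict_mean` and `validation_scalars` are hypothetical helpers, and the loss values are made up.

```python
def reduce_dict_mean(per_process_dicts):
    """Hypothetical single-process stand-in for comm.reduce_dict(..., average=True)."""
    keys = sorted(per_process_dicts[0])
    n = len(per_process_dicts)
    return {k: sum(d[k] for d in per_process_dicts) / n for k in keys}


def validation_scalars(per_process_dicts):
    """Build the scalars that the hook passes to storage.put_scalars()."""
    reduced = reduce_dict_mean(per_process_dicts)
    # Prefix each loss so it is logged separately from the training losses.
    val = {"val_" + k: v for k, v in reduced.items()}
    return {"total_val_loss": sum(val.values()), **val}


# Made-up loss dicts from two processes for the same validation step.
per_process = [
    {"loss_cls": 0.5, "loss_box_reg": 0.25},
    {"loss_cls": 0.25, "loss_box_reg": 0.75},
]
print(validation_scalars(per_process))
# {'total_val_loss': 0.875, 'val_loss_box_reg': 0.5, 'val_loss_cls': 0.375}
```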

And then register the hook with the trainer:

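import os

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg)  # `Trainer` is assumed to be DefaultTrainer or a subclass of it
val_loss = ValidationLoss(cfg)
trainer.register_hooks([val_loss])
# register_hooks() appends ValidationLoss after the default hooks, which end with
# PeriodicWriter; swap the last two so the validation losses are stored before writing
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()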
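Why swap the last two hooks? A simplified model, assuming `register_hooks` appends the new hook after the default hooks (which end with `PeriodicWriter` on the main process) and that the writer reports the scalars recorded during the current iteration. Every class and function here is a hypothetical stand-in for detectron2's hooks and `EventStorage`, not the library's code.

```python
class EventStorageSketch:
    """Hypothetical stand-in: stores (iteration, name, value) records."""

    def __init__(self):
        self.iter = 0
        self.records = []

    def put_scalar(self, name, value):
        self.records.append((self.iter, name, value))


class NoOpHook:
    def after_step(self):
        pass


class ValidationLossSketch:
    def __init__(self, storage):
        self.storage = storage

    def after_step(self):
        self.storage.put_scalar("total_val_loss", 0.875)  # made-up value


class PeriodicWriterSketch:
    def __init__(self, storage):
        self.storage = storage
        self.lines = []

    def after_step(self):
        # Write only the scalars recorded for the current iteration.
        current = [name for it, name, _ in self.storage.records if it == self.storage.iter]
        self.lines.append((self.storage.iter, current))


def train(hooks, storage, max_iter):
    for iteration in range(max_iter):
        storage.iter = iteration
        for hook in hooks:  # hooks run in list order after each step
            hook.after_step()


def last_hooks_swapped(hooks):
    # Same slicing as in the answer: all but the last two, then those two reversed.
    return hooks[:-2] + hooks[-2:][::-1]


def run_demo(swap, max_iter=2):
    storage = EventStorageSketch()
    writer = PeriodicWriterSketch(storage)
    # Appending puts the validation hook after the writer.
    hooks = [NoOpHook(), writer, ValidationLossSketch(storage)]
    if swap:
        hooks = last_hooks_swapped(hooks)
    train(hooks, storage, max_iter)
    return writer.lines


print(run_demo(swap=False))  # [(0, []), (1, [])]
print(run_demo(swap=True))   # [(0, ['total_val_loss']), (1, ['total_val_loss'])]
```

Without the swap, the writer runs before the validation loss of the same iteration has been stored.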
