How can I obtain the current epoch parameter in the model?

I see that the mmcv Runner class is used for training. However, I want to pass the current training epoch to the model, and I'm not sure what to change in apis/train.py.

```python
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # step 1: give default values and override (if exist) from cfg.data
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
        **({} if torch.__version__ != 'parrots' else dict(
               prefetch_num=2,
               pin_memory=False,
           )),
    }

    # step 2: cfg.data.train_dataloader has highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    model = MMDataParallel(
        model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register hooks
    # (fp16 handling is omitted in this excerpt, so use the plain optimizer config)
    optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    # (building `val_dataloader` is omitted from this excerpt)
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
```
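As an aside, the excerpt above already forwards cfg.get('custom_hooks', None) to runner.register_training_hooks, so the current epoch can also be exposed without editing apis/train.py at all, by registering a small custom hook. The sketch below is only an illustration: the hook name `SetEpochInfoHook` is made up here, and it assumes the detector keeps the head that needs the epoch under `bbox_head`.

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class SetEpochInfoHook(Hook):
    """Copy the runner's current epoch onto the model's head before each epoch."""

    def before_train_epoch(self, runner):
        model = runner.model
        # unwrap MMDataParallel / MMDistributedDataParallel if present
        if hasattr(model, 'module'):
            model = model.module
        # assumption: the head that needs the epoch lives under `bbox_head`
        model.bbox_head.epoch = runner.epoch
```

Provided the module defining this class gets imported (so the registration runs), it would be enabled from the config with `custom_hooks = [dict(type='SetEpochInfoHook')]`.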

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

gaotongxiao commented on Mar 13, 2022

@simplify23 Sorry for the late reply. Here is an example of injecting the epoch number into DBNet’s head with minimal changes to MMOCR.

In mmocr/models/textdet/dense_heads/db_head.py:

```python
import time

from mmcv.runner import RUNNERS, EpochBasedRunner


@RUNNERS.register_module()
class RunnerWrapper(EpochBasedRunner):
    """EpochBasedRunner that exposes the current epoch to the head."""

    def train(self, data_loader, **kwargs):
        self.model.train()
        # inject the current epoch number into the head before the epoch starts
        self.model.module.bbox_head.epoch = self._epoch
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1


# HEADS, HeadMixin and BaseModule come from db_head.py's existing imports
@HEADS.register_module()
class DBHead(HeadMixin, BaseModule):
    ...

    def forward(self, inputs):
        # `epoch` is set by RunnerWrapper.train() above
        print(self.epoch)
```

Append the following line to the end of configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py:

```python
runner = dict(type='RunnerWrapper', max_epochs=1200)
```

Then you’ll be able to see the epoch number as you start training with this config.
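One caveat: self.epoch only exists after RunnerWrapper.train() has injected it, so calling DBHead.forward in test mode (or before the first training epoch) would raise an AttributeError. A guarded lookup with a default sidesteps this; the snippet below is a self-contained sketch in which DummyHead merely stands in for the real head.

```python
# DummyHead is only a stand-in for DBHead, to show the guarded lookup.
class DummyHead:

    def forward(self, inputs):
        # `epoch` is injected from outside (RunnerWrapper or a hook);
        # fall back to 0 when nothing has set it yet, e.g. in test mode.
        epoch = getattr(self, 'epoch', 0)
        return epoch


head = DummyHead()
assert head.forward(None) == 0  # nothing injected yet
head.epoch = 3                  # what the runner does before each epoch
assert head.forward(None) == 3
```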

simplify23 commented on Mar 14, 2022

I have implemented this function. Thanks to both the mmcv and mmocr authors!
