Quantization-aware training with the API

Instructions To Reproduce the 🐛 Bug:

I tried to add quantization-aware training to the d2go_beginner.ipynb notebook, but I couldn’t get it to work.

Code:

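import os

from d2go.model_zoo import model_zoo
from d2go.runner import Detectron2GoRunner

# Assumes the balloon_train / balloon_val datasets were registered earlier in the notebook.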


def prepare_for_launch():
    runner = Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("faster_rcnn_fbnetv3a_C4.yaml"))
    cfg.MODEL_EMA.ENABLED = False
    cfg.DATASETS.TRAIN = ("balloon_train",)
    cfg.DATASETS.TEST = ("balloon_val",)
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("faster_rcnn_fbnetv3a_C4.yaml")  # Let training initialize from model zoo
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
    cfg.SOLVER.MAX_ITER = 600    # a few hundred iterations seem good enough for this toy dataset; you will need to train longer for a practical dataset
    cfg.SOLVER.STEPS = []        # do not decay learning rate
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (balloon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
    # NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.

    # quantization-aware training
    cfg.QUANTIZATION.BACKEND = "qnnpack"
    cfg.QUANTIZATION.QAT.ENABLED = True
    cfg.QUANTIZATION.QAT.START_ITER = 0
    cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 0
    cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 5
    cfg.QUANTIZATION.QAT.FREEZE_BN_ITER = 7

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    return cfg, runner

cfg, runner = prepare_for_launch()
print(cfg)
model = runner.build_model(cfg)
runner.do_train(cfg, model, resume=False)
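The four cfg.QUANTIZATION.QAT.*_ITER values above define a schedule over training iterations. The sketch below is a hypothetical helper (QATSchedule and qat_phase are not part of D2Go) that assumes each feature switches on or off once the iteration counter reaches its configured value; it only makes the schedule in this repro easier to read and does not mirror D2Go's actual hook.

```python
from dataclasses import dataclass


@dataclass
class QATSchedule:
    # Assumed semantics: each milestone takes effect once the iteration
    # counter reaches it (an illustration, not D2Go's implementation).
    start_iter: int
    enable_observer_iter: int
    disable_observer_iter: int
    freeze_bn_iter: int


def qat_phase(s: QATSchedule, it: int) -> dict:
    """Return which QAT features are active at iteration `it` (hypothetical)."""
    return {
        "fake_quant": it >= s.start_iter,
        "observers": s.enable_observer_iter <= it < s.disable_observer_iter,
        "bn_frozen": it >= s.freeze_bn_iter,
    }


# Values from the repro above (MAX_ITER = 600).
repro = QATSchedule(start_iter=0, enable_observer_iter=0,
                    disable_observer_iter=5, freeze_bn_iter=7)

for it in (0, 4, 5, 7, 599):
    print(it, qat_phase(repro, it))
```

Under this reading, the repro collects quantization ranges for only the first five iterations and freezes BatchNorm statistics from iteration 7 onward, which is much earlier than the schedule tried in the comments below.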

Error message:

AssertionError                            Traceback (most recent call last)
<ipython-input-11-327fe2b2a9ce> in <module>()
     32 cfg, runner = prepare_for_launch()
     33 print(cfg)
---> 34 model = runner.build_model(cfg)
     35 runner.do_train(cfg, model, resume=False)

15 frames
/usr/local/lib/python3.7/dist-packages/torch/quantization/fuser_method_mappings.py in get_fuser_method(op_list, additional_fuser_method_mapping)
    129                                      additional_fuser_method_mapping)
    130     fuser_method = all_mappings.get(op_list, None)
--> 131     assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
    132     return fuser_method

AssertionError: did not find fuser method for: (<class 'torch.nn.modules.conv.Conv2d'>, <class 'mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm'>, <class 'torch.nn.modules.activation.ReLU'>) 
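The assertion is raised by a lookup keyed on the exact tuple of module classes being fused. The sketch below is a simplified pure-Python model of that behavior, not PyTorch's code; the classes are hypothetical stand-ins for torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU and NaiveSyncBatchNorm. It shows that a BatchNorm variant missing from the table is rejected, even when it subclasses BatchNorm2d, because the key is built from type(m) rather than an isinstance() check.

```python
class Conv2d: ...
class BatchNorm2d: ...
class ReLU: ...


class NaiveSyncBatchNorm(BatchNorm2d):
    """Stand-in for a BatchNorm subclass with no registered fuser."""


def fuse_conv_bn_relu(conv, bn, relu):
    # Placeholder for the real fusion routine.
    return ("ConvBnReLU", conv, bn, relu)


# Simplified fuser table: keys are exact tuples of classes.
FUSER_METHODS = {
    (Conv2d, BatchNorm2d, ReLU): fuse_conv_bn_relu,
}


def get_fuser_method(op_list):
    fuser = FUSER_METHODS.get(op_list)  # exact-type match, no isinstance()
    assert fuser is not None, "did not find fuser method for: {}".format(op_list)
    return fuser


def fuse(modules):
    types = tuple(type(m) for m in modules)
    return get_fuser_method(types)(*modules)


print(fuse([Conv2d(), BatchNorm2d(), ReLU()])[0])  # prints ConvBnReLU
try:
    fuse([Conv2d(), NaiveSyncBatchNorm(), ReLU()])
except AssertionError as err:
    print("AssertionError:", err)
```

The second comment below sets cfg.MODEL.FBNET_V2.NORM and cfg.MODEL.ROI_BOX_HEAD.NORM to "bn", which presumably avoids this path by building plain BatchNorm layers that do have a registered fuser.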

Expected behavior:

Quantization-aware training should work with the API.

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 10 (7 by maintainers)

Top GitHub Comments

4 reactions
chiehpower commented, Mar 31, 2021

Thanks for your contributions!

I am wondering whether there is any further update.

In my experiments, I used the balloon dataset for QAT, and the results were quite poor. Specifically, they were strange: every predicted score was exactly the same (50%).

[Image: QAT model prediction result for 5603212091_2dfe16ea72_b.jpg]

When I switched from the QAT model to a normal model trained with the code from the demo Jupyter notebook, the predictions were more accurate.

In addition, I compared the inference time of the QAT model and the normal model, and it was the same for both. I had expected the QAT model to be faster, even if its predictions were poor.

Similarly, I expected the QAT model to be smaller than the normal model, but it was not.

Thank you.

I used the train_net.py script to train the QAT model directly.
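A likely explanation for the unchanged size and speed, assuming the standard PyTorch eager-mode QAT flow and assuming the model from train_net.py is the unconverted QAT model: during training, quantization is only simulated by fake-quant modules operating in float, so the checkpoint still stores float32 weights (plus observer state) and runs float kernels. Smaller weights and int8 kernels typically appear only after the model is converted (for example with torch.quantization.convert) and exported. The arithmetic below uses a hypothetical parameter count to show the roughly 4x weight-storage gap between float32 and int8; it ignores scales, zero points and other metadata.

```python
BYTES_PER_PARAM = {"float32": 4, "int8": 1}


def weight_storage_mib(num_params: int, dtype: str) -> float:
    """Approximate weight storage in MiB, ignoring quantization metadata."""
    return num_params * BYTES_PER_PARAM[dtype] / (1024 ** 2)


# Hypothetical parameter count, chosen only for illustration.
num_params = 10_000_000

fp32_mib = weight_storage_mib(num_params, "float32")
int8_mib = weight_storage_mib(num_params, "int8")
print(f"float32: {fp32_mib:.1f} MiB, int8: {int8_mib:.1f} MiB, "
      f"ratio: {fp32_mib / int8_mib:.0f}x")
```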

4 reactions
TannerGilbert commented, Mar 22, 2021

I have now tried multiple sets of hyperparameters and achieved a decent total_loss with the following configuration:

import os

from d2go.model_zoo import model_zoo
from d2go.runner import Detectron2GoRunner


def prepare_for_launch():
    runner = Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("faster_rcnn_fbnetv3a_C4.yaml"))
    cfg.MODEL_EMA.ENABLED = False
    cfg.DATASETS.TRAIN = ("balloon_train",)
    cfg.DATASETS.TEST = ("balloon_val",)
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("faster_rcnn_fbnetv3a_C4.yaml")
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.001 
    cfg.SOLVER.MAX_ITER = 1000    
    cfg.SOLVER.STEPS = []
    cfg.SOLVER.WARMUP_ITERS = 1500   
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 

    # quantization-aware training
    cfg.QUANTIZATION.BACKEND = "qnnpack"
    cfg.QUANTIZATION.QAT.ENABLED = True
    cfg.QUANTIZATION.QAT.START_ITER = 500
    cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 500
    cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 1000
    cfg.QUANTIZATION.QAT.FREEZE_BN_ITER = 1000
    # Use plain BatchNorm; the original error came from fusing Conv2d + NaiveSyncBatchNorm + ReLU
    cfg.MODEL.FBNET_V2.NORM = "bn"
    cfg.MODEL.ROI_BOX_HEAD.NORM = "bn"

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    return cfg, runner

cfg, runner = prepare_for_launch()
model = runner.build_model(cfg)
runner.do_train(cfg, model, resume=False)
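Because the QAT milestones interact with MAX_ITER, a quick check before launching can flag phases that never take effect. The helper below is hypothetical and works on a plain dict rather than a D2Go CfgNode; it assumes training runs for iterations 0 through MAX_ITER - 1 (the log below ends at iter 999 for MAX_ITER = 1000) and that a milestone takes effect once the iteration counter reaches it.

```python
QAT_MILESTONES = ("START_ITER", "ENABLE_OBSERVER_ITER",
                  "DISABLE_OBSERVER_ITER", "FREEZE_BN_ITER")


def check_qat_schedule(schedule: dict) -> list:
    """Return warnings about QAT milestones that may never take effect (hypothetical)."""
    last_iter = schedule["MAX_ITER"] - 1  # assumed: iterations run 0 .. MAX_ITER - 1
    warnings = []
    for key in QAT_MILESTONES:
        if schedule[key] > last_iter:
            warnings.append(f"{key}={schedule[key]} is never reached (last iter is {last_iter})")
    if schedule["DISABLE_OBSERVER_ITER"] <= schedule["ENABLE_OBSERVER_ITER"]:
        warnings.append("observers are disabled no later than they are enabled")
    return warnings


# Values copied from the configuration above.
schedule = {
    "MAX_ITER": 1000,
    "START_ITER": 500,
    "ENABLE_OBSERVER_ITER": 500,
    "DISABLE_OBSERVER_ITER": 1000,
    "FREEZE_BN_ITER": 1000,
}
for warning in check_qat_schedule(schedule):
    print("WARNING:", warning)
```

If that reading of the milestones is right, observers stay active and BatchNorm is never frozen in this run; moving DISABLE_OBSERVER_ITER and FREEZE_BN_ITER below MAX_ITER would exercise those phases.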

Output:

[03/22 08:01:00 d2.utils.events]:  eta: 0:00:57  iter: 839  total_loss: 0.1025  loss_cls: 0.04144  loss_box_reg: 0.04826  loss_rpn_cls: 0.01209  loss_rpn_loc: 0.003019  time: 0.4089  data_time: 0.0099  lr: 0.00028011  max_mem: 1981M
[03/22 08:01:11 d2.utils.events]:  eta: 0:00:51  iter: 859  total_loss: 0.09382  loss_cls: 0.04182  loss_box_reg: 0.04343  loss_rpn_cls: 0.009706  loss_rpn_loc: 0.003442  time: 0.4116  data_time: 0.0093  lr: 0.00028676  max_mem: 1981M
[03/22 08:01:21 d2.utils.events]:  eta: 0:00:44  iter: 879  total_loss: 0.0969  loss_cls: 0.0401  loss_box_reg: 0.04823  loss_rpn_cls: 0.01033  loss_rpn_loc: 0.003244  time: 0.4142  data_time: 0.0085  lr: 0.00029341  max_mem: 1981M
[03/22 08:01:31 d2.utils.events]:  eta: 0:00:38  iter: 899  total_loss: 0.1273  loss_cls: 0.04177  loss_box_reg: 0.07245  loss_rpn_cls: 0.00876  loss_rpn_loc: 0.003226  time: 0.4165  data_time: 0.0088  lr: 0.00030007  max_mem: 1981M
[03/22 08:01:42 d2.utils.events]:  eta: 0:00:31  iter: 919  total_loss: 0.08434  loss_cls: 0.0404  loss_box_reg: 0.04349  loss_rpn_cls: 0.008288  loss_rpn_loc: 0.002168  time: 0.4190  data_time: 0.0085  lr: 0.00030672  max_mem: 1981M
[03/22 08:01:52 d2.utils.events]:  eta: 0:00:24  iter: 939  total_loss: 0.09694  loss_cls: 0.04067  loss_box_reg: 0.04277  loss_rpn_cls: 0.01131  loss_rpn_loc: 0.003144  time: 0.4210  data_time: 0.0087  lr: 0.00031337  max_mem: 1981M
[03/22 08:02:03 d2.utils.events]:  eta: 0:00:16  iter: 959  total_loss: 0.07904  loss_cls: 0.03692  loss_box_reg: 0.03721  loss_rpn_cls: 0.01094  loss_rpn_loc: 0.002774  time: 0.4234  data_time: 0.0102  lr: 0.00032003  max_mem: 1981M
[03/22 08:02:14 d2.utils.events]:  eta: 0:00:08  iter: 979  total_loss: 0.1544  loss_cls: 0.04261  loss_box_reg: 0.08179  loss_rpn_cls: 0.009008  loss_rpn_loc: 0.005097  time: 0.4255  data_time: 0.0121  lr: 0.00032668  max_mem: 1981M
[03/22 08:02:24 d2.utils.events]:  eta: 0:00:00  iter: 999  total_loss: 0.1233  loss_cls: 0.03635  loss_box_reg: 0.06346  loss_rpn_cls: 0.009787  loss_rpn_loc: 0.003979  time: 0.4274  data_time: 0.0093  lr: 0.00033333  max_mem: 1981M
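The lr column is still rising at iter 999 (0.00033333, against BASE_LR = 0.001), so this run appears to end inside the warmup phase: WARMUP_ITERS (1500) is larger than MAX_ITER (1000). The sketch below models detectron2-style linear warmup, where the LR factor ramps linearly from WARMUP_FACTOR to 1 over WARMUP_ITERS and then stays at 1, assuming the default WARMUP_FACTOR of 0.001 and no decay steps. It illustrates the qualitative effect only and does not reproduce the exact values in this log.

```python
def linear_warmup_lr(it: int, base_lr: float, warmup_iters: int,
                     warmup_factor: float = 0.001) -> float:
    """LR at iteration `it` under linear warmup with no decay steps."""
    if it >= warmup_iters:
        return base_lr
    alpha = it / warmup_iters
    return base_lr * (warmup_factor * (1 - alpha) + alpha)


base_lr, warmup_iters, max_iter = 0.001, 1500, 1000
final_lr = linear_warmup_lr(max_iter - 1, base_lr, warmup_iters)
print(f"LR at last iteration: {final_lr:.6f} (BASE_LR = {base_lr})")
```

Under this model, keeping WARMUP_ITERS well below MAX_ITER lets most of the run use the intended learning rate.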

However, when I run predictions on the validation set, nothing is detected. Do I need to preprocess the images differently now?

Code:

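import os
import random

import cv2
import matplotlib.pyplot as plt
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer

balloon_metadata = MetadataCatalog.get("balloon_train")  # as in the balloon tutorial

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model
predictor = DefaultPredictor(cfg)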

dataset_dicts = DatasetCatalog.get('balloon_val')
for d in random.sample(dataset_dicts, 3):    
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1], metadata=balloon_metadata, scale=0.8)
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.figure(figsize = (14, 10))
    plt.imshow(cv2.cvtColor(v.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))
    plt.show()
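With SCORE_THRESH_TEST = 0.5, every box scoring 0.5 or lower is dropped, so an empty visualization alone does not show whether the model outputs nothing or only low-confidence boxes. One way to tell is to lower cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST before building the predictor, then count how many scores (from outputs["instances"].scores) survive at several cut-offs. The helper below is a hypothetical pure-Python sketch on a plain list; the example scores are made up.

```python
def count_above(scores, thresholds=(0.05, 0.1, 0.3, 0.5)):
    """Map each threshold to the number of scores strictly above it."""
    return {t: sum(1 for s in scores if s > t) for t in thresholds}


# Made-up scores standing in for outputs["instances"].scores.tolist()
example_scores = [0.42, 0.38, 0.12, 0.07, 0.03]
print(count_above(example_scores))
```

A result like this one, with boxes present at lower cut-offs but none above 0.5, would point to low confidence rather than a preprocessing problem.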

Full notebook: https://colab.research.google.com/drive/1BSa319b6QCfX4yEJ-Zedi5KGnT8ldZ_w?usp=sharing
