Exporting pretrained keypoint_rcnn_fbnetv3a_dsmask_C4 model
Hello all,

Has anyone managed to export a TorchScript file for the pretrained keypoint_rcnn_fbnetv3a_dsmask_C4 model using create_d2go.py? I altered the Wrapper to return “keypoints” in addition to the other outputs (“boxes”, “scores”, “labels”). The keypoints are in out[3] in the Wrapper (scores are in out[2]: res[“scores”] = out[2]).
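A minimal sketch of post-processing a keypoint-returning Wrapper could do. It assumes (not confirmed in this issue) that the keypoint tensor is (N, K, 3) with (x, y, score) in the pixel coordinates of the resized 320-pixel input, the same space as the boxes, which the create_d2go.py Wrapper divides by `scale`. `rescale_outputs` is a hypothetical helper, not part of D2Go.

```python
from typing import Dict

import torch


@torch.jit.script
def rescale_outputs(boxes: torch.Tensor, keypoints: torch.Tensor, scale: float) -> Dict[str, torch.Tensor]:
    # Assumed layouts: boxes (N, 4) = x0, y0, x1, y1; keypoints (N, K, 3) = x, y, score.
    # Both are assumed to be in resized-image pixels; dividing by scale maps them back.
    xy = keypoints[:, :, :2] / scale
    score = keypoints[:, :, 2:]  # scores are not coordinates, so they stay unchanged
    return {"boxes": boxes / scale, "keypoints": torch.cat([xy, score], dim=2)}


out = rescale_outputs(
    torch.tensor([[32.0, 64.0, 96.0, 128.0]]),
    torch.tensor([[[10.0, 20.0, 0.5], [30.0, 40.0, 0.25]]]),
    2.0,
)
print(out["boxes"].tolist())      # [[16.0, 32.0, 48.0, 64.0]]
print(out["keypoints"].tolist())  # [[[5.0, 10.0, 0.5], [15.0, 20.0, 0.25]]]
```

Because the helper is scripted rather than traced, `scale` stays a runtime input instead of being frozen from one example image.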
Here is the code I use to export the model:

```python
#!/usr/bin/env python3
import contextlib
import copy
import os
import unittest
from typing import Dict, List

import torch
from PIL import Image
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import create_runner, GeneralizedRCNNRunner
from d2go.model_zoo import model_zoo
from mobile_cv.common.misc.file_utils import make_temp_directory
from d2go.utils.testing.data_loader_helper import LocalImageGenerator, register_toy_dataset
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader

patch_d2_meta_arch()

cfg_name = 'keypoint_rcnn_fbnetv3a_dsmask_C4.yaml'
pytorch_model = model_zoo.get(cfg_name, trained=True)
# pytorch_model.training=False
# pytorch_model.eval()


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # maps contiguous class indices to COCO category ids
        coco_idx_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                         27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51,
                         52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
                         78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91]
        self.coco_idx = torch.tensor(coco_idx_list)

    def forward(self, inputs: List[torch.Tensor]):
        # inputs[0]: CHW float image in [0, 1], scaled back to [0, 255]
        x = inputs[0].unsqueeze(0) * 255
        # resize so that the shorter side is 320 px
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear",
                                            align_corners=True, recompute_scale_factor=True)
        out = self.model(x[0])
        res = (
            out[3],                                        # keypoints
            out[0] / scale,                                # boxes, mapped back to the input size
            torch.index_select(self.coco_idx, 0, out[1]),  # labels as COCO category ids
            out[4],
        )
        return res


size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
h, w = size_divisibility, size_divisibility * 2
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
    predictor_path = convert_and_export_predictor(
        model_zoo.get_config(cfg_name),
        copy.deepcopy(pytorch_model),
        "torchscript_int8@tracing",
        './',
        data_loader,
    )

orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
wrapped_model = Wrapper(orig_model)

# optionally do a forward
import cv2
im = cv2.imread("inp8.jpg")               # HWC, BGR channel order, uint8
im = torch.from_numpy(im).float() / 255
im = im.permute(2, 0, 1)                  # HWC -> CHW layout expected by the Wrapper
wrapped_model([im])

scripted_model = torch.jit.script(wrapped_model)
scripted_model.save("d2go_tracker_temp.pt")
```

When I use the exported .pt file on Android, I get corrupted keypoint coordinates. I suspect this is due to the TracerWarnings emitted while exporting the TorchScript file (converting a tensor to other Python types causes the value to be treated as a constant in the exported TorchScript file).
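The TracerWarning described above can be reproduced in isolation. A minimal sketch, independent of D2Go: the hypothetical `normalize` function converts a tensor to a Python float, so `torch.jit.trace` freezes the value taken from the example input, while `torch.jit.script` recomputes it on every call. The hypothetical `trace_with_report` helper records each TracerWarning together with the file and line that raised it, which is one way to collect the detailed tracer warning.

```python
import warnings
from typing import Callable, List, Tuple

import torch


def normalize(x: torch.Tensor) -> torch.Tensor:
    # float(...) turns a tensor into a Python number; the tracer cannot record
    # where that number came from, so it is stored as a constant.
    peak = float(x.max())
    return x / peak


def trace_with_report(fn: Callable, example: torch.Tensor) -> Tuple[Callable, List[str]]:
    """Trace fn and collect every TracerWarning with its file and line."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        traced_fn = torch.jit.trace(fn, example)
    lines = [
        f"{w.filename}:{w.lineno}: {w.message}"
        for w in caught
        if issubclass(w.category, torch.jit.TracerWarning)
    ]
    # check_trace runs fn again, so the same warning can appear twice; keep one copy
    return traced_fn, list(dict.fromkeys(lines))


traced, report = trace_with_report(normalize, torch.tensor([1.0, 2.0, 4.0]))
scripted = torch.jit.script(normalize)
for line in report:
    print(line)

other = torch.tensor([3.0, 6.0, 12.0])
print(traced(other).tolist())    # [0.75, 1.5, 3.0]  (peak 4.0 frozen from the example)
print(scripted(other).tolist())  # [0.25, 0.5, 1.0]  (peak recomputed at run time)
```

The same `warnings.catch_warnings(record=True)` pattern could wrap the `convert_and_export_predictor` call. Note that `torch.jit.script(wrapped_model)` scripts only the Wrapper; the loaded `model.jit` is kept as it was exported, so anything frozen during the `@tracing` export stays frozen. Whether such a frozen value is what corrupts the keypoints is not established here.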
I am fairly sure the input format for the forward pass on Android is correct. On Android, the keypoint_rcnn_fbnetv3a_dsmask_C4 model outputs “boxes”, “scores”, “labels”, and “keypoints”, but the “keypoints” are wrong. The other outputs are fine, and I can draw boxes around “persons”.
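On the Python side, the input format is easy to get subtly wrong: `cv2.imread` returns an HWC array, while the Wrapper reads height and width from the last two dimensions, so a CHW tensor is needed. `permute` moves the channel axis; `torch.reshape` keeps the memory order and mixes channels. A small sketch with a 2x2, 3-channel toy image:

```python
import torch

# Toy image in HWC layout (H=2, W=2, C=3), like the array cv2.imread returns.
hwc = torch.arange(12).reshape(2, 2, 3)

chw_permute = hwc.permute(2, 0, 1)  # reorders axes; each pixel keeps its channel values
chw_reshape = hwc.reshape(3, 2, 2)  # reinterprets the same memory; channels get mixed

print(chw_permute[0].tolist())  # [[0, 3], [6, 9]]  (first channel of every pixel)
print(chw_reshape[0].tolist())  # [[0, 1], [2, 3]]  (not a channel plane)
```

This only checks the Python test input; it does not by itself explain the Android keypoints.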
By corrupted keypoints I mean that, for every keypoint, the model on Android returns the same (x, y, probability).

## Expected behavior

Get the same output as when I run the model in Python 3 with DemoPredictor.
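The symptom and the expected behavior can both be checked mechanically. A minimal sketch with two hypothetical helpers: `keypoints_collapsed` flags outputs where every keypoint of an instance has exactly the same (x, y, score), and `keypoint_xy_match` compares keypoint coordinates from a reference run (for example DemoPredictor in Python) with the exported model's output. The (N, K, 3) layout and the 2-pixel tolerance are assumptions; an int8 model will not match the float model exactly.

```python
import torch


def keypoints_collapsed(keypoints: torch.Tensor) -> bool:
    """True if, within every instance, all K keypoints are exactly equal.

    Assumes an (N, K, 3) tensor of (x, y, score).
    """
    if keypoints.numel() == 0:
        return False
    first = keypoints[:, :1, :].expand_as(keypoints)
    return bool(torch.equal(keypoints, first))


def keypoint_xy_match(reference: torch.Tensor, exported: torch.Tensor, atol_px: float = 2.0) -> bool:
    """Compare only the (x, y) coordinates, within atol_px pixels."""
    if reference.shape != exported.shape:
        return False
    return bool(torch.allclose(reference[:, :, :2], exported[:, :, :2], atol=atol_px))


healthy = torch.tensor([[[120.0, 80.0, 0.9], [130.0, 95.0, 0.8]]])
broken = torch.tensor([[[4.0, 4.0, 0.1], [4.0, 4.0, 0.1]]])
print(keypoints_collapsed(healthy), keypoints_collapsed(broken))  # False True
print(keypoint_xy_match(healthy, healthy + 0.5), keypoint_xy_match(healthy, broken))  # True False
```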
- Created 2 years ago
- Comments: 6 (2 by maintainers)
Top GitHub Comments
@sadegh16 Hi, is this still an issue? The latest error message seems to be related to data loading (maybe the dataset is not installed correctly). To debug, could you disable error handling by appending
`D2GO_DATA.MAPPER.CATCH_EXCEPTION False`
to the command?

Hi there, thank you for your feedback! Could you provide the detailed tracer warning? Also, could you try exporting the int8 model.jit with https://github.com/facebookresearch/d2go/tree/master/demo and test the model.jit and wrapper again?