I had to review a lot of documentation and issues to implement the training code. So here is the code you’ll be needing for training.

First, download the pretrained model files from https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing and move them to MODNet/pretrained. If you need to fine-tune the model on your own dataset, download modnet_photographic_portrait_matting.ckpt. If you need the mobilenetv2 backbone model, download that too.

To prepare the dataset, I built a pandas dataframe with 2 columns, ["image", "matte"]. "image" holds the absolute path to each image, and "matte" holds the path to that image's matte.
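A minimal sketch of how such a two-column table could be assembled is shown here. The helper name pair_images_with_mattes, the two directory arguments, and the rule that a matte shares its file stem with its image are all assumptions made for illustration, not something the original post specifies.

```python
from pathlib import Path


def pair_images_with_mattes(image_dir, matte_dir, matte_suffix=".png"):
    """Return rows of {"image": ..., "matte": ...} with absolute paths.

    Assumption: a matte shares its file stem with its image, e.g.
    images/cat.jpg <-> mattes/cat.png. Images without a matte are skipped.
    """
    rows = []
    for image_path in sorted(Path(image_dir).iterdir()):
        if not image_path.is_file():
            continue
        matte_path = Path(matte_dir) / (image_path.stem + matte_suffix)
        if matte_path.is_file():
            rows.append({
                "image": str(image_path.resolve()),
                "matte": str(matte_path.resolve()),
            })
    return rows
```

The rows can then be turned into the dataframe with pandas.DataFrame(rows, columns=["image", "matte"]).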

After downloading, for preprocessing, the code is:

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import distance_transform_edt
from torch.utils.data import Dataset


class ModNetDataLoader(Dataset):
    """Dataset yielding (image, trimap, matte) tensors for MODNet training."""

    def __init__(self, annotations_file, resize_dim, transform=None):
        # annotations_file: pandas DataFrame with columns ["image", "matte"]
        self.img_labels = annotations_file
        # transform must turn a PIL image into a tensor (e.g. ToTensor + Normalize)
        self.transform = transform
        self.resize_dim = resize_dim

    def __len__(self):
        # return the total number of images
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_labels.iloc[idx, 0]
        mask_path = self.img_labels.iloc[idx, 1]

        img = np.asarray(Image.open(img_path))

        # The matte is read from the alpha channel, so the matte file must be RGBA
        in_image = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        mask = in_image[:, :, 3]

        # Force the image to 3 channels (grayscale -> RGB, RGBA -> RGB)
        if img.ndim == 2:
            img = img[:, :, None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:
            img = img[:, :, 0:3]

        # Build the trimap from the raw 0..255 matte, even when no transform is set
        # (the original only defined trimap inside the transform branch)
        trimap = self.get_trimap(mask)

        img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        # Scale the matte to [0, 1] directly. Passing it through Normalize(0.5, 0.5)
        # would map it to [-1, 1]; this assumes the loss expects a [0, 1] matte.
        mask = torch.from_numpy(mask.astype(np.float32) / 255.0).unsqueeze(0)

        img = self._resize(img)
        mask = self._resize(mask)
        trimap = self._resize(trimap, trimap=True)

        # Drop the batch dimension added by _resize
        img = torch.squeeze(img, 0)
        mask = torch.squeeze(mask, 0)
        trimap = torch.squeeze(trimap, 0)

        return img, trimap, mask

    def get_trimap(self, alpha):
        # alpha is a uint8 matte in [0, 255]: 255 = foreground, 0 = background
        # be careful when dealing with regions of alpha=0 and alpha=255
        fg = np.equal(alpha, 255).astype(np.float32)
        unknown = np.not_equal(alpha, 0).astype(np.float32)  # alpha > 0
        unknown = unknown - fg  # 0 < alpha < 255
        # image dilation by a random radius, implemented with the Euclidean distance transform
        unknown = distance_transform_edt(unknown == 0) <= np.random.randint(1, 20)
        trimap = fg
        trimap[unknown] = 0.5
        return torch.unsqueeze(torch.from_numpy(trimap), dim=0)

    def _resize(self, img, trimap=False):
        im = img[None, :, :, :]
        ref_size = self.resize_dim

        # resize image for input
        im_b, im_c, im_h, im_w = im.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        # round both sides down to a multiple of 32
        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32
        # nearest-neighbour keeps the trimap values at exactly 0, 0.5 and 1
        mode = 'nearest' if trimap else 'area'
        return F.interpolate(im, size=(im_rh, im_rw), mode=mode)

You might need to change the get_trimap and __getitem__ methods above to suit your dataset. You should verify that your data is loaded correctly once the dataset has been created (the check is described further down).
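The sizing rule inside _resize can be checked on its own, without loading any images. The sketch below copies that rule into a standalone function; target_size is a hypothetical name used only here.

```python
def target_size(im_h, im_w, ref_size, multiple=32):
    """Mirror the sizing rule of _resize and return (height, width)."""
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        # Scale so the shorter side equals ref_size, keeping the aspect ratio
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        # One side is already on each side of ref_size: keep the original size
        im_rh, im_rw = im_h, im_w
    # Round both sides down to a multiple of 32
    return im_rh - im_rh % multiple, im_rw - im_rw % multiple


print(target_size(1080, 1920, 512))  # (512, 896)
print(target_size(600, 400, 512))    # (576, 384)
print(target_size(300, 200, 512))    # (768, 512)
```

Because the output size depends on each image's aspect ratio, images of different shapes come out at different sizes. The default DataLoader collation stacks the items of a batch into one tensor, which requires equal sizes, so a dataset with mixed aspect ratios would likely need a fixed crop or a batch size of 1.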

Finally, create your dataset using the code below:

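from torchvision import transforms

transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        # map each RGB channel from [0, 1] to [-1, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
# data_csv is the ["image", "matte"] dataframe described earlier
data = ModNetDataLoader(data_csv, 512, transform=transformer)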

After your dataset has been created, first verify it by printing the first item and checking that the shapes of the image, matte and trimap are equal (only the channel counts may differ). IMPORTANT: Print the first trimap. The only values in the array should be 0, 0.5 and 1. Then use the DataLoader class to prepare your data for training:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(data, batch_size=8, shuffle=True)
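The shape and trimap checks described above could be automated along these lines. This is only a sketch: check_sample is a hypothetical helper, and it takes plain shapes and values so that it can be tried without torch installed.

```python
def check_sample(image_shape, trimap_shape, matte_shape, trimap_values):
    """Raise ValueError if a dataset item fails the sanity checks."""
    # Shapes are (channels, height, width); only the channel count may differ
    spatial = {tuple(s)[-2:] for s in (image_shape, trimap_shape, matte_shape)}
    if len(spatial) != 1:
        raise ValueError(f"spatial sizes differ: {sorted(spatial)}")
    # A valid trimap contains only background, unknown and foreground
    unexpected = sorted(set(trimap_values) - {0.0, 0.5, 1.0})
    if unexpected:
        raise ValueError(f"unexpected trimap values: {unexpected}")
    return True


# Stand-alone example with hand-written shapes
check_sample((3, 512, 896), (1, 512, 896), (1, 512, 896), [0.0, 0.5, 1.0])

# With the dataset built above (requires torch):
# img, trimap, matte = data[0]
# check_sample(img.shape, trimap.shape, matte.shape, torch.unique(trimap).tolist())
```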

After this, the code for training is available in the trainer.py file:

import torch
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter

bs = 16         # batch size; pass it to DataLoader(batch_size=bs) so the two stay in sync
lr = 0.01       # learning rate
epochs = 40     # total epochs

modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

for epoch in range(0, epochs):
    # train_dataloader is the DataLoader created in the previous step
    for idx, (image, trimap, gt_matte) in enumerate(train_dataloader):
        semantic_loss, detail_loss, matte_loss = \
            supervised_training_iter(modnet, optimizer, image.cuda(), trimap.cuda(), gt_matte.cuda())
    # step the scheduler once per epoch, not once per batch
    lr_scheduler.step()
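The loop above discards the three losses it receives, which makes it hard to tell whether training is progressing. One possible way to track them is a small running average per epoch; RunningMean is a hypothetical helper, and the commented usage assumes the returned losses are scalar tensors that support .item().

```python
class RunningMean:
    """Accumulate values and report their mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    @property
    def mean(self):
        # Return 0.0 before any value has been recorded
        return self.total / self.count if self.count else 0.0


# Usage inside the training loop (assumption: losses are scalar tensors):
# matte_avg = RunningMean()                 # create once per epoch
# matte_avg.update(matte_loss.item())       # after each supervised_training_iter call
# print(f"epoch {epoch}: matte loss {matte_avg.mean:.4f}")  # after the inner loop
```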

To use the pretrained backbone, change modnet = torch.nn.DataParallel(MODNet()).cuda() to modnet = torch.nn.DataParallel(MODNet(backbone_pretrained=True)).cuda(), provided the mobilenetv2 weights are in your pretrained directory.

For fine-tuning the existing MODNet model, use this snippet before the optimizer line:

modnet = torch.nn.DataParallel(MODNet()).cuda()
state_dict = torch.load("path_to_torchscript_model.ckpt")
modnet.load_state_dict(state_dict)
modnet.train()

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 5

Top GitHub Comments

1 reaction
yashsandansing commented, Nov 4, 2022

@SamStark-AtWork I was the one who wrote this training code, and even I couldn't get it to work later. The loss stays more or less constant throughout training. I read in one of the issues that setting the model to model.eval() in the trainer function before training might give better results. But this script has given me terrible results on multiple datasets. People have gotten training to work, but no one has shared a training script here. I think I once saw a training script pending approval in the PR section. Maybe you can try that code?

0 reactions
SamStark-AtWork commented, Nov 9, 2022

Some updates: I've taken a look at the training script mentioned by @yashsandansing. To be honest, it's not much different. The biggest difference is how the original trimap is produced (which did work a bit better), along with most of the import functions and utilities. There are slight adjustments here and there to the training code across that person's commits, but overall they don't make much difference.

As for the problem I described above, where nothing changed from epoch to epoch: I realized the LR scheduler was stepping on every batch (it should step once per epoch), which was my mistake. The problem went away once that was fixed.
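To see why stepping the scheduler on every batch stalls training, the StepLR decay can be written out by hand. The sketch below uses the closed form lr * gamma ** (steps // step_size) with the values from the training script; step_lr is a hypothetical helper, and the figure of 100 batches per epoch is an assumption chosen for illustration.

```python
def step_lr(base_lr, gamma, step_size, num_steps):
    """Learning rate after num_steps calls to a StepLR-style scheduler."""
    return base_lr * gamma ** (num_steps // step_size)


lr, gamma, epochs = 0.01, 0.1, 40
step_size = int(0.25 * epochs)  # 10, as in the training script
batches_per_epoch = 100         # hypothetical dataset size

# Stepping once per epoch: the rate drops by 10x every 10 epochs
for epoch in (0, 10, 20, 30):
    print(epoch, step_lr(lr, gamma, step_size, epoch))

# Stepping once per batch: the rate has decayed 10 times by the end of epoch 1
print(step_lr(lr, gamma, step_size, batches_per_epoch))
```

With per-batch stepping the learning rate is on the order of 1e-12 after the first epoch under these assumptions, so later epochs barely update the weights.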

The section below is more about quality, for anyone who would like to train or tune the model.

For context, I'm trying to produce a model that generates a matte of an object in a given picture (not a human portrait). My dataset is self-produced and of very questionable quality (most likely one of the reasons I couldn't get great results), and it only contains very similar objects that differ in angle, which might explain the overfitting.

I tried both training the model from scratch and fine-tuning a pre-trained model. The results weren't good, and it feels like there's a lot of overfitting (a better dataset might work better). However, it does work better than a generic VGG16 autoencoder, and at a much lower VRAM cost too (which is incredible). Fine-tuning the pre-trained model works better than expected, and both methods do seem to move toward better results than where they started.

I also tried some hyperparameters; the current optimizer and its parameters are already pretty good. The biggest change I noticed came from the batch size: I tried 4 and 32, and the bigger batch size produced better results (less overfitting).

Accuracy jumped significantly between the untrained model and the end of the first epoch; any change after that was small to almost none. Since my dataset contains a lot of similar pictures, I concluded that overfitting is happening.

So far I haven't seen anyone train or fine-tune a model that gives great results, so be aware of that if you're interested in doing it on your own.
