Transfer-learning to improve accuracy for a specific font and background

In my specific use case, I only need to recognize text that always uses the same font and always has a background darker than the foreground. Additionally, the character set is smaller than the one the default latin.pth checkpoint was trained on. The default model does not meet my accuracy requirements, even after tuning every available parameter. A way to improve accuracy would therefore be to fine-tune the model on generated image-text pairs in the target style of my application.
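
For reference, generating such pairs can be done directly with the TRDG library (the same library that the textgen module below wraps). A minimal sketch, assuming a single target font file and a directory of dark background crops; the paths, strings and colors are placeholders rather than values from my actual setup:

from trdg.generators import GeneratorFromStrings

# Words and phrases drawn from the application's lexicon (placeholders)
strings = ["example", "sample phrase", "text 42"]

generator = GeneratorFromStrings(
    strings,
    count=1000,                       # number of samples to generate
    fonts=["fonts/target-font.ttf"],  # the application's single font
    size=64,                          # pixel height; match the training imgH
    text_color="#e0e0e0",             # light foreground...
    background_type=3,                # ...composited onto background images
    image_dir="backgrounds/",         # ...sampled from this directory
)

# Each item is a (PIL image, ground-truth text) pair
for i, (img, lbl) in enumerate(generator):
    img.save(f"train/{i}.png")        # lbl holds the ground-truth text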

My current solution loads the default latin.pth model but swaps in my own prediction layer, which reflects my reduced character set. Transfer learning is then performed by training only the weights of this final layer. The training process is set up as follows:

import tempfile

from os import path
from typing import List

from easyocr.model.model import Model
from easyocr.utils import CTCLabelConverter
from easyocr.recognition import AlignCollate
import easyocr

import torch
import torch.nn.functional
import torch.utils.data

import numpy as np

from .. import textgen, translations

# Omitted: methods that wrap all code below

# Lexicon, a dictionary of words that appear in this specific use-case
lexicon = translations.Lexicon(dict_path)
lexicon.augment_capitalization(len(lexicon))
lexicon.augment_jibberish(round(len(lexicon)/5))
lexicon.augment_numbers(round(len(lexicon)/10))

# Initialize the CTC label converter (see CRNN paper)
character = lexicon.characters()
character.add(' ')

converter = CTCLabelConverter(character)
num_class = len(converter.character)  # includes the CTC blank token at index 0

# Initialize the feature mapping and sequence labeling model
model = Model(input_channel, output_channel, hidden_size, num_class=num_class)
model = torch.nn.DataParallel(model).to(dev)

for name, param in model.named_parameters():
    if 'localization_fc2' in name:
        print(f'Skip {name} as it is already initialized')
        continue
    try:
        if 'bias' in name:
            torch.nn.init.constant_(param, 0.)
        elif 'weight' in name:
            torch.nn.init.kaiming_normal_(param)
    except Exception:  # for batchnorm.
        if 'weight' in name:
            param.data.fill_(1.)
        continue

# Load weights from EasyOCR default checkpoint for latin characters
checkpoint = torch.load("models/latin.pth")
checkpoint["module.Prediction.weight"] = torch.randn((96, 512)) * 0.01
checkpoint["module.Prediction.bias"] = torch.zeros(96)

model.load_state_dict(checkpoint)
model.train()

# Disable training on all layers except the final prediction layer
for param in model.parameters():
    param.requires_grad = False

for param in model.module.Prediction.parameters():
    param.requires_grad = True

# CTC Loss criterion
criterion = torch.nn.CTCLoss(zero_infinity=False).to(dev)

# Optimizer (the CRNN paper uses ADADELTA; Adagrad is used here instead)
filtered_parameters = []
params_num = 0
for p in filter(lambda p: p.requires_grad, model.parameters()):
    filtered_parameters.append(p)
    params_num += np.prod(p.size())

print('Trainable params num : ', params_num)

optimizer = torch.optim.Adagrad(filtered_parameters)

# Wrap AlignCollate in our own collate function
def collate(batch):
    ratios = []
    imgs = []
    lbls = []
    for img, lbl in batch:
        lbls.append(lbl)
        ratios.append(float(img.size[0])/float(img.size[1]))
        imgs.append(img)

    imgs = AlignCollate(imgH=64, imgW=int(max(ratios) * 64), keep_ratio_with_pad=True)(imgs)

    return imgs, lbls

# TRDG library requires a directory to look for background images
with tempfile.TemporaryDirectory() as bg_dir:
    # Dataset yielding random (image, text) pairs of at most 3 random words each, generated with the TRDG library
    ds = textgen.TextDataset(lexicon, bg_dir)
    train_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, pin_memory=True, drop_last=True, collate_fn=collate)

    # Required for CTCLoss
    torch.backends.cudnn.deterministic = True

    # Training loop
    for i, (img, lbl) in enumerate(train_dl):
        img = img.to(dev)

        # Encode the text label
        lbl_encoded, length = converter.encode(lbl)

        # Run the model (the text argument is only used by attention-based
        # prediction heads, so None is passed here for CTC)
        model.zero_grad()
        preds = model(img, None)
        preds_size = torch.IntTensor([preds.size(1)] * img.size(0))
        preds_log_softmax = preds.log_softmax(2)

        # Calculate the CTC loss; CTCLoss expects inputs of shape (T, N, C),
        # hence the permute from (N, T, C)
        cost = criterion(preds_log_softmax.permute(1, 0, 2), lbl_encoded, preds_size, length)

        print(f"Iteration\t{i}:\tcost {cost.item()}")

        # Optimizer step
        cost.backward()
        torch.nn.utils.clip_grad_norm_(filtered_parameters, 5)
        optimizer.step()

        # Show the result every 5 steps
        if i % 5 == 0:
            _, preds_index = preds_log_softmax.detach().max(2)
            preds_index = preds_index.view(-1)
            preds_str = converter.decode_greedy(preds_index.data, preds_size.data)

            for idx, (true_lbl, pred_lbl) in enumerate(zip(lbl, preds_str)):
                print(f"\t- true {idx}\t: {true_lbl}")
                print(f"\t- pred {idx}\t: {pred_lbl}")

While this already improves accuracy significantly, I would like to go further and also train the remaining layers. However, when I try to train all layers simultaneously, the model quickly diverges. It is not clear to me whether this is due to a mistake in the training script or to something else I am not accounting for.
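
Would something along the following lines be the right direction: unfreezing the pretrained layers at a much lower learning rate than the freshly initialized prediction head, and enabling zero_infinity on the CTC loss so that infeasible alignments (labels longer than the prediction sequence) do not produce infinite losses? A sketch, with illustrative rather than tuned learning rates:

# Unfreeze all layers again
for param in model.parameters():
    param.requires_grad = True

# Discriminative learning rates: small for the pretrained body,
# the default rate for the freshly initialized prediction head
optimizer = torch.optim.Adagrad([
    {"params": [p for n, p in model.named_parameters()
                if "Prediction" not in n], "lr": 1e-4},
    {"params": model.module.Prediction.parameters(), "lr": 1e-2},
])

# zero_infinity=True zeroes out infinite losses and their gradients
criterion = torch.nn.CTCLoss(zero_infinity=True).to(dev)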

How can I (re-)train the model either from scratch or using latin.pth as a starting point?

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 12
  • Comments: 9

Top GitHub Comments

3 reactions
piotrostr commented, Jan 30, 2021

Hey @kurt-stolle. There is a guide on training on your own dataset, as well as a list of failure cases, in that repository. The paper also outlines some possible solutions for overcoming the failure cases, such as low-resolution images. I faced a similar problem to yours and found the guide helpful, so I thought it might be worth sharing 😃

1 reaction
kurt-stolle commented, May 21, 2021

@LanzaMercado Sadly, this issue was never resolved. I ended up writing my own OCR library from off-the-shelf networks, using a three-stage approach similar to EasyOCR's.
