Transfer-learning to improve accuracy for a specific font and background
In my specific use case, I only need to recognize text whose font is always the same and whose background is always darker than the foreground. Additionally, the character set is smaller than the one the default latin.pth checkpoint is trained on. The default model does not meet my accuracy requirements, even after tuning every available parameter. For this reason, a way to improve accuracy would be to fine-tune the model on generated image-text pairs in the target style of my application.
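For illustration, generating such pairs could look roughly like the sketch below, using the TRDG library that my dataset class (shown later) wraps; the font path and word list are placeholders for my own assets:

from trdg.generators import GeneratorFromStrings

# Placeholder strings and font; in practice these come from my lexicon and my application's font
generator = GeneratorFromStrings(
    ["example", "words", "from my lexicon"],
    fonts=["fonts/my_font.ttf"],
    size=64,
    count=1000,
)
for img, lbl in generator:
    # img is a PIL image, lbl is the corresponding ground-truth string
    ...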
My current solution loads the default latin.pth model, but replaces the prediction layer with my own, reflecting the reduced character set. Transfer learning is then performed by training only the weights of this final layer. The training process is set up as follows:
import tempfile
from os import path
from typing import List
from easyocr.model.model import Model
from easyocr.utils import CTCLabelConverter
from easyocr.recognition import AlignCollate
import easyocr
import torch
import torch.nn.functional
import torch.utils.data
import numpy as np
from .. import textgen, translations
# Omitted: methods that wrap all code below
# Lexicon, a dictionary of words that appear in this specific use-case
lexicon = translations.Lexicon(dict_path)
lexicon.augment_capitalization(len(lexicon))
lexicon.augment_jibberish(round(len(lexicon)/5))
lexicon.augment_numbers(round(len(lexicon)/10))
# Initialize the CTC label converter (see CRNN paper)
character = lexicon.characters()
character.add(' ')
converter = CTCLabelConverter(character)
num_class = len(converter.character)
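# Note: CTCLabelConverter prepends a '[blank]' token for CTC decoding, so num_class equals the size of the character set plus one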
# Initialize the feature mapping and sequence labeling model
model = Model(input_channel, output_channel, hidden_size, num_class=num_class)
model = torch.nn.DataParallel(model).to(dev)
for name, param in model.named_parameters():
    if 'localization_fc2' in name:
        print(f'Skip {name} as it is already initialized')
        continue
    try:
        if 'bias' in name:
            torch.nn.init.constant_(param, 0.)
        elif 'weight' in name:
            torch.nn.init.kaiming_normal_(param)
    except Exception:  # for batch norm layers
        if 'weight' in name:
            param.data.fill_(1.)
        continue
# Load weights from EasyOCR default checkpoint for latin characters
checkpoint = torch.load("models/latin.pth")
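# The checkpoint's prediction layer is sized for the full latin character set, so replace its weights with freshly initialized ones matching the reduced character set before loading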
checkpoint["module.Prediction.weight"] = torch.randn((96, 512)) * 0.01
checkpoint["module.Prediction.bias"] = torch.zeros(96)
model.load_state_dict(checkpoint)
model.train()
# Disable training on all layers except the final prediction layer
for param in model.parameters():
    param.requires_grad = False
for param in model.module.Prediction.parameters():
    param.requires_grad = True
# CTC Loss criterion
criterion = torch.nn.CTCLoss(zero_infinity=False).to(dev)
# Optimizer (CRNN paper recommends Adagrad)
filtered_parameters = []
params_num = 0
for p in filter(lambda p: p.requires_grad, model.parameters()):
    filtered_parameters.append(p)
    params_num += np.prod(p.size())
print('Trainable params num : ', params_num)
optimizer = torch.optim.Adagrad(filtered_parameters)
# Wrap AlignCollate in our own collate function
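# AlignCollate resizes and pads each image to a common (imgH, imgW) while preserving aspect ratio, so the batch can be stacked into a single tensor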
def collate(batch):
    ratios = []
    imgs = []
    lbls = []
    for img, lbl in batch:
        lbls.append(lbl)
        ratios.append(float(img.size[0]) / float(img.size[1]))
        imgs.append(img)
    imgs = AlignCollate(imgH=64, imgW=int(max(ratios) * 64), keep_ratio_with_pad=True)(imgs)
    return imgs, lbls
# TRDG library requires a directory to look for background images
with tempfile.TemporaryDirectory() as bg_dir:
    # Dataset that yields random (image, text) pairs, sentences of at most 3 random words, generated with the TRDG library
    ds = textgen.TextDataset(lexicon, bg_dir)
    train_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, pin_memory=True, drop_last=True, collate_fn=collate)
    # Required for CTCLoss
    torch.backends.cudnn.deterministic = True
    # Training loop
    for i, (img, lbl) in enumerate(train_dl):
        img = img.to(dev)
        # Encode the text label
        lbl_encoded, length = converter.encode(lbl)
        # Run the model
        model.zero_grad()
        preds = model(img, None)
        preds_size = torch.IntTensor([preds.size(1)] * img.size(0))
        preds_log_softmax = preds.log_softmax(2)
        # Calculate loss
        cost = criterion(preds_log_softmax.permute(1, 0, 2), lbl_encoded, preds_size, length)
        print(f"Iteration\t{i}:\tcost {cost.item()}")
        # Optimizer step
        cost.backward()
        torch.nn.utils.clip_grad_norm_(filtered_parameters, 5)
        optimizer.step()
        # Show the result every 5 steps
        if i % 5 == 0:
            _, preds_index = preds_log_softmax.detach().max(2)
            preds_index = preds_index.view(-1)
            preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            for idx, (true_lbl, pred_lbl) in enumerate(zip(lbl, preds_str)):
                print(f"\t- true {idx}\t: {true_lbl}")
                print(f"\t- pred {idx}\t: {pred_lbl}")
While this already significantly improves accuracy, I would like to go further and also train the remaining layers. However, I notice that when I try to train all layers simultaneously, the model quickly diverges. It is not clear to me whether this is due to a mistake in the training script or something else that I am not accounting for.
How can I (re-)train the model, either from scratch or using latin.pth as a starting point?
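One option would be to unfreeze everything but give the pretrained layers a much smaller learning rate than the new prediction layer, roughly as sketched below (reusing the variables from the script above; the learning-rate values are illustrative, not tested):

# Enable gradients for all layers again
for param in model.parameters():
    param.requires_grad = True

# Separate the pretrained layers from the freshly initialized prediction head
pretrained_params = [p for n, p in model.named_parameters() if 'Prediction' not in n]
prediction_params = list(model.module.Prediction.parameters())

# Give the pretrained layers a much smaller learning rate (values are illustrative only)
optimizer = torch.optim.Adagrad([
    {'params': pretrained_params, 'lr': 1e-4},
    {'params': prediction_params, 'lr': 1e-2},
])
# In the training loop, keep clipping gradients over all trainable parameters:
# torch.nn.utils.clip_grad_norm_(model.parameters(), 5)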
Hey @kurt-stolle. There is a guide on training on your own dataset, as well as on failure cases, under that repository. In the paper they also outline some possible solutions for overcoming the failure cases, such as low-resolution images. I faced a similar problem to yours and found it helpful, so I thought it might be worth sharing 😃
@LanzaMercado Sadly, this issue was never resolved. I ended up writing my own OCR library with off-the-shelf networks, using a three-stage approach similar to EasyOCR's.