Can't add new language to pre-trained spoken language recognition model: Model forgets other languages
Hello,
I am trying to fine-tune an existing spoken language recognition model. I chose the Common Voice language recognition model and am trying to add a new language. I did everything exactly as described in the fine-tuning tutorial (and made sure the new, previously unknown label is in the label encoder as well).
I also tried freezing more layers; for example, I froze every module except the classifier. However, when I fine-tune the model, the performance gets worse: during the first several epochs the model gives various incorrect outputs, and around the 5th epoch it starts assigning every language the label I want to add.
I also tried fine-tuning the model on 19 different languages (including the previously unknown one), but the results are still the same. Is there any way to fine-tune the model to predict new languages, or is this model not supposed to be fine-tuned? Why can't the model learn new languages without forgetting the old ones during fine-tuning?
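For context, one failure mode when adding a class is re-initializing the whole classification head instead of extending it, which throws away the trained weights for the old languages. Below is a minimal sketch of the extension idea in plain PyTorch; it assumes the head is a standard torch.nn.Linear, which may not match how the pre-trained model's classifier module is actually structured.

    import torch

    def expand_classifier(old_linear: torch.nn.Linear, n_new: int = 1) -> torch.nn.Linear:
        """Add n_new output units to a trained Linear head, copying the
        existing class weights so the old languages are not re-initialized."""
        new_linear = torch.nn.Linear(
            old_linear.in_features, old_linear.out_features + n_new
        )
        with torch.no_grad():
            new_linear.weight[: old_linear.out_features] = old_linear.weight
            new_linear.bias[: old_linear.out_features] = old_linear.bias
        return new_linear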
Here is the class I used for fine-tuning:
import torch
import speechbrain


class LanguageBrain(speechbrain.core.Brain):
    def on_stage_start(self, stage, epoch):
        # Enable grad for all modules we want to fine-tune.
        if stage == speechbrain.Stage.TRAIN:
            for module in [
                self.modules.compute_features,
                self.modules.mean_var_norm,
                self.modules.embedding_model,
                self.modules.classifier,
            ]:
                for p in module.parameters():
                    p.requires_grad = True

    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + language classifier.
        Data augmentation and environmental corruption are applied to the
        input speech.
        """
        batch = batch.to(self.device)
        wavs, lens = batch.sig

        if stage == speechbrain.Stage.TRAIN:
            # Apply the augmentation pipeline: speed perturbation, then
            # reverberation/noise on the perturbed signal (the original
            # code passed the clean wavs to add_rev_noise, discarding the
            # speed perturbation).
            wavs_aug_tot = [wavs]
            wavs_aug = self.hparams.augment_speed(wavs, lens)
            wavs_aug = self.hparams.add_rev_noise(wavs_aug, lens)

            # Speed perturbation changes the length; trim or zero-pad back
            # to the original number of samples.
            if wavs_aug.shape[1] > wavs.shape[1]:
                wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
            else:
                zero_sig = torch.zeros_like(wavs)
                zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
                wavs_aug = zero_sig

            # Replace the clean batch with the augmented one; n_augment
            # stays 1 because the augmented signals are not concatenated
            # with the originals.
            wavs_aug_tot[0] = wavs_aug
            wavs = torch.cat(wavs_aug_tot, dim=0)
            self.n_augment = len(wavs_aug_tot)
            lens = torch.cat([lens] * self.n_augment)

        feats = self.modules.compute_features(wavs)
        feats = self.modules.mean_var_norm(feats, lens)

        # Embeddings + language classifier.
        embeddings = self.modules.embedding_model(feats, lens)
        outputs = self.modules.classifier(embeddings)
        return outputs, lens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss using language-id as the label."""
        predictions, lens = predictions
        langid = batch.lang_id_encoded

        # Replicate the labels to match the augmented batch; n_augment is
        # only set during training, so guard on the stage.
        if stage == speechbrain.Stage.TRAIN:
            langid = torch.cat([langid] * self.n_augment, dim=0)

        loss = self.hparams.compute_cost(predictions, langid.unsqueeze(1), lens)
        return loss

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch."""
        stage_stats = {"loss": stage_loss}
        # Checkpoint on the validation loss only, so that train and valid
        # losses are not compared against each other by min_keys.
        if stage == speechbrain.Stage.VALID:
            self.hparams.checkpointer.save_and_keep_only(
                meta={"loss": stage_stats["loss"]},
                min_keys=["loss"],
            )
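For reference, a hedged sketch of how this Brain is typically instantiated and run; hparams, run_opts, train_data, and valid_data are assumed to come from the usual hyperpyyaml setup shown in the tutorial and are not defined here.

    brain = LanguageBrain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
    )
    brain.fit(
        hparams["epoch_counter"],
        train_set=train_data,
        valid_set=valid_data,
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )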
Top GitHub Comments
That could play an important role. I think it is essential to make sure there is data from different languages in each batch.
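A minimal sketch of one way to do that in plain PyTorch, using inverse-class-frequency sampling so every batch is likely to mix old and new languages; the labels tensor of per-utterance encoded language indices is a hypothetical stand-in for your dataset's labels.

    import torch
    from torch.utils.data import WeightedRandomSampler

    # Hypothetical per-utterance encoded language labels for the train set.
    labels = torch.tensor([0, 0, 0, 1, 1, 2])

    # Weight each utterance by the inverse frequency of its language, so a
    # rare (e.g. newly added) language is sampled as often as a common one.
    class_counts = torch.bincount(labels)
    weights = 1.0 / class_counts[labels].float()
    sampler = WeightedRandomSampler(
        weights, num_samples=len(labels), replacement=True
    )

    # The sampler can then be handed to the DataLoader, e.g. via
    # train_loader_kwargs={"sampler": sampler, ...} when calling fit.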
Looks like this is solved; closing this one—please feel free to reopen 😃