How to load exact pretrained model for the ready-to-use recipe?

Hi!

I’ve been trying for a while to load the pretrained CRDNN model trained on LibriSpeech (https://huggingface.co/speechbrain/asr-crdnn-transformerlm-librispeech) into the exact training recipe (https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech/ASR/seq2seq) it was trained with.

I want to load it into the recipe to have full control over the forward loop, as I need to change it a bit for my research.

So I made some tweaks to train.py to make it work, mainly:

  1. Inherit from the Pretrained interface, i.e.

         class ASR(Pretrained):

     instead of

         class ASR(sb.Brain):

  2. Load the pretrained model:

         brain = ASR.from_hparams(source="speechbrain/asr-crdnn-rnnlm-librispeech",
                                  savedir=hparams["model_dir"])

However, the module names don’t seem to match. When I print brain.mods.keys() I get:

odict_keys(['normalizer', 'encoder', 'decoder', 'lm_model'])

If I try to run the model, I start getting AttributeErrors, because the compute_forward() function of the recipe accesses layers directly by name. You can see that the layer names it uses don’t match the odict_keys above at all.

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        batch = batch.to(self.device)
        wavs, wav_lens = batch.signal
        tokens_bos, _ = batch.tokens_bos
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        # Add augmentation if specified
        if stage == sb.Stage.TRAIN:
            if hasattr(self.modules, "env_corrupt"):
                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
                wavs = torch.cat([wavs, wavs_noise], dim=0)
                wav_lens = torch.cat([wav_lens, wav_lens])
                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)

            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # Forward pass
        feats = self.hparams.compute_features(wavs)
        feats = self.modules.normalize(feats, wav_lens)
        x = self.modules.enc(feats.detach())
        e_in = self.modules.emb(tokens_bos)  # y_in bos + tokens
        h, _ = self.modules.dec(e_in, x, wav_lens)

        # Output layer for seq2seq log-probabilities
        logits = self.modules.seq_lin(h)
        p_seq = self.hparams.log_softmax(logits)

        # Compute outputs
        if stage == sb.Stage.TRAIN:
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch <= self.hparams.number_of_ctc_epochs:
                # Output layer for ctc log-probabilities
                logits = self.modules.ctc_lin(x)
                p_ctc = self.hparams.log_softmax(logits)
                return p_ctc, p_seq, wav_lens
            else:
                return p_seq, wav_lens
        else:
            if stage == sb.Stage.VALID:
                p_tokens, scores = self.hparams.valid_search(x, wav_lens)
            else:
                p_tokens, scores = self.hparams.test_search(x, wav_lens)
            return p_seq, wav_lens, p_tokens

For example:

AttributeError: 'function' object has no attribute 'normalize'
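
The mismatch can be made explicit with a small check. This is only a sketch: the two name lists are copied by hand from the compute_forward() code and the odict_keys output above, and the helper `missing_modules` is a hypothetical name, not part of SpeechBrain. As a side note, since Pretrained is a torch.nn.Module, `self.modules` there is most likely the built-in `modules()` method rather than a module container, which would explain the "'function' object" wording of the error.

```python
# Names that the recipe's compute_forward() reads from self.modules
# (copied by hand from the code above; env_corrupt is skipped because
# it is guarded by hasattr and therefore optional).
recipe_module_names = ["normalize", "enc", "emb", "dec", "seq_lin", "ctc_lin"]

# Keys reported by brain.mods.keys() on the pretrained interface
# (copied from the output above).
pretrained_module_names = ["normalizer", "encoder", "decoder", "lm_model"]


def missing_modules(required, available):
    """Return the required names that are absent from the available ones, in order."""
    available = set(available)
    return [name for name in required if name not in available]


missing = missing_modules(recipe_module_names, pretrained_module_names)
print(missing)
# → ['normalize', 'enc', 'emb', 'dec', 'seq_lin', 'ctc_lin']
```

None of the names the recipe expects are present, so the recipe's forward pass cannot run on the pretrained interface without renaming or regrouping modules.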

Is https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech/blob/main/hyperparams.yaml equivalent to https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/seq2seq/hparams/train_BPE_1000.yaml? They have a lot of differences, so I am not sure whether the training file in the LibriSpeech folder matches the Hugging Face model.

It seems, however, that the Hugging Face and LibriSpeech recipes don’t match. Was the model on Hugging Face trained with a different architecture or hparams file? How can I make the above work?

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 15

Top GitHub Comments

TParcollet commented, Apr 26, 2022

Say no more: https://github.com/speechbrain/speechbrain/blob/cec3bc7f21361ad1123ee08dbca0129f60940dc6/recipes/VoxCeleb/SpeakerRec/hparams/verification_plda_xvector.yaml#L77

We have a Pretrainer that can load any arbitrary .ckpt as long as the YAML definition corresponding to this model is given :_)
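
A rough sketch of how a recipe's train.py might use such a Pretrainer, assuming the hparams YAML defines a `pretrainer` entry (a `speechbrain.utils.parameter_transfer.Pretrainer` with `loadables` and `paths`, as in the linked VoxCeleb file). The helper name `load_pretrained_modules` and the `"pretrainer"` key are illustrative assumptions; nothing in this thread confirms them for the LibriSpeech recipe.

```python
def load_pretrained_modules(hparams):
    """Fetch and load the checkpoints declared by the YAML's Pretrainer entry.

    Assumes hparams["pretrainer"] is a Pretrainer-like object exposing
    collect_files() and load_collected(); the key name is an assumption.
    """
    pretrainer = hparams.get("pretrainer")
    if pretrainer is None:
        raise KeyError("hparams has no 'pretrainer' entry; define one in the YAML first")
    # Resolve (download or link) every checkpoint listed under `paths`.
    pretrainer.collect_files()
    # Copy the collected parameters into the objects listed under `loadables`.
    pretrainer.load_collected()
    return hparams
```

Called after the YAML is loaded and before the Brain is created, this would leave the recipe's own module names (normalize, enc, emb, dec, ...) in place, so compute_forward() could in principle stay unchanged while the weights come from the checkpoint.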

KacperKubara commented, Apr 26, 2022

Alright, it works now. I was running it from the wrong folder…

I will recreate this checkpoint loading in my experiments and will let you know if everything works! Thanks.
