Examples for speech recognition trainings from scratch

🚀 Feature request

Fine-tuning is rather straightforward, but it looks to me as if running a training from scratch isn't. I am rather new to 🤗, but from what I've learned so far, it's rather tricky to figure out how to start a new Speech2Text training (for example).

We have run_wav2vec2_pretraining_no_trainer.py for training a new Wav2Vec2 model from scratch, but I wonder why this script (explicitly) does not use the Trainer API. Is there any particular reason?

Motivation

After running into out-of-memory issues while training Wav2Vec2, I figured it would be better to use a smaller model for this purpose. Since training an end-to-end model with Wav2Vec2 requires multiple stages, I thought it would be better to start with a simple Speech2Text transformer model and continue from there. However, so far I have been unable to run a training properly. For some reason, the word error rate is basically 0% from the start, only to get worse over time until the model no longer predicts anything at all. I have no explanation for this, but you can take a look at the code that (in a sense) brought me here.
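Since the symptom is reported through the word error rate, it can help to cross-check the project's `error_rate` helper (its implementation isn't shown here) against an independent computation on a few decoded reference/prediction pairs. The sketch below is a standard corpus-level WER based on word-level Levenshtein distance, in plain Python; `edit_distance` and `word_error_rate` are hypothetical helper names, not part of 🤗 Transformers.

```python
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance between two token lists (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))  # distances from the empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,                # deletion of r
                curr[j - 1] + 1,            # insertion of h
                prev[j - 1] + int(r != h),  # substitution, free if the words match
            )
        prev = curr
    return prev[-1]


def word_error_rate(references: List[str], predictions: List[str]) -> float:
    """Corpus-level WER: total word edits divided by total reference words."""
    if len(references) != len(predictions):
        raise ValueError("references and predictions must have the same length")
    errors = sum(edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    total = sum(len(r.split()) for r in references)
    if total == 0:
        # Empty references make WER undefined; reporting 0 here would hide the problem.
        raise ValueError("references contain no words; WER is undefined")
    return errors / total
```

For example, `word_error_rate(["the cat sat"], ["the sat"])` gives 1/3 (one deletion over three reference words), and an empty prediction gives 1.0. The sketch raises on empty references instead of returning a number, because a WER computed over references that decode to nothing says little about the model.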

Code:
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import List, Dict, Union

import torch
import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import TrainingArguments, Trainer, trainer_utils, Speech2TextTokenizer, Speech2TextFeatureExtractor, \
    Speech2TextProcessor, Speech2TextConfig, Speech2TextModel, Speech2TextForConditionalGeneration, Seq2SeqTrainer, \
    IntervalStrategy, EarlyStoppingCallback, Seq2SeqTrainingArguments
import sentencepiece as spm
import tensorflow as tf

from .speech.bin.hf_train import get_dataset, get_preprocessor
from .speech.data.speech_dataset import SpeechRecognitionDatasets
from .speech import bin as binaries
from .speech.lab.training.metrics import error_rate
import numpy as np


class Speech2TextTFDataset(IterableDataset):

    def __init__(self, processor: Speech2TextProcessor, text_preprocessor, dataset: tf.data.Dataset, num_samples: int = None):
        self.processor = processor
        self.text_preprocessor = text_preprocessor
        self.dataset = dataset
        self.num_samples = num_samples

    def __len__(self):
        if self.num_samples is None:
            raise RuntimeError("Number of samples is unknown.")
        return self.num_samples

    def __getitem__(self, item):
        raise NotImplementedError

    def __iter__(self):
        for example in self.dataset:
            inputs = example["inputs"]
            targets = example["targets"].numpy()[0].decode()
            targets = self.text_preprocessor.preprocess(targets)

            sampling_rate = self.processor.feature_extractor.sampling_rate
            # Extract features & target labels
            audio_features = self.processor.feature_extractor(inputs, sampling_rate=sampling_rate)["input_features"][0]
            labels = self.processor.tokenizer.encode(targets)

            size, _ = audio_features.shape
            attention_mask = torch.ones(size)

            yield dict(inputs=audio_features, targets=labels, attention_mask=attention_mask)

    @classmethod
    def get_split(cls, processor, text_preprocessor, datasets: SpeechRecognitionDatasets, split: str, max_samples=None):
        dataset = datasets.get(split, load_noise=False)
        if split == "train":
            dataset = dataset.repeat()

        if max_samples is not None:
            dataset = dataset.take(max_samples)

        num_samples = datasets.get_num_speech_samples(split)
        return cls(processor, text_preprocessor, dataset, num_samples=num_samples)


@dataclass
class Speech2TextCollator:

    def __init__(self, processor: Speech2TextProcessor):
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        inputs = [torch.Tensor(f["inputs"]) for f in features]
        targets = [torch.Tensor(f["targets"]) for f in features]
        # Create batches
        inputs_batch = pad_sequence(inputs, batch_first=True)
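        # Pad labels with -100 so padded positions are ignored by the cross-entropy loss
        targets_batch = pad_sequence(targets, batch_first=True, padding_value=-100).long()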
        attention_mask = pad_sequence([f["attention_mask"] for f in features], batch_first=True).long()
        return dict(
            input_features=inputs_batch,
            # decoder_input_ids=targets_batch,
            attention_mask=attention_mask,
            labels=targets_batch
        )


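def compute_metrics(processor: Speech2TextProcessor, pred):
    # Without predict_with_generate, the first element of pred.predictions holds the (teacher-forced) logits
    pred_ids = np.argmax(pred.predictions[0], axis=-1)
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # Label positions padded with -100 cannot be decoded, so map them back to the pad id first.
    # (group_tokens is a Wav2Vec2 CTC tokenizer option and has no effect on Speech2TextTokenizer.)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, processor.tokenizer.pad_token_id)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer = error_rate(targets=label_str, predictions=pred_str, tokens="words")
    cer = error_rate(targets=label_str, predictions=pred_str, tokens="characters")

    return {"wer": wer, "cer": cer}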


def get_sentence_piece_model(sentence_generator, text_preprocessor, overwrite=False):

    model_prefix = "/tmp/en"
    vocab_file = model_prefix + ".json"
    spm_file = model_prefix + ".model"

    if os.path.exists(vocab_file) and os.path.exists(spm_file) and not overwrite:
        return vocab_file, spm_file

    text_fp = "/tmp/spm.txt"
    with open(text_fp, "w") as f:
        for sentence in sentence_generator():
            text = sentence.strip()
            text = text_preprocessor.preprocess(text)
            f.write(text)
            f.write("\n")

    spm.SentencePieceTrainer.Train(
        input=text_fp,
        vocab_size=1000,
        model_prefix=model_prefix,
        user_defined_symbols=["<mask>"],
        # hard_vocab_limit=False,
    )

    processor = spm.SentencePieceProcessor()
    processor.Load(model_file=model_prefix + ".model")

    vocab_file = model_prefix + ".json"
    spm_file = model_prefix + ".model"

    # noinspection PyUnresolvedReferences
    vocab = {processor.id_to_piece(piece_id): piece_id for piece_id in range(processor.get_piece_size())}
    with open(vocab_file, "w") as f:
        json.dump(vocab, f, indent=2)

    return vocab_file, spm_file


def main():

    # TODO Paths!
    local_raw_root = "/data/-asr/corpora/raw"
    local_shards_root = "/data/-asr/corpora/sharded"
    # TODO Paths!
    remote_raw_root = "/mariana/asr/raw"
    remote_shards_root = "/mariana/asr/corpora/sharded"
    remote_converted_root = "/mariana/asr/corpora/converted"
    remote_vocabs_root = "/mariana/asr/vocabularies/masking"

    tf.config.set_visible_devices([], "GPU")

    out_dir = "/data/-asr/models/huggingface/dev"
    log_dir = os.path.join(out_dir, "logs")
    config_fp = os.path.join(os.path.dirname(binaries.__file__), "configs/data/en/timit.yml")
    early_stopping_patience = 5

    print(f"data config: {config_fp}")

    with tf.device("cpu"):
        asr_datasets = get_dataset(
            config_fp=config_fp,
            local_raw_root=local_raw_root,
            local_shards_root=local_shards_root,
            remote_raw_root=remote_raw_root,
            remote_shards_root=remote_shards_root,
            remote_converted_root=remote_converted_root,
        )

        vocab_config_fp = os.path.join(
            os.path.dirname(binaries.__file__), f"configs/vocabulary/{asr_datasets.language}.yml"
        )

        text_preprocessor = get_preprocessor(
            remote_vocabs_root=remote_vocabs_root, vocab_config_fp=vocab_config_fp, asr_datasets=asr_datasets
        )

    sampling_rate = 16_000
    max_vocab_samples = 100000

    def sentence_generator():
        for i, example in tqdm.tqdm(enumerate(asr_datasets.get("train", load_noise=False)), total=max_vocab_samples):
            if i >= max_vocab_samples:
                break
            targets = example["targets"].numpy()[0].decode()
            yield text_preprocessor.preprocess(targets)

    vocab_file, spm_file = get_sentence_piece_model(
        sentence_generator=sentence_generator, text_preprocessor=text_preprocessor, overwrite=False
    )
    tokenizer = Speech2TextTokenizer(vocab_file=vocab_file, spm_file=spm_file)
    feature_extractor = Speech2TextFeatureExtractor(sampling_rate=sampling_rate)
    processor = Speech2TextProcessor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer
    )

    save_and_eval_steps = 1

    training_args = Seq2SeqTrainingArguments(
        output_dir=out_dir,
        evaluation_strategy=IntervalStrategy("steps"),
        save_steps=save_and_eval_steps,
        eval_steps=save_and_eval_steps,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=log_dir,
        group_by_length=True,
        # label_smoothing_factor=1,
        load_best_model_at_end=True,
        save_total_limit=2,
    )

    # Create the model
    config = Speech2TextConfig(
        return_dict=True,
        sampling_rate=sampling_rate,
        vocab_size=tokenizer.vocab_size,
        pad_token_id=processor.tokenizer.pad_token_id,
        bos_token_id=processor.tokenizer.bos_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        decoder_start_token_id=processor.tokenizer.bos_token_id,
    )

    model = Speech2TextForConditionalGeneration(config)
    # model = Speech2TextModel(config)
    model.train()

    train_dataset = Speech2TextTFDataset.get_split(
        processor=processor, text_preprocessor=text_preprocessor, datasets=asr_datasets, split="train"
    )
    eval_dataset = Speech2TextTFDataset.get_split(
        processor=processor, text_preprocessor=text_preprocessor, datasets=asr_datasets, split="dev", max_samples=3
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Speech2TextCollator(processor=processor),
        compute_metrics=partial(compute_metrics, processor),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
    )

    last_checkpoint = trainer_utils.get_last_checkpoint(out_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)

    print("All done.")


if __name__ == '__main__':
    main()
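On label padding specifically: `torch.nn.utils.rnn.pad_sequence` pads with 0 unless `padding_value` is given, and 0 is an ordinary vocabulary id, so padded positions are otherwise trained on like real targets. Speech2Text, like other 🤗 seq2seq models, computes its loss with `torch.nn.CrossEntropyLoss`, whose default `ignore_index` is -100, so the usual convention is to pad labels with -100 and map them back to the tokenizer's pad id before decoding. The plain-Python sketch below illustrates both steps; `pad_labels`, `restore_pad_ids` and `IGNORE_INDEX` are illustrative names, not library functions.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_labels(batch: Sequence[Sequence[int]], pad_value: int = IGNORE_INDEX) -> List[List[int]]:
    """Right-pad variable-length label sequences so padded positions are ignored by the loss."""
    max_len = max(len(seq) for seq in batch)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in batch]


def restore_pad_ids(labels: Sequence[Sequence[int]], pad_token_id: int) -> List[List[int]]:
    """Swap IGNORE_INDEX back to a real token id so a tokenizer can decode the labels."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in labels]
```

With torch, the first step corresponds to `pad_sequence(targets, batch_first=True, padding_value=-100)` in the collator, and the second to `np.where(pred.label_ids != -100, pred.label_ids, processor.tokenizer.pad_token_id)` in `compute_metrics` before `batch_decode`. Because the tokenizer here is built from a custom SentencePiece vocabulary, it may also be worth checking what `processor.tokenizer.pad_token_id` actually resolves to.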

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments:14 (5 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Feb 15, 2022

There are also the XLS-R checkpoints which have been pretrained on over 128 languages 😃 https://huggingface.co/models?other=xls_r_pretrained

1 reaction
stefan-falk commented, Jan 19, 2022

@patrickvonplaten Okay, I see. The issue here is just that I am now reliant on the availability of pre-trained models in all the languages I want to support. For example, facebook/wav2vec2-base was only trained on English, which probably does not help for languages like Chinese. Going for the cross-language model is also not an option due to its size.
