Examples for training speech recognition models from scratch
🚀 Feature request
Fine-tuning is rather straightforward, but it looks to me as if running a training from scratch isn't. I am rather new to 🤗, but from what I've learned so far, it's rather tricky to find out how to start a new Speech2Text training (for example).
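To make the contrast concrete, here is a minimal sketch of what I mean (the checkpoint name and `vocab_size` are just illustrative): fine-tuning loads pretrained weights from the Hub, while training from scratch instantiates a randomly initialized model from a config.

```python
from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration

# Fine-tuning: weights come from a pretrained checkpoint on the Hub
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")

# Training from scratch: weights are randomly initialized from a config
config = Speech2TextConfig(vocab_size=1000)  # illustrative value
model = Speech2TextForConditionalGeneration(config)
```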
We have run_wav2vec2_pretraining_no_trainer.py for training a new Wav2Vec2 model from scratch, but I wonder why it (explicitly) does not use the Trainer API. Is there any particular reason?
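For context, the *_no_trainer example scripts are built on 🤗 Accelerate instead of the Trainer. Below is a simplified, self-contained sketch of that pattern; the dummy model and data are placeholders, and the real pretraining script adds feature masking, negative sampling, and a contrastive loss on top.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 10)  # stand-in for Wav2Vec2ForPreTraining
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    [(torch.randn(10), torch.randn(10)) for _ in range(8)], batch_size=4
)
# prepare() wraps model, optimizer, and dataloader for (multi-)GPU execution
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```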
Motivation
After running into out-of-memory issues during Wav2Vec2 training, I figured it would be better to use a smaller model for this purpose. Since training an end-to-end model using Wav2Vec2 requires multiple stages, I thought it would be better to start with a simple Speech2Text transformer model and continue from there. However, up until now I have been unable to run a training properly. For some reason the word error rate (WER) is basically 0% from the start, only to get worse over time, to the point where the model stops predicting anything at all. I have no explanation for this, but you can take a look at the code that (in a sense) brought me here.
Code

```python
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import List, Dict, Union

import numpy as np
import sentencepiece as spm
import tensorflow as tf
import torch
import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import TrainingArguments, Trainer, trainer_utils, Speech2TextTokenizer, Speech2TextFeatureExtractor, \
    Speech2TextProcessor, Speech2TextConfig, Speech2TextModel, Speech2TextForConditionalGeneration, Seq2SeqTrainer, \
    IntervalStrategy, EarlyStoppingCallback, Seq2SeqTrainingArguments

from .speech.bin.hf_train import get_dataset, get_preprocessor
from .speech.data.speech_dataset import SpeechRecognitionDatasets
from .speech import bin as binaries
from .speech.lab.training.metrics import error_rate

class Speech2TextTFDataset(IterableDataset):

    def __init__(self, processor: Speech2TextProcessor, text_preprocessor, dataset: tf.data.Dataset,
                 num_samples: int = None):
        self.processor = processor
        self.text_preprocessor = text_preprocessor
        self.dataset = dataset
        self.num_samples = num_samples

    def __len__(self):
        if self.num_samples is None:
            raise RuntimeError("Number of samples is unknown.")
        return self.num_samples

    def __getitem__(self, item):
        raise NotImplementedError

    def __iter__(self):
        for example in self.dataset:
            inputs = example["inputs"]
            targets = example["targets"].numpy()[0].decode()
            targets = self.text_preprocessor.preprocess(targets)
            sampling_rate = self.processor.feature_extractor.sampling_rate
            # Extract features & target labels
            audio_features = self.processor.feature_extractor(inputs, sampling_rate=sampling_rate)["input_features"][0]
            labels = self.processor.tokenizer.encode(targets)
            size, _ = audio_features.shape
            attention_mask = torch.ones(size)
            yield dict(inputs=audio_features, targets=labels, attention_mask=attention_mask)

    @classmethod
    def get_split(cls, processor, text_preprocessor, datasets: SpeechRecognitionDatasets, split: str,
                  max_samples=None):
        dataset = datasets.get(split, load_noise=False)
        if split == "train":
            dataset = dataset.repeat()
        if max_samples is not None:
            dataset = dataset.take(max_samples)
        num_samples = datasets.get_num_speech_samples(split)
        return cls(processor, text_preprocessor, dataset, num_samples=num_samples)

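# Collates variable-length feature/label sequences from the dataset above into padded batches.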
@dataclass
class Speech2TextCollator:

    def __init__(self, processor: Speech2TextProcessor):
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        inputs = [torch.Tensor(f["inputs"]) for f in features]
        targets = [torch.Tensor(f["targets"]) for f in features]
        # Create batches
        inputs_batch = pad_sequence(inputs, batch_first=True)
        targets_batch = pad_sequence(targets, batch_first=True).long()
        attention_mask = pad_sequence([f["attention_mask"] for f in features], batch_first=True).long()
        return dict(
            input_features=inputs_batch,
            # decoder_input_ids=targets_batch,
            attention_mask=attention_mask,
            labels=targets_batch,
        )

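# Computes WER/CER from the argmax of the decoder logits (pred.predictions[0]).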
def compute_metrics(processor: Speech2TextProcessor, pred):
    # pred_logits = pred.predictions
    pred_ids = np.argmax(pred.predictions[0], axis=-1)
    pred_str = processor.batch_decode(pred_ids)
    # We do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = error_rate(targets=label_str, predictions=pred_str, tokens="words")
    cer = error_rate(targets=label_str, predictions=pred_str, tokens="characters")
    return {"wer": wer, "cer": cer}

def get_sentence_piece_model(sentence_generator, text_preprocessor, overwrite=False):
    model_prefix = "/tmp/en"
    vocab_file = model_prefix + ".json"
    spm_file = model_prefix + ".model"
    if os.path.exists(vocab_file) and os.path.exists(spm_file) and not overwrite:
        return vocab_file, spm_file

    text_fp = "/tmp/spm.txt"
    with open(text_fp, "w") as f:
        for sentence in sentence_generator():
            text = sentence.strip()
            text = text_preprocessor.preprocess(text)
            f.write(text)
            f.write("\n")

    spm.SentencePieceTrainer.Train(
        input=text_fp,
        vocab_size=1000,
        model_prefix=model_prefix,
        user_defined_symbols=["<mask>"],
        # hard_vocab_limit=False,
    )

    processor = spm.SentencePieceProcessor()
    processor.Load(model_file=model_prefix + ".model")
    # noinspection PyUnresolvedReferences
    vocab = {processor.id_to_piece(piece_id): piece_id for piece_id in range(processor.get_piece_size())}
    with open(vocab_file, "w") as f:
        json.dump(vocab, f, indent=2)
    return vocab_file, spm_file

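# End-to-end driver: builds datasets, tokenizer/processor, a fresh model, and runs training.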
def main():
    # TODO Paths!
    local_raw_root = "/data/-asr/corpora/raw"
    local_shards_root = "/data/-asr/corpora/sharded"
    # TODO Paths!
    remote_raw_root = "/mariana/asr/raw"
    remote_shards_root = "/mariana/asr/corpora/sharded"
    remote_converted_root = "/mariana/asr/corpora/converted"
    remote_vocabs_root = "/mariana/asr/vocabularies/masking"

    tf.config.set_visible_devices([], "GPU")

    out_dir = "/data/-asr/models/huggingface/dev"
    log_dir = os.path.join(out_dir, "logs")
    config_fp = os.path.join(os.path.dirname(binaries.__file__), "configs/data/en/timit.yml")
    early_stopping_patience = 5

    print(f"data config: {config_fp}")

    with tf.device("cpu"):
        asr_datasets = get_dataset(
            config_fp=config_fp,
            local_raw_root=local_raw_root,
            local_shards_root=local_shards_root,
            remote_raw_root=remote_raw_root,
            remote_shards_root=remote_shards_root,
            remote_converted_root=remote_converted_root,
        )

    vocab_config_fp = os.path.join(
        os.path.dirname(binaries.__file__), f"configs/vocabulary/{asr_datasets.language}.yml"
    )
    text_preprocessor = get_preprocessor(
        remote_vocabs_root=remote_vocabs_root, vocab_config_fp=vocab_config_fp, asr_datasets=asr_datasets
    )

    sampling_rate = 16_000
    max_vocab_samples = 100_000

    def sentence_generator():
        for i, example in tqdm.tqdm(enumerate(asr_datasets.get("train", load_noise=False)), total=max_vocab_samples):
            if i >= max_vocab_samples:
                break
            targets = example["targets"].numpy()[0].decode()
            yield text_preprocessor.preprocess(targets)

    vocab_file, spm_file = get_sentence_piece_model(
        sentence_generator=sentence_generator, text_preprocessor=text_preprocessor, overwrite=False
    )

    tokenizer = Speech2TextTokenizer(vocab_file=vocab_file, spm_file=spm_file)
    feature_extractor = Speech2TextFeatureExtractor(sampling_rate=sampling_rate)
    processor = Speech2TextProcessor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer
    )

    save_and_eval_steps = 1
    training_args = Seq2SeqTrainingArguments(
        output_dir=out_dir,
        evaluation_strategy=IntervalStrategy("steps"),
        save_steps=save_and_eval_steps,
        eval_steps=save_and_eval_steps,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=log_dir,
        group_by_length=True,
        # label_smoothing_factor=1,
        load_best_model_at_end=True,
        save_total_limit=2,
    )

    # Create the model
    config = Speech2TextConfig(
        return_dict=True,
        sampling_rate=sampling_rate,
        vocab_size=tokenizer.vocab_size,
        pad_token_id=processor.tokenizer.pad_token_id,
        bos_token_id=processor.tokenizer.bos_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        decoder_start_token_id=processor.tokenizer.bos_token_id,
    )
    model = Speech2TextForConditionalGeneration(config)
    # model = Speech2TextModel(config)
    model.train()

    train_dataset = Speech2TextTFDataset.get_split(
        processor=processor, text_preprocessor=text_preprocessor, datasets=asr_datasets, split="train"
    )
    eval_dataset = Speech2TextTFDataset.get_split(
        processor=processor, text_preprocessor=text_preprocessor, datasets=asr_datasets, split="dev", max_samples=3
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Speech2TextCollator(processor=processor),
        compute_metrics=partial(compute_metrics, processor),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
    )

    last_checkpoint = trainer_utils.get_last_checkpoint(out_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)

    print("All done.")


if __name__ == "__main__":
    main()
```
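One detail worth flagging in the collator above: pad_sequence pads the labels with 0 by default, which is a real token id, whereas 🤗 models only exclude label positions set to -100 from the cross-entropy loss. A tiny illustration of the difference (the token ids are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

targets = [torch.tensor([5, 17, 23]), torch.tensor([8, 9])]
# padding_value=-100 marks padded positions so the loss ignores them;
# padding with 0 instead trains the model to predict token id 0.
labels = pad_sequence(targets, batch_first=True, padding_value=-100)
print(labels)  # tensor([[   5,   17,   23],
               #         [   8,    9, -100]])
```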
Comments
There are also the XLS-R checkpoints, which have been pretrained on 128 languages 😃 https://huggingface.co/models?other=xls_r_pretrained
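(For reference, loading one of those checkpoints for CTC fine-tuning looks roughly like the sketch below; `vocab_size` and `pad_token_id` are assumptions that must match your own tokenizer.)

```python
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    ctc_loss_reduction="mean",
    pad_token_id=0,   # assumption: your tokenizer's pad token id
    vocab_size=32,    # assumption: your tokenizer's vocabulary size
)
```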
@patrickvonplaten Okay, I see. The issue here is just that I am now reliant on the availability of pre-trained models in all the languages I want to support. For example, facebook/wav2vec2-base was only trained on English, which probably does not help for languages like Chinese. Going for the cross-lingual model is also not an option due to its size.