[Wav2Vec2] Wav2Vec2Conformer Fine-Tuned seems to give Gibberish on Librispeech example
See original GitHub issue

🐛 Bug
Wav2Vec2's newly released fine-tuned Conformer checkpoints (see here) don't produce reasonable results on an example of Librispeech.
I'm not sure if the model requires a different input pre-processing or a different dictionary.
To Reproduce
- Download the 960h fine-tuned checkpoint:

```bash
wget https://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_960h_FT.pt
```
- Download the Librispeech dictionary:

```bash
wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt
```
- Load a sample of the Librispeech clean dataset for inference. You can load a dummy sample via the Hugging Face Hub:

```bash
pip install datasets
```

```python
from datasets import load_dataset

libri_dummy = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# check out the dataset
print(libri_dummy)
```
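If you would rather test on a local LibriSpeech file instead of the dummy dataset, here is a minimal sketch (not from the original report; the path below is a placeholder):

```python
import soundfile as sf

# hypothetical path to any locally downloaded LibriSpeech utterance
audio, sr = sf.read("/path/to/LibriSpeech/dev-clean/xxx.flac")
assert sr == 16000  # the checkpoint expects 16 kHz input
```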
- Run a forward pass:

```python
import torch
import fairseq

input_sample = torch.tensor(libri_dummy[0]["audio"]["array"])[None, :]

# normalize the raw waveform (the fine-tuned checkpoints expect
# layer-normalized input)
input_sample = torch.nn.functional.layer_norm(input_sample, input_sample.shape)

# point `data` at the folder containing dict.ltr.txt
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    ["LL_relpos_PT_960h_FT.pt"], arg_overrides={"data": "/path/to/folder/of/dict"}
)
model = model[0]
model.eval()

logits = model(source=input_sample, padding_mask=None)["encoder_out"]
```
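As a quick sanity check, one can inspect the output shape; assuming the standard wav2vec2 feature stride (320 samples, i.e. 20 ms at 16 kHz), there should be roughly one frame per 20 ms of audio:

```python
# expected shape: [seq_len, batch, vocab_size]; vocab_size is 32 for this dict
print(logits.shape)
```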
- Decode the prediction

The output is a tensor of shape [seq_len, 1, vocab_size]. We are interested in the most likely token for each time step, so we can take the argmax:

```python
predicted_ids = torch.argmax(logits[:, 0], dim=-1)
```
- Now we'll create our own decoder based on the dict we downloaded previously to decode the result (it's just the downloaded dict written out in JSON format):

```python
json_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
```
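For reference, the same mapping can be derived from dict.ltr.txt itself. A sketch, assuming fairseq's default special-token order (`<s>`, `<pad>`, `</s>`, `<unk>`) and that the file sits in the working directory:

```python
# fairseq prepends the four special tokens, then assigns ids in file order
json_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
with open("dict.ltr.txt") as f:
    for line in f:
        token = line.split()[0]  # each line is "<token> <count>"
        json_dict[token] = len(json_dict)
```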
and create a decoder:

```python
import numpy as np
from itertools import groupby


class Decoder:
    def __init__(self, json_dict):
        self.dict = json_dict
        self.look_up = np.asarray(list(self.dict.keys()))

    def decode(self, ids):
        # map ids to their tokens
        converted_tokens = self.look_up[ids]
        # collapse repeated tokens (CTC-style de-duplication)
        fused_tokens = [tok[0] for tok in groupby(converted_tokens)]
        # drop the blank token "<s>" and turn the word delimiter "|" into spaces
        output = ' '.join(''.join(''.join(fused_tokens).split("<s>")).split("|"))
        return output
```
Now we can decode the output and compare it to the correct output:

```python
decoder = Decoder(json_dict=json_dict)
print("Prediction: ", decoder.decode(predicted_ids))
```
As we can see, the prediction is wrong:

```
Prediction: AY N N N VN V'V'V'IRSIMG KMNJB TPEPEMEDYR RGQTQ'OB 'HNJ<unk>TQURIEMJ' 'B'F' TM'VS'NEMDJH DSB'CNSTITE RYKYRSITPSV'DYNY' M'SOEPUGSYDYH'BYTITIPKV UFMQ'W'YJRDH' MVY'SGM'GNE F'YZH'U IFB'N' ' A V YA IN
```
The correct transcription is:

```python
print(libri_dummy[0]["text"])
```

```
'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'
```
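To put a number on the mismatch, one could compute the word error rate, e.g. with the third-party jiwer package (not used in the original report):

```python
# pip install jiwer
import jiwer

# a WER near (or above) 1.0 confirms the output is essentially random
print(jiwer.wer(libri_dummy[0]["text"], decoder.decode(predicted_ids)))
```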
Also, looking at the predicted ids of the model (the argmax of the logits):

```
tensor([ 7, 22,  4,  4,  9,  4,  4,  9,  9,  9,  4,  4,  9,  4,  4,  4, 25,  9,
         4,  4,  4,  4,  4, 25, 27, 25, 25, 25, 27, 25, 27, 10, 13,  0, 12, 10,
        17, 21, 21, 21,  0,  4, 26, 17, 17, 17,  9,  9,  9,  9, 29,  0, 24,  4,
         6, 23,  5, 23,  5, 17, 17,  5,  5,  5, 14, 22, 22, 13, 13,  4,  4, 13,
        21, 21, 21, 21, 30,  6, 30, 27,  8, 24,  0,  4, 27, 11,  9, 29, 29,  3,
         6, 30, 16, 13, 10, 10,  5, 17, 29, 29, 29, 27,  4,  4, 27, 24, 24, 27,
        20, 27, 27, 27,  4,  4,  6,  0, 17, 27, 25, 12, 12, 27,  9,  5, 17, 14,
        29, 11,  4, 14, 12, 12, 24, 27, 19, 19,  9, 12,  6, 10, 10, 10,  6,  6,
         5,  4,  4,  0,  4, 13, 22, 22, 22, 22, 26, 22, 13, 13, 12, 10,  0,  6,
        23, 12,  0, 25, 27, 14, 22,  9, 22, 27, 27,  4, 17, 27, 12,  8,  8,  5,
         5, 23, 16, 21, 12, 22, 14, 22, 11, 27, 24, 22,  6, 10,  6, 10, 10, 10,
        23, 26, 25, 25, 25,  0,  4, 16, 20, 17, 30, 27, 27, 18, 27, 22, 29, 13,
        14, 11, 11,  0, 27,  4,  4, 17, 17, 25, 22, 22, 27, 12, 21, 17, 27, 27,
        27, 21,  9,  5,  5,  5,  4,  0,  0, 20, 27, 22, 22, 31, 11, 11, 11, 27,
        16, 16,  4, 10,  0,  0,  0, 20, 24, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27,  9,  9,  9, 27,  4,  4,  4, 27,  4,  4,  4,  7,  4, 25,  4, 22,  7,
         4,  4, 10,  9])
```
It does seem like there is something wrong with the model itself and not just with the dictionary: there is no overwhelmingly frequent id that could represent silence (see the quick check below).
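A quick way to verify this claim (a hypothetical check, not part of the original report):

```python
from collections import Counter

# for a healthy CTC model, the blank token ("<s>", id 0) should dominate;
# here the predictions are spread across many ids instead
print(Counter(predicted_ids.tolist()).most_common(5))
```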
Expected behavior
The model should work correctly here.
Environment
- fairseq Version (e.g., 1.0 or main): main
- PyTorch Version (e.g., 1.0): 1.10.2+cu102
- OS (e.g., Linux): Linux
- How you installed fairseq (pip, source): pip (as shown in the README)
- Build command you used (if compiling from source):
- Python version: 3.9.7
Additional context
Top GitHub Comments

> Thanks for the ping @patrickvonplaten. I will look into this and get back to you.

> The very first command actually worked correctly @rahulshivajipawar

> There is also a HF implementation now: https://github.com/huggingface/transformers/pull/16812
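For completeness, a minimal sketch of trying the Transformers port on the same sample (the checkpoint identifier below is an assumption; check the linked PR and the Hub for the released names):

```python
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ConformerForCTC

# assumed checkpoint name for the 960h fine-tuned relative-position model
model_id = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ConformerForCTC.from_pretrained(model_id)

inputs = processor(libri_dummy[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

print(processor.batch_decode(torch.argmax(logits, dim=-1)))
```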