Unable to generate predictions for wav2vec model fine-tuned with custom data
Discussed in https://github.com/PyTorchLightning/pytorch-lightning/discussions/11432

Originally posted by nayak24 on January 11, 2022:

Hi, I'm trying to fine-tune the baseline wav2vec model with my own audio training/test data using Lightning Flash, essentially following the tutorial in this doc exactly: https://lightning-flash.readthedocs.io/en/latest/reference/speech_recognition.html
However, I am running into an issue when generating the prediction for an audio file, and I'm getting an empty output:
```
94.4 M    Trainable params
0         Non-trainable params
94.4 M    Total params
377.585   Total estimated model params size (MB)
Epoch 0: 100%|██████████| 88/88 [02:43<00:00, 1.86s/it, loss=633, v_num=57, train_loss_step=750.0]
Predicting: 88it [00:00, ?it/s]
[['']]
```
I'm not sure what the issue is, as I've only replaced the Timit dataset with my own input data for fine-tuning; the rest of the script follows the doc above exactly. All of the input data are wav files with the following format:
| Property | Value |
| --- | --- |
| format | 1 (uncompressed PCM) |
| number of channels | 1 (mono) |
| sampleRate | 16000 |
| byteRate | 32000 |
| blockAlign | 2 |
| bitsPerSample (bit depth) | 16 |
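(As a quick sanity check, a short script along these lines, using Python's standard `wave` module, can confirm that a file matches this format; the path below is just one file from the set.)

```python
import wave

# Illustrative path; substitute any file from the training set
with wave.open("FLT034/FLT034-14.wav", "rb") as f:
    assert f.getnchannels() == 1      # mono
    assert f.getframerate() == 16000  # 16 kHz sample rate
    assert f.getsampwidth() == 2      # 16-bit PCM (2 bytes per sample)
```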
I'm new to PyTorch Lightning and to training with wav2vec as a whole, so I'm guessing I'm missing something obvious. Any help would be greatly appreciated!
Here is the full script I'm running:
```python
import torch
import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_csv(
    input_fields="file",
    target_fields="text",
    # train_file="data/timit/train.json",
    # test_file="data/timit/test.json",
    train_file="FLT034/FLT034-TRAIN.csv",
    test_file="FLT034/FLT034-TEST.csv",
    batch_size=4,
)

# Can use any wav2vec model from HuggingFace as the backbone for fine-tuning
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# Create trainer and fine-tune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# Predict on audio files
# datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
datamodule = SpeechRecognitionData.from_files(predict_files=["FLT034/FLT034-14.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# Save checkpoint
trainer.save_checkpoint("FL034_trained_model.pt")
```
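(Side note: the saved checkpoint can later be reloaded through Lightning's standard `load_from_checkpoint`; a minimal sketch, assuming the checkpoint path from the script above:)

```python
from flash.audio import SpeechRecognition

# Reload the fine-tuned model from the saved checkpoint for inference
model = SpeechRecognition.load_from_checkpoint("FL034_trained_model.pt")
```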
And here is a sample of the train.csv file with the annotations:
```
file,text
"./FLT034-12.wav","Weather at one seven five eight zulu."
"./FLT034-13.wav","Wind one niner zero at eight."
"./FLT034-14.wav","Visibility eight ceiling eight hundred overcast."
"./FLT034-15.wav","Temperature one five"
"./FLT034-16.wav","Dewpoint one four"
"./FLT034-17.wav","Altimeter three zero"
"./FLT034-18.wav","Get both sides on a mic"
```
@HiPracheta @krshrimali Done some debugging here. Turns out we weren't freezing the backbone properly. I have a PR that fixes it (#1275). With the changes from that PR I am getting the right result without needing to change the learning rate (just setting `strategy="freeze"`).

Hi, @HiPracheta - Looks like @ethanwharris is spot on with his suggestion. Setting `learning_rate=1e-5` does the job. I used the script below:
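(The referenced script was not captured here; as a sketch of the relevant change, assuming the same setup as the issue script above, either suggestion would look like this:)

```python
# Option 1: lower the task's learning rate, keeping strategy="no_freeze"
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h", learning_rate=1e-5)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# Option 2: freeze the backbone during fine-tuning (works once PR #1275 lands)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
```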
And the output prediction looks correct after listening to the audio file (checked on my local system).

@ethanwharris - Probably something we can add to our documentation? I'll create a quick PR if it sounds good to you!