
Unable to generate predictions for wav2vec model fine-tuned with custom data


Discussed in https://github.com/PyTorchLightning/pytorch-lightning/discussions/11432


Originally posted by nayak24, January 11, 2022:

Hi, I’m trying to fine-tune the baseline wav2vec model with my own audio training/test data using Lightning Flash, essentially following the tutorial in this doc exactly: https://lightning-flash.readthedocs.io/en/latest/reference/speech_recognition.html

However, I am running into an issue when generating the prediction for an audio file, and I’m getting a null output:

94.4 M    Trainable params
0         Non-trainable params
94.4 M    Total params
377.585   Total estimated model params size (MB)
Epoch 0: 100%|██████████| 88/88 [02:43<00:00, 1.86s/it, loss=633, v_num=57, train_loss_step=750.0]
Predicting: 88it [00:00, ?it/s]
[['']]
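For context (an editorial aside, not part of the original issue): [['']] is exactly what greedy CTC decoding returns when the model emits the blank token at every timestep, which is what happens when fine-tuning diverges. A minimal sketch of the decoding step with a toy vocabulary:

import torch

blank_id = 0                 # CTC blank token id (assumed 0 here)
vocab = {1: "A", 2: "B"}     # toy vocabulary for illustration

# Pretend logits where the blank token wins at every timestep
logits = torch.zeros(10, 3)
logits[:, blank_id] = 5.0

ids = logits.argmax(dim=-1)  # greedy decode: all blanks
# CTC decoding: collapse repeated ids, then drop blanks
prev, chars = None, []
for i in ids.tolist():
    if i != blank_id and i != prev:
        chars.append(vocab[i])
    prev = i
print(repr("".join(chars)))  # '' -> an empty transcription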

I’m not sure what the issue is, as I’ve only replaced the Timit dataset with my own input data for fine-tuning, and the rest of the script follows exactly from the doc above. All of the input data are wav files with the following format:

format                    | 1 (uncompressed PCM)
number of channels        | 1 (mono)
sampleRate                | 16000
byteRate                  | 32000
blockAlign                | 2
bitsPerSample (bit depth) | 16
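A quick way to confirm that each file actually matches this format (an editorial addition; the path is a placeholder) is the standard-library wave module:

import wave

with wave.open("FLT034/FLT034-14.wav", "rb") as w:
    assert w.getnchannels() == 1,     "expected mono audio"
    assert w.getframerate() == 16000, "expected a 16 kHz sample rate"
    assert w.getsampwidth() == 2,     "expected 16-bit samples (2 bytes)"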

I’m new to PyTorch Lightning and training with wav2vec as a whole, so I’m guessing that I’m missing something obvious. Any help would be greatly appreciated!

Here is the full script I’m running:

import torch
import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

#download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_csv(
    input_fields="file",
    target_fields="text",
    #train_file="data/timit/train.json",
    #test_file="data/timit/test.json",
    train_file="FLT034/FLT034-TRAIN.csv",
    test_file="FLT034/FLT034-TEST.csv",
    batch_size=4,
)

# can use any wav2vec model in HuggingFace as backbone for finetuning
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# create trainer and finetune model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy='no_freeze')

# predict on audio files
#datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
datamodule = SpeechRecognitionData.from_files(predict_files=["FLT034/FLT034-14.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# Save Checkpoint 
trainer.save_checkpoint("FL034_trained_model.pt") 
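The saved checkpoint can later be reloaded for inference (an editorial addition, not part of the original script; load_from_checkpoint is inherited from LightningModule):

from flash.audio import SpeechRecognition

# Reload the fine-tuned task; predictions can then be produced with
# trainer.predict as in the script above.
model = SpeechRecognition.load_from_checkpoint("FL034_trained_model.pt")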

And here is a sample of the train.csv file with the annotations:

file,text
"./FLT034-12.wav","Weather at one seven five eight zulu."
"./FLT034-13.wav","Wind one niner zero at eight."
"./FLT034-14.wav","Visibility eight ceiling eight hundred overcast."
"./FLT034-15.wav","Temperature one five"
"./FLT034-16.wav","Dewpoint one four"
"./FLT034-17.wav","Altimeter three zero"
"./FLT034-18.wav","Get both sides on a mic"

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 13 (6 by maintainers)

Top GitHub Comments

2 reactions
ethanwharris commented, Apr 6, 2022

@HiPracheta @krshrimali Done some debugging here. Turns out we weren’t freezing the backbone properly. I have a PR that fixes it (#1275). With the changes from that PR I am getting the right result without needing to change the learning rate (just setting strategy="freeze") 😃
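Applied to the script from the question, the suggested call would look like this (a sketch assuming the fix from #1275 is in your Flash install; model and datamodule as defined above):

import flash

# `model` and `datamodule` as created earlier in the question's script
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")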

0 reactions
krshrimali commented, Apr 4, 2022

Hi, @HiPracheta - Looks like @ethanwharris is spot on with his suggestion. Setting the learning_rate=1e-5 does the job.

I used the script below:

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h", learning_rate=1e-5)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
predict_datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=predict_datamodule)

print("Predictions: ", predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")

And the output prediction (which sounds correct after listening to the audio file) is:

Predictions:  [['SHE HAD YER DARK SUIT IN GREASY WASHWATER ALL YEAR']]

(checked on my local system). 😃

@ethanwharris - Probably something we can add to our documentation? I’ll create a quick PR if it sounds good to you!
