Integrate k2 into SpeechBrain
We are planning to integrate k2 into SpeechBrain. At the first stage, we aim to apply k2 for decoding. The following shows how to use k2 for 1-best CTC decoding. We are creating an issue here so that others know what we're doing, and it can also serve as a place for future discussion.
```python
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

from snowfall.training.ctc_graph import build_ctc_topo2
from speechbrain.pretrained import EncoderDecoderASR

import k2
import torch


def load_model():
    model = EncoderDecoderASR.from_hparams(
        source="speechbrain/asr-transformer-transformerlm-librispeech",
        savedir="pretrained_models/asr-transformer-transformerlm-librispeech",
        # run_opts={'device': 'cuda:0'},
    )
    return model


@torch.no_grad()
def main():
    model = load_model()
    device = model.device

    # See https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech/blob/main/example.wav
    sound_file = './example.wav'
    wav = model.load_audio(sound_file)
    # wav is a 1-d tensor, e.g., [52173]

    wavs = wav.unsqueeze(0).float().to(device)
    # wavs is a 2-d tensor, e.g., [1, 52173]

    wav_lens = torch.tensor([1.0])
    wav_lens = wav_lens.to(device)
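    # Note (added): SpeechBrain encoders take relative lengths in [0, 1];
    # 1.0 above means the whole waveform is valid (no padding).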
    encoder_out = model.modules.encoder(wavs, wav_lens)
    # encoder_out.shape [N, T, C], e.g., [1, 82, 768]

    logits = model.hparams.ctc_lin(encoder_out)
    # logits.shape [N, T, C], e.g., [1, 82, 5000]

    log_probs = model.hparams.log_softmax(logits)
    # log_probs.shape [N, T, C], e.g., [1, 82, 5000]

    vocab_size = model.tokenizer.vocab_size()
    ctc_topo = build_ctc_topo2(list(range(vocab_size)))
    ctc_topo = k2.create_fsa_vec([ctc_topo]).to(device)
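    # Note (added): build_ctc_topo2 builds a (modified) CTC topology over the
    # tokenizer's units; its labels include the blank (0) and its aux_labels
    # carry the token ids, so intersecting with it collapses blanks and
    # repeats. create_fsa_vec wraps the single FSA into an FsaVec for k2.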
    supervision_segments = torch.tensor([[0, 0, log_probs.size(1)]],
                                        dtype=torch.int32)
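    # Note (added): each supervision row above is
    # [fsa_index, start_frame, num_frames]; here there is a single utterance
    # starting at frame 0 and covering all T frames.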
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0, 8, 30,
                                         10000)
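    # Note (added): the positional arguments above are search_beam=20.0,
    # output_beam=8, min_active_states=30 and max_active_states=10000,
    # i.e., a beam-pruned intersection of the topology with the dense FSA.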
    best_path = k2.shortest_path(lattices, use_double_scores=True)
    aux_labels = best_path[0].aux_labels
    # Drop the 0s, which are epsilons (blanks/repeats removed by the topology)
    aux_labels = aux_labels[aux_labels.nonzero().squeeze()]
    # The last entry is -1 (from the final arc), so remove it
    aux_labels = aux_labels[:-1]

    hyp = model.tokenizer.decode(aux_labels.tolist())
    print(hyp)


if __name__ == '__main__':
    main()
```
The above code is available at https://gist.github.com/csukuangfj/c68697cd144c8f063cc7ec4fd885fd6f
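As a quick sanity check (an addition here, not part of the original snippet), the k2 1-best hypothesis can be cross-checked against SpeechBrain's built-in greedy CTC decoder. The sketch below assumes the model's CTC head uses blank index 0, as in the LibriSpeech transformer recipe:

```python
# Hedged sketch: compare k2 1-best decoding with SpeechBrain's frame-wise
# greedy CTC decoding. Assumption: blank_id=0 for this model's CTC head.
from speechbrain.decoders.ctc import ctc_greedy_decode


def greedy_baseline(model, log_probs, wav_lens):
    # log_probs: [N, T, C] CTC log-probabilities; wav_lens: relative lengths
    tokens = ctc_greedy_decode(log_probs, wav_lens, blank_id=0)
    # Decode the first utterance back to text with the same tokenizer
    return model.tokenizer.decode(tokens[0])
```

With a single utterance and beams this wide, the two hypotheses should normally agree, which makes this a cheap way to spot integration bugs.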
Top GitHub Comments

Yes, there seems to be something wrong; we will try to debug.

> On Sat, Jul 17, 2021 at 12:02 AM Mirco Ravanelli @.***> wrote:
> I am also taking a look.