-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy path short_audio_transcribe_whisper.py
180 lines (112 loc) · 4.08 KB
/
short_audio_transcribe_whisper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import argparse
import whisper
import torch
from tqdm import tqdm
import sys
import os
from common.constants import Languages
from common.log import logger
from common.stdout_wrapper import SAFE_STDOUT
import re
from transformers import pipeline
# Device string for the Hugging Face pipeline: first GPU when CUDA is
# available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Whisper model; assigned in the __main__ section. It stays None when
# whisper.load_model fails, in which case the Belle pipeline is used instead.
model = None

# Whisper language code -> language tag prefix ("XX|").
lang2token = dict(
    zh="ZH|",
    ja="JP|",
    en="EN|",
)
# Hugging Face checkpoint used when the Whisper model could not be loaded.
_BELA_MODEL_ID = "BELLE-2/Belle-whisper-large-v2-zh"
# Pipeline cache: built on the first transcribe_bela call, then reused.
_bela_transcriber = None


def _get_bela_transcriber():
    """Build the Belle ASR pipeline on first use and return the cached one afterwards."""
    global _bela_transcriber
    if _bela_transcriber is None:
        transcriber = pipeline(
            "automatic-speech-recognition",
            model=_BELA_MODEL_ID,
            device=device,
        )
        # Force Chinese + "transcribe" so the decoder neither auto-detects
        # the language nor switches to translation.
        transcriber.model.config.forced_decoder_ids = (
            transcriber.tokenizer.get_decoder_prompt_ids(
                language="zh",
                task="transcribe",
            )
        )
        _bela_transcriber = transcriber
    return _bela_transcriber


def transcribe_bela(audio_path):
    """Transcribe ``audio_path`` as Chinese with the Belle Whisper pipeline.

    The pipeline (a large-v2 sized checkpoint) is loaded once and reused for
    subsequent calls; previously it was rebuilt for every file.

    Args:
        audio_path: Path to an audio file accepted by the transformers
            automatic-speech-recognition pipeline.

    Returns:
        The transcribed text (a single string; the language is always Chinese).
    """
    transcription = _get_bela_transcriber()(audio_path)
    text = transcription["text"]
    print(text)
    return text
def transcribe_one(audio_path, mytype):
    """Transcribe one short audio file with the module-level Whisper ``model``.

    The audio is padded/trimmed to Whisper's 30-second window. The spoken
    language is detected, and then the window is decoded with beam search
    (beam size 5).

    Args:
        audio_path: Path to an audio file readable by ``whisper.load_audio``.
        mytype: Name of the Whisper model that ``model`` was loaded from.
            Kept for backward compatibility only. The mel-bin count is now
            read from the loaded model instead of being inferred from this name.

    Returns:
        ``(text, lang)``: the decoded text and the detected Whisper language
        code (for example ``"zh"``, ``"ja"`` or ``"en"``).
    """
    # Load the audio and pad/trim it to fit 30 seconds.
    audio = whisper.pad_or_trim(whisper.load_audio(audio_path))

    # Ask the model how many mel bins it expects (80 for most checkpoints,
    # 128 for large-v3-class ones). Matching on the literal name "large-v3"
    # missed aliases that resolve to 128-bin checkpoints in newer whisper
    # releases (e.g. "large", "turbo").
    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(
        model.device
    )

    # Detect the spoken language once and reuse the result.
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")

    # fp16 only on CUDA. DecodingOptions defaults to fp16=True, so this
    # matches the previous per-device branches.
    use_fp16 = torch.cuda.is_available()
    if lang == "zh":
        # A punctuated Chinese prompt, presumably to nudge the decoder toward
        # punctuated output.
        options = whisper.DecodingOptions(
            beam_size=5, fp16=use_fp16, prompt="生于忧患,死于欢乐。不亦快哉!"
        )
    else:
        options = whisper.DecodingOptions(beam_size=5, fp16=use_fp16)

    result = whisper.decode(model, mel, options)
    # Print the recognized text.
    print(result.text)
    return result.text, lang
if __name__ == "__main__":
    # Transcribe every .wav file in --input_file and write one
    # "<file_pos><wav>|<speaker>|<LANG>|<text>" line per file to ./esd.list.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--language", type=str, default="ja", choices=["ja", "en", "zh"]
    )
    parser.add_argument("--mytype", type=str, default="medium")
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--input_file", type=str, default="./wavs/")
    parser.add_argument("--file_pos", type=str, default="")
    args = parser.parse_args()
    # NOTE: --language and --model_name are accepted for CLI compatibility
    # but are not used by this script; the language is auto-detected.

    mytype = args.mytype
    # An explicitly empty --input_file falls back to the default directory.
    input_file = args.input_file or "./wavs/"
    file_pos = args.file_pos

    try:
        model = whisper.load_model(mytype, download_root="./whisper_model/")
    except Exception as e:
        # Loading failed: keep model = None and fall back to the
        # Chinese-specialised Belle pipeline (transcribe_bela).
        print(str(e))
        print("中文特化逻辑")

    # Map Whisper language codes to the tags written to esd.list. These are
    # the same JP/EN/ZH tags used by lang2token (Japanese was previously
    # written as "JA", which disagreed with them).
    language_ids = {"ja": "JP", "en": "EN", "zh": "ZH"}
    # Speaker name = text before the first "_" plus the extension, e.g.
    # "deedee_0001.wav" -> "deedee.wav" (".wav" is stripped when writing).
    name_pattern = re.compile(r'(^.*?)_.*?(\..*?$)')

    # Sort so that esd.list is reproducible across runs and filesystems.
    wav_files = sorted(f for f in os.listdir(input_file) if f.endswith(".wav"))

    with open("./esd.list", "w", encoding="utf-8") as f:
        for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
            # os.path.join works whether or not input_file ends with "/".
            audio_path = os.path.join(input_file, wav_file)
            if model is not None:
                text, lang = transcribe_one(audio_path, mytype)
            else:
                # The Belle fallback always decodes Chinese and returns only
                # the text, so it must not be tuple-unpacked.
                text = transcribe_bela(audio_path)
                lang = "zh"

            language_id = language_ids.get(lang)
            if language_id is None:
                # Previously this raised NameError (first file) or silently
                # reused the previous file's tag.
                logger.warning(
                    f"Skipping {wav_file}: unsupported detected language {lang!r}"
                )
                continue

            match = name_pattern.search(wav_file)
            if match:
                extracted_name = match.group(1) + match.group(2)
            else:
                print("No match found")
                extracted_name = "sample"

            f.write(
                f"{file_pos}{wav_file}|{extracted_name.replace('.wav', '')}"
                f"|{language_id}|{text}\n"
            )
            # Flush per line so partial results survive an interrupted run.
            f.flush()
    sys.exit(0)