11import os
22import argparse
3-
3+ import whisper
4+ import torch
45
56from tqdm import tqdm
67import sys
1617from common .log import logger
1718from common .stdout_wrapper import SAFE_STDOUT
1819
20+ import re
21+
1922# Local directory used as the download cache (cache_dir) for the ModelScope models
2023local_dir_root = "./models_from_modelscope"
2124model_dir = snapshot_download ('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' , cache_dir = local_dir_root )
2730
2831model_dir_en = snapshot_download ('damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline' , cache_dir = local_dir_root )
2932
30-
33+ device = "cuda:0" if torch . cuda . is_available () else "cpu"
3134
3235
3336
6972)
7073
7174
75+ model = whisper .load_model ("medium" ,download_root = "./whisper_model/" )
76+
77+
7278
7379lang2token = {
7480 'zh' : "ZH|" ,
7985
8086def transcribe_one (audio_path ,language ):
8187
88+ audio = whisper .load_audio (audio_path )
89+ audio = whisper .pad_or_trim (audio )
90+ mel = whisper .log_mel_spectrogram (audio ).to (model .device )
91+ _ , probs = model .detect_language (mel )
92+ language = max (probs , key = probs .get )
93+
8294 if language == "zh" :
8395
8496 rec_result = inference_pipeline (audio_in = audio_path , param_dict = param_dict )
@@ -89,7 +101,7 @@ def transcribe_one(audio_path,language):
89101
90102 print (rec_result ["text" ])
91103
92- return rec_result ["text" ]
104+ return rec_result ["text" ], language
93105
94106
95107if __name__ == "__main__" :
@@ -127,22 +139,28 @@ def transcribe_one(audio_path,language):
127139 ]
128140
129141
130- if language == "ja" :
131- language_id = Languages .JP
132- elif language == "en" :
133- language_id = Languages .EN
134- elif language == "zh" :
135- language_id = Languages .ZH
136- else :
137- raise ValueError (f"{ language } is not supported." )
138-
139142 with open ("./esd.list" , "w" , encoding = "utf-8" ) as f :
140143 for wav_file in tqdm (wav_files , file = SAFE_STDOUT ):
141144 file_name = os .path .basename (wav_file )
142145
143- text = transcribe_one (f"{ input_file } " + wav_file ,language )
144-
145- f .write (file_pos + f"{ file_name } |{ speaker_name } |{ language_id } |{ text } \n " )
146+ text ,lang = transcribe_one (f"{ input_file } " + wav_file ,language )
147+
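148+ # Use a regex to keep the part of wav_file before the first '_' plus the extension (e.g. 'deedee' once '.wav' is stripped on write)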
149+ match = re .search (r'(^.*?)_.*?(\..*?$)' , wav_file )
150+ if match :
151+ extracted_name = match .group (1 ) + match .group (2 )
152+ else :
153+ print ("No match found" )
154+ extracted_name = "sample"
155+
156+ if lang == "ja" :
157+ language_id = "JA"
158+ elif lang == "en" :
159+ language_id = "EN"
160+ elif lang == "zh" :
161+ language_id = "ZH"
162+
163+ f .write (file_pos + f"{ file_name } |{ extracted_name .replace ('.wav' ,'' )} |{ language_id } |{ text } \n " )
146164
147165 f .flush ()
148166 sys .exit (0 )
0 commit comments