
Commit 5b39353

committed
Integrate Alibaba SenseVoice
1 parent afcf9a5 commit 5b39353

6 files changed

Lines changed: 195 additions & 137 deletions


GPU诊断.bat

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+venv\python.exe gpu_diagnostics.py
+pause

gpu_diagnostics.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import torch
+
+def gpu_diagnostics():
+    if torch.cuda.is_available():
+        print("GPU diagnostics report:")
+        print("="*40)
+        for i in range(torch.cuda.device_count()):
+            props = torch.cuda.get_device_properties(i)
+            total_memory = props.total_memory / (1024 * 1024)
+            reserved_memory = torch.cuda.memory_reserved(i) / (1024 * 1024)
+            allocated_memory = torch.cuda.memory_allocated(i) / (1024 * 1024)
+            free_memory = total_memory - allocated_memory
+
+            print(f"GPU {i}: {props.name}")
+            print(f"  Total VRAM     : {round(total_memory, 2)} MB")
+            print(f"  Reserved VRAM  : {round(reserved_memory, 2)} MB")
+            print(f"  Allocated VRAM : {round(allocated_memory, 2)} MB")
+            print(f"  Free VRAM      : {round(free_memory, 2)} MB")
+            print("="*40)
+    else:
+        print("No GPU found; using CPU")
+
+if __name__ == "__main__":
+    gpu_diagnostics()
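
One caveat about this report: free_memory is computed as total memory minus this process's allocated memory, so it does not see VRAM held by other processes or by the CUDA context itself. For a driver-level number, torch.cuda.mem_get_info returns the actual (free, total) byte counts per device; a minimal sketch, not part of the commit:

import torch

def driver_level_report():
    # mem_get_info reports free/total bytes as the CUDA driver sees them,
    # including memory consumed by other processes sharing the GPU.
    for i in range(torch.cuda.device_count()):
        free_b, total_b = torch.cuda.mem_get_info(i)
        print(f"GPU {i}: {free_b / 1024**2:.2f} MB free of {total_b / 1024**2:.2f} MB")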

requirements.txt

Lines changed: 5 additions & 5 deletions
@@ -35,7 +35,7 @@ filelock==3.13.1
 fonttools==4.47.2
 frozenlist==1.4.1
 fsspec==2023.10.0
-funasr==0.8.8
+funasr
 gast==0.5.4
 google-auth==2.26.2
 google-auth-oauthlib==1.2.0
@@ -137,10 +137,10 @@ tokenizers==0.15.0
 tomli==2.0.1
 tomlkit==0.12.0
 toolz==0.12.0
-torch
-torch-complex
-torchaudio
-torchvision
+torch==2.1.2+cu118
+torch-complex==0.4.3
+torchaudio==2.1.2
+torchvision==0.16.2+cu118
 tqdm==4.66.1
 transformers==4.36.2
 typer==0.9.0
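
Note that the +cu118 builds pinned here are hosted on the PyTorch wheel index, not PyPI, so a plain pip install -r requirements.txt will not resolve them. Assuming the standard CUDA 11.8 wheel index is intended, installation would look something like:

pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118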

short_audio_transcribe_ali.py

Lines changed: 146 additions & 71 deletions
@@ -6,110 +6,181 @@
 from tqdm import tqdm
 import sys
 import os
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks


-from modelscope.hub.snapshot_download import snapshot_download
-

 from common.constants import Languages
 from common.log import logger
 from common.stdout_wrapper import SAFE_STDOUT

 import re

-# Specify the local directory
-local_dir_root = "./models_from_modelscope"
-model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', cache_dir=local_dir_root)
-model_dir_punc_ct = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch', cache_dir=local_dir_root)
-model_dir_vad = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch', cache_dir=local_dir_root)
-
-model_dir_ja = snapshot_download('damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', cache_dir=local_dir_root)
-
-
-model_dir_en = snapshot_download('damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', cache_dir=local_dir_root)

 device = "cuda:0" if torch.cuda.is_available() else "cpu"


-
-inference_pipeline = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=model_dir,
-    vad_model=model_dir_vad,
-    punc_model=model_dir_punc_ct,
-    #lm_model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
-    #lm_weight=0.15,
-    #beam_size=10,
-)
-param_dict = {}
-param_dict['use_timestamp'] = False
-# folderpath = sys.argv[1]
-extensions = ['wav']
-
-
-
-inference_pipeline_ja = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=model_dir_ja,
-    # vad_model=model_dir_vad,
-    # punc_model=model_dir_punc_ct,
-    #lm_model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
-    #lm_weight=0.15,
-    #beam_size=10,
-)
-
-
-inference_pipeline_en = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=model_dir_en,
-    # vad_model=model_dir_vad,
-    # punc_model=model_dir_punc_ct,
-    #lm_model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
-    #lm_weight=0.15,
-    #beam_size=10,
-)
-
-
-model = whisper.load_model("medium",download_root="./whisper_model/")
-
-
+from funasr import AutoModel
+
+model_dir = "iic/SenseVoiceSmall"
+
+
+emo_dict = {
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
+}
+
+event_dict = {
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|Cry|>": "😭",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "🤧",
+}
+
+emoji_dict = {
+    "<|nospeech|><|Event_UNK|>": "❓",
+    "<|zh|>": "",
+    "<|en|>": "",
+    "<|yue|>": "",
+    "<|ja|>": "",
+    "<|ko|>": "",
+    "<|nospeech|>": "",
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
+    "<|Cry|>": "😭",
+    "<|EMO_UNKNOWN|>": "",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "😷",
+    "<|Sing|>": "",
+    "<|Speech_Noise|>": "",
+    "<|withitn|>": "",
+    "<|woitn|>": "",
+    "<|GBG|>": "",
+    "<|Event_UNK|>": "",
+}
+
+lang_dict = {
+    "<|zh|>": "<|lang|>",
+    "<|en|>": "<|lang|>",
+    "<|yue|>": "<|lang|>",
+    "<|ja|>": "<|lang|>",
+    "<|ko|>": "<|lang|>",
+    "<|nospeech|>": "<|lang|>",
+}
+
+emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}

 lang2token = {
     'zh': "ZH|",
     'ja': "JP|",
     "en": "EN|",
+    "ko": "KO|",
+    "yue": "YUE|",
 }

+def format_str(s):
+    for sptk in emoji_dict:
+        s = s.replace(sptk, emoji_dict[sptk])
+    return s
+
+
+def format_str_v2(s):
+    sptk_dict = {}
+    for sptk in emoji_dict:
+        sptk_dict[sptk] = s.count(sptk)
+        s = s.replace(sptk, "")
+    emo = "<|NEUTRAL|>"
+    for e in emo_dict:
+        if sptk_dict[e] > sptk_dict[emo]:
+            emo = e
+    for e in event_dict:
+        if sptk_dict[e] > 0:
+            s = event_dict[e] + s
+    s = s + emo_dict[emo]
+
+    for emoji in emo_set.union(event_set):
+        s = s.replace(" " + emoji, emoji)
+        s = s.replace(emoji + " ", emoji)
+    return s.strip()
+
+def format_str_v3(s):
+    def get_emo(s):
+        return s[-1] if s[-1] in emo_set else None
+    def get_event(s):
+        return s[0] if s[0] in event_set else None
+
+    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
+    for lang in lang_dict:
+        s = s.replace(lang, "<|lang|>")
+    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+    new_s = " " + s_list[0]
+    cur_ent_event = get_event(new_s)
+    for i in range(1, len(s_list)):
+        if len(s_list[i]) == 0:
+            continue
+        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
+            s_list[i] = s_list[i][1:]
+        #else:
+        cur_ent_event = get_event(s_list[i])
+        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
+            new_s = new_s[:-1]
+        new_s += s_list[i].strip().lstrip()
+    new_s = new_s.replace("The.", " ")
+    return new_s.strip()

 def transcribe_one(audio_path,language):

-    audio = whisper.load_audio(audio_path)
-    audio = whisper.pad_or_trim(audio)
-    mel = whisper.log_mel_spectrogram(audio).to(model.device)
-    _, probs = model.detect_language(mel)
-    language = max(probs, key=probs.get)
+    model = AutoModel(model=model_dir,
+                      vad_model="fsmn-vad",
+                      vad_kwargs={"max_single_segment_time": 30000},
+                      trust_remote_code=True, device="cuda:0")
+
+    res = model.generate(
+        input=audio_path,
+        cache={},
+        language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
+        use_itn=False,
+        batch_size_s=0,
+    )
+
+    try:

-    if language == "zh":
-
-        rec_result = inference_pipeline(audio_in=audio_path, param_dict=param_dict)
-    elif language == "ja":
-        rec_result = inference_pipeline_ja(audio_in=audio_path, param_dict=param_dict)
-    else:
-        rec_result = inference_pipeline_en(audio_in=audio_path, param_dict=param_dict)
+        text = res[0]["text"]
+        text = format_str_v3(text)
+        print(text)
+    except Exception as e:
+        print(e)
+        text = ""

-    print(rec_result["text"])

-    return rec_result["text"],language
+    return text,language


 if __name__ == "__main__":

     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--language", type=str, default="ja", choices=["ja", "en", "zh"]
+        "--language", type=str, default="ja", choices=["ja", "en", "zh","yue","ko"]
     )
     parser.add_argument("--model_name", type=str, required=True)

@@ -159,6 +230,10 @@ def transcribe_one(audio_path,language):
         language_id = "EN"
     elif lang == "zh":
         language_id = "ZH"
+    elif lang == "yue":
+        language_id = "YUE"
+    elif lang == "ko":
+        language_id = "KO"

     f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
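
As committed, transcribe_one constructs a fresh AutoModel on every call, so transcribing a folder reloads SenseVoiceSmall once per WAV file, and the hard-coded device="cuda:0" ignores the device fallback computed at the top of the script. A minimal sketch of hoisting the load to module level, reusing only the funasr calls and names (model_dir, device, format_str_v3) already present in this diff:

from funasr import AutoModel

# Load SenseVoiceSmall once at import time instead of once per audio file.
model = AutoModel(model=model_dir,
                  vad_model="fsmn-vad",
                  vad_kwargs={"max_single_segment_time": 30000},
                  trust_remote_code=True,
                  device=device)  # falls back to "cpu" when CUDA is unavailable

def transcribe_one(audio_path, language):
    res = model.generate(
        input=audio_path,
        cache={},
        language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=False,
        batch_size_s=0,
    )
    try:
        # SenseVoice emits rich tags; format_str_v3 folds them into emoji/plain text.
        text = format_str_v3(res[0]["text"])
        print(text)
    except Exception as e:
        print(e)
        text = ""
    return text, language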