Skip to content

Commit 2f11d8d

Browse files
committed
添加功能,批量转写,根据字幕手动转写
1 parent 9db84a1 commit 2f11d8d

10 files changed

Lines changed: 1037 additions & 49 deletions

argparse_tools.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import yaml
5+
import sys
6+
7+
8+
class ArgumentParser(argparse.ArgumentParser):
    """argparse.ArgumentParser with optional YAML config-file support.

    Originated from https://github.com/bw2/ConfigArgParse, but lacks some
    of the features that project has:

    - Not supporting multiple config files
    - Automatically adding "--config" as an option.
    - Not supporting any formats other than yaml
    - Not checking argument type
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every parser built from this class transparently gains "--config".
        self.add_argument("--config", help="Give config file in yaml format")

    def parse_known_args(self, args=None, namespace=None):
        """Parse args, first applying defaults taken from the "--config" yaml file.

        Values in the config file become parser defaults, so explicit
        command-line options still override them on the second parse.

        Raises:
            SystemExit: via ``self.error`` when the config file is missing,
                does not hold a yaml mapping, or contains a key that matches
                no registered argument.
        """
        # First pass: only to discover the "--config" value itself.
        _args, _ = super().parse_known_args(args, namespace)
        if _args.config is not None:
            if not Path(_args.config).exists():
                self.error(f"No such file: {_args.config}")

            with open(_args.config, "r", encoding="utf-8") as f:
                d = yaml.safe_load(f)
            # FIX: this message was a plain string missing the f-prefix, so it
            # printed the literal text "{_args.config}" instead of the path.
            if not isinstance(d, dict):
                self.error(f"Config file has non dict value: {_args.config}")

            # Reject any key that does not correspond to a registered argument,
            # mirroring argparse's own "unrecognized arguments" error.
            for key in d:
                for action in self._actions:
                    if key == action.dest:
                        break
                else:
                    self.error(f"unrecognized arguments: {key} (from {_args.config})")

            # NOTE(kamo): Ignore "--config" from a config file
            # NOTE(kamo): Unlike "configargparse", this module doesn't check type.
            #   i.e. We can set any type value regardless of argument type.
            self.set_defaults(**d)
        # Second pass: real parsing, now with config-derived defaults in place.
        return super().parse_known_args(args, namespace)
49+
50+
51+
def get_commandline_args():
    """Reconstruct the current command line as a single shell-ready string.

    Returns ``sys.executable`` followed by every entry of ``sys.argv``,
    wrapping an argument in single quotes when it contains a shell
    metacharacter, and rewriting embedded single quotes as ``'\\''``.

    NOTE(review): an argument containing only a single quote (and no other
    metacharacter) is escaped but left unquoted — preserved as-is from the
    original; confirm against upstream before changing.
    """
    # Characters that force an argument to be single-quoted for the shell.
    special = " ;&()|^<>?*[]$`\"\\!{}"

    def _escape(token):
        quoted = token.replace("'", "'\\''")
        if any(ch in token for ch in special):
            return "'" + quoted + "'"
        return quoted

    escaped = [_escape(token) for token in sys.argv]
    return sys.executable + " " + " ".join(escaped)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,5 @@ xxhash==3.4.1
156156
yapf==0.40.2
157157
yarl==1.9.4
158158
zipp==3.17.0
159+
faster-whisper
160+
moviepy

short_audio_transcribe_ali.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import argparse
3-
3+
import whisper
4+
import torch
45

56
from tqdm import tqdm
67
import sys
@@ -16,6 +17,8 @@
1617
from common.log import logger
1718
from common.stdout_wrapper import SAFE_STDOUT
1819

20+
import re
21+
1922
# 指定本地目录
2023
local_dir_root = "./models_from_modelscope"
2124
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', cache_dir=local_dir_root)
@@ -27,7 +30,7 @@
2730

2831
model_dir_en = snapshot_download('damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', cache_dir=local_dir_root)
2932

30-
33+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
3134

3235

3336

@@ -69,6 +72,9 @@
6972
)
7073

7174

75+
model = whisper.load_model("medium",download_root="./whisper_model/")
76+
77+
7278

7379
lang2token = {
7480
'zh': "ZH|",
@@ -79,6 +85,12 @@
7985

8086
def transcribe_one(audio_path,language):
8187

88+
audio = whisper.load_audio(audio_path)
89+
audio = whisper.pad_or_trim(audio)
90+
mel = whisper.log_mel_spectrogram(audio).to(model.device)
91+
_, probs = model.detect_language(mel)
92+
language = max(probs, key=probs.get)
93+
8294
if language == "zh":
8395

8496
rec_result = inference_pipeline(audio_in=audio_path, param_dict=param_dict)
@@ -89,7 +101,7 @@ def transcribe_one(audio_path,language):
89101

90102
print(rec_result["text"])
91103

92-
return rec_result["text"]
104+
return rec_result["text"],language
93105

94106

95107
if __name__ == "__main__":
@@ -127,22 +139,28 @@ def transcribe_one(audio_path,language):
127139
]
128140

129141

130-
if language == "ja":
131-
language_id = Languages.JP
132-
elif language == "en":
133-
language_id = Languages.EN
134-
elif language == "zh":
135-
language_id = Languages.ZH
136-
else:
137-
raise ValueError(f"{language} is not supported.")
138-
139142
with open("./esd.list", "w", encoding="utf-8") as f:
140143
for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
141144
file_name = os.path.basename(wav_file)
142145

143-
text = transcribe_one(f"{input_file}"+wav_file,language)
144-
145-
f.write(file_pos+f"{file_name}|{speaker_name}|{language_id}|{text}\n")
146+
text,lang = transcribe_one(f"{input_file}"+wav_file,language)
147+
148+
# 使用正则表达式提取'deedee'
149+
match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
150+
if match:
151+
extracted_name = match.group(1) + match.group(2)
152+
else:
153+
print("No match found")
154+
extracted_name = "sample"
155+
156+
if lang == "ja":
157+
language_id = "JA"
158+
elif lang == "en":
159+
language_id = "EN"
160+
elif lang == "zh":
161+
language_id = "ZH"
162+
163+
f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
146164

147165
f.flush()
148166
sys.exit(0)

short_audio_transcribe_bcut.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
from bcut_asr import BcutASR
1414
from bcut_asr.orm import ResultStateEnum
1515

16+
import whisper
17+
import torch
18+
19+
import re
20+
21+
22+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
23+
24+
25+
model = whisper.load_model("medium",download_root="./whisper_model/")
26+
1627

1728

1829
lang2token = {
@@ -24,6 +35,12 @@
2435

2536
def transcribe_one(audio_path):
2637

38+
audio = whisper.load_audio(audio_path)
39+
audio = whisper.pad_or_trim(audio)
40+
mel = whisper.log_mel_spectrogram(audio).to(model.device)
41+
_, probs = model.detect_language(mel)
42+
language = max(probs, key=probs.get)
43+
2744
asr = BcutASR(audio_path)
2845
asr.upload() # 上传文件
2946
asr.create_task() # 创建任务
@@ -52,9 +69,9 @@ def transcribe_one(audio_path):
5269
print(text)
5370

5471
# 输出srt格式
55-
return text
72+
return text,language
5673
else:
57-
return "必剪无法识别"
74+
return "必剪无法识别",language
5875

5976

6077

@@ -91,22 +108,28 @@ def transcribe_one(audio_path):
91108
]
92109

93110

94-
if language == "ja":
95-
language_id = Languages.JP
96-
elif language == "en":
97-
language_id = Languages.EN
98-
elif language == "zh":
99-
language_id = Languages.ZH
100-
else:
101-
raise ValueError(f"{language} is not supported.")
102-
103111
with open("./esd.list", "w", encoding="utf-8") as f:
104112
for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
105113
file_name = os.path.basename(wav_file)
114+
115+
# 使用正则表达式提取'deedee'
116+
match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
117+
if match:
118+
extracted_name = match.group(1) + match.group(2)
119+
else:
120+
print("No match found")
121+
extracted_name = "sample"
106122

107-
text = transcribe_one(f"{input_file}"+wav_file)
123+
text,lang = transcribe_one(f"{input_file}"+wav_file)
124+
125+
if lang == "ja":
126+
language_id = "JA"
127+
elif lang == "en":
128+
language_id = "EN"
129+
elif lang == "zh":
130+
language_id = "ZH"
108131

109-
f.write(file_pos+f"{file_name}|{speaker_name}|{language_id}|{text}\n")
132+
f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
110133

111134
f.flush()
112135
sys.exit(0)

0 commit comments

Comments
 (0)