Skip to content

Commit

Permalink
让 config.py 中的设置项都生效
Browse files Browse the repository at this point in the history
  • Loading branch information
HaujetZhao committed Jan 8, 2024
1 parent f4e5d9a commit 4623fb9
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 25 deletions.
2 changes: 0 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class ModelPaths:
paraformer_path = Path() / 'models' / 'paraformer-offline-zh' / 'model.int8.onnx'
tokens_path = Path() / 'models' / 'paraformer-offline-zh' / 'tokens.txt'
punc_model_dir = Path() / 'models' / 'punc_ct-transformer_cn-en'




class ParaformerArgs:
Expand Down
10 changes: 6 additions & 4 deletions util/client_recv_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import keyboard
import websockets
from config import ClientConfig as Config
from util.client_cosmic import Cosmic, console
from util.client_check_websocket import check_websocket
from util.client_hot_sub import hot_sub
Expand Down Expand Up @@ -37,11 +38,12 @@ async def recv_result():
# 打字
await type_result(text)

# 重命名录音文件
file_audio = rename_audio(message['task_id'], text, message['time_start'])
if Config.save_audio:
# 重命名录音文件
file_audio = rename_audio(message['task_id'], text, message['time_start'])

# 记录写入 md 文件
write_md(text, message['time_start'], file_audio)
# 记录写入 md 文件
write_md(text, message['time_start'], file_audio)

# 控制台输出
console.print(f' 转录时延:{delay:.2f}s')
Expand Down
5 changes: 3 additions & 2 deletions util/client_rename_audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from util.client_cosmic import Cosmic, console
from pathlib import Path
from typing import Union
import time
from util.client_cosmic import Cosmic, console
from config import ClientConfig as Config
from os import makedirs


Expand All @@ -19,7 +20,7 @@ def rename_audio(task_id, text, time_start) -> Union[Path, None]:
time_year = time.strftime('%Y', time.localtime(time_start))
time_month = time.strftime('%m', time.localtime(time_start))
time_ymdhms = time.strftime("%Y%m%d-%H%M%S", time.localtime(time_start))
file_stem = f'({time_ymdhms}){text[:20]}'
file_stem = f'({time_ymdhms}){text[:Config.audio_name_len]}'

# 重命名
file_path_new = file_path.with_name(file_stem + file_path.suffix)
Expand Down
12 changes: 7 additions & 5 deletions util/client_send_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ async def send_audio():
Cosmic.queue_in.task_done()
if task['type'] == 'begin':
time_start = task['time']
elif task['type'] == 'data':
elif task['type'] == 'data':
# 在阈值之前积攒音频数据
if task['time'] - time_start < Config.threshold:
cache.append(task['data'])
continue

# 创建音频文件
if not file_path:
if Config.save_audio and not file_path:
file_path, file = create_file(task['data'].shape[1], time_start)
Cosmic.audio_files[task_id] = file_path

Expand All @@ -72,7 +72,8 @@ async def send_audio():

# 保存音频至本地文件
duration += len(data) / 48000
write_file(file, data)
if Config.save_audio:
write_file(file, data)

# 发送音频数据用于识别
message = {
Expand All @@ -90,7 +91,8 @@ async def send_audio():
task = asyncio.create_task(send_message(message))
elif task['type'] == 'finish':
# 完成写入本地文件
finish_file(file)
if Config.save_audio:
finish_file(file)

console.print(f'任务标识:{task_id}')
console.print(f' 录音时长:{duration:.2f}s')
Expand All @@ -109,4 +111,4 @@ async def send_audio():
task = asyncio.create_task(send_message(message))
break
except Exception as e:
print(e)
print(e)
7 changes: 6 additions & 1 deletion util/hot_kwds.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

from config import ClientConfig as Config

kwd_list = []

Expand All @@ -9,6 +9,11 @@ def do_updata_kwd(kwd_text: str):
kwd_list.clear()
kwd_list.append('')

# 如果不启用关键词功能,直接返回
if not Config.hot_kwd:
return len(kwd_list)

# 更新关键词
for kwd in kwd_text.splitlines():
kwd = kwd.strip()
if not kwd or kwd.startswith('#'):
Expand Down
27 changes: 16 additions & 11 deletions util/server_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from util.server_cosmic import console
from config import ServerConfig as Config
from util.server_classes import Task, Result
from util.chinese_itn import chinese_to_num
from util.format_tools import adjust_space
Expand All @@ -13,6 +14,18 @@
results = {}


def format_text(text, punc_model):
    """Apply the post-processing steps enabled in ``Config`` to ``text``.

    Each step is gated by its own switch, and the steps run in a fixed order:
    spacing, punctuation, number conversion, then spacing once more.

    Args:
        text: The recognized text to format.
        punc_model: Callable used to add punctuation; the first element of
            its return value is taken as the new text. The punctuation step
            is skipped when this is falsy.

    Returns:
        The text after every enabled step has been applied.
    """
    # First spacing pass, before any other step touches the text.
    text = adjust_space(text) if Config.format_spell else text

    # Punctuation needs the switch on, a usable model, and non-empty text.
    wants_punc = Config.format_punc and punc_model and text
    if wants_punc:
        punctuated = punc_model(text)
        text = punctuated[0]

    # Convert Chinese numerals to digits.
    text = chinese_to_num(text) if Config.format_num else text

    # Second spacing pass, run after the steps above have rewritten the text.
    text = adjust_space(text) if Config.format_spell else text

    return text


def recognize(recognizer, punc_model, task: Task):

# inspect({key:value for key, value in task.__dict__.items() if not key.startswith('_') and key != 'data'})
Expand Down Expand Up @@ -42,7 +55,7 @@ def recognize(recognizer, punc_model, task: Task):
result.time_submit = task.time_submit
result.time_complete = time.time()

# 先粗去重
# 先粗去重,依据:字级时间戳
m = n = len(stream.result.timestamps)
for i, timestamp in enumerate(stream.result.timestamps, start=0):
if timestamp > task.overlap / 2:
Expand All @@ -57,7 +70,7 @@ def recognize(recognizer, punc_model, task: Task):
if task.is_final:
n = len(stream.result.timestamps)

# 再细去重
# 再细去重,依据:在端点是否有重复的字
if result.tokens and result.tokens[-2:] == stream.result.tokens[m:n][:2]:
m += 2
elif result.tokens and result.tokens[-1:] == stream.result.tokens[m:n][:1]:
Expand All @@ -67,7 +80,6 @@ def recognize(recognizer, punc_model, task: Task):
result.timestamps += [t + task.offset for t in stream.result.timestamps[m:n]]
result.tokens += [token for token in stream.result.tokens[m:n]]


# token 合并为文本
text = ' '.join(result.tokens).replace('@@ ', '')
text = re.sub('([^a-zA-Z0-9]) (?![a-zA-Z0-9])', r'\1', text)
Expand All @@ -78,14 +90,7 @@ def recognize(recognizer, punc_model, task: Task):
return result

# 调整文本格式
text = adjust_space(text) # 调空格
if punc_model and text:
text = punc_model(text)[0] # 加标点
text = chinese_to_num(text) # 转数字
text = adjust_space(text) # 调空格

result.text = text

result.text = format_text(text, punc_model)

# 若最后一个片段完成识别,从字典摘取任务
result = results.pop(task.task_id)
Expand Down

0 comments on commit 4623fb9

Please sign in to comment.