|
| 1 | +import os |
| 2 | +import json |
| 3 | +import time |
| 4 | +from typing import Optional, Tuple, List |
| 5 | +from .base import ASRProviderBase |
| 6 | +from config.logger import setup_logging |
| 7 | +from core.providers.asr.dto.dto import InterfaceType |
| 8 | +import vosk |
| 9 | + |
# Tag bound to every log record emitted from this module (the module's import name).
TAG = __name__
# Module-wide logger object returned by the project's setup_logging().
logger = setup_logging()
| 12 | + |
class ASRProvider(ASRProviderBase):
    """Speech-to-text provider that runs a VOSK model loaded from local disk.

    The model is loaded once in the constructor. Each call to
    ``speech_to_text`` builds its own ``KaldiRecognizer`` so that audio
    buffered by one utterance can never leak into another one.
    """

    # Sample rate (Hz) handed to every KaldiRecognizer.
    SAMPLE_RATE = 16000
    # Number of PCM bytes passed to the recognizer per AcceptWaveform call.
    CHUNK_SIZE = 2000

    def __init__(self, config: dict, delete_audio_file: bool = True):
        """Read the provider settings and load the VOSK model.

        Args:
            config: Provider settings. ``model_path`` is the model directory;
                ``output_dir`` (default ``"tmp/"``) is created if missing.
            delete_audio_file: When False, the audio of every request is
                written to disk via ``save_audio_to_file`` and its path is
                returned to the caller.

        Raises:
            FileNotFoundError: ``model_path`` is missing or does not exist.
            Exception: Anything raised by VOSK while loading the model.
        """
        super().__init__()
        self.interface_type = InterfaceType.LOCAL
        self.model_path = config.get("model_path")
        self.output_dir = config.get("output_dir", "tmp/")
        self.delete_audio_file = delete_audio_file

        # Populated by _load_model().
        self.model = None
        self.recognizer = None
        self._load_model()

        # Make sure the output directory exists.
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_model(self):
        """Load the VOSK model from ``self.model_path``.

        Also builds one recognizer up front so that an unusable model fails at
        start-up instead of on the first request. That instance stays
        available as ``self.recognizer`` but is not used for recognition.

        Raises:
            FileNotFoundError: The model path is unset or does not exist.
            Exception: Any other loading error; it is logged and re-raised.
        """
        try:
            # A missing config value would otherwise surface as an opaque
            # TypeError from os.path.exists(None).
            if not self.model_path or not os.path.exists(self.model_path):
                raise FileNotFoundError(f"VOSK模型路径不存在: {self.model_path}")

            logger.bind(tag=TAG).info(f"正在加载VOSK模型: {self.model_path}")
            self.model = vosk.Model(self.model_path)
            self.recognizer = vosk.KaldiRecognizer(self.model, self.SAMPLE_RATE)
            logger.bind(tag=TAG).info("VOSK模型加载成功")
        except Exception as e:
            logger.bind(tag=TAG).error(f"加载VOSK模型失败: {e}")
            raise

    def _recognize(self, pcm: bytes) -> str:
        """Transcribe one complete utterance of raw PCM audio.

        A new recognizer is created for every call, so no decoding state is
        shared between utterances or between overlapping calls.
        """
        recognizer = vosk.KaldiRecognizer(self.model, self.SAMPLE_RATE)
        segments = []
        for offset in range(0, len(pcm), self.CHUNK_SIZE):
            # AcceptWaveform returns a truthy value when a finished segment
            # is ready to be read with Result().
            if recognizer.AcceptWaveform(pcm[offset:offset + self.CHUNK_SIZE]):
                segments.append(json.loads(recognizer.Result()).get("text", ""))

        # Collect whatever is still pending after the last chunk.
        segments.append(json.loads(recognizer.FinalResult()).get("text", ""))
        return " ".join(segment for segment in segments if segment).strip()

    async def speech_to_text(
        self, audio_data: List[bytes], session_id: str, audio_format: str = "opus"
    ) -> Tuple[Optional[str], Optional[str]]:
        """Convert audio packets into text.

        Args:
            audio_data: Audio packets. They are used as PCM as-is when
                ``audio_format`` is ``"pcm"``; otherwise they go through
                ``self.decode_opus`` first (not defined in this class,
                presumably inherited from the base provider).
            session_id: Passed to ``save_audio_to_file`` when audio is kept.
            audio_format: Format of ``audio_data``; defaults to ``"opus"``.

        Returns:
            ``(text, file_path)``. ``file_path`` is None unless the audio was
            saved. ``("", None)`` is returned when the model is not loaded,
            the audio is empty, or recognition fails (the error is logged).
        """
        try:
            # Refuse to run without a loaded model.
            if not self.model:
                logger.bind(tag=TAG).error("VOSK模型未加载,无法进行识别")
                return "", None

            # Decode the packets unless they already are PCM.
            if audio_format == "pcm":
                pcm_data = audio_data
            else:
                pcm_data = self.decode_opus(audio_data)

            if not pcm_data:
                logger.bind(tag=TAG).warning("解码后的PCM数据为空,无法进行识别")
                return "", None

            # Merge all PCM frames into one buffer.
            combined_pcm_data = b"".join(pcm_data)
            if not combined_pcm_data:
                logger.bind(tag=TAG).warning("合并后的PCM数据为空")
                return "", None

            # Audio is written to disk only when the caller wants to keep it,
            # so there is never a temporary file to clean up afterwards.
            file_path = None
            if not self.delete_audio_file:
                file_path = self.save_audio_to_file(pcm_data, session_id)

            # Monotonic clock: immune to wall-clock adjustments.
            start_time = time.monotonic()
            text_result = self._recognize(combined_pcm_data)
            logger.bind(tag=TAG).debug(
                f"VOSK语音识别耗时: {time.monotonic() - start_time:.3f}s | 结果: {text_result}"
            )

            return text_result, file_path

        except Exception as e:
            logger.bind(tag=TAG).error(f"VOSK语音识别失败: {e}")
            return "", None
0 commit comments