diff --git a/docker-setup.sh b/docker-setup.sh index 04ea547a3..2067ef39d 100755 --- a/docker-setup.sh +++ b/docker-setup.sh @@ -74,6 +74,17 @@ else $DOWNLOAD_CMD "data/.config.yaml" "https://raw.githubusercontent.com/xinnan-tech/xiaozhi-esp32-server/main/main/xiaozhi-server/config.yaml" fi +# 下载量化模型 +echo "下载量化模型..." +mkdir -p models/all-MiniLM-L6-v2 +if [ "$DOWNLOAD_CMD" = "powershell -Command Invoke-WebRequest -Uri" ]; then + $DOWNLOAD_CMD "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip" $DOWNLOAD_CMD_SUFFIX "models/all-MiniLM-L6-v2/all-MiniLM-L6-v2.zip" + +else + $DOWNLOAD_CMD "models/all-MiniLM-L6-v2.zip" "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip" +fi +unzip models/all-MiniLM-L6-v2.zip -d models/all-MiniLM-L6-v2 + # 检查文件是否存在 echo "检查文件完整性..." FILES_TO_CHECK="docker-compose.yml data/.config.yaml models/SenseVoiceSmall/model.pt" diff --git a/docs/Deployment.md b/docs/Deployment.md index dfac9dcde..14241f678 100644 --- a/docs/Deployment.md +++ b/docs/Deployment.md @@ -306,6 +306,17 @@ LLM: - 线路二:百度网盘下载[SenseVoiceSmall](https://pan.baidu.com/share/init?surl=QlgM58FHhYv1tFnUT_A8Sg&pwd=qvna) 提取码: `qvna` +如果需要使用 oceanbase 的长期记忆向量化能力,需要额外下载量化模型,用于把历史对话转换成向量数据。目前支持 `all-MiniLM-L6-v2` 模型。因为模型较大,需要独立下载: + +``` +wget https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip +mkdir models/all-MiniLM-L6-v2 +``` +解压后放在`models/all-MiniLM-L6-v2`目录下。 + + + + ## 运行状态确认 如果你能看到,类似以下日志,则是本项目服务启动成功的标志。 diff --git a/main/xiaozhi-server/config.yaml b/main/xiaozhi-server/config.yaml index f1ac85c06..bebad1b67 100644 --- a/main/xiaozhi-server/config.yaml +++ b/main/xiaozhi-server/config.yaml @@ -194,6 +194,13 @@ Memory: mem_local_short: # 本地记忆功能,通过selected_module的llm总结,数据保存在本地,不会上传到服务器 type: mem_local_short + oceanbase: + uri: "*****:**" + user: "***" + password: "***" + database: xiaozhi + table: xiaozhi + model_path: 
import os.path
import traceback

from pyobvector import MilvusLikeClient
from sentence_transformers import SentenceTransformer

from core.providers.memory.base import MemoryProviderBase, logger

TAG = __name__

# Reference DDL for the memory table. It is only ever written to the log (on a
# connection failure) so the operator can create the table by hand; this
# module never executes it.
create_table_sql = """
CREATE TABLE xiaozhi(
    id INT AUTO_INCREMENT PRIMARY KEY,
    role VARCHAR(200),
    content text,
    embedding VECTOR(384),
    VECTOR INDEX idx1(embedding) WITH (distance=L2, type=hnsw)
    );
"""


class MemoryProvider(MemoryProviderBase):
    """Long-term memory provider that stores dialogue turns in OceanBase.

    Every non-system message with content is embedded with a local
    sentence-transformers model and inserted as a ``role`` / ``content`` /
    ``embedding`` row. Queries embed the query text and run a vector search
    against the ``embedding`` column.
    """

    def __init__(self, config):
        """Read connection settings, validate the model path and connect.

        Args:
            config: Mapping with the optional keys ``uri``, ``user``,
                ``password``, ``database``, ``table`` (or the legacy
                ``table_name``) and ``model_path``.

        Raises:
            Exception: If the embedding model directory does not exist.
        """
        super().__init__(config)
        self.uri = config.get("uri", "localhost")
        self.user = config.get("user", "root@test")
        self.password = config.get("password", "")
        self.db_name = config.get("database", "xiaozhi")
        # The shipped config.yaml names this key `table`, while this provider
        # originally read `table_name`. Accept both so that neither existing
        # nor documented configurations are silently ignored.
        self.table_name = (
            config.get("table_name") or config.get("table") or "xiaozhi"
        )
        self.model_path = os.path.abspath(
            config.get("model_path", "models/all-MiniLM-L6-v2")
        )
        if not os.path.exists(self.model_path):
            raise Exception(f"模型路径不存在,请下载量化模型 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip 并解压到: {self.model_path}")
        # Loaded lazily on first use and then reused; see _get_model().
        self._model = None
        logger.bind(tag=TAG).info(f"连接到 oceanbase 服务: {self.uri}")
        self.client = self.connect_to_client()

    def connect_to_client(self):
        """Create the OceanBase client.

        Returns:
            The connected ``MilvusLikeClient``, or ``None`` when the
            connection attempt raised. The failure is logged together with
            the reference table DDL.
        """
        try:
            return MilvusLikeClient(
                uri=self.uri,
                user=self.user,
                password=self.password,
                db_name=self.db_name,
            )
        except Exception as e:
            log = logger.bind(tag=TAG)
            log.error(f"连接到 oceanbase 服务时发生错误: {str(e)}")
            log.error(f"请检查配置并确认表是否存在: 初始化sql: {create_table_sql}")
            log.error(f"详细错误: {traceback.format_exc()}")
            return None

    def _get_model(self):
        """Return the embedding model, loading it from disk only once."""
        # getattr keeps this safe for instances created without __init__.
        model = getattr(self, "_model", None)
        if model is None:
            model = SentenceTransformer(self.model_path)
            self._model = model
        return model

    def _string_to_embeddings(self, sentences):
        """Encode ``sentences`` with the local model and return the result.

        The value is whatever ``SentenceTransformer.encode`` returns for the
        given input; it is passed on to the database client unchanged.
        """
        # Reuse the cached model: rebuilding SentenceTransformer on every
        # call reloads the weights from disk for each message and each query.
        return self._get_model().encode(sentences)

    def init_memory(self, role_id, llm):
        """Delegate to the base class; this provider keeps no extra state."""
        super().init_memory(role_id, llm)

    async def save_memory(self, msgs):
        """Embed and persist the non-system messages of a dialogue.

        Nothing is written when there is no client or fewer than two
        messages. Failures are logged and swallowed so that a storage
        problem does not interrupt the conversation.

        Args:
            msgs: Sequence of message objects exposing ``role`` and
                ``content`` attributes.

        Returns:
            Always ``None``.
        """
        if not self.client or not msgs or len(msgs) < 2:
            return None

        try:
            records = [
                {
                    "role": message.role,
                    "content": message.content,
                    "embedding": self._string_to_embeddings(message.content),
                }
                for message in msgs
                if message.role != "system" and message.content
            ]
            # Rows are inserted one at a time, as before, so an error on a
            # later row does not affect rows that were already written.
            for record in records:
                self.client.insert(
                    collection_name=self.table_name, data=[record]
                )
            logger.bind(tag=TAG).info("Save memory")
        except Exception:
            logger.bind(tag=TAG).error(f"保存记忆失败: {traceback.format_exc()}")
        return None

    @staticmethod
    def _format_results(results) -> str:
        """Render search rows as one ``role: content`` line per row.

        NOTE(review): assumes ``search`` yields dict rows keyed by the
        requested ``output_fields`` -- confirm against the installed
        pyobvector version. Rows of any other shape fall back to ``str(row)``.
        """
        lines = []
        for row in results or []:
            if isinstance(row, dict):
                content = row.get("content", "")
                if content:
                    lines.append(f"{row.get('role', '')}: {content}")
            else:
                lines.append(str(row))
        return "\n".join(lines)

    async def query_memory(self, query: str) -> str:
        """Return the stored memories closest to ``query`` as text.

        Args:
            query: Natural-language query text.

        Returns:
            Up to five matches, one ``role: content`` line each, or an empty
            string when there is no client, nothing matched, or the lookup
            failed (the failure is logged).
        """
        if not self.client:
            return ""

        try:
            # Embed inside the try block so that a model failure follows the
            # same "log and return empty" path as a database failure.
            query_vector = self._string_to_embeddings(query)
            results = self.client.search(
                collection_name=self.table_name,
                data=query_vector,
                anns_field="embedding",
                limit=5,
                output_fields=["role", "content"],
            )
            # The declared return type is str; hand back text, not raw rows.
            return self._format_results(results)
        except Exception:
            logger.bind(tag=TAG).error(f"查询记忆失败: {traceback.format_exc()}")
            return ""