|
1 | 1 | import asyncio
|
| 2 | +import contextlib |
2 | 3 | import hashlib
|
3 | 4 | import logging
|
4 | 5 | import os
|
5 | 6 | import socket
|
6 | 7 | import subprocess
|
7 | 8 | import sys
|
8 |
| -from typing import Any, AsyncGenerator |
| 9 | +from asyncio.subprocess import Process |
| 10 | +from dataclasses import dataclass |
| 11 | +from typing import Any, AsyncGenerator, Optional |
9 | 12 | from urllib.parse import urlparse
|
10 | 13 |
|
11 | 14 | import chromadb
|
|
16 | 19 | from chromadb.config import APIVersion, Settings
|
17 | 20 | from chromadb.utils import embedding_functions
|
18 | 21 |
|
19 |
| -from vectorcode.cli_utils import Config, expand_path |
| 22 | +from vectorcode.cli_utils import Config, LockManager, expand_path |
20 | 23 |
|
21 | 24 | logger = logging.getLogger(name=__name__)
|
22 | 25 |
|
@@ -112,32 +115,6 @@ async def start_server(configs: Config):
|
112 | 115 | return process
|
113 | 116 |
|
114 | 117 |
|
115 |
| -__CLIENT_CACHE: dict[str, AsyncClientAPI] = {} |
116 |
| - |
117 |
| - |
118 |
| -async def get_client(configs: Config) -> AsyncClientAPI: |
119 |
| - client_entry = configs.db_url |
120 |
| - if __CLIENT_CACHE.get(client_entry) is None: |
121 |
| - settings: dict[str, Any] = {"anonymized_telemetry": False} |
122 |
| - if isinstance(configs.db_settings, dict): |
123 |
| - valid_settings = { |
124 |
| - k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ |
125 |
| - } |
126 |
| - settings.update(valid_settings) |
127 |
| - parsed_url = urlparse(configs.db_url) |
128 |
| - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" |
129 |
| - settings["chroma_server_http_port"] = parsed_url.port or 8000 |
130 |
| - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" |
131 |
| - settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 |
132 |
| - settings_obj = Settings(**settings) |
133 |
| - __CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient( |
134 |
| - settings=settings_obj, |
135 |
| - host=str(settings_obj.chroma_server_host), |
136 |
| - port=int(settings_obj.chroma_server_http_port or 8000), |
137 |
| - ) |
138 |
| - return __CLIENT_CACHE[client_entry] |
139 |
| - |
140 |
| - |
141 | 118 | def get_collection_name(full_path: str) -> str:
|
142 | 119 | full_path = str(expand_path(full_path, absolute=True))
|
143 | 120 | hasher = hashlib.sha256()
|
@@ -261,3 +238,80 @@ async def list_collection_files(collection: AsyncCollection) -> list[str]:
|
261 | 238 | or []
|
262 | 239 | )
|
263 | 240 | )
|
| 241 | + |
| 242 | + |
@dataclass
class _ClientModel:
    """Cache record kept by ``ClientManager`` for one project root."""

    # Client object that ``ClientManager.get_client`` yields to callers.
    client: AsyncClientAPI
    # True when ``ClientManager`` launched the server itself (via
    # ``start_server``) instead of finding one already answering at the
    # configured URL.
    is_bundled: bool = False
    # Subprocess returned by ``start_server`` for a bundled server; stays
    # ``None`` when an existing server was reused.
    process: Optional[Process] = None
| 248 | + |
| 249 | + |
class ClientManager:
    """Singleton registry of ChromaDB clients, keyed by absolute project root.

    Every call to ``ClientManager()`` returns the same instance, so the
    client cache (and the handles of any servers this manager started) is
    shared by all callers in the process.
    """

    singleton: Optional["ClientManager"] = None
    __clients: dict[str, _ClientModel]

    def __new__(cls) -> "ClientManager":
        """Return the shared instance, creating it on first use."""
        if cls.singleton is None:
            cls.singleton = super().__new__(cls)
            cls.singleton.__clients = {}
        return cls.singleton

    @contextlib.asynccontextmanager
    async def get_client(
        self, configs: Config, need_lock: bool = True
    ) -> AsyncGenerator[AsyncClientAPI, None]:
        """Yield the cached client for ``configs.project_root``.

        On the first request for a project root, ``try_server`` is asked
        about ``configs.db_url``; when it returns a falsy result a server is
        launched with ``start_server`` and the cache entry is marked as
        bundled. A client is then created and cached.

        Args:
            configs: Project configuration; ``project_root``, ``db_url``,
                ``db_path`` and ``db_settings`` are read.
            need_lock: When true and the entry is bundled, the lock obtained
                from ``LockManager().get_lock(str(configs.db_path))`` is held
                for the duration of the ``async with`` body.

        Yields:
            The client stored for this project root.
        """
        project_root = str(expand_path(str(configs.project_root), True))
        entry = self.__clients.get(project_root)
        if entry is None:
            process = None
            is_bundled = False
            if not await try_server(configs.db_url):
                logger.info(f"Starting a new server at {configs.db_url}")
                process = await start_server(configs)
                is_bundled = True

            entry = _ClientModel(
                client=await self._create_client(configs),
                is_bundled=is_bundled,
                process=process,
            )
            self.__clients[project_root] = entry

        lock = None
        if entry.is_bundled and need_lock:
            lock = LockManager().get_lock(str(configs.db_path))
            logger.debug(f"Locking {configs.db_path}")
            await lock.acquire()
        try:
            yield entry.client
        finally:
            # An exception raised inside the caller's ``async with`` body is
            # re-raised at the ``yield``; without ``finally`` the lock would
            # stay held after such a failure.
            if lock is not None:
                logger.debug(f"Unlocking {configs.db_path}")
                await lock.release()

    def get_processes(self) -> list[Process]:
        """Return the subprocess handles recorded in the client cache."""
        return [i.process for i in self.__clients.values() if i.process is not None]

    async def kill_servers(self):
        """Terminate every recorded server process and wait for all to exit."""
        termination_tasks: list[asyncio.Task] = []
        for p in self.get_processes():
            logger.info(f"Killing bundled chroma server with PID: {p.pid}")
            # A server that has already exited must not prevent the
            # remaining ones from being terminated.
            with contextlib.suppress(ProcessLookupError):
                p.terminate()
            termination_tasks.append(asyncio.create_task(p.wait()))
        await asyncio.gather(*termination_tasks)

    async def _create_client(self, configs: Config) -> AsyncClientAPI:
        """Build an async HTTP client for the server at ``configs.db_url``.

        Telemetry is disabled by default. Entries of ``configs.db_settings``
        whose keys are fields of ``Settings`` are applied next, and the
        host, port, SSL flag and API path derived from ``configs.db_url``
        are applied last, so the URL takes precedence over ``db_settings``.
        """
        settings: dict[str, Any] = {"anonymized_telemetry": False}
        if isinstance(configs.db_settings, dict):
            valid_settings = {
                k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
            }
            settings.update(valid_settings)
        parsed_url = urlparse(configs.db_url)
        settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
        settings["chroma_server_http_port"] = parsed_url.port or 8000
        settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
        settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
        settings_obj = Settings(**settings)
        return await chromadb.AsyncHttpClient(
            settings=settings_obj,
            host=str(settings_obj.chroma_server_host),
            port=int(settings_obj.chroma_server_http_port or 8000),
        )

    def clear(self):
        """Forget all cached entries; recorded processes are not terminated here."""
        self.__clients.clear()
0 commit comments