|
1 | 1 | import asyncio |
| 2 | +import contextlib |
2 | 3 | import hashlib |
3 | 4 | import logging |
4 | 5 | import os |
|
18 | 19 | from chromadb.config import APIVersion, Settings |
19 | 20 | from chromadb.utils import embedding_functions |
20 | 21 |
|
21 | | -from vectorcode.cli_utils import Config, expand_path |
| 22 | +from vectorcode.cli_utils import Config, LockManager, expand_path |
22 | 23 |
|
23 | 24 | logger = logging.getLogger(name=__name__) |
24 | 25 |
|
@@ -114,26 +115,6 @@ async def start_server(configs: Config): |
114 | 115 | return process |
115 | 116 |
|
116 | 117 |
|
117 | | -async def get_client(configs: Config) -> AsyncClientAPI: |
118 | | - settings: dict[str, Any] = {"anonymized_telemetry": False} |
119 | | - if isinstance(configs.db_settings, dict): |
120 | | - valid_settings = { |
121 | | - k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ |
122 | | - } |
123 | | - settings.update(valid_settings) |
124 | | - parsed_url = urlparse(configs.db_url) |
125 | | - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" |
126 | | - settings["chroma_server_http_port"] = parsed_url.port or 8000 |
127 | | - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" |
128 | | - settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 |
129 | | - settings_obj = Settings(**settings) |
130 | | - return await chromadb.AsyncHttpClient( |
131 | | - settings=settings_obj, |
132 | | - host=str(settings_obj.chroma_server_host), |
133 | | - port=int(settings_obj.chroma_server_http_port or 8000), |
134 | | - ) |
135 | | - |
136 | | - |
137 | 118 | def get_collection_name(full_path: str) -> str: |
138 | 119 | full_path = str(expand_path(full_path, absolute=True)) |
139 | 120 | hasher = hashlib.sha256() |
@@ -276,27 +257,61 @@ def __new__(cls) -> "ClientManager": |
276 | 257 | cls.__singleton.__clients = {} |
277 | 258 | return cls.__singleton |
278 | 259 |
|
279 | | - async def get_client(self, configs: Config) -> _ClientModel: |
| 260 | + @contextlib.asynccontextmanager |
| 261 | + async def get_client(self, configs: Config, need_lock: bool = True): |
280 | 262 | project_root = str(expand_path(str(configs.project_root), True)) |
| 263 | + is_bundled = False |
281 | 264 | if self.__clients.get(project_root) is None: |
282 | | - is_bundled = False |
283 | 265 | process = None |
284 | 266 | if not await try_server(configs.db_url): |
285 | 267 | logger.info(f"Starting a new server at {configs.db_url}") |
286 | 268 | process = await start_server(configs) |
287 | 269 | is_bundled = True |
288 | 270 |
|
289 | 271 | self.__clients[project_root] = _ClientModel( |
290 | | - client=await get_client(configs), is_bundled=is_bundled, process=process |
| 272 | + client=await self._create_client(configs), |
| 273 | + is_bundled=is_bundled, |
| 274 | + process=process, |
291 | 275 | ) |
292 | | - return self.__clients[project_root] |
| 276 | + lock = None |
| 277 | + if self.__clients[project_root].is_bundled and need_lock: |
| 278 | + lock = LockManager().get_lock(str(configs.db_path)) |
| 279 | + logger.debug(f"Locking {configs.db_path}") |
| 280 | + await lock.acquire() |
| 281 | + yield self.__clients[project_root].client |
| 282 | + if lock is not None: |
| 283 | + logger.debug(f"Unlocking {configs.db_path}") |
| 284 | + await lock.release() |
293 | 285 |
|
294 | 286 | def get_processes(self) -> list[Process]: |
295 | 287 | return [i.process for i in self.__clients.values() if i.process is not None] |
296 | 288 |
|
297 | 289 | async def kill_servers(self): |
298 | 290 | termination_tasks: list[asyncio.Task] = [] |
299 | 291 | for p in ClientManager().get_processes(): |
| 292 | + logger.info(f"Killing bundled chroma server with PID: {p.pid}") |
300 | 293 | p.terminate() |
301 | 294 | termination_tasks.append(asyncio.create_task(p.wait())) |
302 | 295 | await asyncio.gather(*termination_tasks) |
| 296 | + |
| 297 | + async def _create_client(self, configs: Config) -> AsyncClientAPI: |
| 298 | + settings: dict[str, Any] = {"anonymized_telemetry": False} |
| 299 | + if isinstance(configs.db_settings, dict): |
| 300 | + valid_settings = { |
| 301 | + k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ |
| 302 | + } |
| 303 | + settings.update(valid_settings) |
| 304 | + parsed_url = urlparse(configs.db_url) |
| 305 | + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" |
| 306 | + settings["chroma_server_http_port"] = parsed_url.port or 8000 |
| 307 | + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" |
| 308 | + settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 |
| 309 | + settings_obj = Settings(**settings) |
| 310 | + return await chromadb.AsyncHttpClient( |
| 311 | + settings=settings_obj, |
| 312 | + host=str(settings_obj.chroma_server_host), |
| 313 | + port=int(settings_obj.chroma_server_http_port or 8000), |
| 314 | + ) |
| 315 | + |
| 316 | + def clear(self): |
| 317 | + self.__clients.clear() |
0 commit comments