Skip to content

Commit 9e0b64e

Browse files
committed
refactor(cli): Use a context manager for client with filelock when necessary
1 parent 452602d commit 9e0b64e

File tree

10 files changed

+482
-471
lines changed

10 files changed

+482
-471
lines changed

src/vectorcode/cli_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,8 +627,14 @@ def __new__(cls) -> "LockManager":
627627
cls.__singleton.__locks = {}
628628
return cls.__singleton
629629

630-
def get(self, path: str | os.PathLike) -> AsyncFileLock:
630+
def get_lock(self, path: str | os.PathLike) -> AsyncFileLock:
631631
path = str(expand_path(str(path), True))
632+
if os.path.isdir(path):
633+
lock_file = os.path.join(path, "vectorcode.lock")
634+
logger.info(f"Creating {lock_file} for locking.")
635+
with open(lock_file, mode="w") as fin:
636+
fin.write("")
637+
path = lock_file
632638
if self.__locks.get(path) is None:
633639
self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType]
634640
return self.__locks[path]

src/vectorcode/common.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import hashlib
34
import logging
45
import os
@@ -18,7 +19,7 @@
1819
from chromadb.config import APIVersion, Settings
1920
from chromadb.utils import embedding_functions
2021

21-
from vectorcode.cli_utils import Config, expand_path
22+
from vectorcode.cli_utils import Config, LockManager, expand_path
2223

2324
logger = logging.getLogger(name=__name__)
2425

@@ -114,26 +115,6 @@ async def start_server(configs: Config):
114115
return process
115116

116117

117-
async def get_client(configs: Config) -> AsyncClientAPI:
118-
settings: dict[str, Any] = {"anonymized_telemetry": False}
119-
if isinstance(configs.db_settings, dict):
120-
valid_settings = {
121-
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
122-
}
123-
settings.update(valid_settings)
124-
parsed_url = urlparse(configs.db_url)
125-
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
126-
settings["chroma_server_http_port"] = parsed_url.port or 8000
127-
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
128-
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
129-
settings_obj = Settings(**settings)
130-
return await chromadb.AsyncHttpClient(
131-
settings=settings_obj,
132-
host=str(settings_obj.chroma_server_host),
133-
port=int(settings_obj.chroma_server_http_port or 8000),
134-
)
135-
136-
137118
def get_collection_name(full_path: str) -> str:
138119
full_path = str(expand_path(full_path, absolute=True))
139120
hasher = hashlib.sha256()
@@ -276,27 +257,61 @@ def __new__(cls) -> "ClientManager":
276257
cls.__singleton.__clients = {}
277258
return cls.__singleton
278259

279-
async def get_client(self, configs: Config) -> _ClientModel:
260+
@contextlib.asynccontextmanager
261+
async def get_client(self, configs: Config, need_lock: bool = True):
280262
project_root = str(expand_path(str(configs.project_root), True))
263+
is_bundled = False
281264
if self.__clients.get(project_root) is None:
282-
is_bundled = False
283265
process = None
284266
if not await try_server(configs.db_url):
285267
logger.info(f"Starting a new server at {configs.db_url}")
286268
process = await start_server(configs)
287269
is_bundled = True
288270

289271
self.__clients[project_root] = _ClientModel(
290-
client=await get_client(configs), is_bundled=is_bundled, process=process
272+
client=await self._create_client(configs),
273+
is_bundled=is_bundled,
274+
process=process,
291275
)
292-
return self.__clients[project_root]
276+
lock = None
277+
if self.__clients[project_root].is_bundled and need_lock:
278+
lock = LockManager().get_lock(str(configs.db_path))
279+
logger.debug(f"Locking {configs.db_path}")
280+
await lock.acquire()
281+
yield self.__clients[project_root].client
282+
if lock is not None:
283+
logger.debug(f"Unlocking {configs.db_path}")
284+
await lock.release()
293285

294286
def get_processes(self) -> list[Process]:
295287
return [i.process for i in self.__clients.values() if i.process is not None]
296288

297289
async def kill_servers(self):
298290
termination_tasks: list[asyncio.Task] = []
299291
for p in ClientManager().get_processes():
292+
logger.info(f"Killing bundled chroma server with PID: {p.pid}")
300293
p.terminate()
301294
termination_tasks.append(asyncio.create_task(p.wait()))
302295
await asyncio.gather(*termination_tasks)
296+
297+
async def _create_client(self, configs: Config) -> AsyncClientAPI:
298+
settings: dict[str, Any] = {"anonymized_telemetry": False}
299+
if isinstance(configs.db_settings, dict):
300+
valid_settings = {
301+
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
302+
}
303+
settings.update(valid_settings)
304+
parsed_url = urlparse(configs.db_url)
305+
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
306+
settings["chroma_server_http_port"] = parsed_url.port or 8000
307+
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
308+
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
309+
settings_obj = Settings(**settings)
310+
return await chromadb.AsyncHttpClient(
311+
settings=settings_obj,
312+
host=str(settings_obj.chroma_server_host),
313+
port=int(settings_obj.chroma_server_http_port or 8000),
314+
)
315+
316+
def clear(self):
317+
self.__clients.clear()

0 commit comments

Comments
 (0)