From 11ff042dce3ef5bca1580369c1dc327c52884343 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Wed, 25 Jun 2025 19:28:44 +0800 Subject: [PATCH 01/17] refactor(cli): implement filelock to protect `db_path`. --- pyproject.toml | 1 + src/vectorcode/cli_utils.py | 16 +++++++++++ src/vectorcode/common.py | 28 ++++++++++++++++++++ src/vectorcode/main.py | 14 +++------- src/vectorcode/subcommands/query/__init__.py | 4 +-- 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dc05e2c6..45ec151f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "charset-normalizer>=3.4.1", "json5", "posthog<6.0.0", + "filelock>=3.15.0", ] requires-python = ">=3.11,<3.14" readme = "README.md" diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 38fd74ef..f41fa7c6 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -12,6 +12,7 @@ import json5 import shtab +from filelock import AsyncFileLock from vectorcode import __version__ @@ -610,3 +611,18 @@ def config_logging( handlers=handlers, level=level, ) + + +class LockManager: + """ + A class that manages file locks that protects the database files in daemon processes (LSP, MCP). + """ + + def __init__(self) -> None: + self.__locks: dict[str, AsyncFileLock] = {} + + def get(self, path: str | os.PathLike) -> AsyncFileLock: + path = str(expand_path(str(path), True)) + if self.__locks.get(path) is None: + self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType] + return self.__locks[path] diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index 65ce495e..622ed829 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -5,6 +5,7 @@ import socket import subprocess import sys +from asyncio.subprocess import Process from typing import Any, AsyncGenerator from urllib.parse import urlparse @@ -261,3 +262,30 @@ async def list_collection_files(collection: AsyncCollection) -> list[str]: or [] ) ) + + +class ClientManager: + __singleton = None + + # keys: project roots + # values: clients + __clients: dict[str, AsyncClientAPI] = {} + __server_processes = [] + + @classmethod + def get_instance(cls) -> "ClientManager": + if cls.__singleton is None: + cls.__singleton = ClientManager() + return cls.__singleton + + async def get(self, configs: Config) -> AsyncClientAPI: + project_root = str(expand_path(str(configs.project_root), True)) + if self.__clients.get(project_root) is None: + if not await try_server(configs.db_url): + self.__server_processes.append(await start_server(configs)) + + self.__clients[project_root] = await get_client(configs) + return self.__clients[project_root] + + def get_processes(self) -> list[Process]: + return self.__server_processes diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 3ea8eefa..689c5834 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -12,6 +12,7 @@ get_project_config, parse_cli_args, ) +from vectorcode.common import ClientManager logger = logging.getLogger(name=__name__) @@ -63,12 +64,6 @@ async def async_main(): return await chunks(final_configs) - from vectorcode.common import start_server, try_server - - server_process = None - if not await try_server(final_configs.db_url): - server_process = await start_server(final_configs) - if final_configs.pipe: # pragma: nocover # NOTE: NNCF (intel GPU acceleration for sentence transformer) keeps showing logs. # This disables logs below ERROR so that it doesn't hurt the `pipe` output. @@ -105,10 +100,9 @@ async def async_main(): return_val = 1 logger.error(traceback.format_exc()) finally: - if server_process is not None: - logger.info("Shutting down the bundled Chromadb instance.") - server_process.terminate() - await server_process.wait() + for p in ClientManager.get_instance().get_processes(): + p.terminate() + await p.wait() return return_val diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index 51c3a550..b1a3c04f 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -17,7 +17,7 @@ expand_path, ) from vectorcode.common import ( - get_client, + ClientManager, get_collection, verify_ef, ) @@ -160,7 +160,7 @@ async def query(configs: Config) -> int: "Having both chunk and document in the output is not supported!", ) return 1 - client = await get_client(configs) + client = await ClientManager.get_instance().get(configs) try: collection = await get_collection(client, configs, False) if not verify_ef(collection, configs): From 28c44874ce2f43709cf6ec684267df3327e6a6d3 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Wed, 25 Jun 2025 20:06:30 +0800 Subject: [PATCH 02/17] fix(cli): fixed singleton implementation. --- src/vectorcode/cli_utils.py | 10 ++++-- src/vectorcode/common.py | 38 ++++++++++++-------- src/vectorcode/main.py | 2 +- src/vectorcode/mcp_main.py | 21 ++++++----- src/vectorcode/subcommands/query/__init__.py | 2 +- 5 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index f41fa7c6..ba7ba2c9 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -618,8 +618,14 @@ class LockManager: A class that manages file locks that protects the database files in daemon processes (LSP, MCP). """ - def __init__(self) -> None: - self.__locks: dict[str, AsyncFileLock] = {} + __locks: dict[str, AsyncFileLock] + __singleton: "LockManager" + + def __new__(cls) -> "LockManager": + if cls.__singleton is None: + cls.__singleton = super().__new__(cls) + cls.__singleton.__locks = {} + return cls.__singleton def get(self, path: str | os.PathLike) -> AsyncFileLock: path = str(expand_path(str(path), True)) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index 622ed829..5c4bb5b2 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -6,7 +6,8 @@ import subprocess import sys from asyncio.subprocess import Process -from typing import Any, AsyncGenerator +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Optional from urllib.parse import urlparse import chromadb @@ -264,28 +265,37 @@ async def list_collection_files(collection: AsyncCollection) -> list[str]: ) -class ClientManager: - __singleton = None +@dataclass +class _ClientModel: + client: AsyncClientAPI + is_bundled: bool = False + process: Optional[Process] = None + - # keys: project roots - # values: clients - __clients: dict[str, AsyncClientAPI] = {} - __server_processes = [] +class ClientManager: + __singleton: Optional["ClientManager"] = None + __clients: dict[str, _ClientModel] - @classmethod - def get_instance(cls) -> "ClientManager": + def __new__(cls) -> "ClientManager": if cls.__singleton is None: - cls.__singleton = ClientManager() + cls.__singleton = super().__new__(cls) + cls.__singleton.__clients = {} return cls.__singleton - async def get(self, configs: Config) -> AsyncClientAPI: + async def get_client(self, configs: Config) -> _ClientModel: project_root = str(expand_path(str(configs.project_root), True)) if self.__clients.get(project_root) is None: + is_bundled = False + process = None if not await try_server(configs.db_url): - self.__server_processes.append(await start_server(configs)) + logger.info(f"Starting a new server at {configs.db_url}") + process = await start_server(configs) + is_bundled = True - self.__clients[project_root] = await get_client(configs) + self.__clients[project_root] = _ClientModel( + client=await get_client(configs), is_bundled=is_bundled, process=process + ) return self.__clients[project_root] def get_processes(self) -> list[Process]: - return self.__server_processes + return [i.process for i in self.__clients.values() if i.process is not None] diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 689c5834..6efb8ea8 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -100,7 +100,7 @@ async def async_main(): return_val = 1 logger.error(traceback.format_exc()) finally: - for p in ClientManager.get_instance().get_processes(): + for p in ClientManager().get_processes(): p.terminate() await p.wait() return return_val diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index 86d9989c..4402241e 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -32,6 +32,7 @@ from vectorcode.cli_utils import ( Config, + LockManager, cleanup_path, config_logging, expand_globs, @@ -39,11 +40,12 @@ get_project_config, load_config_file, ) -from vectorcode.common import get_client, get_collection, get_collections +from vectorcode.common import ClientManager, get_client, get_collection, get_collections from vectorcode.subcommands.prompt import prompt_by_categories from vectorcode.subcommands.query import get_query_result_files logger = logging.getLogger(name=__name__) +locks = LockManager() @dataclass @@ -79,18 +81,17 @@ def get_arg_parser(): return parser +default_project_root: Optional[str] = None default_config: Optional[Config] = None default_client: Optional[AsyncClientAPI] = None default_collection: Optional[AsyncCollection] = None async def list_collections() -> list[str]: - global default_config, default_client, default_collection names: list[str] = [] - client = default_client - if client is None: - # load from global config when failed to detect a project-local config. - client = await get_client(await load_config_file()) + client = ( + await ClientManager().get_client(await load_config_file(default_project_root)) + ).client async for col in get_collections(client): if col.metadata is not None: names.append(cleanup_path(str(col.metadata.get("path")))) @@ -110,7 +111,7 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int] ) config = await get_project_config(project_root) try: - client = await get_client(config) + client = (await ClientManager().get_client(config)).client collection = await get_collection(client, config, True) except Exception as e: logger.error("Failed to access collection at %s", project_root) @@ -136,6 +137,7 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int] if os.path.isfile(ignore_spec): logger.info(f"Loading ignore specs from {ignore_spec}.") paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec) + stats = VectoriseStats() collection_lock = asyncio.Lock() stats_lock = asyncio.Lock() @@ -187,7 +189,7 @@ async def query_tool( else: config = await get_project_config(project_root) try: - client = await get_client(config) + client = (await ClientManager().get_client(config)).client collection = await get_collection(client, config, False) except Exception as e: logger.error("Failed to access collection at %s", project_root) @@ -225,7 +227,7 @@ async def query_tool( async def mcp_server(): - global default_config, default_client, default_collection + global default_config, default_client, default_collection, default_project_root local_config_dir = await find_project_config_dir(".") @@ -233,6 +235,7 @@ async def mcp_server(): logger.info("Found project config: %s", local_config_dir) project_root = str(Path(local_config_dir).parent.resolve()) + default_project_root = project_root default_config = await get_project_config(project_root) default_config.project_root = project_root default_client = await get_client(default_config) diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index b1a3c04f..5fae9699 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -160,7 +160,7 @@ async def query(configs: Config) -> int: "Having both chunk and document in the output is not supported!", ) return 1 - client = await ClientManager.get_instance().get(configs) + client = (await ClientManager().get_client(configs)).client try: collection = await get_collection(client, configs, False) if not verify_ef(collection, configs): From 042b254e6dd9d2eb909767eb7cf6a1838ccc6f26 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Wed, 25 Jun 2025 20:24:29 +0800 Subject: [PATCH 03/17] feat(cli): Remove client cache and fix termination issues --- src/vectorcode/cli_utils.py | 2 +- src/vectorcode/common.py | 42 ++++++++++++++++--------------------- src/vectorcode/mcp_main.py | 17 ++++++++++----- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index ba7ba2c9..23d81dd0 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -622,7 +622,7 @@ class LockManager: __singleton: "LockManager" def __new__(cls) -> "LockManager": - if cls.__singleton is None: + if not hasattr(cls, "__singleton") or cls.__singleton is None: cls.__singleton = super().__new__(cls) cls.__singleton.__locks = {} return cls.__singleton diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index 5c4bb5b2..9e45dd47 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -114,30 +114,24 @@ async def start_server(configs: Config): return process -__CLIENT_CACHE: dict[str, AsyncClientAPI] = {} - - async def get_client(configs: Config) -> AsyncClientAPI: - client_entry = configs.db_url - if __CLIENT_CACHE.get(client_entry) is None: - settings: dict[str, Any] = {"anonymized_telemetry": False} - if isinstance(configs.db_settings, dict): - valid_settings = { - k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ - } - settings.update(valid_settings) - parsed_url = urlparse(configs.db_url) - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" - settings["chroma_server_http_port"] = parsed_url.port or 8000 - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" - settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 - settings_obj = Settings(**settings) - __CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient( - settings=settings_obj, - host=str(settings_obj.chroma_server_host), - port=int(settings_obj.chroma_server_http_port or 8000), - ) - return __CLIENT_CACHE[client_entry] + settings: dict[str, Any] = {"anonymized_telemetry": False} + if isinstance(configs.db_settings, dict): + valid_settings = { + k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ + } + settings.update(valid_settings) + parsed_url = urlparse(configs.db_url) + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" + settings["chroma_server_http_port"] = parsed_url.port or 8000 + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" + settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 + settings_obj = Settings(**settings) + return await chromadb.AsyncHttpClient( + settings=settings_obj, + host=str(settings_obj.chroma_server_host), + port=int(settings_obj.chroma_server_http_port or 8000), + ) def get_collection_name(full_path: str) -> str: @@ -277,7 +271,7 @@ class ClientManager: __clients: dict[str, _ClientModel] def __new__(cls) -> "ClientManager": - if cls.__singleton is None: + if not hasattr(cls, "__singleton") or cls.__singleton is None: cls.__singleton = super().__new__(cls) cls.__singleton.__clients = {} return cls.__singleton diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index 4402241e..383d4c96 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -40,7 +40,7 @@ get_project_config, load_config_file, ) -from vectorcode.common import ClientManager, get_client, get_collection, get_collections +from vectorcode.common import ClientManager, get_collection, get_collections from vectorcode.subcommands.prompt import prompt_by_categories from vectorcode.subcommands.query import get_query_result_files @@ -238,7 +238,7 @@ async def mcp_server(): default_project_root = project_root default_config = await get_project_config(project_root) default_config.project_root = project_root - default_client = await get_client(default_config) + default_client = (await ClientManager().get_client(default_config)).client try: default_collection = await get_collection(default_client, default_config) logger.info("Collection initialised for %s.", project_root) @@ -295,9 +295,16 @@ def parse_cli_args(args: Optional[list[str]] = None) -> MCPConfig: async def run_server(): # pragma: nocover - mcp = await mcp_server() - await mcp.run_stdio_async() - return 0 + try: + mcp = await mcp_server() + await mcp.run_stdio_async() + finally: + termination_tasks: list[asyncio.Task] = [] + for p in ClientManager().get_processes(): + p.terminate() + termination_tasks.append(asyncio.create_task(p.wait())) + await asyncio.gather(*termination_tasks) + return 0 def main(): # pragma: nocover From 568a609bcde65af187620c23a649a4700c1c4219 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Thu, 26 Jun 2025 12:36:43 +0800 Subject: [PATCH 04/17] refactor(cli): Refactor client termination to ClientManager --- src/vectorcode/common.py | 7 +++++++ src/vectorcode/lsp_main.py | 14 ++++++++------ src/vectorcode/main.py | 4 +--- src/vectorcode/mcp_main.py | 6 +----- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index 9e45dd47..ec2bd8a5 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -293,3 +293,10 @@ async def get_client(self, configs: Config) -> _ClientModel: def get_processes(self) -> list[Process]: return [i.process for i in self.__clients.values() if i.process is not None] + + async def kill_servers(self): + termination_tasks: list[asyncio.Task] = [] + for p in ClientManager().get_processes(): + p.terminate() + termination_tasks.append(asyncio.create_task(p.wait())) + await asyncio.gather(*termination_tasks) diff --git a/src/vectorcode/lsp_main.py b/src/vectorcode/lsp_main.py index f9b883f9..24da40b1 100644 --- a/src/vectorcode/lsp_main.py +++ b/src/vectorcode/lsp_main.py @@ -43,7 +43,7 @@ get_project_config, parse_cli_args, ) -from vectorcode.common import get_client, get_collection, try_server +from vectorcode.common import ClientManager, get_collection, try_server from vectorcode.subcommands.ls import get_collection_list from vectorcode.subcommands.query import build_query_results @@ -114,7 +114,7 @@ async def execute_command(ls: LanguageServer, args: list[str]): parsed_args.project_root ].merge_from(parsed_args) final_configs.pipe = True - client = await get_client(final_configs) + client = (await ClientManager().get_client(final_configs)).client if final_configs.action in {CliAction.vectorise, CliAction.query}: collection = await get_collection( client=client, @@ -123,7 +123,7 @@ async def execute_command(ls: LanguageServer, args: list[str]): ) else: final_configs = parsed_args - client = await get_client(parsed_args) + client = (await ClientManager().get_client(final_configs)).client collection = None logger.info("Merged final configs: %s", final_configs) progress_token = str(uuid.uuid4()) @@ -266,9 +266,11 @@ async def lsp_start() -> int: logger.info(f"{DEFAULT_PROJECT_ROOT=}") logger.info("Parsed LSP server CLI arguments: %s", args) - await asyncio.to_thread(server.start_io) - - return 0 + try: + await asyncio.to_thread(server.start_io) + finally: + await ClientManager().kill_servers() + return 0 def main(): # pragma: nocover diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 6efb8ea8..70cc1aba 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -100,9 +100,7 @@ async def async_main(): return_val = 1 logger.error(traceback.format_exc()) finally: - for p in ClientManager().get_processes(): - p.terminate() - await p.wait() + await ClientManager().kill_servers() return return_val diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index 383d4c96..bdaa8112 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -299,11 +299,7 @@ async def run_server(): # pragma: nocover mcp = await mcp_server() await mcp.run_stdio_async() finally: - termination_tasks: list[asyncio.Task] = [] - for p in ClientManager().get_processes(): - p.terminate() - termination_tasks.append(asyncio.create_task(p.wait())) - await asyncio.gather(*termination_tasks) + await ClientManager().kill_servers() return 0 From 25b2cbd683e03e0a90614291d298acab522a22e6 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Thu, 26 Jun 2025 15:44:40 +0800 Subject: [PATCH 05/17] refactor(cli): Use a context manager for client with filelock when necessary --- src/vectorcode/cli_utils.py | 8 +- src/vectorcode/common.py | 65 +++-- src/vectorcode/lsp_main.py | 241 +++++++++---------- src/vectorcode/mcp_main.py | 214 ++++++++-------- src/vectorcode/subcommands/clean.py | 7 +- src/vectorcode/subcommands/drop.py | 28 +-- src/vectorcode/subcommands/ls.py | 60 ++--- src/vectorcode/subcommands/query/__init__.py | 88 +++---- src/vectorcode/subcommands/update.py | 126 +++++----- src/vectorcode/subcommands/vectorise.py | 116 ++++----- 10 files changed, 482 insertions(+), 471 deletions(-) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 23d81dd0..4400743d 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -627,8 +627,14 @@ def __new__(cls) -> "LockManager": cls.__singleton.__locks = {} return cls.__singleton - def get(self, path: str | os.PathLike) -> AsyncFileLock: + def get_lock(self, path: str | os.PathLike) -> AsyncFileLock: path = str(expand_path(str(path), True)) + if os.path.isdir(path): + lock_file = os.path.join(path, "vectorcode.lock") + logger.info(f"Creating {lock_file} for locking.") + with open(lock_file, mode="w") as fin: + fin.write("") + path = lock_file if self.__locks.get(path) is None: self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType] return self.__locks[path] diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index ec2bd8a5..b4ae94cf 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import hashlib import logging import os @@ -18,7 +19,7 @@ from chromadb.config import APIVersion, Settings from chromadb.utils import embedding_functions -from vectorcode.cli_utils import Config, expand_path +from vectorcode.cli_utils import Config, LockManager, expand_path logger = logging.getLogger(name=__name__) @@ -114,26 +115,6 @@ async def start_server(configs: Config): return process -async def get_client(configs: Config) -> AsyncClientAPI: - settings: dict[str, Any] = {"anonymized_telemetry": False} - if isinstance(configs.db_settings, dict): - valid_settings = { - k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ - } - settings.update(valid_settings) - parsed_url = urlparse(configs.db_url) - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" - settings["chroma_server_http_port"] = parsed_url.port or 8000 - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" - settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 - settings_obj = Settings(**settings) - return await chromadb.AsyncHttpClient( - settings=settings_obj, - host=str(settings_obj.chroma_server_host), - port=int(settings_obj.chroma_server_http_port or 8000), - ) - - def get_collection_name(full_path: str) -> str: full_path = str(expand_path(full_path, absolute=True)) hasher = hashlib.sha256() @@ -276,10 +257,11 @@ def __new__(cls) -> "ClientManager": cls.__singleton.__clients = {} return cls.__singleton - async def get_client(self, configs: Config) -> _ClientModel: + @contextlib.asynccontextmanager + async def get_client(self, configs: Config, need_lock: bool = True): project_root = str(expand_path(str(configs.project_root), True)) + is_bundled = False if self.__clients.get(project_root) is None: - is_bundled = False process = None if not await try_server(configs.db_url): logger.info(f"Starting a new server at {configs.db_url}") @@ -287,9 +269,19 @@ async def get_client(self, configs: Config) -> _ClientModel: is_bundled = True self.__clients[project_root] = _ClientModel( - client=await get_client(configs), is_bundled=is_bundled, process=process + client=await self._create_client(configs), + is_bundled=is_bundled, + process=process, ) - return self.__clients[project_root] + lock = None + if self.__clients[project_root].is_bundled and need_lock: + lock = LockManager().get_lock(str(configs.db_path)) + logger.debug(f"Locking {configs.db_path}") + await lock.acquire() + yield self.__clients[project_root].client + if lock is not None: + logger.debug(f"Unlocking {configs.db_path}") + await lock.release() def get_processes(self) -> list[Process]: return [i.process for i in self.__clients.values() if i.process is not None] @@ -297,6 +289,29 @@ def get_processes(self) -> list[Process]: async def kill_servers(self): termination_tasks: list[asyncio.Task] = [] for p in ClientManager().get_processes(): + logger.info(f"Killing bundled chroma server with PID: {p.pid}") p.terminate() termination_tasks.append(asyncio.create_task(p.wait())) await asyncio.gather(*termination_tasks) + + async def _create_client(self, configs: Config) -> AsyncClientAPI: + settings: dict[str, Any] = {"anonymized_telemetry": False} + if isinstance(configs.db_settings, dict): + valid_settings = { + k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ + } + settings.update(valid_settings) + parsed_url = urlparse(configs.db_url) + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" + settings["chroma_server_http_port"] = parsed_url.port or 8000 + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" + settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 + settings_obj = Settings(**settings) + return await chromadb.AsyncHttpClient( + settings=settings_obj, + host=str(settings_obj.chroma_server_host), + port=int(settings_obj.chroma_server_http_port or 8000), + ) + + def clear(self): + self.__clients.clear() diff --git a/src/vectorcode/lsp_main.py b/src/vectorcode/lsp_main.py index 24da40b1..bd78854a 100644 --- a/src/vectorcode/lsp_main.py +++ b/src/vectorcode/lsp_main.py @@ -35,7 +35,6 @@ from vectorcode import __version__ from vectorcode.cli_utils import ( CliAction, - Config, cleanup_path, config_logging, expand_globs, @@ -43,28 +42,14 @@ get_project_config, parse_cli_args, ) -from vectorcode.common import ClientManager, get_collection, try_server +from vectorcode.common import ClientManager, get_collection from vectorcode.subcommands.ls import get_collection_list from vectorcode.subcommands.query import build_query_results -cached_project_configs: dict[str, Config] = {} DEFAULT_PROJECT_ROOT: str | None = None logger = logging.getLogger(__name__) -async def make_caches(project_root: str): - assert os.path.isabs(project_root) - if cached_project_configs.get(project_root) is None: - cached_project_configs[project_root] = await get_project_config(project_root) - config = cached_project_configs[project_root] - config.project_root = project_root - db_url = config.db_url - if not await try_server(db_url): # pragma: nocover - raise ConnectionError( - "Failed to find an existing ChromaDB server, which is a hard requirement for LSP mode!" - ) - - def get_arg_parser(): parser = argparse.ArgumentParser( "vectorcode-server", description="VectorCode LSP daemon." @@ -109,134 +94,140 @@ async def execute_command(ls: LanguageServer, args: list[str]): collection = None if parsed_args.project_root is not None: parsed_args.project_root = os.path.abspath(str(parsed_args.project_root)) - await make_caches(parsed_args.project_root) - final_configs = await cached_project_configs[ - parsed_args.project_root - ].merge_from(parsed_args) + + final_configs = await ( + await get_project_config(parsed_args.project_root) + ).merge_from(parsed_args) final_configs.pipe = True - client = (await ClientManager().get_client(final_configs)).client + else: + final_configs = parsed_args + logger.info("Merged final configs: %s", final_configs) + async with ClientManager().get_client(final_configs) as client: + progress_token = str(uuid.uuid4()) + if final_configs.action in {CliAction.vectorise, CliAction.query}: collection = await get_collection( client=client, configs=final_configs, make_if_missing=final_configs.action in {CliAction.vectorise}, ) - else: - final_configs = parsed_args - client = (await ClientManager().get_client(final_configs)).client - collection = None - logger.info("Merged final configs: %s", final_configs) - progress_token = str(uuid.uuid4()) - - await ls.progress.create_async(progress_token) - match final_configs.action: - case CliAction.query: - ls.progress.begin( - progress_token, - types.WorkDoneProgressBegin( - "VectorCode", - message=f"Querying {cleanup_path(str(final_configs.project_root))}", - ), - ) - final_results = [] - try: - assert collection is not None, ( - "Failed to find the correct collection." - ) - final_results.extend( - await build_query_results(collection, final_configs) - ) - finally: - log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s." - ls.progress.end( + await ls.progress.create_async(progress_token) + match final_configs.action: + case CliAction.query: + ls.progress.begin( progress_token, - types.WorkDoneProgressEnd(message=log_message), + types.WorkDoneProgressBegin( + "VectorCode", + message=f"Querying {cleanup_path(str(final_configs.project_root))}", + ), ) - logger.info(log_message) - return final_results - case CliAction.ls: - ls.progress.begin( - progress_token, - types.WorkDoneProgressBegin( - "VectorCode", - message="Looking for available projects indexed by VectorCode", - ), - ) - projects: list[dict] = [] - try: - projects.extend(await get_collection_list(client)) - finally: - ls.progress.end( + final_results = [] + try: + assert collection is not None, ( + "Failed to find the correct collection." + ) + final_results.extend( + await build_query_results(collection, final_configs) + ) + finally: + log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s." + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd(message=log_message), + ) + logger.info(log_message) + return final_results + case CliAction.ls: + ls.progress.begin( progress_token, - types.WorkDoneProgressEnd(message="List retrieved."), + types.WorkDoneProgressBegin( + "VectorCode", + message="Looking for available projects indexed by VectorCode", + ), ) - logger.info(f"Retrieved {len(projects)} project(s).") - return projects - case CliAction.vectorise: - assert collection is not None, "Failed to find the correct collection." - ls.progress.begin( - progress_token, - types.WorkDoneProgressBegin( - title="VectorCode", message="Vectorising files...", percentage=0 - ), - ) - files = await expand_globs( - final_configs.files - or load_files_from_include(str(final_configs.project_root)), - recursive=final_configs.recursive, - include_hidden=final_configs.include_hidden, - ) - if not final_configs.force: # pragma: nocover - # tested in 'vectorise.py' - for spec in find_exclude_specs(final_configs): - if os.path.isfile(spec): - logger.info(f"Loading ignore specs from {spec}.") - files = exclude_paths_by_spec((str(i) for i in files), spec) - stats = VectoriseStats() - collection_lock = asyncio.Lock() - stats_lock = asyncio.Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - final_configs, - max_batch_size, - semaphore, + projects: list[dict] = [] + try: + projects.extend(await get_collection_list(client)) + finally: + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd(message="List retrieved."), ) + logger.info(f"Retrieved {len(projects)} project(s).") + return projects + case CliAction.vectorise: + assert collection is not None, ( + "Failed to find the correct collection." ) - for file in files - ] - for i, task in enumerate(asyncio.as_completed(tasks), start=1): - await task - ls.progress.report( + ls.progress.begin( progress_token, - types.WorkDoneProgressReport( + types.WorkDoneProgressBegin( + title="VectorCode", message="Vectorising files...", - percentage=int(100 * i / len(tasks)), + percentage=0, ), ) + files = await expand_globs( + final_configs.files + or load_files_from_include(str(final_configs.project_root)), + recursive=final_configs.recursive, + include_hidden=final_configs.include_hidden, + ) + if not final_configs.force: # pragma: nocover + # tested in 'vectorise.py' + for spec in find_exclude_specs(final_configs): + if os.path.isfile(spec): + logger.info(f"Loading ignore specs from {spec}.") + files = exclude_paths_by_spec( + (str(i) for i in files), spec + ) + stats = VectoriseStats() + collection_lock = asyncio.Lock() + stats_lock = asyncio.Lock() + max_batch_size = await client.get_max_batch_size() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + tasks = [ + asyncio.create_task( + chunked_add( + str(file), + collection, + collection_lock, + stats, + stats_lock, + final_configs, + max_batch_size, + semaphore, + ) + ) + for file in files + ] + for i, task in enumerate(asyncio.as_completed(tasks), start=1): + await task + ls.progress.report( + progress_token, + types.WorkDoneProgressReport( + message="Vectorising files...", + percentage=int(100 * i / len(tasks)), + ), + ) - await remove_orphanes(collection, collection_lock, stats, stats_lock) + await remove_orphanes( + collection, collection_lock, stats, stats_lock + ) - ls.progress.end( - progress_token, - types.WorkDoneProgressEnd( - message=f"Vectorised {stats.add + stats.update} files." - ), - ) - return stats.to_dict() - case _ as c: # pragma: nocover - error_message = f"Unsupported vectorcode subcommand: {str(c)}" - logger.error( - error_message, - ) - raise JsonRpcInvalidRequest(error_message) + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd( + message=f"Vectorised {stats.add + stats.update} files." + ), + ) + return stats.to_dict() + case _ as c: # pragma: nocover + error_message = f"Unsupported vectorcode subcommand: {str(c)}" + logger.error( + error_message, + ) + raise JsonRpcInvalidRequest(error_message) except Exception as e: # pragma: nocover if isinstance(e, JsonRpcException): # pygls exception. raise it as is. diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index bdaa8112..8d5fdd44 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -8,9 +8,6 @@ from typing import Optional import shtab -from chromadb.api import AsyncClientAPI -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.errors import InvalidCollectionException from vectorcode.subcommands.vectorise import ( VectoriseStats, @@ -83,20 +80,18 @@ def get_arg_parser(): default_project_root: Optional[str] = None default_config: Optional[Config] = None -default_client: Optional[AsyncClientAPI] = None -default_collection: Optional[AsyncCollection] = None async def list_collections() -> list[str]: names: list[str] = [] - client = ( - await ClientManager().get_client(await load_config_file(default_project_root)) - ).client - async for col in get_collections(client): - if col.metadata is not None: - names.append(cleanup_path(str(col.metadata.get("path")))) - logger.info("Retrieved the following collections: %s", names) - return names + async with ClientManager().get_client( + await load_config_file(default_project_root) + ) as client: + async for col in get_collections(client): + if col.metadata is not None: + names.append(cleanup_path(str(col.metadata.get("path")))) + logger.info("Retrieved the following collections: %s", names) + return names async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]: @@ -111,8 +106,53 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int] ) config = await get_project_config(project_root) try: - client = (await ClientManager().get_client(config)).client - collection = await get_collection(client, config, True) + async with ClientManager().get_client(config) as client: + collection = await get_collection(client, config, True) + if collection is None: # pragma: nocover + raise McpError( + ErrorData( + code=1, + message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", + ) + ) + paths = [os.path.expanduser(i) for i in await expand_globs(paths)] + final_config = await config.merge_from( + Config( + files=[i for i in paths if os.path.isfile(i)], + project_root=project_root, + ) + ) + for ignore_spec in find_exclude_specs(final_config): + if os.path.isfile(ignore_spec): + logger.info(f"Loading ignore specs from {ignore_spec}.") + paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec) + + stats = VectoriseStats() + collection_lock = asyncio.Lock() + stats_lock = asyncio.Lock() + max_batch_size = await client.get_max_batch_size() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + tasks = [ + asyncio.create_task( + chunked_add( + str(file), + collection, + collection_lock, + stats, + stats_lock, + final_config, + max_batch_size, + semaphore, + ) + ) + for file in paths + ] + for i, task in enumerate(asyncio.as_completed(tasks), start=1): + await task + + await remove_orphanes(collection, collection_lock, stats, stats_lock) + + return stats.to_dict() except Exception as e: logger.error("Failed to access collection at %s", project_root) raise McpError( @@ -121,49 +161,6 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int] message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.", ) ) - if collection is None: # pragma: nocover - raise McpError( - ErrorData( - code=1, - message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", - ) - ) - - paths = [os.path.expanduser(i) for i in await expand_globs(paths)] - final_config = await config.merge_from( - Config(files=[i for i in paths if os.path.isfile(i)], project_root=project_root) - ) - for ignore_spec in find_exclude_specs(final_config): - if os.path.isfile(ignore_spec): - logger.info(f"Loading ignore specs from {ignore_spec}.") - paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec) - - stats = VectoriseStats() - collection_lock = asyncio.Lock() - stats_lock = asyncio.Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - final_config, - max_batch_size, - semaphore, - ) - ) - for file in paths - ] - for i, task in enumerate(asyncio.as_completed(tasks), start=1): - await task - - await remove_orphanes(collection, collection_lock, stats, stats_lock) - - return stats.to_dict() async def query_tool( @@ -186,51 +183,55 @@ async def query_tool( message="Use `list_collections` tool to get a list of valid paths for this field.", ) ) - else: - config = await get_project_config(project_root) - try: - client = (await ClientManager().get_client(config)).client + config = await get_project_config(project_root) + try: + async with ClientManager().get_client(config) as client: collection = await get_collection(client, config, False) - except Exception as e: - logger.error("Failed to access collection at %s", project_root) - raise McpError( - ErrorData( - code=1, - message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", + + if collection is None: + raise McpError( + ErrorData( + code=1, + message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", + ) ) + query_config = await config.merge_from( + Config(n_result=n_query, query=query_messages) ) - if collection is None: + logger.info("Built the final config: %s", query_config) + result_paths = await get_query_result_files( + collection=collection, + configs=query_config, + ) + results: list[str] = [] + for path in result_paths: + if os.path.isfile(path): + with open(path) as fin: + rel_path = os.path.relpath(path, config.project_root) + results.append( + f"{rel_path}\n{fin.read()}", + ) + logger.info("Retrieved the following files: %s", result_paths) + return results + + except Exception as e: + logger.error("Failed to access collection at %s", project_root) raise McpError( ErrorData( code=1, - message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", + message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", ) ) - query_config = await config.merge_from( - Config(n_result=n_query, query=query_messages) - ) - logger.info("Built the final config: %s", query_config) - result_paths = await get_query_result_files( - collection=collection, - configs=query_config, - ) - results: list[str] = [] - for path in result_paths: - if os.path.isfile(path): - with open(path) as fin: - rel_path = os.path.relpath(path, config.project_root) - results.append( - f"{rel_path}\n{fin.read()}", - ) - logger.info("Retrieved the following files: %s", result_paths) - return results async def mcp_server(): - global default_config, default_client, default_collection, default_project_root + global default_config, default_project_root local_config_dir = await find_project_config_dir(".") + default_instructions = "\n".join( + "\n".join(i) for i in prompt_by_categories.values() + ) if local_config_dir is not None: logger.info("Found project config: %s", local_config_dir) project_root = str(Path(local_config_dir).parent.resolve()) @@ -238,27 +239,24 @@ async def mcp_server(): default_project_root = project_root default_config = await get_project_config(project_root) default_config.project_root = project_root - default_client = (await ClientManager().get_client(default_config)).client - try: - default_collection = await get_collection(default_client, default_config) + async with ClientManager().get_client(default_config) as client: logger.info("Collection initialised for %s.", project_root) - except InvalidCollectionException: # pragma: nocover - default_collection = None - default_instructions = "\n".join( - "\n".join(i) for i in prompt_by_categories.values() - ) - if default_client is None: - if mcp_config.ls_on_start: # pragma: nocover - logger.warning( - "Failed to initialise a chromadb client. Ignoring --ls-on-start flag." - ) - else: - if mcp_config.ls_on_start: - logger.info("Adding available collections to the server instructions.") - default_instructions += "\nYou have access to the following collections:\n" - for name in await list_collections(): - default_instructions += f"{name}" + if client is None: + if mcp_config.ls_on_start: # pragma: nocover + logger.warning( + "Failed to initialise a chromadb client. Ignoring --ls-on-start flag." + ) + else: + if mcp_config.ls_on_start: + logger.info( + "Adding available collections to the server instructions." + ) + default_instructions += ( + "\nYou have access to the following collections:\n" + ) + for name in await list_collections(): + default_instructions += f"{name}" mcp = FastMCP("VectorCode", instructions=default_instructions) mcp.add_tool( diff --git a/src/vectorcode/subcommands/clean.py b/src/vectorcode/subcommands/clean.py index 4a58aeb9..bae7ed48 100644 --- a/src/vectorcode/subcommands/clean.py +++ b/src/vectorcode/subcommands/clean.py @@ -4,7 +4,7 @@ from chromadb.api import AsyncClientAPI from vectorcode.cli_utils import Config -from vectorcode.common import get_client, get_collections +from vectorcode.common import ClientManager, get_collections logger = logging.getLogger(name=__name__) @@ -21,5 +21,6 @@ async def run_clean_on_client(client: AsyncClientAPI, pipe_mode: bool): async def clean(configs: Config) -> int: - await run_clean_on_client(await get_client(configs), configs.pipe) - return 0 + async with ClientManager().get_client(configs) as client: + await run_clean_on_client(client, configs.pipe) + return 0 diff --git a/src/vectorcode/subcommands/drop.py b/src/vectorcode/subcommands/drop.py index 08fbbbae..155c303f 100644 --- a/src/vectorcode/subcommands/drop.py +++ b/src/vectorcode/subcommands/drop.py @@ -3,22 +3,22 @@ from chromadb.errors import InvalidCollectionException from vectorcode.cli_utils import Config -from vectorcode.common import get_client, get_collection +from vectorcode.common import ClientManager, get_collection logger = logging.getLogger(name=__name__) async def drop(config: Config) -> int: - client = await get_client(config) - try: - collection = await get_collection(client, config) - collection_path = collection.metadata["path"] - await client.delete_collection(collection.name) - print(f"Collection for {collection_path} has been deleted.") - logger.info(f"Deteted collection at {collection_path}.") - return 0 - except (ValueError, InvalidCollectionException) as e: - logger.error( - f"{e.__class__.__name__}: There's no existing collection for {config.project_root}" - ) - return 1 + async with ClientManager().get_client(config) as client: + try: + collection = await get_collection(client, config) + collection_path = collection.metadata["path"] + await client.delete_collection(collection.name) + print(f"Collection for {collection_path} has been deleted.") + logger.info(f"Deteted collection at {collection_path}.") + return 0 + except (ValueError, InvalidCollectionException) as e: + logger.error( + f"{e.__class__.__name__}: There's no existing collection for {config.project_root}" + ) + return 1 diff --git a/src/vectorcode/subcommands/ls.py b/src/vectorcode/subcommands/ls.py index 246eb85b..c78d82ac 100644 --- a/src/vectorcode/subcommands/ls.py +++ b/src/vectorcode/subcommands/ls.py @@ -8,7 +8,7 @@ from chromadb.api.types import IncludeEnum from vectorcode.cli_utils import Config, cleanup_path -from vectorcode.common import get_client, get_collections +from vectorcode.common import ClientManager, get_collections logger = logging.getLogger(name=__name__) @@ -36,34 +36,34 @@ async def get_collection_list(client: AsyncClientAPI) -> list[dict]: async def ls(configs: Config) -> int: - client = await get_client(configs) - result: list[dict] = await get_collection_list(client) - logger.info(f"Found the following collections: {result}") + async with ClientManager().get_client(configs) as client: + result: list[dict] = await get_collection_list(client) + logger.info(f"Found the following collections: {result}") - if configs.pipe: - print(json.dumps(result)) - else: - table = [] - for meta in result: - project_root = meta["project-root"] - if os.environ.get("HOME"): - project_root = project_root.replace(os.environ["HOME"], "~") - row = [ - project_root, - meta["size"], - meta["num_files"], - meta["embedding_function"], - ] - table.append(row) - print( - tabulate.tabulate( - table, - headers=[ - "Project Root", - "Collection Size", - "Number of Files", - "Embedding Function", - ], + if configs.pipe: + print(json.dumps(result)) + else: + table = [] + for meta in result: + project_root = meta["project-root"] + if os.environ.get("HOME"): + project_root = project_root.replace(os.environ["HOME"], "~") + row = [ + project_root, + meta["size"], + meta["num_files"], + meta["embedding_function"], + ] + table.append(row) + print( + tabulate.tabulate( + table, + headers=[ + "Project Root", + "Collection Size", + "Number of Files", + "Embedding Function", + ], + ) ) - ) - return 0 + return 0 diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index 5fae9699..e1c0fc2f 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -160,52 +160,52 @@ async def query(configs: Config) -> int: "Having both chunk and document in the output is not supported!", ) return 1 - client = (await ClientManager().get_client(configs)).client - try: - collection = await get_collection(client, configs, False) - if not verify_ef(collection, configs): + async with ClientManager().get_client(configs) as client: + try: + collection = await get_collection(client, configs, False) + if not verify_ef(collection, configs): + return 1 + except (ValueError, InvalidCollectionException) as e: + logger.error( + f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}", + ) + return 1 + except InvalidDimensionException as e: + logger.error( + f"{e.__class__.__name__}: The collection was embedded with a different embedding model.", + ) + return 1 + except IndexError as e: # pragma: nocover + logger.error( + f"{e.__class__.__name__}: Failed to get the collection. Please check your config." + ) return 1 - except (ValueError, InvalidCollectionException) as e: - logger.error( - f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}", - ) - return 1 - except InvalidDimensionException as e: - logger.error( - f"{e.__class__.__name__}: The collection was embedded with a different embedding model.", - ) - return 1 - except IndexError as e: # pragma: nocover - logger.error( - f"{e.__class__.__name__}: Failed to get the collection. Please check your config." - ) - return 1 - if not configs.pipe: - print("Starting querying...") + if not configs.pipe: + print("Starting querying...") - if QueryInclude.chunk in configs.include: - if len((await collection.get(where={"start": {"$gte": 0}}))["ids"]) == 0: - logger.warning( - """ -This collection doesn't contain line range metadata. Falling back to `--include path document`. -Please re-vectorise it to use `--include chunk`.""", - ) - configs.include = [QueryInclude.path, QueryInclude.document] + if QueryInclude.chunk in configs.include: + if len((await collection.get(where={"start": {"$gte": 0}}))["ids"]) == 0: + logger.warning( + """ + This collection doesn't contain line range metadata. Falling back to `--include path document`. + Please re-vectorise it to use `--include chunk`.""", + ) + configs.include = [QueryInclude.path, QueryInclude.document] - try: - structured_result = await build_query_results(collection, configs) - except RerankerError as e: # pragma: nocover - # error logs should be handled where they're raised - logger.error(f"{e.__class__.__name__}") - return 1 + try: + structured_result = await build_query_results(collection, configs) + except RerankerError as e: # pragma: nocover + # error logs should be handled where they're raised + logger.error(f"{e.__class__.__name__}") + return 1 - if configs.pipe: - print(json.dumps(structured_result)) - else: - for idx, result in enumerate(structured_result): - for include_item in configs.include: - print(f"{include_item.to_header()}{result.get(include_item.value)}") - if idx != len(structured_result) - 1: - print() - return 0 + if configs.pipe: + print(json.dumps(structured_result)) + else: + for idx, result in enumerate(structured_result): + for include_item in configs.include: + print(f"{include_item.to_header()}{result.get(include_item.value)}") + if idx != len(structured_result) - 1: + print() + return 0 diff --git a/src/vectorcode/subcommands/update.py b/src/vectorcode/subcommands/update.py index 2d7d4322..ff4efa1e 100644 --- a/src/vectorcode/subcommands/update.py +++ b/src/vectorcode/subcommands/update.py @@ -9,78 +9,78 @@ from chromadb.errors import InvalidCollectionException from vectorcode.cli_utils import Config -from vectorcode.common import get_client, get_collection, verify_ef +from vectorcode.common import ClientManager, get_collection, verify_ef from vectorcode.subcommands.vectorise import VectoriseStats, chunked_add, show_stats logger = logging.getLogger(name=__name__) async def update(configs: Config) -> int: - client = await get_client(configs) - try: - collection = await get_collection(client, configs, False) - except IndexError as e: - print( - f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config." - ) - return 1 - except (ValueError, InvalidCollectionException) as e: - print( - f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}", - file=sys.stderr, - ) - return 1 - if collection is None or not verify_ef(collection, configs): - return 1 + async with ClientManager().get_client(configs) as client: + try: + collection = await get_collection(client, configs, False) + except IndexError as e: + print( + f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config." + ) + return 1 + except (ValueError, InvalidCollectionException) as e: + print( + f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}", + file=sys.stderr, + ) + return 1 + if collection is None or not verify_ef(collection, configs): + return 1 - metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"] - if metas is None: - return 0 - files_gen = (str(meta.get("path", "")) for meta in metas) - files = set() - orphanes = set() - for file in files_gen: - if os.path.isfile(file): - files.add(file) - else: - orphanes.add(file) + metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"] + if metas is None: + return 0 + files_gen = (str(meta.get("path", "")) for meta in metas) + files = set() + orphanes = set() + for file in files_gen: + if os.path.isfile(file): + files.add(file) + else: + orphanes.add(file) - stats = VectoriseStats(removed=len(orphanes)) - collection_lock = Lock() - stats_lock = Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) + stats = VectoriseStats(removed=len(orphanes)) + collection_lock = Lock() + stats_lock = Lock() + max_batch_size = await client.get_max_batch_size() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) - with tqdm.tqdm( - total=len(files), desc="Vectorising files...", disable=configs.pipe - ) as bar: - logger.info(f"Updating embeddings for {len(files)} file(s).") - try: - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, + with tqdm.tqdm( + total=len(files), desc="Vectorising files...", disable=configs.pipe + ) as bar: + logger.info(f"Updating embeddings for {len(files)} file(s).") + try: + tasks = [ + asyncio.create_task( + chunked_add( + str(file), + collection, + collection_lock, + stats, + stats_lock, + configs, + max_batch_size, + semaphore, + ) ) - ) - for file in files - ] - for task in asyncio.as_completed(tasks): - await task - bar.update(1) - except asyncio.CancelledError: # pragma: nocover - print("Abort.", file=sys.stderr) - return 1 + for file in files + ] + for task in asyncio.as_completed(tasks): + await task + bar.update(1) + except asyncio.CancelledError: # pragma: nocover + print("Abort.", file=sys.stderr) + return 1 - if len(orphanes): - logger.info(f"Removing {len(orphanes)} orphaned files from database.") - await collection.delete(where={"path": {"$in": list(orphanes)}}) + if len(orphanes): + logger.info(f"Removing {len(orphanes)} orphaned files from database.") + await collection.delete(where={"path": {"$in": list(orphanes)}}) - show_stats(configs, stats) - return 0 + show_stats(configs, stats) + return 0 diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py index 40ef1619..a0bea88f 100644 --- a/src/vectorcode/subcommands/vectorise.py +++ b/src/vectorcode/subcommands/vectorise.py @@ -24,7 +24,7 @@ expand_path, ) from vectorcode.common import ( - get_client, + ClientManager, get_collection, list_collection_files, verify_ef, @@ -251,64 +251,64 @@ def find_exclude_specs(configs: Config) -> list[str]: async def vectorise(configs: Config) -> int: assert configs.project_root is not None - client = await get_client(configs) - try: - collection = await get_collection(client, configs, True) - except IndexError as e: - print( - f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config." - ) - return 1 - if not verify_ef(collection, configs): - return 1 - - files = await expand_globs( - configs.files or load_files_from_include(str(configs.project_root)), - recursive=configs.recursive, - include_hidden=configs.include_hidden, - ) - - if not configs.force: - for spec_path in find_exclude_specs(configs): - if os.path.isfile(spec_path): - logger.info(f"Loading ignore specs from {spec_path}.") - files = exclude_paths_by_spec((str(i) for i in files), spec_path) - else: # pragma: nocover - logger.info("Ignoring exclude specs.") - - stats = VectoriseStats() - collection_lock = Lock() - stats_lock = Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - - with tqdm.tqdm( - total=len(files), desc="Vectorising files...", disable=configs.pipe - ) as bar: + async with ClientManager().get_client(configs) as client: try: - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - ) - for file in files - ] - for task in asyncio.as_completed(tasks): - await task - bar.update(1) - except asyncio.CancelledError: - print("Abort.", file=sys.stderr) + collection = await get_collection(client, configs, True) + except IndexError as e: + print( + f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config." + ) + return 1 + if not verify_ef(collection, configs): return 1 - await remove_orphanes(collection, collection_lock, stats, stats_lock) + files = await expand_globs( + configs.files or load_files_from_include(str(configs.project_root)), + recursive=configs.recursive, + include_hidden=configs.include_hidden, + ) - show_stats(configs=configs, stats=stats) - return 0 + if not configs.force: + for spec_path in find_exclude_specs(configs): + if os.path.isfile(spec_path): + logger.info(f"Loading ignore specs from {spec_path}.") + files = exclude_paths_by_spec((str(i) for i in files), spec_path) + else: # pragma: nocover + logger.info("Ignoring exclude specs.") + + stats = VectoriseStats() + collection_lock = Lock() + stats_lock = Lock() + max_batch_size = await client.get_max_batch_size() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + + with tqdm.tqdm( + total=len(files), desc="Vectorising files...", disable=configs.pipe + ) as bar: + try: + tasks = [ + asyncio.create_task( + chunked_add( + str(file), + collection, + collection_lock, + stats, + stats_lock, + configs, + max_batch_size, + semaphore, + ) + ) + for file in files + ] + for task in asyncio.as_completed(tasks): + await task + bar.update(1) + except asyncio.CancelledError: + print("Abort.", file=sys.stderr) + return 1 + + await remove_orphanes(collection, collection_lock, stats, stats_lock) + + show_stats(configs=configs, stats=stats) + return 0 From 1acf1a5232b7fe1bf7047d6ef9e569fc3a88d657 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Thu, 26 Jun 2025 18:19:52 +0800 Subject: [PATCH 06/17] tests(cli): fix failed tests due to ClientManager refactor. --- tests/subcommands/query/test_query.py | 24 ++-- tests/subcommands/test_clean.py | 4 +- tests/subcommands/test_drop.py | 27 +++-- tests/subcommands/test_ls.py | 19 ++- tests/subcommands/test_update.py | 13 ++- tests/subcommands/test_vectorise.py | 35 +++--- tests/test_common.py | 162 +++++++++++++++----------- tests/test_lsp.py | 103 ++-------------- tests/test_main.py | 55 --------- tests/test_mcp.py | 109 ++++++++++++----- 10 files changed, 269 insertions(+), 282 deletions(-) diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py index 4a54de9d..9f8e4078 100644 --- a/tests/subcommands/query/test_query.py +++ b/tests/subcommands/query/test_query.py @@ -355,7 +355,7 @@ async def test_query_success(mock_config): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), @@ -367,6 +367,7 @@ async def test_query_success(mock_config): patch("os.path.relpath", return_value="rel/path.py"), patch("os.path.abspath", return_value="/abs/path.py"), ): + MockClientManager.return_value._create_client.return_value = mock_client # Set up the mock file paths and contents mock_get_files.return_value = ["file1.py", "file2.py"] mock_file_handle = MagicMock() @@ -396,7 +397,7 @@ async def test_query_pipe_mode(mock_config): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), @@ -408,6 +409,7 @@ async def test_query_pipe_mode(mock_config): patch("os.path.relpath", return_value="rel/path.py"), patch("os.path.abspath", return_value="/abs/path.py"), ): + MockClientManager.return_value._create_client.return_value = mock_client # Set up the mock file paths and contents mock_get_files.return_value = ["file1.py", "file2.py"] mock_file_handle = MagicMock() @@ -434,7 +436,7 @@ async def test_query_absolute_path(mock_config): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), @@ -445,6 +447,7 @@ async def test_query_absolute_path(mock_config): patch("os.path.relpath", return_value="rel/path.py"), patch("os.path.abspath", return_value="/abs/path.py"), ): + MockClientManager.return_value._create_client.return_value = mock_client # Set up the mock file paths and contents mock_get_files.return_value = ["file1.py"] mock_file_handle = MagicMock() @@ -463,7 +466,7 @@ async def test_query_collection_not_found(): config = Config(project_root="/test/project") with ( - patch("vectorcode.subcommands.query.get_client"), + patch("vectorcode.subcommands.query.ClientManager"), patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, patch("sys.stderr"), ): @@ -482,7 +485,7 @@ async def test_query_invalid_collection(): config = Config(project_root="/test/project") with ( - patch("vectorcode.subcommands.query.get_client"), + patch("vectorcode.subcommands.query.ClientManager"), patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, patch("sys.stderr"), ): @@ -503,7 +506,7 @@ async def test_query_invalid_dimension(): config = Config(project_root="/test/project") with ( - patch("vectorcode.subcommands.query.get_client"), + patch("vectorcode.subcommands.query.ClientManager"), patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, patch("sys.stderr"), ): @@ -524,7 +527,7 @@ async def test_query_invalid_file(mock_config): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), @@ -532,6 +535,7 @@ async def test_query_invalid_file(mock_config): patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files, patch("os.path.isfile", return_value=False), ): + MockClientManager.return_value._create_client.return_value = mock_client # Set up the mock file paths mock_get_files.return_value = ["invalid_file.py"] @@ -549,12 +553,13 @@ async def test_query_invalid_ef(mock_config): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), patch("vectorcode.subcommands.query.verify_ef", return_value=False), ): + MockClientManager.return_value._create_client.return_value = mock_client # Call the function result = await query(mock_config) @@ -580,13 +585,14 @@ async def test_query_chunk_mode_no_metadata_fallback(mock_config): mock_collection.get.return_value = {"ids": []} with ( - patch("vectorcode.subcommands.query.get_client", return_value=mock_client), + patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.query.get_collection", return_value=mock_collection ), patch("vectorcode.subcommands.query.verify_ef", return_value=True), patch("vectorcode.subcommands.query.build_query_results") as mock_build_results, ): + MockClientManager.return_value._create_client.return_value = mock_client mock_build_results.return_value = [] # Return empty results for simplicity result = await query(mock_config) diff --git a/tests/subcommands/test_clean.py b/tests/subcommands/test_clean.py index 1fc345fd..8c79fd7f 100644 --- a/tests/subcommands/test_clean.py +++ b/tests/subcommands/test_clean.py @@ -73,10 +73,10 @@ async def mock_get_collections(client): @pytest.mark.asyncio async def test_clean(): - mock_client = AsyncMock(spec=AsyncClientAPI) + AsyncMock(spec=AsyncClientAPI) mock_config = Config(pipe=False) - with patch("vectorcode.subcommands.clean.get_client", return_value=mock_client): + with patch("vectorcode.subcommands.clean.ClientManager"): result = await clean(mock_config) assert result == 0 diff --git a/tests/subcommands/test_drop.py b/tests/subcommands/test_drop.py index dcf4b1f0..15b990d8 100644 --- a/tests/subcommands/test_drop.py +++ b/tests/subcommands/test_drop.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, patch import pytest @@ -31,19 +32,31 @@ def mock_collection(): async def test_drop_success(mock_config, mock_client, mock_collection): mock_client.get_collection.return_value = mock_collection mock_client.delete_collection = AsyncMock() - with patch("vectorcode.subcommands.drop.get_client", return_value=mock_client): - with patch( + with ( + patch("vectorcode.subcommands.drop.ClientManager") as MockClientManager, + patch( "vectorcode.subcommands.drop.get_collection", return_value=mock_collection - ): - result = await drop(mock_config) - assert result == 0 - mock_client.delete_collection.assert_called_once_with(mock_collection.name) + ), + ): + mock_client = AsyncMock() + + @asynccontextmanager + async def _get_client(self, config=None, need_lock=True): + yield mock_client + + mock_client_manager = MockClientManager.return_value + mock_client_manager._create_client = AsyncMock(return_value=mock_client) + mock_client_manager.get_client = _get_client + + result = await drop(mock_config) + assert result == 0 + mock_client.delete_collection.assert_called_once_with(mock_collection.name) @pytest.mark.asyncio async def test_drop_collection_not_found(mock_config, mock_client): mock_client.get_collection.side_effect = ValueError("Collection not found") - with patch("vectorcode.subcommands.drop.get_client", return_value=mock_client): + with patch("vectorcode.subcommands.drop.ClientManager"): with patch( "vectorcode.subcommands.drop.get_collection", side_effect=ValueError("Collection not found"), diff --git a/tests/subcommands/test_ls.py b/tests/subcommands/test_ls.py index 36f67469..bbc674eb 100644 --- a/tests/subcommands/test_ls.py +++ b/tests/subcommands/test_ls.py @@ -1,6 +1,6 @@ import json import socket -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import tabulate @@ -77,7 +77,7 @@ async def mock_get_collections(client): yield mock_collection with ( - patch("vectorcode.subcommands.ls.get_client", return_value=mock_client), + patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.ls.get_collection_list", return_value=[ @@ -90,6 +90,10 @@ async def mock_get_collections(client): ], ), ): + mock_client = MagicMock() + mock_client_manager = MockClientManager.return_value + mock_client_manager._create_client = AsyncMock(return_value=mock_client) + config = Config(pipe=True) await ls(config) captured = capsys.readouterr() @@ -126,7 +130,7 @@ async def mock_get_collections(client): yield mock_collection with ( - patch("vectorcode.subcommands.ls.get_client", return_value=mock_client), + patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.ls.get_collection_list", return_value=[ @@ -139,6 +143,10 @@ async def mock_get_collections(client): ], ), ): + mock_client = MagicMock() + mock_client_manager = MockClientManager.return_value + mock_client_manager._create_client = AsyncMock(return_value=mock_client) + config = Config(pipe=False) await ls(config) captured = capsys.readouterr() @@ -159,7 +167,7 @@ async def mock_get_collections(client): # Test with HOME environment variable set monkeypatch.setenv("HOME", "/test") with ( - patch("vectorcode.subcommands.ls.get_client", return_value=mock_client), + patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.ls.get_collection_list", return_value=[ @@ -172,6 +180,9 @@ async def mock_get_collections(client): ], ), ): + mock_client = MagicMock() + mock_client_manager = MockClientManager.return_value + mock_client_manager._create_client = AsyncMock(return_value=mock_client) config = Config(pipe=False) await ls(config) captured = capsys.readouterr() diff --git a/tests/subcommands/test_update.py b/tests/subcommands/test_update.py index febfa405..314f7c2a 100644 --- a/tests/subcommands/test_update.py +++ b/tests/subcommands/test_update.py @@ -19,7 +19,7 @@ async def test_update_success(): mock_client.get_max_batch_size.return_value = 100 with ( - patch("vectorcode.subcommands.update.get_client", return_value=mock_client), + patch("vectorcode.subcommands.update.ClientManager"), patch( "vectorcode.subcommands.update.get_collection", return_value=mock_collection ), @@ -50,7 +50,7 @@ async def test_update_with_orphans(): mock_client.get_max_batch_size.return_value = 100 with ( - patch("vectorcode.subcommands.update.get_client", return_value=mock_client), + patch("vectorcode.subcommands.update.ClientManager"), patch( "vectorcode.subcommands.update.get_collection", return_value=mock_collection ), @@ -78,10 +78,11 @@ async def test_update_index_error(): # mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.update.get_client", return_value=mock_client), + patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, patch("vectorcode.subcommands.update.get_collection", side_effect=IndexError), patch("sys.stderr"), ): + MockClientManager.return_value._create_client.return_value = mock_client config = Config(project_root="/test/project", pipe=False) result = await update(config) @@ -94,10 +95,11 @@ async def test_update_value_error(): # mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.update.get_client", return_value=mock_client), + patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, patch("vectorcode.subcommands.update.get_collection", side_effect=ValueError), patch("sys.stderr"), ): + MockClientManager.return_value._create_client.return_value = mock_client config = Config(project_root="/test/project", pipe=False) result = await update(config) @@ -110,13 +112,14 @@ async def test_update_invalid_collection_exception(): # mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.update.get_client", return_value=mock_client), + patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.update.get_collection", side_effect=InvalidCollectionException, ), patch("sys.stderr"), ): + MockClientManager.return_value._create_client.return_value = mock_client config = Config(project_root="/test/project", pipe=False) result = await update(config) diff --git a/tests/subcommands/test_vectorise.py b/tests/subcommands/test_vectorise.py index 2f363a8b..6b2287bf 100644 --- a/tests/subcommands/test_vectorise.py +++ b/tests/subcommands/test_vectorise.py @@ -370,9 +370,7 @@ async def test_vectorise(capsys): with ExitStack() as stack: stack.enter_context( - patch( - "vectorcode.subcommands.vectorise.get_client", return_value=mock_client - ) + patch("vectorcode.subcommands.vectorise.ClientManager"), ) stack.enter_context(patch("os.path.isfile", return_value=False)) stack.enter_context( @@ -427,7 +425,7 @@ async def mock_chunked_add(*args, **kwargs): "vectorcode.subcommands.vectorise.chunked_add", side_effect=mock_chunked_add ) as mock_add, patch("sys.stderr") as mock_stderr, - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.vectorise.get_collection", return_value=mock_collection, @@ -438,6 +436,7 @@ async def mock_chunked_add(*args, **kwargs): lambda x: not (x.endswith("gitignore") or x.endswith("vectorcode.exclude")), ), ): + MockClientManager.return_value._create_client.return_value = mock_client result = await vectorise(configs) assert result == 1 mock_add.assert_called_once() @@ -458,7 +457,7 @@ async def test_vectorise_orphaned_files(): pipe=False, ) - mock_client = AsyncMock() + AsyncMock() mock_collection = AsyncMock() # Define a mock response for collection.get in vectorise @@ -494,7 +493,7 @@ def is_file_side_effect(path): "vectorcode.subcommands.vectorise.TreeSitterChunker", return_value=mock_chunker, ), - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager"), patch( "vectorcode.subcommands.vectorise.get_collection", return_value=mock_collection, @@ -532,10 +531,11 @@ async def test_vectorise_collection_index_error(): mock_client = AsyncMock() with ( - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, patch("vectorcode.subcommands.vectorise.get_collection") as mock_get_collection, patch("os.path.isfile", return_value=False), ): + MockClientManager.return_value._create_client.return_value = mock_client mock_get_collection.side_effect = IndexError("Collection not found") result = await vectorise(configs) assert result == 1 @@ -558,7 +558,7 @@ async def test_vectorise_verify_ef_false(): mock_collection = AsyncMock() with ( - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.vectorise.get_collection", return_value=mock_collection, @@ -566,6 +566,7 @@ async def test_vectorise_verify_ef_false(): patch("vectorcode.subcommands.vectorise.verify_ef", return_value=False), patch("os.path.isfile", return_value=False), ): + MockClientManager.return_value._create_client.return_value = mock_client result = await vectorise(configs) assert result == 1 @@ -588,7 +589,7 @@ async def test_vectorise_gitignore(): mock_collection.get.return_value = {"metadatas": []} with ( - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.vectorise.get_collection", return_value=mock_collection, @@ -608,6 +609,7 @@ async def test_vectorise_gitignore(): "vectorcode.subcommands.vectorise.exclude_paths_by_spec" ) as mock_exclude_paths, ): + MockClientManager.return_value._create_client.return_value = mock_client await vectorise(configs) mock_exclude_paths.assert_called_once() @@ -635,7 +637,7 @@ async def test_vectorise_exclude_file(tmpdir): mock_collection.get.return_value = {"ids": []} with ( - patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, patch( "vectorcode.subcommands.vectorise.get_collection", return_value=mock_collection, @@ -652,6 +654,7 @@ async def test_vectorise_exclude_file(tmpdir): ), patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add, ): + MockClientManager.return_value._create_client.return_value = mock_client await vectorise(configs) # Assert that chunked_add is only called for test_file.py, not excluded_file.py call_args = [call[0][0] for call in mock_chunked_add.call_args_list] @@ -664,7 +667,6 @@ async def test_vectorise_exclude_file(tmpdir): @pytest.mark.asyncio -@patch("vectorcode.subcommands.vectorise.get_client", new_callable=AsyncMock) @patch("vectorcode.subcommands.vectorise.get_collection", new_callable=AsyncMock) @patch("vectorcode.subcommands.vectorise.expand_globs", new_callable=AsyncMock) @patch("vectorcode.subcommands.vectorise.chunked_add", new_callable=AsyncMock) @@ -681,7 +683,6 @@ async def test_vectorise_uses_global_exclude_when_local_missing( mock_chunked_add, mock_expand_globs, mock_get_collection, - mock_get_client, tmp_path, ): """ @@ -712,14 +713,20 @@ def isfile_side_effect(p): global_exclude_content = "*.bin" m_open = mock_open(read_data=global_exclude_content) - with patch("builtins.open", m_open): + with ( + patch("builtins.open", m_open), + patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, + ): mock_spec_instance = MagicMock() mock_spec_instance.match_file = lambda path: str(path).endswith(".bin") mock_gitignore_spec.from_lines.return_value = mock_spec_instance mock_client_instance = AsyncMock() mock_client_instance.get_max_batch_size = AsyncMock(return_value=100) - mock_get_client.return_value = mock_client_instance + + MockClientManager.return_value._create_client.return_value = ( + mock_client_instance + ) mock_collection_instance = AsyncMock() mock_collection_instance.get = AsyncMock( diff --git a/tests/test_common.py b/tests/test_common.py index 98f1370b..e3345ec0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -13,7 +13,7 @@ from vectorcode.cli_utils import Config from vectorcode.common import ( - get_client, + ClientManager, get_collection, get_collection_name, get_collections, @@ -152,78 +152,75 @@ async def test_try_server_versions(): @pytest.mark.asyncio async def test_get_client(): + config = Config(db_url="https://test_host:1234", db_path="test_db") + config1 = Config( + db_url="http://test_host1:1234", + db_path="test_db", + db_settings={"anonymized_telemetry": True}, + ) + config1_alt = Config( + db_url="http://test_host1:1234", + db_path="test_db", + db_settings={"anonymized_telemetry": True, "other_setting": "value"}, + ) # Patch chromadb.AsyncHttpClient to avoid actual network calls - with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: + with ( + patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient, + patch("vectorcode.common.try_server", return_value=True), + ): mock_client = MagicMock(spec=AsyncClientAPI) MockAsyncHttpClient.return_value = mock_client - config = Config(db_url="https://test_host:1234", db_path="test_db") - client = await get_client(config) - - assert isinstance(client, AsyncClientAPI) - MockAsyncHttpClient.assert_called_once() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry - is False - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_ssl_enabled - is True - ) - - # Test with valid db_settings (only anonymized_telemetry) - config = Config( - db_url="http://test_host1:1234", - db_path="test_db", - db_settings={"anonymized_telemetry": True}, - ) - client = await get_client(config) - - assert isinstance(client, AsyncClientAPI) - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host1" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry - is True - ) + async with ( + ClientManager().get_client(config) as client, + ): + assert isinstance(client, AsyncClientAPI) + MockAsyncHttpClient.assert_called() + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host + == "test_host" + ) + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port + == 1234 + ) + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry + is False + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].chroma_server_ssl_enabled + is True + ) - # Test with multiple db_settings, including an invalid one. The invalid one - # should be filtered out inside get_client. - config = Config( - db_url="http://test_host2:1234", - db_path="test_db", - db_settings={"anonymized_telemetry": True, "other_setting": "value"}, - ) - client = await get_client(config) - assert isinstance(client, AsyncClientAPI) - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host2" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry - is True - ) + async with ( + ClientManager().get_client(config1) as client1, + ClientManager().get_client(config1_alt) as client1_alt, + ): + assert isinstance(client1, AsyncClientAPI) + MockAsyncHttpClient.assert_called() + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host + == "test_host1" + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].chroma_server_http_port + == 1234 + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].anonymized_telemetry + is True + ) + + # Test with multiple db_settings, including an invalid one. The invalid one + # should be filtered out inside get_client. + assert id(client1_alt) == id(client1) def test_verify_ef(): @@ -581,3 +578,32 @@ async def test_wait_for_server_timeout(): # Verify try_server was called multiple times (due to retries) assert mock_try_server.call_count > 1 + + +@pytest.mark.asyncio +async def test_client_manager_get_client(): + ClientManager().clear() + with ( + patch("vectorcode.common.try_server", return_value=False) as mock_try_server, + patch( + "vectorcode.common.start_server", return_value=MagicMock() + ) as mock_start_server, + ): + # need to start a new server + manager = ClientManager() + async with manager.get_client(Config()): + mock_try_server.assert_called_once() + mock_start_server.assert_called_once() + + ClientManager().clear() + with ( + patch("vectorcode.common.try_server", return_value=True) as mock_try_server, + patch( + "vectorcode.common.start_server", return_value=MagicMock() + ) as mock_start_server, + ): + # need to start a new server + manager = ClientManager() + async with manager.get_client(Config()): + mock_try_server.assert_called_once() + mock_start_server.assert_not_called() diff --git a/tests/test_lsp.py b/tests/test_lsp.py index 46bcf7eb..18f999ff 100644 --- a/tests/test_lsp.py +++ b/tests/test_lsp.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -9,7 +10,6 @@ from vectorcode.lsp_main import ( execute_command, lsp_start, - make_caches, ) @@ -39,66 +39,20 @@ def mock_config(): return config -@pytest.mark.asyncio -async def test_make_caches(tmp_path): - project_root = str(tmp_path) - config_file = tmp_path / ".vectorcode" / "config.json" - config_file.parent.mkdir(exist_ok=True) - config_file.write_text('{"host": "test_host", "port": 9999}') - from vectorcode.lsp_main import cached_project_configs - - with ( - patch( - "vectorcode.lsp_main.get_project_config", new_callable=AsyncMock - ) as mock_get_project_config, - patch( - "vectorcode.lsp_main.try_server", new_callable=AsyncMock - ) as mock_try_server, - ): - mock_try_server.return_value = True - await make_caches(project_root) - - mock_get_project_config.assert_called_once_with(project_root) - assert project_root in cached_project_configs - - -@pytest.mark.asyncio -async def test_make_caches_server_unavailable(tmp_path): - project_root = str(tmp_path) - config_file = tmp_path / ".vectorcode" / "config.json" - config_file.parent.mkdir(exist_ok=True) - config_file.write_text('{"host": "test_host", "port": 9999}') - - with ( - patch("vectorcode.lsp_main.get_project_config", new_callable=AsyncMock), - patch( - "vectorcode.lsp_main.try_server", new_callable=AsyncMock - ) as mock_try_server, - ): - mock_try_server.return_value = False - with pytest.raises(ConnectionError): - await make_caches(project_root) - - @pytest.mark.asyncio async def test_execute_command_query(mock_language_server, mock_config): with ( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock), + patch("vectorcode.lsp_main.ClientManager"), patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock), patch( "vectorcode.lsp_main.build_query_results", new_callable=AsyncMock ) as mock_get_query_result_files, patch("os.path.isfile", return_value=True), - patch("vectorcode.lsp_main.try_server", return_value=True), patch("builtins.open", MagicMock()) as mock_open, - patch("vectorcode.lsp_main.cached_project_configs", {}), ): - from vectorcode.lsp_main import cached_project_configs - - cached_project_configs.clear() mock_parse_cli_args.return_value = mock_config mock_get_query_result_files.return_value = ["/test/file.txt"] @@ -110,9 +64,6 @@ async def test_execute_command_query(mock_language_server, mock_config): # Ensure parsed_args.project_root is not None mock_config.project_root = "/test/project" - # Add a mock config to cached_project_configs - cached_project_configs["/test/project"] = mock_config - # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -131,22 +82,17 @@ async def test_execute_command_query_default_proj_root( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock), + patch("vectorcode.lsp_main.ClientManager"), patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock), patch( "vectorcode.lsp_main.build_query_results", new_callable=AsyncMock ) as mock_get_query_result_files, patch("os.path.isfile", return_value=True), - patch("vectorcode.lsp_main.try_server", return_value=True), patch("builtins.open", MagicMock()) as mock_open, - patch("vectorcode.lsp_main.cached_project_configs", {}), ): - from vectorcode.lsp_main import cached_project_configs - global DEFAULT_PROJECT_ROOT mock_config.project_root = None - cached_project_configs.clear() mock_parse_cli_args.return_value = mock_config mock_get_query_result_files.return_value = ["/test/file.txt"] @@ -158,9 +104,6 @@ async def test_execute_command_query_default_proj_root( # Ensure parsed_args.project_root is not None DEFAULT_PROJECT_ROOT = "/test/project" - # Add a mock config to cached_project_configs - cached_project_configs["/test/project"] = mock_config - # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -183,26 +126,18 @@ async def test_execute_command_ls(mock_language_server, mock_config): patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock), + patch("vectorcode.lsp_main.ClientManager"), patch( "vectorcode.lsp_main.get_collection_list", new_callable=AsyncMock ) as mock_get_collection_list, - patch("vectorcode.lsp_main.cached_project_configs", {}), patch("vectorcode.common.get_embedding_function") as mock_embedding_function, patch("vectorcode.common.get_collection") as mock_get_collection, - patch("vectorcode.lsp_main.try_server", return_value=True), ): - from vectorcode.lsp_main import cached_project_configs - - cached_project_configs.clear() mock_parse_cli_args.return_value = mock_config # Ensure parsed_args.project_root is not None mock_config.project_root = "/test/project" - # Add a mock config to cached_project_configs - cached_project_configs["/test/project"] = mock_config - # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -236,9 +171,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch( - "vectorcode.lsp_main.get_client", new_callable=AsyncMock - ) as mock_get_client, + patch("vectorcode.lsp_main.ClientManager") as MockClientManager, patch( "vectorcode.lsp_main.get_collection", new_callable=AsyncMock ) as mock_get_collection, @@ -255,16 +188,11 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf patch( "vectorcode.lsp_main.chunked_add", new_callable=AsyncMock ) as mock_chunked_add, - patch("vectorcode.lsp_main.try_server", return_value=True), - patch("vectorcode.lsp_main.cached_project_configs", {}), patch( "vectorcode.lsp_main.load_files_from_include", return_value=dummy_initial_files, ) as mock_load_files_from_include, patch("os.cpu_count", return_value=1), # For asyncio.Semaphore - patch( - "vectorcode.lsp_main.make_caches", new_callable=AsyncMock - ), # Mock make_caches to avoid actual file system ops patch( "vectorcode.lsp_main.remove_orphanes", new_callable=AsyncMock ) as mock_remove_orphanes, @@ -273,15 +201,14 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf from lsprotocol import types - from vectorcode.lsp_main import cached_project_configs - - cached_project_configs.clear() - cached_project_configs["/test/project"] = mock_config # Add config to cache + @asynccontextmanager + async def _get_client(*args): + yield mock_client # Set return values for mocks mock_parse_cli_args.return_value = mock_config mock_client = AsyncMock() - mock_get_client.return_value = mock_client + MockClientManager.return_value.get_client.side_effect = _get_client mock_collection = AsyncMock() mock_get_collection.return_value = mock_collection mock_client.get_max_batch_size.return_value = 100 # Mock batch size @@ -319,7 +246,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf recursive=mock_config.recursive, include_hidden=mock_config.include_hidden, ) - mock_find_exclude_specs.assert_called_once_with(mock_config) + mock_find_exclude_specs.assert_called_once() mock_exclude_paths_by_spec.assert_not_called() # Because mock_find_exclude_specs returns empty list (no specs to exclude by) mock_client.get_max_batch_size.assert_called_once() @@ -332,7 +259,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf ANY, # asyncio.Lock object ANY, # stats dict ANY, # stats_lock - mock_config, + ANY, 100, # max_batch_size ANY, # semaphore ) @@ -362,16 +289,9 @@ async def test_execute_command_unsupported_action( patch( "vectorcode.lsp_main.get_collection", new_callable=AsyncMock ) as mock_get_collection, - patch("vectorcode.lsp_main.cached_project_configs", {}), - patch("vectorcode.lsp_main.try_server", return_value=True), ): - from vectorcode.lsp_main import cached_project_configs - - cached_project_configs.clear() mock_parse_cli_args.return_value = mock_config - # Add a mock config to cached_project_configs - cached_project_configs["/test/project"] = mock_config mock_collection = MagicMock() mock_get_collection.return_value = mock_collection @@ -449,7 +369,6 @@ async def test_execute_command_no_default_project_root( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock), ): mock_parse_cli_args.return_value = mock_config with pytest.raises((AssertionError, JsonRpcInternalError)): diff --git a/tests/test_main.py b/tests/test_main.py index d6eadd0b..34ce181f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -140,32 +140,6 @@ async def test_async_main_cli_action_prompts(monkeypatch): mock_prompts.assert_called_once() -@pytest.mark.asyncio -async def test_async_main_try_server_unavailable(monkeypatch): - mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query) - monkeypatch.setattr( - "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args) - ) - mock_final_configs = MagicMock(host="test_host", port=1234, action=CliAction.query) - monkeypatch.setattr( - "vectorcode.main.get_project_config", - AsyncMock( - return_value=MagicMock( - merge_from=AsyncMock(return_value=mock_final_configs) - ) - ), - ) - monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=False)) - mock_start_server = AsyncMock() - monkeypatch.setattr("vectorcode.common.start_server", mock_start_server) - monkeypatch.setattr("vectorcode.subcommands.query", AsyncMock(return_value=0)) - mock_start_server.return_value.wait = AsyncMock() - mock_start_server.return_value.terminate = MagicMock() - - await async_main() - mock_start_server.assert_called_once_with(mock_final_configs) - - @pytest.mark.asyncio async def test_async_main_cli_action_query(monkeypatch): mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query) @@ -345,35 +319,6 @@ async def test_async_main_exception_handling(monkeypatch): mock_logger.error.assert_called_once() -@pytest.mark.asyncio -async def test_async_main_server_process_termination(monkeypatch): - mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query) - monkeypatch.setattr( - "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args) - ) - mock_final_configs = MagicMock(host="test_host", port=1234, action=CliAction.query) - monkeypatch.setattr( - "vectorcode.main.get_project_config", - AsyncMock( - return_value=MagicMock( - merge_from=AsyncMock(return_value=mock_final_configs) - ) - ), - ) - monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=False)) - mock_server_process = AsyncMock() - mock_start_server = AsyncMock(return_value=mock_server_process) - monkeypatch.setattr("vectorcode.common.start_server", mock_start_server) - monkeypatch.setattr("vectorcode.subcommands.query", AsyncMock(return_value=0)) - mock_server_process.terminate = MagicMock() - mock_server_process.wait = AsyncMock() - - await async_main() - - mock_server_process.terminate.assert_called_once() - await mock_server_process.wait() - - def test_main(monkeypatch): mock_async_main = AsyncMock(return_value=0) monkeypatch.setattr("vectorcode.main.async_main", mock_async_main) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index a9be2f11..5be133af 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,6 +1,7 @@ import os import tempfile from argparse import ArgumentParser +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest @@ -20,11 +21,14 @@ @pytest.mark.asyncio async def test_list_collections_success(): with ( - patch("vectorcode.mcp_main.get_client") as mock_get_client, patch("vectorcode.mcp_main.get_collections") as mock_get_collections, + patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) mock_collection1 = AsyncMock() mock_collection1.metadata = {"path": "path1"} @@ -44,11 +48,14 @@ async def async_generator(): @pytest.mark.asyncio async def test_list_collections_no_metadata(): with ( - patch("vectorcode.mcp_main.get_client") as mock_get_client, patch("vectorcode.mcp_main.get_collections") as mock_get_collections, + patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) mock_collection1 = AsyncMock() mock_collection1.metadata = {"path": "path1"} mock_collection2 = AsyncMock() @@ -84,7 +91,6 @@ async def test_query_tool_success(): with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config, - patch("vectorcode.mcp_main.get_client") as mock_get_client, patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch( "vectorcode.subcommands.query.get_query_result_files" @@ -93,12 +99,16 @@ async def test_query_tool_success(): patch("os.path.isfile", return_value=True), patch("os.path.relpath", return_value="rel/path.py"), patch("vectorcode.cli_utils.load_config_file") as mock_load_config_file, + patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_config = Config(chunk_size=100, overlap_ratio=0.1, reranker=None) mock_load_config_file.return_value = mock_config mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) # Mock the collection's query method to return a valid QueryResult mock_collection = AsyncMock() @@ -131,11 +141,17 @@ async def test_query_tool_collection_access_failure(): with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config"), - patch("vectorcode.mcp_main.get_client") as mock_get_client, - patch("vectorcode.mcp_main.get_collection") as mock_get_collection, + patch("vectorcode.mcp_main.get_collection"), # Still mock get_collection + patch("vectorcode.common.ClientManager") as MockClientManager, ): - mock_get_client.side_effect = Exception("Failed to connect") - mock_get_collection.side_effect = Exception("Failed to connect") + mock_client_manager_instance = MockClientManager.return_value + + async def failing_get_client(*args, **kwargs): + raise Exception("Failed to connect") + + mock_client_manager_instance.get_client = asynccontextmanager( + failing_get_client + ) with pytest.raises(McpError) as exc_info: await query_tool( @@ -154,9 +170,15 @@ async def test_query_tool_no_collection(): with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config"), - patch("vectorcode.mcp_main.get_client"), - patch("vectorcode.mcp_main.get_collection") as mock_get_collection, + patch( + "vectorcode.mcp_main.get_collection" + ) as mock_get_collection, # Still mock get_collection + patch("vectorcode.common.ClientManager") as MockClientManager, ): + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock() + ) # Provide a working get_client mock_get_collection.return_value = None with pytest.raises(McpError) as exc_info: @@ -166,8 +188,8 @@ async def test_query_tool_no_collection(): assert exc_info.value.error.code == 1 assert ( - exc_info.value.error.message - == "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field." + "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field." + in exc_info.value.error.message ) @@ -188,9 +210,9 @@ async def test_vectorise_files_success(): f.write("def func(): pass") with ( + patch("vectorcode.common.ClientManager") as MockClientManager, patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config, - patch("vectorcode.mcp_main.get_client") as mock_get_client, patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch("vectorcode.subcommands.vectorise.chunked_add"), patch( @@ -200,7 +222,17 @@ async def test_vectorise_files_success(): mock_config = Config(project_root=temp_dir) mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_get_client.return_value = mock_client + + mock_client_manager_instance = MockClientManager.return_value + # Ensure ClientManager's internal client creation method returns our mock. + mock_client_manager_instance._create_client = AsyncMock( + return_value=mock_client + ) + # Ensure ClientManager's get_client context manager yields our mock. + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) + mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} mock_get_collection.return_value = mock_collection @@ -210,18 +242,29 @@ async def test_vectorise_files_success(): assert result["add"] == 1 mock_get_project_config.assert_called_once_with(temp_dir) - mock_get_client.assert_called_once_with(mock_config) - mock_get_collection.assert_called_once_with(mock_client, mock_config, True) + # Assert that the mocked get_collection was called with our mock_client. + mock_get_collection.assert_called_once() @pytest.mark.asyncio -async def test_vectorise_files_collection_access_failure(): +async def test_vectorise_files_collection_access_failure(): # Removed client_manager fixture with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config"), - patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")), + # patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")), # Removed explicit patch + patch( + "vectorcode.common.ClientManager" + ) as MockClientManager, # Patch ClientManager class patch("vectorcode.mcp_main.get_collection"), ): + mock_client_manager_instance = MockClientManager.return_value + + async def failing_get_client(*args, **kwargs): + raise Exception("Client error") + + mock_client_manager_instance.get_client = asynccontextmanager( + failing_get_client + ) with pytest.raises(McpError) as exc_info: await vectorise_files(paths=["file.py"], project_root="/valid/path") @@ -257,7 +300,6 @@ def mock_open_side_effect(filename, *args, **kwargs): with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config, - patch("vectorcode.mcp_main.get_client") as mock_get_client, patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add, patch( @@ -270,11 +312,18 @@ def mock_open_side_effect(filename, *args, **kwargs): "os.path.isfile", side_effect=lambda x: x in [file1, excluded_file, exclude_spec_file], ), + patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_config = Config(project_root=temp_dir) mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance._create_client = AsyncMock( + return_value=mock_client + ) + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} mock_get_collection.return_value = mock_collection @@ -297,14 +346,18 @@ async def test_mcp_server(): "vectorcode.mcp_main.find_project_config_dir" ) as mock_find_project_config_dir, patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file, - patch("vectorcode.mcp_main.get_client") as mock_get_client, + # patch("vectorcode.mcp_main.get_client") as mock_get_client, # Removed patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool, + patch("vectorcode.common.ClientManager") as MockClientManager, # Added ): mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) mock_collection = AsyncMock() mock_get_collection.return_value = mock_collection @@ -320,12 +373,13 @@ async def test_mcp_server_ls_on_start(): "vectorcode.mcp_main.find_project_config_dir" ) as mock_find_project_config_dir, patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file, - patch("vectorcode.mcp_main.get_client") as mock_get_client, + # patch("vectorcode.mcp_main.get_client") as mock_get_client, # Removed patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch( "vectorcode.mcp_main.get_collections", spec=AsyncMock ) as mock_get_collections, patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool, + patch("vectorcode.common.ClientManager") as MockClientManager, # Added ): from vectorcode.mcp_main import mcp_config @@ -333,7 +387,10 @@ async def test_mcp_server_ls_on_start(): mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - mock_get_client.return_value = mock_client + mock_client_manager_instance = MockClientManager.return_value + mock_client_manager_instance.get_client = asynccontextmanager( + AsyncMock(return_value=mock_client) + ) mock_collection = AsyncMock() mock_collection.metadata = {"path": "/path/to/project"} mock_get_collection.return_value = mock_collection From 117bc2c5908c2c6d4973c06e0e4f284a9dc4c1f0 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Fri, 27 Jun 2025 17:06:37 +0800 Subject: [PATCH 07/17] fix(cli): Fix termination and add test case for `kill_servers` --- src/vectorcode/common.py | 2 +- tests/test_common.py | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index b4ae94cf..232a7e94 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -288,7 +288,7 @@ def get_processes(self) -> list[Process]: async def kill_servers(self): termination_tasks: list[asyncio.Task] = [] - for p in ClientManager().get_processes(): + for p in self.get_processes(): logger.info(f"Killing bundled chroma server with PID: {p.pid}") p.terminate() termination_tasks.append(asyncio.create_task(p.wait())) diff --git a/tests/test_common.py b/tests/test_common.py index e3345ec0..946514e0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3,7 +3,7 @@ import subprocess import sys import tempfile -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -607,3 +607,23 @@ async def test_client_manager_get_client(): async with manager.get_client(Config()): mock_try_server.assert_called_once() mock_start_server.assert_not_called() + + +@pytest.mark.asyncio +async def test_client_manager_kill_servers(): + manager = ClientManager() + manager.clear() + + async def _try_server(url): + return "127.0.0.1" in url or "localhost" in url + + mock_process = AsyncMock() + with ( + patch("vectorcode.common.start_server", return_value=mock_process), + patch("vectorcode.common.try_server", side_effect=_try_server), + ): + manager._create_client = AsyncMock(return_value=AsyncMock()) + async with manager.get_client(Config(db_url="http://test_host:1081")): + pass + await manager.kill_servers() + mock_process.terminate.assert_called_once() From 115ac76922f184ea7eace7a4493d4b183031dc64 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Fri, 27 Jun 2025 17:17:44 +0800 Subject: [PATCH 08/17] test(cli): fix mocking in mcp tests --- tests/test_mcp.py | 55 +++++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 38 deletions(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 5be133af..6aa7fd32 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -25,9 +25,8 @@ async def test_list_collections_success(): patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) + MockClientManager.return_value._create_client = AsyncMock( + return_value=mock_client ) mock_collection1 = AsyncMock() @@ -52,8 +51,7 @@ async def test_list_collections_no_metadata(): patch("vectorcode.common.ClientManager") as MockClientManager, ): mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( + MockClientManager.return_value._create_client = asynccontextmanager( AsyncMock(return_value=mock_client) ) mock_collection1 = AsyncMock() @@ -105,9 +103,8 @@ async def test_query_tool_success(): mock_load_config_file.return_value = mock_config mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) + MockClientManager.return_value._create_client = AsyncMock( + return_value=mock_client ) # Mock the collection's query method to return a valid QueryResult @@ -144,14 +141,11 @@ async def test_query_tool_collection_access_failure(): patch("vectorcode.mcp_main.get_collection"), # Still mock get_collection patch("vectorcode.common.ClientManager") as MockClientManager, ): - mock_client_manager_instance = MockClientManager.return_value async def failing_get_client(*args, **kwargs): raise Exception("Failed to connect") - mock_client_manager_instance.get_client = asynccontextmanager( - failing_get_client - ) + MockClientManager.return_value._create_client.side_effect = failing_get_client with pytest.raises(McpError) as exc_info: await query_tool( @@ -175,10 +169,7 @@ async def test_query_tool_no_collection(): ) as mock_get_collection, # Still mock get_collection patch("vectorcode.common.ClientManager") as MockClientManager, ): - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock() - ) # Provide a working get_client + MockClientManager.return_value._create_client.return_value = AsyncMock() mock_get_collection.return_value = None with pytest.raises(McpError) as exc_info: @@ -223,15 +214,10 @@ async def test_vectorise_files_success(): mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value # Ensure ClientManager's internal client creation method returns our mock. - mock_client_manager_instance._create_client = AsyncMock( + MockClientManager.return_value._create_client = AsyncMock( return_value=mock_client ) - # Ensure ClientManager's get_client context manager yields our mock. - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) - ) mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} @@ -257,14 +243,12 @@ async def test_vectorise_files_collection_access_failure(): # Removed client_ma ) as MockClientManager, # Patch ClientManager class patch("vectorcode.mcp_main.get_collection"), ): - mock_client_manager_instance = MockClientManager.return_value async def failing_get_client(*args, **kwargs): raise Exception("Client error") - mock_client_manager_instance.get_client = asynccontextmanager( - failing_get_client - ) + MockClientManager.return_value._create_client = failing_get_client + with pytest.raises(McpError) as exc_info: await vectorise_files(paths=["file.py"], project_root="/valid/path") @@ -317,13 +301,10 @@ def mock_open_side_effect(filename, *args, **kwargs): mock_config = Config(project_root=temp_dir) mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance._create_client = AsyncMock( + MockClientManager.return_value._create_client = AsyncMock( return_value=mock_client ) - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) - ) + mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} mock_get_collection.return_value = mock_collection @@ -354,10 +335,8 @@ async def test_mcp_server(): mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) - ) + + MockClientManager.return_value.get_client = AsyncMock(return_value=mock_client) mock_collection = AsyncMock() mock_get_collection.return_value = mock_collection @@ -387,9 +366,9 @@ async def test_mcp_server_ls_on_start(): mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - mock_client_manager_instance = MockClientManager.return_value - mock_client_manager_instance.get_client = asynccontextmanager( - AsyncMock(return_value=mock_client) + + MockClientManager.return_value._create_client = AsyncMock( + return_value=mock_client ) mock_collection = AsyncMock() mock_collection.metadata = {"path": "/path/to/project"} From 831dd506b491ed1b9ea1a4c5ee3d950dd9fe7aa4 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Fri, 27 Jun 2025 17:56:57 +0800 Subject: [PATCH 09/17] tests: Move client manager tests to test_client_manager --- tests/test_common.py | 161 ++++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 95 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index 946514e0..8b970923 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -150,79 +150,6 @@ async def test_try_server_versions(): assert await try_server("http://localhost:8300") is False -@pytest.mark.asyncio -async def test_get_client(): - config = Config(db_url="https://test_host:1234", db_path="test_db") - config1 = Config( - db_url="http://test_host1:1234", - db_path="test_db", - db_settings={"anonymized_telemetry": True}, - ) - config1_alt = Config( - db_url="http://test_host1:1234", - db_path="test_db", - db_settings={"anonymized_telemetry": True, "other_setting": "value"}, - ) - # Patch chromadb.AsyncHttpClient to avoid actual network calls - with ( - patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient, - patch("vectorcode.common.try_server", return_value=True), - ): - mock_client = MagicMock(spec=AsyncClientAPI) - MockAsyncHttpClient.return_value = mock_client - - async with ( - ClientManager().get_client(config) as client, - ): - assert isinstance(client, AsyncClientAPI) - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry - is False - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].chroma_server_ssl_enabled - is True - ) - - async with ( - ClientManager().get_client(config1) as client1, - ClientManager().get_client(config1_alt) as client1_alt, - ): - assert isinstance(client1, AsyncClientAPI) - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host1" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].anonymized_telemetry - is True - ) - - # Test with multiple db_settings, including an invalid one. The invalid one - # should be filtered out inside get_client. - assert id(client1_alt) == id(client1) - - def test_verify_ef(): # Mocking AsyncCollection and Config mock_collection = MagicMock() @@ -582,31 +509,75 @@ async def test_wait_for_server_timeout(): @pytest.mark.asyncio async def test_client_manager_get_client(): - ClientManager().clear() + config = Config(db_url="https://test_host:1234", db_path="test_db") + config1 = Config( + db_url="http://test_host1:1234", + db_path="test_db", + db_settings={"anonymized_telemetry": True}, + ) + config1_alt = Config( + db_url="http://test_host1:1234", + db_path="test_db", + db_settings={"anonymized_telemetry": True, "other_setting": "value"}, + ) + # Patch chromadb.AsyncHttpClient to avoid actual network calls with ( - patch("vectorcode.common.try_server", return_value=False) as mock_try_server, - patch( - "vectorcode.common.start_server", return_value=MagicMock() - ) as mock_start_server, + patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient, + patch("vectorcode.common.try_server", return_value=True), ): - # need to start a new server - manager = ClientManager() - async with manager.get_client(Config()): - mock_try_server.assert_called_once() - mock_start_server.assert_called_once() + mock_client = MagicMock(spec=AsyncClientAPI) + MockAsyncHttpClient.return_value = mock_client - ClientManager().clear() - with ( - patch("vectorcode.common.try_server", return_value=True) as mock_try_server, - patch( - "vectorcode.common.start_server", return_value=MagicMock() - ) as mock_start_server, - ): - # need to start a new server - manager = ClientManager() - async with manager.get_client(Config()): - mock_try_server.assert_called_once() - mock_start_server.assert_not_called() + async with ( + ClientManager().get_client(config) as client, + ): + assert isinstance(client, AsyncClientAPI) + MockAsyncHttpClient.assert_called() + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host + == "test_host" + ) + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port + == 1234 + ) + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry + is False + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].chroma_server_ssl_enabled + is True + ) + + async with ( + ClientManager().get_client(config1) as client1, + ClientManager().get_client(config1_alt) as client1_alt, + ): + assert isinstance(client1, AsyncClientAPI) + MockAsyncHttpClient.assert_called() + assert ( + MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host + == "test_host1" + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].chroma_server_http_port + == 1234 + ) + assert ( + MockAsyncHttpClient.call_args.kwargs[ + "settings" + ].anonymized_telemetry + is True + ) + + # Test with multiple db_settings, including an invalid one. The invalid one + # should be filtered out inside get_client. + assert id(client1_alt) == id(client1) @pytest.mark.asyncio From db0b19a3374eabe45ec5d5b897f92556228533a2 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Fri, 27 Jun 2025 18:32:14 +0800 Subject: [PATCH 10/17] tests(cli): fix broken tests --- tests/test_mcp.py | 81 +++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 6aa7fd32..2b809426 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,7 +1,6 @@ import os import tempfile from argparse import ArgumentParser -from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest @@ -22,12 +21,12 @@ async def test_list_collections_success(): with ( patch("vectorcode.mcp_main.get_collections") as mock_get_collections, - patch("vectorcode.common.ClientManager") as MockClientManager, + patch("vectorcode.common.try_server", return_value=True), ): + from vectorcode.mcp_main import ClientManager + mock_client = AsyncMock() - MockClientManager.return_value._create_client = AsyncMock( - return_value=mock_client - ) + ClientManager._create_client = AsyncMock(return_value=mock_client) mock_collection1 = AsyncMock() mock_collection1.metadata = {"path": "path1"} @@ -48,12 +47,13 @@ async def async_generator(): async def test_list_collections_no_metadata(): with ( patch("vectorcode.mcp_main.get_collections") as mock_get_collections, - patch("vectorcode.common.ClientManager") as MockClientManager, + patch("vectorcode.common.try_server", return_value=True), ): + from vectorcode.mcp_main import ClientManager + mock_client = AsyncMock() - MockClientManager.return_value._create_client = asynccontextmanager( - AsyncMock(return_value=mock_client) - ) + ClientManager._create_client = AsyncMock(return_value=mock_client) + mock_collection1 = AsyncMock() mock_collection1.metadata = {"path": "path1"} mock_collection2 = AsyncMock() @@ -93,19 +93,19 @@ async def test_query_tool_success(): patch( "vectorcode.subcommands.query.get_query_result_files" ) as mock_get_query_result_files, + patch("vectorcode.common.try_server", return_value=True), patch("builtins.open", create=True) as mock_open, patch("os.path.isfile", return_value=True), patch("os.path.relpath", return_value="rel/path.py"), patch("vectorcode.cli_utils.load_config_file") as mock_load_config_file, - patch("vectorcode.common.ClientManager") as MockClientManager, ): + from vectorcode.mcp_main import ClientManager + mock_config = Config(chunk_size=100, overlap_ratio=0.1, reranker=None) mock_load_config_file.return_value = mock_config mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - MockClientManager.return_value._create_client = AsyncMock( - return_value=mock_client - ) + ClientManager._create_client = AsyncMock(return_value=mock_client) # Mock the collection's query method to return a valid QueryResult mock_collection = AsyncMock() @@ -139,13 +139,13 @@ async def test_query_tool_collection_access_failure(): patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config"), patch("vectorcode.mcp_main.get_collection"), # Still mock get_collection - patch("vectorcode.common.ClientManager") as MockClientManager, ): + from vectorcode.mcp_main import ClientManager async def failing_get_client(*args, **kwargs): raise Exception("Failed to connect") - MockClientManager.return_value._create_client.side_effect = failing_get_client + ClientManager._create_client = AsyncMock(side_effect=failing_get_client) with pytest.raises(McpError) as exc_info: await query_tool( @@ -169,7 +169,8 @@ async def test_query_tool_no_collection(): ) as mock_get_collection, # Still mock get_collection patch("vectorcode.common.ClientManager") as MockClientManager, ): - MockClientManager.return_value._create_client.return_value = AsyncMock() + mock_client = AsyncMock() + MockClientManager.return_value._create_client.return_value = mock_client mock_get_collection.return_value = None with pytest.raises(McpError) as exc_info: @@ -201,7 +202,6 @@ async def test_vectorise_files_success(): f.write("def func(): pass") with ( - patch("vectorcode.common.ClientManager") as MockClientManager, patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config, patch("vectorcode.mcp_main.get_collection") as mock_get_collection, @@ -209,15 +209,16 @@ async def test_vectorise_files_success(): patch( "vectorcode.subcommands.vectorise.hash_file", return_value="test_hash" ), + patch("vectorcode.common.try_server", return_value=True), ): + from vectorcode.mcp_main import ClientManager + mock_config = Config(project_root=temp_dir) mock_get_project_config.return_value = mock_config mock_client = AsyncMock() # Ensure ClientManager's internal client creation method returns our mock. - MockClientManager.return_value._create_client = AsyncMock( - return_value=mock_client - ) + ClientManager._create_client = AsyncMock(return_value=mock_client) mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} @@ -233,21 +234,16 @@ async def test_vectorise_files_success(): @pytest.mark.asyncio -async def test_vectorise_files_collection_access_failure(): # Removed client_manager fixture +async def test_vectorise_files_collection_access_failure(): with ( patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config"), - # patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")), # Removed explicit patch - patch( - "vectorcode.common.ClientManager" - ) as MockClientManager, # Patch ClientManager class + patch("vectorcode.common.ClientManager"), # Patch ClientManager class patch("vectorcode.mcp_main.get_collection"), ): + from vectorcode.mcp_main import ClientManager - async def failing_get_client(*args, **kwargs): - raise Exception("Client error") - - MockClientManager.return_value._create_client = failing_get_client + ClientManager._create_client = AsyncMock(side_effect=Exception("Client error")) with pytest.raises(McpError) as exc_info: await vectorise_files(paths=["file.py"], project_root="/valid/path") @@ -296,14 +292,14 @@ def mock_open_side_effect(filename, *args, **kwargs): "os.path.isfile", side_effect=lambda x: x in [file1, excluded_file, exclude_spec_file], ), - patch("vectorcode.common.ClientManager") as MockClientManager, + patch("vectorcode.common.try_server", return_value=True), ): + from vectorcode.mcp_main import ClientManager + mock_config = Config(project_root=temp_dir) mock_get_project_config.return_value = mock_config mock_client = AsyncMock() - MockClientManager.return_value._create_client = AsyncMock( - return_value=mock_client - ) + ClientManager._create_client = AsyncMock(return_value=mock_client) mock_collection = AsyncMock() mock_collection.get.return_value = {"ids": [], "metadatas": []} @@ -330,13 +326,15 @@ async def test_mcp_server(): # patch("vectorcode.mcp_main.get_client") as mock_get_client, # Removed patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool, - patch("vectorcode.common.ClientManager") as MockClientManager, # Added + patch("vectorcode.common.try_server", return_value=True), ): + from vectorcode.mcp_main import ClientManager + mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - MockClientManager.return_value.get_client = AsyncMock(return_value=mock_client) + ClientManager._create_client = AsyncMock(return_value=mock_client) mock_collection = AsyncMock() mock_get_collection.return_value = mock_collection @@ -347,30 +345,29 @@ async def test_mcp_server(): @pytest.mark.asyncio async def test_mcp_server_ls_on_start(): + mock_collection = AsyncMock() + with ( patch( "vectorcode.mcp_main.find_project_config_dir" ) as mock_find_project_config_dir, patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file, - # patch("vectorcode.mcp_main.get_client") as mock_get_client, # Removed patch("vectorcode.mcp_main.get_collection") as mock_get_collection, patch( "vectorcode.mcp_main.get_collections", spec=AsyncMock ) as mock_get_collections, patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool, - patch("vectorcode.common.ClientManager") as MockClientManager, # Added + patch("vectorcode.common.try_server", return_value=True), ): - from vectorcode.mcp_main import mcp_config + from vectorcode.mcp_main import ClientManager, mcp_config mcp_config.ls_on_start = True mock_find_project_config_dir.return_value = "/path/to/config" mock_load_config_file.return_value = Config(project_root="/path/to/project") mock_client = AsyncMock() - MockClientManager.return_value._create_client = AsyncMock( - return_value=mock_client - ) - mock_collection = AsyncMock() + ClientManager._create_client = AsyncMock(return_value=mock_client) + mock_collection.metadata = {"path": "/path/to/project"} mock_get_collection.return_value = mock_collection From fc6d77bf1e7fc8c2bce51fc5def891c77c61dbdc Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Fri, 27 Jun 2025 21:39:47 +0800 Subject: [PATCH 11/17] cov --- tests/test_cli_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 3252683f..655ef19f 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -11,6 +11,7 @@ from vectorcode.cli_utils import ( CliAction, Config, + LockManager, PromptCategory, QueryInclude, cleanup_path, @@ -553,3 +554,11 @@ def test_shtab(): .stderr.read() .decode() ) == "" + + +@pytest.mark.asyncio +async def test_filelock(): + manager = LockManager() + with tempfile.TemporaryDirectory() as tmp_dir: + manager.get_lock(tmp_dir) + assert os.path.isfile(os.path.join(tmp_dir, "vectorcode.lock")) From e4ef8f82c67fdc1cabdfc163981dd5b43863645f Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Sat, 28 Jun 2025 11:07:06 +0800 Subject: [PATCH 12/17] feat: Fix singleton implementation for LockManager and ClientManager --- src/vectorcode/cli_utils.py | 10 ++++----- src/vectorcode/common.py | 10 ++++----- src/vectorcode/mcp_main.py | 2 +- tests/test_common.py | 43 ++++++++++++++++++++++++++++++++++++- tests/test_mcp.py | 7 ++++-- 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 4400743d..282bb5e4 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -619,13 +619,13 @@ class LockManager: """ __locks: dict[str, AsyncFileLock] - __singleton: "LockManager" + singleton: Optional["LockManager"] = None def __new__(cls) -> "LockManager": - if not hasattr(cls, "__singleton") or cls.__singleton is None: - cls.__singleton = super().__new__(cls) - cls.__singleton.__locks = {} - return cls.__singleton + if cls.singleton is None: + cls.singleton = super().__new__(cls) + cls.singleton.__locks = {} + return cls.singleton def get_lock(self, path: str | os.PathLike) -> AsyncFileLock: path = str(expand_path(str(path), True)) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index 232a7e94..a297c25f 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -248,14 +248,14 @@ class _ClientModel: class ClientManager: - __singleton: Optional["ClientManager"] = None + singleton: Optional["ClientManager"] = None __clients: dict[str, _ClientModel] def __new__(cls) -> "ClientManager": - if not hasattr(cls, "__singleton") or cls.__singleton is None: - cls.__singleton = super().__new__(cls) - cls.__singleton.__clients = {} - return cls.__singleton + if cls.singleton is None: + cls.singleton = super().__new__(cls) + cls.singleton.__clients = {} + return cls.singleton @contextlib.asynccontextmanager async def get_client(self, configs: Config, need_lock: bool = True): diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index 8d5fdd44..72f57d92 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -188,7 +188,7 @@ async def query_tool( async with ClientManager().get_client(config) as client: collection = await get_collection(client, config, False) - if collection is None: + if collection is None: # pragma: nocover raise McpError( ErrorData( code=1, diff --git a/tests/test_common.py b/tests/test_common.py index 8b970923..f9b360b1 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -509,15 +509,19 @@ async def test_wait_for_server_timeout(): @pytest.mark.asyncio async def test_client_manager_get_client(): - config = Config(db_url="https://test_host:1234", db_path="test_db") + config = Config( + db_url="https://test_host:1234", db_path="test_db", project_root="test_proj" + ) config1 = Config( db_url="http://test_host1:1234", db_path="test_db", + project_root="test_proj1", db_settings={"anonymized_telemetry": True}, ) config1_alt = Config( db_url="http://test_host1:1234", db_path="test_db", + project_root="test_proj1", db_settings={"anonymized_telemetry": True, "other_setting": "value"}, ) # Patch chromadb.AsyncHttpClient to avoid actual network calls @@ -580,6 +584,42 @@ async def test_client_manager_get_client(): assert id(client1_alt) == id(client1) +@pytest.mark.asyncio +async def test_client_manager_list_server_processes(): + async def _try_server(url): + return "127.0.0.1" in url or "localhost" in url + + async def _start_server(cfg): + return AsyncMock() + + with ( + tempfile.TemporaryDirectory() as temp_dir, + patch("vectorcode.common.start_server", side_effect=_start_server), + patch("vectorcode.common.try_server", side_effect=_try_server), + ): + db_path = os.path.join(temp_dir, "db") + os.makedirs(db_path, exist_ok=True) + + ClientManager._create_client = AsyncMock() + async with ClientManager().get_client( + Config( + db_url="http://test_host:8001", + project_root="proj1", + db_path=db_path, + ) + ): + print(ClientManager().get_processes()) + async with ClientManager().get_client( + Config( + db_url="http://test_host:8002", + project_root="proj2", + db_path=db_path, + ) + ): + pass + assert len(ClientManager().get_processes()) == 2 + + @pytest.mark.asyncio async def test_client_manager_kill_servers(): manager = ClientManager() @@ -596,5 +636,6 @@ async def _try_server(url): manager._create_client = AsyncMock(return_value=AsyncMock()) async with manager.get_client(Config(db_url="http://test_host:1081")): pass + assert len(manager.get_processes()) == 1 await manager.kill_servers() mock_process.terminate.assert_called_once() diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 2b809426..43b9eac9 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -87,6 +87,7 @@ async def test_query_tool_invalid_project_root(): @pytest.mark.asyncio async def test_query_tool_success(): with ( + tempfile.TemporaryDirectory() as temp_dir, patch("os.path.isdir", return_value=True), patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config, patch("vectorcode.mcp_main.get_collection") as mock_get_collection, @@ -101,7 +102,9 @@ async def test_query_tool_success(): ): from vectorcode.mcp_main import ClientManager - mock_config = Config(chunk_size=100, overlap_ratio=0.1, reranker=None) + mock_config = Config( + chunk_size=100, overlap_ratio=0.1, reranker=None, project_root=temp_dir + ) mock_load_config_file.return_value = mock_config mock_get_project_config.return_value = mock_config mock_client = AsyncMock() @@ -126,7 +129,7 @@ async def test_query_tool_success(): mock_open.return_value = mock_file_handle result = await query_tool( - n_query=2, query_messages=["keyword1"], project_root="/valid/path" + n_query=2, query_messages=["keyword1"], project_root=temp_dir ) assert len(result) == 2 From 7d375105c9cab1b7fd52aafbdff398edf8a7be84 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Sat, 28 Jun 2025 13:05:03 +0800 Subject: [PATCH 13/17] tests(cli): fixed some test warnings --- tests/test_common.py | 2 ++ tests/test_main.py | 11 +---------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index f9b360b1..40b51fd6 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -629,6 +629,7 @@ async def _try_server(url): return "127.0.0.1" in url or "localhost" in url mock_process = AsyncMock() + mock_process.terminate = MagicMock() with ( patch("vectorcode.common.start_server", return_value=mock_process), patch("vectorcode.common.try_server", side_effect=_try_server), @@ -639,3 +640,4 @@ async def _try_server(url): assert len(manager.get_processes()) == 1 await manager.kill_servers() mock_process.terminate.assert_called_once() + mock_process.wait.assert_awaited() diff --git a/tests/test_main.py b/tests/test_main.py index 34ce181f..c9f9b718 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,7 +4,7 @@ from vectorcode import __version__ from vectorcode.cli_utils import CliAction -from vectorcode.main import async_main, main +from vectorcode.main import async_main @pytest.mark.asyncio @@ -317,12 +317,3 @@ async def test_async_main_exception_handling(monkeypatch): with patch("vectorcode.main.logger") as mock_logger: assert await async_main() == 1 mock_logger.error.assert_called_once() - - -def test_main(monkeypatch): - mock_async_main = AsyncMock(return_value=0) - monkeypatch.setattr("vectorcode.main.async_main", mock_async_main) - monkeypatch.setattr("asyncio.run", MagicMock(return_value=0)) - - result = main() - assert result == 0 From a803eba399cf0c47a17175b129803655799a46f9 Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Sat, 28 Jun 2025 13:16:20 +0800 Subject: [PATCH 14/17] cov --- src/vectorcode/subcommands/update.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/vectorcode/subcommands/update.py b/src/vectorcode/subcommands/update.py index ff4efa1e..1416a7b8 100644 --- a/src/vectorcode/subcommands/update.py +++ b/src/vectorcode/subcommands/update.py @@ -30,12 +30,19 @@ async def update(configs: Config) -> int: file=sys.stderr, ) return 1 - if collection is None or not verify_ef(collection, configs): + if collection is None: # pragma: nocover + logger.error( + f"Failed to find a collection at {configs.project_root} from {configs.db_url}" + ) + return 1 + if not verify_ef(collection, configs): # pragma: nocover return 1 metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"] - if metas is None: + if metas is None or len(metas) == 0: # pragma: nocover + logger.debug("Empty collection.") return 0 + files_gen = (str(meta.get("path", "")) for meta in metas) files = set() orphanes = set() From cb5160aa71d992bfa75a6dc1f59d54622e828c6f Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Sat, 28 Jun 2025 15:04:30 +0800 Subject: [PATCH 15/17] fix(cli): only create lock file when doesn't exist --- src/vectorcode/cli_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 282bb5e4..e49b4b24 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -632,8 +632,9 @@ def get_lock(self, path: str | os.PathLike) -> AsyncFileLock: if os.path.isdir(path): lock_file = os.path.join(path, "vectorcode.lock") logger.info(f"Creating {lock_file} for locking.") - with open(lock_file, mode="w") as fin: - fin.write("") + if not os.path.isfile(lock_file): + with open(lock_file, mode="w") as fin: + fin.write("") path = lock_file if self.__locks.get(path) is None: self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType] From f71d4873733db6fe5267b52358fa0d16055ca7de Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Sat, 28 Jun 2025 15:32:34 +0800 Subject: [PATCH 16/17] docs(cli, nvim): remove standalone server requirement for LSP and MCP --- docs/cli.md | 9 ++++----- docs/neovim.md | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index 3b386cdb..f6dc17c2 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -696,8 +696,9 @@ will: Note that: 1. For easier parsing, `--pipe` is assumed to be enabled in LSP mode; -2. At the time this only work with vectorcode setup that uses a **standalone - ChromaDB server**, which is not difficult to setup using docker; +2. A `vectorcode.lock` file will be created in your `db_path` directory __if + you're using the bundled chromadb server__. Please do not delete it while a + vectorcode process is running; 3. The LSP server supports `vectorise`, `query` and `ls` subcommands. The other subcommands may be added in the future. @@ -714,9 +715,7 @@ features: - `vectorise`: vectorise files into a given project. To try it out, install the `vectorcode[mcp]` dependency group and the MCP server -is available in the shell as `vectorcode-mcp-server`, and make sure you're using -a [standalone chromadb server](#chromadb) configured in the [JSON](#configuring-vectorcode) -via the `host` and `port` options. +is available in the shell as `vectorcode-mcp-server`. The MCP server entry point (`vectorcode-mcp-server`) provides some CLI options that you can use to customise the default behaviour of the server. To view the diff --git a/docs/neovim.md b/docs/neovim.md index 0f61566c..a2a3981c 100644 --- a/docs/neovim.md +++ b/docs/neovim.md @@ -332,7 +332,7 @@ interface: | Features | `default` | `lsp` | |----------|-----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------| | **Pros** | Fully backward compatible with minimal extra config required | Less IO overhead for loading/unloading embedding models; Progress reports. | -| **Cons** | Heavy IO overhead because the embedding model and database client need to be initialised for every query. | Requires `vectorcode-server`; Only works if you're using a standalone ChromaDB server. | +| **Cons** | Heavy IO overhead because the embedding model and database client need to be initialised for every query. | Requires `vectorcode-server` | You may choose which backend to use by setting the [`setup`](#setupopts) option `async_backend`, and acquire the corresponding backend by the following API: From 07af955c64e7ee667e7238ea84f5621e719a8f99 Mon Sep 17 00:00:00 2001 From: Davidyz Date: Sat, 28 Jun 2025 07:33:24 +0000 Subject: [PATCH 17/17] Auto generate docs --- doc/VectorCode-cli.txt | 7 +++---- doc/VectorCode.txt | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/doc/VectorCode-cli.txt b/doc/VectorCode-cli.txt index 510510fa..92aa5117 100644 --- a/doc/VectorCode-cli.txt +++ b/doc/VectorCode-cli.txt @@ -771,7 +771,8 @@ will use that as the default project root for this process; Note that: 1. For easier parsing, `--pipe` is assumed to be enabled in LSP mode; -2. At the time this only work with vectorcode setup that uses a **standalone ChromaDB server**, which is not difficult to setup using docker; +2. A `vectorcode.lock` file will be created in your `db_path` directory **if you’re using the bundled chromadb server**. Please do not delete it while a +vectorcode process is running; 3. The LSP server supports `vectorise`, `query` and `ls` subcommands. The other subcommands may be added in the future. @@ -789,9 +790,7 @@ features: - `vectorise`vectorise files into a given project. To try it out, install the `vectorcode[mcp]` dependency group and the MCP -server is available in the shell as `vectorcode-mcp-server`, and make sure -you’re using a |VectorCode-cli-standalone-chromadb-server| configured in the -|VectorCode-cli-json| via the `host` and `port` options. +server is available in the shell as `vectorcode-mcp-server`. The MCP server entry point (`vectorcode-mcp-server`) provides some CLI options that you can use to customise the default behaviour of the server. To view the diff --git a/doc/VectorCode.txt b/doc/VectorCode.txt index 22aeeec1..78410ecb 100644 --- a/doc/VectorCode.txt +++ b/doc/VectorCode.txt @@ -372,9 +372,9 @@ path to the executable) by calling `vim.lsp.config('vectorcode_server', opts)`. minimal extra config required loading/unloading embedding models; Progress reports. - Cons Heavy IO overhead because the Requires vectorcode-server; Only - embedding model and database works if you’re using a standalone - client need to be initialised ChromaDB server. + Cons Heavy IO overhead because the Requires vectorcode-server + embedding model and database + client need to be initialised for every query. ------------------------------------------------------------------------------- You may choose which backend to use by setting the |VectorCode-`setup`| option