From d3ebe527a16048c613c98dc56f139a2734651265 Mon Sep 17 00:00:00 2001 From: Jared Rieger Date: Sat, 31 May 2025 12:36:57 +1000 Subject: [PATCH 1/3] add db type to config --- src/vectorcode/cli_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index c6b3cb5a..83b63552 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -2,6 +2,7 @@ import atexit import glob import logging +from optparse import Option import os import sys from dataclasses import dataclass, field, fields @@ -62,6 +63,11 @@ class CliAction(Enum): hooks = "hooks" +class DbType(StrEnum): + local = "local" # Local ChromaDB instance + chromadb = "chromadb" # Remote ChromaDB instance + + @dataclass class Config: no_stderr: bool = False @@ -74,6 +80,7 @@ class Config: project_root: Optional[Union[str, Path]] = None query: Optional[list[str]] = None db_url: str = "http://127.0.0.1:8000" + db_type: DbType = DbType.local # falls back to a local instance embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is. embedding_params: dict[str, Any] = field(default_factory=(lambda: {})) n_result: int = 1 @@ -106,6 +113,8 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": default_config = Config() db_path = config_dict.get("db_path") db_url = config_dict.get("db_url") + db_type = config_dict.get("db_type", default_config.db_type) + if db_url is None: host = config_dict.get("host") port = config_dict.get("port") @@ -135,6 +144,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": "embedding_params", default_config.embedding_params ), "db_url": db_url, + "db_type": db_type, "db_path": db_path, "db_log_path": os.path.expanduser( config_dict.get("db_log_path", default_config.db_log_path) @@ -521,6 +531,9 @@ async def get_project_config(project_root: Union[str, Path]) -> Config: if config is None: config = await load_config_file() config.project_root = project_root + + if config.db_type is None: + config.db_type = "local" return config From 74a67267a0f9c84e577ef5fb5ed93fa0b667205a Mon Sep 17 00:00:00 2001 From: Jared Rieger Date: Sat, 31 May 2025 12:38:26 +1000 Subject: [PATCH 2/3] WIP: base implementation for chroma --- src/vectorcode/common.py | 34 ++++- src/vectorcode/db/base.py | 132 ++++++++++++++++++ src/vectorcode/db/chroma.py | 117 ++++++++++++++++ src/vectorcode/db/factory.py | 25 ++++ src/vectorcode/db/local.py | 171 ++++++++++++++++++++++++ src/vectorcode/main.py | 14 +- src/vectorcode/subcommands/vectorise.py | 10 +- 7 files changed, 484 insertions(+), 19 deletions(-) create mode 100644 src/vectorcode/db/base.py create mode 100644 src/vectorcode/db/chroma.py create mode 100644 src/vectorcode/db/factory.py create mode 100644 src/vectorcode/db/local.py diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index f4fff1a6..62f4b104 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -6,16 +6,15 @@ import subprocess import sys from typing import Any, AsyncGenerator -from urllib.parse import urlparse import chromadb import httpx from chromadb.api import AsyncClientAPI from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.config import APIVersion, Settings from chromadb.utils import embedding_functions from vectorcode.cli_utils import Config, expand_path +from vectorcode.db.base import VectorStore logger = logging.getLogger(name=__name__) @@ -169,11 +168,36 @@ def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction | None raise +def build_collection_metadata(configs: Config) -> dict[str, str | int]: + assert configs.project_root is not None + full_path = str(expand_path(str(configs.project_root), absolute=True)) + + collection_meta: dict[str, str | int] = { + "path": full_path, + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), + "embedding_function": configs.embedding_function, + } + + if configs.hnsw: + for key in configs.hnsw.keys(): + target_key = key + if not key.startswith("hnsw:"): + target_key = f"hnsw:{key}" + collection_meta[target_key] = configs.hnsw[key] + logger.debug( + f"Getting/Creating collection with the following metadata: {collection_meta}" + ) + + return collection_meta + + __COLLECTION_CACHE: dict[str, AsyncCollection] = {} async def get_collection( - client: AsyncClientAPI, configs: Config, make_if_missing: bool = False + db: VectorStore, configs: Config, make_if_missing: bool = False ): """ Raise ValueError when make_if_missing is False and no collection is found; @@ -205,11 +229,11 @@ async def get_collection( f"Getting/Creating collection with the following metadata: {collection_meta}" ) if not make_if_missing: - __COLLECTION_CACHE[full_path] = await client.get_collection( + __COLLECTION_CACHE[full_path] = await db.get_collection( collection_name, embedding_function ) else: - collection = await client.get_or_create_collection( + collection = await db.get_or_create_collection( collection_name, metadata=collection_meta, embedding_function=embedding_function, diff --git a/src/vectorcode/db/base.py b/src/vectorcode/db/base.py new file mode 100644 index 00000000..7eeff340 --- /dev/null +++ b/src/vectorcode/db/base.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +from typing import Any +from urllib.parse import urlparse + +from vectorcode.cli_utils import Config +from vectorcode.common import ( + build_collection_metadata, + expand_path, + get_collection_name, + get_embedding_function, +) + + +class VectorStore(ABC): + """Base class for vector database implementations. + + This abstract class defines the interface that all vector database implementations + must follow. It provides methods for common vector database operations like + querying, adding, and deleting vectors. + """ + + configs: Config + + def __init__(self, configs: Config): + self.__COLLECTION_CACHE: dict[str, Any] = {} + self.configs = configs + + assert configs.project_root is not None + self.full_path = str(expand_path(str(configs.project_root), absolute=True)) + + self.collection_metadata = build_collection_metadata(configs) + self.collection_name = get_collection_name(self.full_path) + self.embedding_function = get_embedding_function(configs) + + @abstractmethod + async def connect(self) -> None: + """Establish connection to the vector database.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to the vector database.""" + pass + + # @abstractmethod + # async def get_or_create_collection( + # self, + # collection_name: str, + # metadata: Optional[Dict[str, Any]] = None, + # embedding_function: Optional[Any] = None, + # ) -> Any: + # """Get an existing collection or create a new one if it doesn't exist.""" + # pass + + @abstractmethod + async def get_collection( + self, + make_if_missing: bool = False, + ) -> Any: + """Get an existing collection.""" + pass + + # @abstractmethod + # async def query( + # self, + # collection: Any, + # query_texts: List[str], + # n_results: int, + # where: Optional[Dict[str, Any]] = None, + # include: Optional[List[str]] = None, + # ) -> Dict[str, Any]: + # """Query the vector database for similar vectors.""" + # pass + # + # @abstractmethod + # async def add( + # self, + # collection: Any, + # documents: List[str], + # metadatas: List[Dict[str, Any]], + # ids: Optional[List[str]] = None, + # ) -> None: + # """Add documents to the vector database.""" + # pass + # + # @abstractmethod + # async def delete( + # self, + # collection: Any, + # where: Optional[Dict[str, Any]] = None, + # ) -> None: + # """Delete documents from the vector database.""" + # pass + # + # @abstractmethod + # async def count( + # self, + # collection: Any, + # ) -> int: + # """Get the number of documents in the collection.""" + # pass + # + # @abstractmethod + # async def get( + # self, + # collection: Any, + # ids: Union[str, List[str]], + # include: Optional[List[str]] = None, + # ) -> Dict[str, Any]: + # """Get documents by their IDs.""" + # pass + + @abstractmethod + async def check_health(self) -> bool: + """Check if the database is healthy and accessible.""" + pass + + def print_config(self) -> None: + """Print the current database configuration.""" + parsed_url = urlparse(self.configs.db_url) + + print(f"{self.configs.db_type.title()} Configuration:") + print(f" URL: {self.configs.db_url}") + print(f" Host: {parsed_url.hostname or 'localhost'}") + print( + f" Port: {parsed_url.port or (8000 if self.configs.db_type == 'chroma' else 6333)}" + ) + print(f" SSL: {parsed_url.scheme == 'https'}") + if self.configs.db_settings: + print(" Settings:") + for key, value in self.configs.db_settings.items(): + print(f" {key}: {value}") diff --git a/src/vectorcode/db/chroma.py b/src/vectorcode/db/chroma.py new file mode 100644 index 00000000..4bb32902 --- /dev/null +++ b/src/vectorcode/db/chroma.py @@ -0,0 +1,117 @@ +import logging +import os +import socket +from typing import Any, Dict, override + +import chromadb +from chromadb.api import AsyncClientAPI +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.config import Settings + +from vectorcode.cli_utils import Config +from vectorcode.db.base import VectorStore + +logger = logging.getLogger(__name__) + + +class ChromaVectorStore(VectorStore): + """ChromaDB implementation of the vector store.""" + + _client: AsyncClientAPI | None = None + _chroma_settings: Settings + + def __init__(self, configs: Config): + super().__init__(configs) + settings: Dict[str, Any] = {"anonymized_telemetry": False} + if isinstance(self.configs.db_settings, dict): + valid_settings = { + k: v + for k, v in self.configs.db_settings.items() + if k in Settings.__fields__ + } + settings.update(valid_settings) + + from urllib.parse import urlparse + + parsed_url = urlparse(self.configs.db_url) + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" + settings["chroma_server_http_port"] = parsed_url.port or 8000 + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" + settings["chroma_server_api_default_path"] = "/api/v2" + + self._chroma_settings = Settings(**settings) + + async def connect(self) -> None: + """Establish connection to ChromaDB.""" + try: + self._client = await chromadb.AsyncHttpClient( + settings=self._chroma_settings, + host=str(self._chroma_settings.chroma_server_host), + port=int(self._chroma_settings.chroma_server_http_port or 8000), + ) + await self.check_health() + except Exception as e: + logger.error(f"Could not connect to ChromaDB: {e}") + + @override + async def check_health(self) -> bool: + try: + if self._client is None: + await self.connect() + + assert self._client is not None, "Chroma client is not connected." + await self._client.heartbeat() + + return True + except Exception as e: + logger.error(f"ChromaDB is not healthy: {e}") + return False + + async def disconnect(self) -> None: + """Close connection to ChromaDB.""" + return None + + async def get_collection( + self, + make_if_missing: bool = False, + ) -> AsyncCollection: + """ + Raise ValueError when make_if_missing is False and no collection is found; + Raise IndexError on hash collision. + """ + if not self._client: + await self.connect() + + assert self._client is not None, "Chroma client is not connected." + + if self.__COLLECTION_CACHE.get(self.full_path) is None: + if not make_if_missing: + self.__COLLECTION_CACHE[ + self.full_path + ] = await self._client.get_collection( + self.collection_name, self.embedding_function + ) + else: + collection = await self._client.get_or_create_collection( + self.collection_name, + metadata=self.collection_metadata, + embedding_function=self.embedding_function, + ) + if ( + not collection.metadata.get("hostname") == socket.gethostname() + or collection.metadata.get("username") + not in ( + os.environ.get("USER"), + os.environ.get("USERNAME"), + "DEFAULT_USER", + ) + or not collection.metadata.get("created-by") == "VectorCode" + ): + logger.error( + f"Failed to use existing collection due to metadata mismatch: {self.collection_metadata}" + ) + raise IndexError( + "Failed to create the collection due to hash collision. Please file a bug report." + ) + self.__COLLECTION_CACHE[self.full_path] = collection + return self.__COLLECTION_CACHE[self.full_path] diff --git a/src/vectorcode/db/factory.py b/src/vectorcode/db/factory.py new file mode 100644 index 00000000..5cd5577f --- /dev/null +++ b/src/vectorcode/db/factory.py @@ -0,0 +1,25 @@ +from typing import Dict, Type + +from vectorcode.cli_utils import Config, DbType +from vectorcode.db.base import VectorStore +from vectorcode.db.chroma import ChromaVectorStore +from vectorcode.db.local import LocalChromaVectorStore + + +class VectorStoreFactory: + """Factory for creating vector store instances.""" + + _stores: Dict[DbType, Type[VectorStore]] = { + DbType.chromadb: ChromaVectorStore, + DbType.local: LocalChromaVectorStore, + } + + @classmethod + def create_store(cls, configs: Config) -> VectorStore: + """Create a vector store instance based on configuration.""" + store_type = configs.db_type + if store_type not in cls._stores: + raise ValueError(f"Unsupported vector store type: {store_type}") + + store_class = cls._stores[store_type] + return store_class(configs) diff --git a/src/vectorcode/db/local.py b/src/vectorcode/db/local.py new file mode 100644 index 00000000..3cb781b3 --- /dev/null +++ b/src/vectorcode/db/local.py @@ -0,0 +1,171 @@ +import asyncio +from asyncio.subprocess import Process +import logging +import subprocess +import os +import socket +import sys +from typing import Any, Dict, override + +import chromadb +from chromadb.api import AsyncClientAPI +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.config import Settings + +from vectorcode.cli_utils import Config +from vectorcode.db.chroma import ChromaVectorStore + +logger = logging.getLogger(__name__) + + +class LocalChromaVectorStore(ChromaVectorStore): + """ChromaDB implementation of the vector store.""" + + _client: AsyncClientAPI | None = None + _process: Process | None = None + _chroma_settings: Settings + + def __init__(self, configs: Config): + super().__init__(configs) + settings: Dict[str, Any] = {"anonymized_telemetry": False} + if isinstance(self.configs.db_settings, dict): + valid_settings = { + k: v + for k, v in self.configs.db_settings.items() + if k in Settings.__fields__ + } + settings.update(valid_settings) + + from urllib.parse import urlparse + + parsed_url = urlparse(self.configs.db_url) + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" + settings["chroma_server_http_port"] = parsed_url.port or 8000 + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" + settings["chroma_server_api_default_path"] = "/api/v2" + + self._chroma_settings = Settings(**settings) + + async def _start_chroma_process(self) -> None: + if self._process is not None: + return + + assert self.configs.db_path is not None, "ChromaDB db_path must be set." + db_path = os.path.expanduser(self.configs.db_path) + self.configs.db_log_path = os.path.expanduser(self.configs.db_log_path) + if not os.path.isdir(self.configs.db_log_path): + os.makedirs(self.configs.db_log_path) + if not os.path.isdir(db_path): + logger.warning( + f"Using local database at {os.path.expanduser('~/.local/share/vectorcode/chromadb/')}.", + ) + db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") + env = os.environ.copy() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # OS selects a free ephemeral port + port = int(s.getsockname()[1]) + + server_url = f"http://127.0.0.1:{port}" + logger.warning(f"Starting bundled ChromaDB server at {server_url}.") + env.update({"ANONYMIZED_TELEMETRY": "False"}) + + self._process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "chromadb.cli.cli", + "run", + "--host", + "localhost", + "--port", + str(port), + "--path", + db_path, + "--log-path", + os.path.join(str(self.configs.db_log_path), "chroma.log"), + stdout=subprocess.DEVNULL, + stderr=sys.stderr, + env=env, + ) + + async def connect(self) -> None: + """Establish connection to ChromaDB.""" + if self._process is None: + await self._start_chroma_process() + + try: + self._client = await chromadb.AsyncHttpClient( + settings=self._chroma_settings, + host=str(self._chroma_settings.chroma_server_host), + port=int(self._chroma_settings.chroma_server_http_port or 8000), + ) + await self.check_health() + except Exception as e: + logger.error(f"Could not connect to ChromaDB: {e}") + + # @override + # async def check_health(self) -> bool: + # try: + # if self._client is None: + # await self.connect() + # + # assert self._client is not None, "Chroma client is not connected." + # await self._client.heartbeat() + # + # return True + # except Exception as e: + # logger.error(f"ChromaDB is not healthy: {e}") + # return False + + async def disconnect(self) -> None: + """Close connection to ChromaDB.""" + if self._process is None: + return + + logger.info("Shutting down the bundled Chromadb instance.") + self._process.terminate() + await self._process.wait() + + # async def get_collection( + # self, + # make_if_missing: bool = False, + # ) -> AsyncCollection: + # """ + # Raise ValueError when make_if_missing is False and no collection is found; + # Raise IndexError on hash collision. + # """ + # if not self._client: + # await self.connect() + # + # assert self._client is not None, "Chroma client is not connected." + # + # if self.__COLLECTION_CACHE.get(self.full_path) is None: + # if not make_if_missing: + # self.__COLLECTION_CACHE[ + # self.full_path + # ] = await self._client.get_collection( + # self.collection_name, self.embedding_function + # ) + # else: + # collection = await self._client.get_or_create_collection( + # self.collection_name, + # metadata=self.collection_metadata, + # embedding_function=self.embedding_function, + # ) + # if ( + # not collection.metadata.get("hostname") == socket.gethostname() + # or collection.metadata.get("username") + # not in ( + # os.environ.get("USER"), + # os.environ.get("USERNAME"), + # "DEFAULT_USER", + # ) + # or not collection.metadata.get("created-by") == "VectorCode" + # ): + # logger.error( + # f"Failed to use existing collection due to metadata mismatch: {self.collection_metadata}" + # ) + # raise IndexError( + # "Failed to create the collection due to hash collision. Please file a bug report." + # ) + # self.__COLLECTION_CACHE[self.full_path] = collection + # return self.__COLLECTION_CACHE[self.full_path] diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 3b64cff4..e9ad4f26 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -71,11 +71,10 @@ async def async_main(): return await hooks(cli_args) - from vectorcode.common import start_server, try_server + from vectorcode.db.factory import VectorStoreFactory - server_process = None - if not await try_server(final_configs.db_url): - server_process = await start_server(final_configs) + db = VectorStoreFactory.create_store(final_configs) + await db.connect() if final_configs.pipe: # NOTE: NNCF (intel GPU acceleration for sentence transformer) keeps showing logs. @@ -92,7 +91,7 @@ async def async_main(): case CliAction.vectorise: from vectorcode.subcommands import vectorise - return_val = await vectorise(final_configs) + return_val = await vectorise(db, final_configs) case CliAction.drop: from vectorcode.subcommands import drop @@ -113,10 +112,7 @@ async def async_main(): return_val = 1 logger.error(traceback.format_exc()) finally: - if server_process is not None: - logger.info("Shutting down the bundled Chromadb instance.") - server_process.terminate() - await server_process.wait() + await db.disconnect() return return_val diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py index a838124f..33df3dd3 100644 --- a/src/vectorcode/subcommands/vectorise.py +++ b/src/vectorcode/subcommands/vectorise.py @@ -14,6 +14,7 @@ from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.types import IncludeEnum +from vectorcode.db.base import VectorStore from vectorcode.chunking import Chunk, TreeSitterChunker from vectorcode.cli_utils import ( GLOBAL_EXCLUDE_SPEC, @@ -22,7 +23,7 @@ expand_globs, expand_path, ) -from vectorcode.common import get_client, get_collection, verify_ef +from vectorcode.common import verify_ef logger = logging.getLogger(name=__name__) @@ -158,11 +159,9 @@ def load_files_from_include(project_root: str) -> list[str]: return [] -async def vectorise(configs: Config) -> int: - assert configs.project_root is not None - client = await get_client(configs) +async def vectorise(db: VectorStore, configs: Config) -> int: try: - collection = await get_collection(client, configs, True) + collection = await db.get_collection(True) except IndexError: print("Failed to get/create the collection. Please check your config.") return 1 @@ -180,6 +179,7 @@ async def vectorise(configs: Config) -> int: specs = [ gitignore_path, ] + assert configs.project_root is not None exclude_spec_path = os.path.join( configs.project_root, ".vectorcode", "vectorcode.exclude" ) From 88226f4f1bfc8537f8f63de0238a3ef3bcc5e28a Mon Sep 17 00:00:00 2001 From: Jared Rieger Date: Sat, 31 May 2025 17:03:25 +1000 Subject: [PATCH 3/3] fix connection to local --- src/vectorcode/db/base.py | 57 +++++++------- src/vectorcode/db/chroma.py | 105 ++++++++++++-------------- src/vectorcode/db/local.py | 143 ++++++++++-------------------------- 3 files changed, 112 insertions(+), 193 deletions(-) diff --git a/src/vectorcode/db/base.py b/src/vectorcode/db/base.py index 7eeff340..6be6739d 100644 --- a/src/vectorcode/db/base.py +++ b/src/vectorcode/db/base.py @@ -3,12 +3,10 @@ from urllib.parse import urlparse from vectorcode.cli_utils import Config -from vectorcode.common import ( - build_collection_metadata, - expand_path, - get_collection_name, - get_embedding_function, -) + + +class VectorStoreConnectionError(Exception): + pass class VectorStore(ABC): @@ -19,18 +17,13 @@ class VectorStore(ABC): querying, adding, and deleting vectors. """ - configs: Config + _configs: Config def __init__(self, configs: Config): self.__COLLECTION_CACHE: dict[str, Any] = {} - self.configs = configs + self._configs = configs assert configs.project_root is not None - self.full_path = str(expand_path(str(configs.project_root), absolute=True)) - - self.collection_metadata = build_collection_metadata(configs) - self.collection_name = get_collection_name(self.full_path) - self.embedding_function = get_embedding_function(configs) @abstractmethod async def connect(self) -> None: @@ -43,23 +36,30 @@ async def disconnect(self) -> None: pass # @abstractmethod - # async def get_or_create_collection( - # self, - # collection_name: str, - # metadata: Optional[Dict[str, Any]] = None, - # embedding_function: Optional[Any] = None, - # ) -> Any: - # """Get an existing collection or create a new one if it doesn't exist.""" + # async def check_health(self) -> None: + # """Check if the database is healthy and accessible. Raises a VectorStoreConnectionError if not.""" # pass @abstractmethod async def get_collection( self, + collection_name: str, + collection_meta: dict[str, Any] | None = None, make_if_missing: bool = False, ) -> Any: """Get an existing collection.""" pass + # @abstractmethod + # async def get_or_create_collection( + # self, + # collection_name: str, + # metadata: Optional[Dict[str, Any]] = None, + # embedding_function: Optional[Any] = None, + # ) -> Any: + # """Get an existing collection or create a new one if it doesn't exist.""" + # pass + # @abstractmethod # async def query( # self, @@ -110,23 +110,18 @@ async def get_collection( # """Get documents by their IDs.""" # pass - @abstractmethod - async def check_health(self) -> bool: - """Check if the database is healthy and accessible.""" - pass - def print_config(self) -> None: """Print the current database configuration.""" - parsed_url = urlparse(self.configs.db_url) + parsed_url = urlparse(self._configs.db_url) - print(f"{self.configs.db_type.title()} Configuration:") - print(f" URL: {self.configs.db_url}") + print(f"{self._configs.db_type.title()} Configuration:") + print(f" URL: {self._configs.db_url}") print(f" Host: {parsed_url.hostname or 'localhost'}") print( - f" Port: {parsed_url.port or (8000 if self.configs.db_type == 'chroma' else 6333)}" + f" Port: {parsed_url.port or (8000 if self._configs.db_type == 'chroma' else 6333)}" ) print(f" SSL: {parsed_url.scheme == 'https'}") - if self.configs.db_settings: + if self._configs.db_settings: print(" Settings:") - for key, value in self.configs.db_settings.items(): + for key, value in self._configs.db_settings.items(): print(f" {key}: {value}") diff --git a/src/vectorcode/db/chroma.py b/src/vectorcode/db/chroma.py index 4bb32902..07bc6bae 100644 --- a/src/vectorcode/db/chroma.py +++ b/src/vectorcode/db/chroma.py @@ -1,15 +1,15 @@ import logging -import os -import socket from typing import Any, Dict, override +from urllib.parse import urlparse import chromadb from chromadb.api import AsyncClientAPI from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.config import Settings +from chromadb.utils import embedding_functions from vectorcode.cli_utils import Config -from vectorcode.db.base import VectorStore +from vectorcode.db.base import VectorStore, VectorStoreConnectionError logger = logging.getLogger(__name__) @@ -19,21 +19,20 @@ class ChromaVectorStore(VectorStore): _client: AsyncClientAPI | None = None _chroma_settings: Settings + _embedding_function: chromadb.EmbeddingFunction | None def __init__(self, configs: Config): super().__init__(configs) settings: Dict[str, Any] = {"anonymized_telemetry": False} - if isinstance(self.configs.db_settings, dict): + if isinstance(self._configs.db_settings, dict): valid_settings = { k: v - for k, v in self.configs.db_settings.items() + for k, v in self._configs.db_settings.items() if k in Settings.__fields__ } settings.update(valid_settings) - from urllib.parse import urlparse - - parsed_url = urlparse(self.configs.db_url) + parsed_url = urlparse(self._configs.db_url) settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" settings["chroma_server_http_port"] = parsed_url.port or 8000 settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" @@ -41,38 +40,55 @@ def __init__(self, configs: Config): self._chroma_settings = Settings(**settings) - async def connect(self) -> None: - """Establish connection to ChromaDB.""" try: - self._client = await chromadb.AsyncHttpClient( - settings=self._chroma_settings, - host=str(self._chroma_settings.chroma_server_host), - port=int(self._chroma_settings.chroma_server_http_port or 8000), + self._embedding_function = getattr( + embedding_functions, configs.embedding_function + )(**configs.embedding_params) + except AttributeError: + logger.warning( + f"Failed to use {configs.embedding_function}. Falling back to Sentence Transformer.", + ) + self._embedding_function = ( + embedding_functions.SentenceTransformerEmbeddingFunction() # type:ignore ) - await self.check_health() except Exception as e: - logger.error(f"Could not connect to ChromaDB: {e}") + e.add_note( + "\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment." + ) + logger.error( + f"Failed to use {configs.embedding_function} with following error.", + ) + raise @override - async def check_health(self) -> bool: + async def connect(self) -> None: + """Establish connection to ChromaDB.""" try: if self._client is None: - await self.connect() + logger.debug( + f"Connecting to ChromaDB at {self._chroma_settings.chroma_server_host}:{self._chroma_settings.chroma_server_http_port}." + ) + self._client = await chromadb.AsyncHttpClient( + settings=self._chroma_settings, + host=str(self._chroma_settings.chroma_server_host), + port=int(self._chroma_settings.chroma_server_http_port or 8000), + ) - assert self._client is not None, "Chroma client is not connected." await self._client.heartbeat() - - return True except Exception as e: - logger.error(f"ChromaDB is not healthy: {e}") - return False + logger.error(f"Could not connect to ChromaDB: {e}") + raise VectorStoreConnectionError(e) + @override async def disconnect(self) -> None: - """Close connection to ChromaDB.""" - return None + """Not required for non local chromadb.""" + pass + @override async def get_collection( self, + collection_name: str, + collection_meta: dict[str, Any] | None = None, make_if_missing: bool = False, ) -> AsyncCollection: """ @@ -84,34 +100,11 @@ async def get_collection( assert self._client is not None, "Chroma client is not connected." - if self.__COLLECTION_CACHE.get(self.full_path) is None: - if not make_if_missing: - self.__COLLECTION_CACHE[ - self.full_path - ] = await self._client.get_collection( - self.collection_name, self.embedding_function - ) - else: - collection = await self._client.get_or_create_collection( - self.collection_name, - metadata=self.collection_metadata, - embedding_function=self.embedding_function, - ) - if ( - not collection.metadata.get("hostname") == socket.gethostname() - or collection.metadata.get("username") - not in ( - os.environ.get("USER"), - os.environ.get("USERNAME"), - "DEFAULT_USER", - ) - or not collection.metadata.get("created-by") == "VectorCode" - ): - logger.error( - f"Failed to use existing collection due to metadata mismatch: {self.collection_metadata}" - ) - raise IndexError( - "Failed to create the collection due to hash collision. Please file a bug report." - ) - self.__COLLECTION_CACHE[self.full_path] = collection - return self.__COLLECTION_CACHE[self.full_path] + if not make_if_missing: + return await self._client.get_collection(collection_name) + else: + return await self._client.get_or_create_collection( + collection_name, + metadata=collection_meta, + embedding_function=self._embedding_function, + ) diff --git a/src/vectorcode/db/local.py b/src/vectorcode/db/local.py index 3cb781b3..88339106 100644 --- a/src/vectorcode/db/local.py +++ b/src/vectorcode/db/local.py @@ -5,14 +5,10 @@ import os import socket import sys -from typing import Any, Dict, override +from typing import override -import chromadb -from chromadb.api import AsyncClientAPI -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.config import Settings -from vectorcode.cli_utils import Config +from vectorcode.cli_utils import Config, expand_path from vectorcode.db.chroma import ChromaVectorStore logger = logging.getLogger(__name__) @@ -21,40 +17,22 @@ class LocalChromaVectorStore(ChromaVectorStore): """ChromaDB implementation of the vector store.""" - _client: AsyncClientAPI | None = None _process: Process | None = None - _chroma_settings: Settings + _full_path: str def __init__(self, configs: Config): super().__init__(configs) - settings: Dict[str, Any] = {"anonymized_telemetry": False} - if isinstance(self.configs.db_settings, dict): - valid_settings = { - k: v - for k, v in self.configs.db_settings.items() - if k in Settings.__fields__ - } - settings.update(valid_settings) - - from urllib.parse import urlparse - - parsed_url = urlparse(self.configs.db_url) - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" - settings["chroma_server_http_port"] = parsed_url.port or 8000 - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" - settings["chroma_server_api_default_path"] = "/api/v2" - - self._chroma_settings = Settings(**settings) + self._full_path = str(expand_path(str(configs.project_root), absolute=True)) async def _start_chroma_process(self) -> None: if self._process is not None: return - assert self.configs.db_path is not None, "ChromaDB db_path must be set." - db_path = os.path.expanduser(self.configs.db_path) - self.configs.db_log_path = os.path.expanduser(self.configs.db_log_path) - if not os.path.isdir(self.configs.db_log_path): - os.makedirs(self.configs.db_log_path) + assert self._configs.db_path is not None, "ChromaDB db_path must be set." + db_path = os.path.expanduser(self._configs.db_path) + self._configs.db_log_path = os.path.expanduser(self._configs.db_log_path) + if not os.path.isdir(self._configs.db_log_path): + os.makedirs(self._configs.db_log_path) if not os.path.isdir(db_path): logger.warning( f"Using local database at {os.path.expanduser('~/.local/share/vectorcode/chromadb/')}.", @@ -63,10 +41,10 @@ async def _start_chroma_process(self) -> None: env = os.environ.copy() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) # OS selects a free ephemeral port - port = int(s.getsockname()[1]) + self._chroma_settings.chroma_server_http_port = int(s.getsockname()[1]) - server_url = f"http://127.0.0.1:{port}" - logger.warning(f"Starting bundled ChromaDB server at {server_url}.") + server_url = f"http://127.0.0.1:{self._chroma_settings.chroma_server_http_port}" + logger.info(f"Starting bundled ChromaDB server at {server_url}.") env.update({"ANONYMIZED_TELEMETRY": "False"}) self._process = await asyncio.create_subprocess_exec( @@ -77,45 +55,43 @@ async def _start_chroma_process(self) -> None: "--host", "localhost", "--port", - str(port), + str(self._chroma_settings.chroma_server_http_port), "--path", db_path, "--log-path", - os.path.join(str(self.configs.db_log_path), "chroma.log"), + os.path.join(str(self._configs.db_log_path), "chroma.log"), stdout=subprocess.DEVNULL, stderr=sys.stderr, env=env, ) + @override async def connect(self) -> None: """Establish connection to ChromaDB.""" if self._process is None: await self._start_chroma_process() - - try: - self._client = await chromadb.AsyncHttpClient( - settings=self._chroma_settings, - host=str(self._chroma_settings.chroma_server_host), - port=int(self._chroma_settings.chroma_server_http_port or 8000), - ) - await self.check_health() - except Exception as e: - logger.error(f"Could not connect to ChromaDB: {e}") - - # @override - # async def check_health(self) -> bool: - # try: - # if self._client is None: - # await self.connect() - # - # assert self._client is not None, "Chroma client is not connected." - # await self._client.heartbeat() - # - # return True - # except Exception as e: - # logger.error(f"ChromaDB is not healthy: {e}") - # return False - + # Wait for server to start up + await asyncio.sleep(2) + + # we have to wait until the local chroma server is ready + # Retry connection with exponential backoff + max_retries = 5 + retry_delay = 0.5 + + for attempt in range(max_retries): + try: + await super().connect() + return + except Exception as e: + if attempt == max_retries - 1: + raise + logger.debug( + f"Connection attempt {attempt + 1} failed, retrying in {retry_delay}s: {e}" + ) + await asyncio.sleep(retry_delay) + retry_delay *= 2 + + @override async def disconnect(self) -> None: """Close connection to ChromaDB.""" if self._process is None: @@ -124,48 +100,3 @@ async def disconnect(self) -> None: logger.info("Shutting down the bundled Chromadb instance.") self._process.terminate() await self._process.wait() - - # async def get_collection( - # self, - # make_if_missing: bool = False, - # ) -> AsyncCollection: - # """ - # Raise ValueError when make_if_missing is False and no collection is found; - # Raise IndexError on hash collision. - # """ - # if not self._client: - # await self.connect() - # - # assert self._client is not None, "Chroma client is not connected." - # - # if self.__COLLECTION_CACHE.get(self.full_path) is None: - # if not make_if_missing: - # self.__COLLECTION_CACHE[ - # self.full_path - # ] = await self._client.get_collection( - # self.collection_name, self.embedding_function - # ) - # else: - # collection = await self._client.get_or_create_collection( - # self.collection_name, - # metadata=self.collection_metadata, - # embedding_function=self.embedding_function, - # ) - # if ( - # not collection.metadata.get("hostname") == socket.gethostname() - # or collection.metadata.get("username") - # not in ( - # os.environ.get("USER"), - # os.environ.get("USERNAME"), - # "DEFAULT_USER", - # ) - # or not collection.metadata.get("created-by") == "VectorCode" - # ): - # logger.error( - # f"Failed to use existing collection due to metadata mismatch: {self.collection_metadata}" - # ) - # raise IndexError( - # "Failed to create the collection due to hash collision. Please file a bug report." - # ) - # self.__COLLECTION_CACHE[self.full_path] = collection - # return self.__COLLECTION_CACHE[self.full_path]