Skip to content

[WIP]: Refactor db layer #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/vectorcode/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import atexit
import glob
import logging
from optparse import Option

Check failure on line 5 in src/vectorcode/cli_utils.py

View workflow job for this annotation

GitHub Actions / style-check

Ruff (F401)

src/vectorcode/cli_utils.py:5:22: F401 `optparse.Option` imported but unused
import os
import sys
from dataclasses import dataclass, field, fields
Expand Down Expand Up @@ -62,6 +63,11 @@
hooks = "hooks"


class DbType(StrEnum):
local = "local" # Local ChromaDB instance
chromadb = "chromadb" # Remote ChromaDB instance


@dataclass
class Config:
no_stderr: bool = False
Expand All @@ -74,6 +80,7 @@
project_root: Optional[Union[str, Path]] = None
query: Optional[list[str]] = None
db_url: str = "http://127.0.0.1:8000"
db_type: DbType = DbType.local # falls back to a local instance
embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is.
embedding_params: dict[str, Any] = field(default_factory=(lambda: {}))
n_result: int = 1
Expand Down Expand Up @@ -106,6 +113,8 @@
default_config = Config()
db_path = config_dict.get("db_path")
db_url = config_dict.get("db_url")
db_type = config_dict.get("db_type", default_config.db_type)

if db_url is None:
host = config_dict.get("host")
port = config_dict.get("port")
Expand Down Expand Up @@ -135,6 +144,7 @@
"embedding_params", default_config.embedding_params
),
"db_url": db_url,
"db_type": db_type,
"db_path": db_path,
"db_log_path": os.path.expanduser(
config_dict.get("db_log_path", default_config.db_log_path)
Expand Down Expand Up @@ -521,6 +531,9 @@
if config is None:
config = await load_config_file()
config.project_root = project_root

if config.db_type is None:
config.db_type = "local"
return config


Expand Down
34 changes: 29 additions & 5 deletions src/vectorcode/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
import subprocess
import sys
from typing import Any, AsyncGenerator
from urllib.parse import urlparse

import chromadb
import httpx
from chromadb.api import AsyncClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.config import APIVersion, Settings
from chromadb.utils import embedding_functions

from vectorcode.cli_utils import Config, expand_path
from vectorcode.db.base import VectorStore

logger = logging.getLogger(name=__name__)

Expand Down Expand Up @@ -120,15 +119,15 @@
settings: dict[str, Any] = {"anonymized_telemetry": False}
if isinstance(configs.db_settings, dict):
valid_settings = {
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__

Check failure on line 122 in src/vectorcode/common.py

View workflow job for this annotation

GitHub Actions / style-check

Ruff (F821)

src/vectorcode/common.py:122:70: F821 Undefined name `Settings`
}
settings.update(valid_settings)
parsed_url = urlparse(configs.db_url)

Check failure on line 125 in src/vectorcode/common.py

View workflow job for this annotation

GitHub Actions / style-check

Ruff (F821)

src/vectorcode/common.py:125:22: F821 Undefined name `urlparse`
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
settings["chroma_server_http_port"] = parsed_url.port or 8000
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2

Check failure on line 129 in src/vectorcode/common.py

View workflow job for this annotation

GitHub Actions / style-check

Ruff (F821)

src/vectorcode/common.py:129:73: F821 Undefined name `APIVersion`
settings_obj = Settings(**settings)

Check failure on line 130 in src/vectorcode/common.py

View workflow job for this annotation

GitHub Actions / style-check

Ruff (F821)

src/vectorcode/common.py:130:24: F821 Undefined name `Settings`
__CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient(
settings=settings_obj,
host=str(settings_obj.chroma_server_host),
Expand Down Expand Up @@ -169,11 +168,36 @@
raise


def build_collection_metadata(configs: Config) -> dict[str, str | int]:
assert configs.project_root is not None
full_path = str(expand_path(str(configs.project_root), absolute=True))

collection_meta: dict[str, str | int] = {
"path": full_path,
"hostname": socket.gethostname(),
"created-by": "VectorCode",
"username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
"embedding_function": configs.embedding_function,
}

if configs.hnsw:
for key in configs.hnsw.keys():
target_key = key
if not key.startswith("hnsw:"):
target_key = f"hnsw:{key}"
collection_meta[target_key] = configs.hnsw[key]
logger.debug(
f"Getting/Creating collection with the following metadata: {collection_meta}"
)

return collection_meta


__COLLECTION_CACHE: dict[str, AsyncCollection] = {}


async def get_collection(
client: AsyncClientAPI, configs: Config, make_if_missing: bool = False
db: VectorStore, configs: Config, make_if_missing: bool = False
):
"""
Raise ValueError when make_if_missing is False and no collection is found;
Expand Down Expand Up @@ -205,11 +229,11 @@
f"Getting/Creating collection with the following metadata: {collection_meta}"
)
if not make_if_missing:
__COLLECTION_CACHE[full_path] = await client.get_collection(
__COLLECTION_CACHE[full_path] = await db.get_collection(
collection_name, embedding_function
)
else:
collection = await client.get_or_create_collection(
collection = await db.get_or_create_collection(
collection_name,
metadata=collection_meta,
embedding_function=embedding_function,
Expand Down
127 changes: 127 additions & 0 deletions src/vectorcode/db/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from abc import ABC, abstractmethod
from typing import Any
from urllib.parse import urlparse

from vectorcode.cli_utils import Config


class VectorStoreConnectionError(Exception):
pass


class VectorStore(ABC):
"""Base class for vector database implementations.

This abstract class defines the interface that all vector database implementations
must follow. It provides methods for common vector database operations like
querying, adding, and deleting vectors.
"""

_configs: Config

def __init__(self, configs: Config):
self.__COLLECTION_CACHE: dict[str, Any] = {}
self._configs = configs

assert configs.project_root is not None

@abstractmethod
async def connect(self) -> None:
"""Establish connection to the vector database."""
pass

@abstractmethod
async def disconnect(self) -> None:
"""Close connection to the vector database."""
pass

# @abstractmethod
# async def check_health(self) -> None:
# """Check if the database is healthy and accessible. Raises a VectorStoreConnectionError if not."""
# pass

@abstractmethod
async def get_collection(
self,
collection_name: str,
collection_meta: dict[str, Any] | None = None,
make_if_missing: bool = False,
) -> Any:
"""Get an existing collection."""
pass

# @abstractmethod
# async def get_or_create_collection(
# self,
# collection_name: str,
# metadata: Optional[Dict[str, Any]] = None,
# embedding_function: Optional[Any] = None,
# ) -> Any:
# """Get an existing collection or create a new one if it doesn't exist."""
# pass

# @abstractmethod
# async def query(
# self,
# collection: Any,
# query_texts: List[str],
# n_results: int,
# where: Optional[Dict[str, Any]] = None,
# include: Optional[List[str]] = None,
# ) -> Dict[str, Any]:
# """Query the vector database for similar vectors."""
# pass
#
# @abstractmethod
# async def add(
# self,
# collection: Any,
# documents: List[str],
# metadatas: List[Dict[str, Any]],
# ids: Optional[List[str]] = None,
# ) -> None:
# """Add documents to the vector database."""
# pass
#
# @abstractmethod
# async def delete(
# self,
# collection: Any,
# where: Optional[Dict[str, Any]] = None,
# ) -> None:
# """Delete documents from the vector database."""
# pass
#
# @abstractmethod
# async def count(
# self,
# collection: Any,
# ) -> int:
# """Get the number of documents in the collection."""
# pass
#
# @abstractmethod
# async def get(
# self,
# collection: Any,
# ids: Union[str, List[str]],
# include: Optional[List[str]] = None,
# ) -> Dict[str, Any]:
# """Get documents by their IDs."""
# pass

def print_config(self) -> None:
"""Print the current database configuration."""
parsed_url = urlparse(self._configs.db_url)

print(f"{self._configs.db_type.title()} Configuration:")
print(f" URL: {self._configs.db_url}")
print(f" Host: {parsed_url.hostname or 'localhost'}")
print(
f" Port: {parsed_url.port or (8000 if self._configs.db_type == 'chroma' else 6333)}"
)
print(f" SSL: {parsed_url.scheme == 'https'}")
if self._configs.db_settings:
print(" Settings:")
for key, value in self._configs.db_settings.items():
print(f" {key}: {value}")
110 changes: 110 additions & 0 deletions src/vectorcode/db/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
from typing import Any, Dict, override
from urllib.parse import urlparse

import chromadb
from chromadb.api import AsyncClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from vectorcode.cli_utils import Config
from vectorcode.db.base import VectorStore, VectorStoreConnectionError

logger = logging.getLogger(__name__)


class ChromaVectorStore(VectorStore):
"""ChromaDB implementation of the vector store."""

_client: AsyncClientAPI | None = None
_chroma_settings: Settings
_embedding_function: chromadb.EmbeddingFunction | None

def __init__(self, configs: Config):
super().__init__(configs)
settings: Dict[str, Any] = {"anonymized_telemetry": False}
if isinstance(self._configs.db_settings, dict):
valid_settings = {
k: v
for k, v in self._configs.db_settings.items()
if k in Settings.__fields__
}
settings.update(valid_settings)

parsed_url = urlparse(self._configs.db_url)
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
settings["chroma_server_http_port"] = parsed_url.port or 8000
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
settings["chroma_server_api_default_path"] = "/api/v2"

self._chroma_settings = Settings(**settings)

try:
self._embedding_function = getattr(
embedding_functions, configs.embedding_function
)(**configs.embedding_params)
except AttributeError:
logger.warning(
f"Failed to use {configs.embedding_function}. Falling back to Sentence Transformer.",
)
self._embedding_function = (
embedding_functions.SentenceTransformerEmbeddingFunction() # type:ignore
)
except Exception as e:
e.add_note(
"\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment."
)
logger.error(
f"Failed to use {configs.embedding_function} with following error.",
)
raise

@override
async def connect(self) -> None:
"""Establish connection to ChromaDB."""
try:
if self._client is None:
logger.debug(
f"Connecting to ChromaDB at {self._chroma_settings.chroma_server_host}:{self._chroma_settings.chroma_server_http_port}."
)
self._client = await chromadb.AsyncHttpClient(
settings=self._chroma_settings,
host=str(self._chroma_settings.chroma_server_host),
port=int(self._chroma_settings.chroma_server_http_port or 8000),
)

await self._client.heartbeat()
except Exception as e:
logger.error(f"Could not connect to ChromaDB: {e}")
raise VectorStoreConnectionError(e)

@override
async def disconnect(self) -> None:
"""Not required for non local chromadb."""
pass

@override
async def get_collection(
self,
collection_name: str,
collection_meta: dict[str, Any] | None = None,
make_if_missing: bool = False,
) -> AsyncCollection:
"""
Raise ValueError when make_if_missing is False and no collection is found;
Raise IndexError on hash collision.
"""
if not self._client:
await self.connect()

assert self._client is not None, "Chroma client is not connected."

if not make_if_missing:
return await self._client.get_collection(collection_name)
else:
return await self._client.get_or_create_collection(
collection_name,
metadata=collection_meta,
embedding_function=self._embedding_function,
)
25 changes: 25 additions & 0 deletions src/vectorcode/db/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Dict, Type

from vectorcode.cli_utils import Config, DbType
from vectorcode.db.base import VectorStore
from vectorcode.db.chroma import ChromaVectorStore
from vectorcode.db.local import LocalChromaVectorStore


class VectorStoreFactory:
"""Factory for creating vector store instances."""

_stores: Dict[DbType, Type[VectorStore]] = {
DbType.chromadb: ChromaVectorStore,
DbType.local: LocalChromaVectorStore,
}

@classmethod
def create_store(cls, configs: Config) -> VectorStore:
"""Create a vector store instance based on configuration."""
store_type = configs.db_type
if store_type not in cls._stores:
raise ValueError(f"Unsupported vector store type: {store_type}")

store_class = cls._stores[store_type]
return store_class(configs)
Loading
Loading