Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
Expand Down Expand Up @@ -407,7 +407,7 @@ def get_collection(
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get a collection with the given name.
Expand Down Expand Up @@ -439,7 +439,7 @@ def get_or_create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
"""Get or create a collection with the given name and metadata.
Expand Down
6 changes: 3 additions & 3 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ async def create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> AsyncCollection:
Expand Down Expand Up @@ -400,7 +400,7 @@ async def get_collection(
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
"""Get a collection with the given name.
Expand Down Expand Up @@ -432,7 +432,7 @@ async def get_or_create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
"""Get or create a collection with the given name and metadata.
Expand Down
6 changes: 3 additions & 3 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> AsyncCollection:
Expand Down Expand Up @@ -219,7 +219,7 @@ async def get_collection(
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
model = await self._server.get_collection(
Expand Down Expand Up @@ -248,7 +248,7 @@ async def get_or_create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> AsyncCollection:
if configuration is None:
Expand Down
6 changes: 3 additions & 3 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
get_or_create: bool = False,
) -> Collection:
Expand Down Expand Up @@ -195,7 +195,7 @@ def get_collection(
name: str,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
model = self._server.get_collection(
Expand Down Expand Up @@ -224,7 +224,7 @@ def get_or_create_collection(
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
) -> Collection:
if configuration is None:
Expand Down
31 changes: 16 additions & 15 deletions chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypedDict, Dict, Any, Optional, cast, get_args
from typing import Type, TypedDict, Dict, Any, Optional, cast, get_args
import json
from chromadb.api.types import (
Embeddable,
Space,
CollectionMetadata,
UpdateMetadata,
Expand Down Expand Up @@ -40,7 +41,7 @@ class SpannConfiguration(TypedDict, total=False):
class CollectionConfiguration(TypedDict, total=True):
hnsw: Optional[HNSWConfiguration]
spann: Optional[SpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
embedding_function: Optional[EmbeddingFunction[Embeddable]]


def load_collection_configuration_from_json_str(
Expand Down Expand Up @@ -88,13 +89,13 @@ def load_collection_configuration_from_json(
f"Embedding function name not found in config: {ef_config}"
)
try:
ef = known_embedding_functions[ef_name]
ef_class = known_embedding_functions[ef_name]
except KeyError:
raise ValueError(
f"Embedding function {ef_name} not found. Add @register_embedding_function decorator to the class definition."
)
try:
ef = ef.build_from_config(ef_config["config"]) # type: ignore
ef = ef_class.build_from_config(ef_config["config"])
except Exception as e:
raise ValueError(
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
Expand All @@ -106,7 +107,7 @@ def load_collection_configuration_from_json(
return CollectionConfiguration(
hnsw=hnsw_config,
spann=spann_config,
embedding_function=ef, # type: ignore
embedding_function=ef,
)


Expand Down Expand Up @@ -257,7 +258,7 @@ def json_to_create_spann_configuration(
class CreateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[CreateHNSWConfiguration]
spann: Optional[CreateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
embedding_function: Optional[EmbeddingFunction[Embeddable]]


def load_collection_configuration_from_create_collection_configuration(
Expand Down Expand Up @@ -381,7 +382,7 @@ def create_collection_configuration_to_json(
}

try:
ef = cast(EmbeddingFunction, config.get("embedding_function")) # type: ignore
ef = cast(EmbeddingFunction[Embeddable], config.get("embedding_function"))
if ef.is_legacy():
ef_config = {"type": "legacy"}
else:
Expand Down Expand Up @@ -456,7 +457,7 @@ def create_collection_configuration_to_json(


def populate_create_hnsw_defaults(
config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction] = None # type: ignore
config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction[Embeddable]] = None
) -> CreateHNSWConfiguration:
"""Populate a CreateHNSW configuration with default values"""
if config.get("space") is None:
Expand Down Expand Up @@ -522,7 +523,7 @@ def json_to_update_spann_configuration(
class UpdateCollectionConfiguration(TypedDict, total=False):
hnsw: Optional[UpdateHNSWConfiguration]
spann: Optional[UpdateSpannConfiguration]
embedding_function: Optional[EmbeddingFunction] # type: ignore
embedding_function: Optional[EmbeddingFunction[Embeddable]]


def update_collection_configuration_from_legacy_collection_metadata(
Expand Down Expand Up @@ -697,9 +698,9 @@ def overwrite_spann_configuration(

# TODO: make warnings prettier and add link to migration docs
def overwrite_embedding_function(
existing_embedding_function: EmbeddingFunction, # type: ignore
update_embedding_function: EmbeddingFunction, # type: ignore
) -> EmbeddingFunction: # type: ignore
existing_embedding_function: EmbeddingFunction[Embeddable],
update_embedding_function: EmbeddingFunction[Embeddable],
) -> EmbeddingFunction[Embeddable]:
"""Overwrite an EmbeddingFunction with a new configuration"""
# Check for legacy embedding functions
if existing_embedding_function.is_legacy() or update_embedding_function.is_legacy():
Expand Down Expand Up @@ -768,8 +769,8 @@ def overwrite_collection_configuration(


def validate_embedding_function_conflict_on_create(
embedding_function: Optional[EmbeddingFunction], # type: ignore
configuration_ef: Optional[EmbeddingFunction], # type: ignore
embedding_function: Optional[EmbeddingFunction[Embeddable]],
configuration_ef: Optional[EmbeddingFunction[Embeddable]],
) -> None:
"""
Validates that there are no conflicting embedding functions between function parameter
Expand Down Expand Up @@ -800,7 +801,7 @@ def validate_embedding_function_conflict_on_create(
# if there is an issue with deserializing the config, an error shouldn't be raised
# at get time. CollectionCommon.py will raise an error at _embed time if there is an issue deserializing.
def validate_embedding_function_conflict_on_get(
embedding_function: Optional[EmbeddingFunction], # type: ignore
embedding_function: Optional[EmbeddingFunction[Embeddable]],
persisted_ef_config: Optional[Dict[str, Any]],
) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
model: CollectionModel,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(), # type: ignore
] = ef.DefaultEmbeddingFunction(),
data_loader: Optional[DataLoader[Loadable]] = None,
):
"""Initializes a new instance of the Collection class."""
Expand Down
2 changes: 1 addition & 1 deletion chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def test_delete_add_after_persist(settings: Settings) -> None:
"hnsw:batch_size": 3,
"hnsw:sync_threshold": 3,
},
embedding_function=DefaultEmbeddingFunction(), # type: ignore[arg-type]
embedding_function=DefaultEmbeddingFunction(),
id=UUID("0851f751-2f11-4424-ab23-4ae97074887a"),
dimension=2,
dtype=None,
Expand Down
15 changes: 15 additions & 0 deletions chromadb/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib
from typing import Type, TypeVar, cast

from chromadb.api.types import Document, Documents, Embeddable

C = TypeVar("C")


Expand All @@ -10,3 +12,16 @@ def get_class(fqn: str, type: Type[C]) -> Type[C]:
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
return cast(Type[C], cls)


def text_only_embeddable_check(input: Embeddable, embedding_function_name: str) -> Documents:
"""
Helper function to determine if a given Embeddable is text-only.

Once the minimum supported python version is bumped up to 3.10, this should
be replaced with TypeGuard:
https://docs.python.org/3.10/library/typing.html#typing.TypeGuard
"""
if not all(isinstance(item, Document) for item in input):
raise ValueError(f"{embedding_function_name} only supports text documents, not images")
return cast(Documents, input)
15 changes: 9 additions & 6 deletions chromadb/utils/embedding_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Dict, Any, Type, Set
from typing import Dict, Any, Type, Set, cast
from chromadb.api.types import (
Document,
Embeddable,
EmbeddingFunction,
Embeddings,
Documents,
)

# Import all embedding functions
from chromadb.utils import text_only_embeddable_check
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
Expand Down Expand Up @@ -106,13 +109,13 @@ def get_builtins() -> Set[str]:
return _all_classes


class DefaultEmbeddingFunction(EmbeddingFunction[Documents]):
class DefaultEmbeddingFunction(EmbeddingFunction[Embeddable]):
def __init__(self) -> None:
if is_thin_client:
return

def __call__(self, input: Documents) -> Embeddings:
# Delegate to ONNXMiniLM_L6_V2
def __call__(self, input: Embeddable) -> Embeddings:
# Delegate to ONNXMiniLM_L6_V2
return ONNXMiniLM_L6_V2()(input)

@staticmethod
Expand All @@ -136,7 +139,7 @@ def validate_config(config: Dict[str, Any]) -> None:


# Dictionary of supported embedding functions
known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignore
known_embedding_functions: Dict[str, Type[EmbeddingFunction[Embeddable]]] = {
"cohere": CohereEmbeddingFunction,
"openai": OpenAIEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
Expand Down Expand Up @@ -197,7 +200,7 @@ def _register(cls): # type: ignore


# Function to convert config to embedding function
def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction: # type: ignore
def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction[Embeddable]:
"""Convert a config dictionary to an embedding function.

Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from chromadb.utils import text_only_embeddable_check
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction
from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction
from typing import Dict, Any, cast
import json
import numpy as np


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
This class is used to generate embeddings for a list of texts using Amazon Bedrock.
"""
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
**kwargs,
)

def __call__(self, input: Documents) -> Embeddings:
def __call__(self, input: Embeddable) -> Embeddings:
"""
Generate embeddings for the given documents.

Expand All @@ -65,6 +66,7 @@ def __call__(self, input: Documents) -> Embeddings:
content_type = "application/json"
embeddings = []

input = text_only_embeddable_check(input, "Amazon Bedrock")
for text in input:
input_body = {"inputText": text}
body = json.dumps(input_body)
Expand All @@ -86,7 +88,7 @@ def name() -> str:
return "amazon_bedrock"

@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
try:
import boto3
except ImportError:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from chromadb.api.types import (
Documents,
Embeddings,
Images,
Embeddable,
EmbeddingFunction,
)
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import List, Dict, Any, Union, cast, Sequence
from typing import List, Dict, Any, cast, Sequence
import numpy as np


Expand Down Expand Up @@ -100,7 +98,7 @@ def embed_image(self, uris: List[str]) -> List[List[float]]:
"The provided embedding function does not support image embeddings."
)

def __call__(self, input: Union[Documents, Images]) -> Embeddings:
def __call__(self, input: Embeddable) -> Embeddings:
"""
Get the embeddings for a list of texts or images.

Expand Down Expand Up @@ -134,7 +132,7 @@ def name() -> str:
@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "EmbeddingFunction[Union[Documents, Images]]":
) -> "EmbeddingFunction[Embeddable]":
# This is a placeholder implementation since we can't easily serialize and deserialize
# langchain embedding functions. Users will need to recreate the langchain embedding function
# and pass it to create_langchain_embedding.
Expand Down
Loading