diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 60865835cc4..138049aaa36 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -370,7 +370,7 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, get_or_create: bool = False, ) -> Collection: @@ -407,7 +407,7 @@ def get_collection( name: str, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: """Get a collection with the given name. @@ -439,7 +439,7 @@ def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: """Get or create a collection with the given name and metadata. diff --git a/chromadb/api/async_api.py b/chromadb/api/async_api.py index f3eb365cc7c..81a1db1ea41 100644 --- a/chromadb/api/async_api.py +++ b/chromadb/api/async_api.py @@ -363,7 +363,7 @@ async def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, get_or_create: bool = False, ) -> AsyncCollection: @@ -400,7 +400,7 @@ async def get_collection( name: str, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> AsyncCollection: """Get a collection with the given name. @@ -432,7 +432,7 @@ async def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> AsyncCollection: """Get or create a collection with the given name and metadata. 
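Example (illustrative, not part of this patch): the client methods retyped above can be exercised as below; the collection name and documents are placeholders, and the default path assumes the ONNX dependencies for DefaultEmbeddingFunction are installed.

    import chromadb
    from chromadb.utils.embedding_functions import DefaultEmbeddingFunction

    client = chromadb.Client()

    # No embedding_function argument: the now correctly typed
    # DefaultEmbeddingFunction() default is used.
    notes = client.create_collection(name="notes")
    notes.add(ids=["1", "2"], documents=["Hello, world!", "How are you?"])

    # Passing it explicitly type-checks against EmbeddingFunction[Embeddable]
    # without a `# type: ignore`.
    same_notes = client.get_collection(
        name="notes", embedding_function=DefaultEmbeddingFunction()
    )
    print(same_notes.count())  # -> 2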
diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py index 9bdf59af330..0758720e40c 100644 --- a/chromadb/api/async_client.py +++ b/chromadb/api/async_client.py @@ -180,7 +180,7 @@ async def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, get_or_create: bool = False, ) -> AsyncCollection: @@ -219,7 +219,7 @@ async def get_collection( name: str, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> AsyncCollection: model = await self._server.get_collection( @@ -248,7 +248,7 @@ async def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> AsyncCollection: if configuration is None: diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 4bc0e7d755a..003a1ad17fd 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -156,7 +156,7 @@ def create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, get_or_create: bool = False, ) -> Collection: @@ -195,7 +195,7 @@ def get_collection( name: str, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: model = self._server.get_collection( @@ -224,7 +224,7 @@ def get_or_create_collection( metadata: Optional[CollectionMetadata] = None, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ) -> Collection: if configuration is None: diff --git a/chromadb/api/collection_configuration.py b/chromadb/api/collection_configuration.py index 943415d2e67..24783aeb5e9 100644 --- a/chromadb/api/collection_configuration.py +++ b/chromadb/api/collection_configuration.py @@ -1,6 +1,7 @@ -from typing import TypedDict, Dict, Any, Optional, cast, get_args +from typing import Type, TypedDict, Dict, Any, Optional, cast, get_args import json from chromadb.api.types import ( + Embeddable, Space, CollectionMetadata, UpdateMetadata, @@ -40,7 +41,7 @@ class SpannConfiguration(TypedDict, total=False): class CollectionConfiguration(TypedDict, total=True): hnsw: Optional[HNSWConfiguration] spann: Optional[SpannConfiguration] - embedding_function: Optional[EmbeddingFunction] # type: ignore + embedding_function: Optional[EmbeddingFunction[Embeddable]] def load_collection_configuration_from_json_str( @@ -88,13 +89,13 @@ def load_collection_configuration_from_json( f"Embedding function name not found in config: {ef_config}" ) try: - ef = known_embedding_functions[ef_name] + ef_class = known_embedding_functions[ef_name] except KeyError: raise ValueError( f"Embedding function {ef_name} not found. 
Add @register_embedding_function decorator to the class definition." ) try: - ef = ef.build_from_config(ef_config["config"]) # type: ignore + ef = ef_class.build_from_config(ef_config["config"]) except Exception as e: raise ValueError( f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}" @@ -106,7 +107,7 @@ def load_collection_configuration_from_json( return CollectionConfiguration( hnsw=hnsw_config, spann=spann_config, - embedding_function=ef, # type: ignore + embedding_function=ef, ) @@ -257,7 +258,7 @@ def json_to_create_spann_configuration( class CreateCollectionConfiguration(TypedDict, total=False): hnsw: Optional[CreateHNSWConfiguration] spann: Optional[CreateSpannConfiguration] - embedding_function: Optional[EmbeddingFunction] # type: ignore + embedding_function: Optional[EmbeddingFunction[Embeddable]] def load_collection_configuration_from_create_collection_configuration( @@ -381,7 +382,7 @@ def create_collection_configuration_to_json( } try: - ef = cast(EmbeddingFunction, config.get("embedding_function")) # type: ignore + ef = cast(EmbeddingFunction[Embeddable], config.get("embedding_function")) if ef.is_legacy(): ef_config = {"type": "legacy"} else: @@ -456,7 +457,7 @@ def create_collection_configuration_to_json( def populate_create_hnsw_defaults( - config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction] = None # type: ignore + config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction[Embeddable]] = None ) -> CreateHNSWConfiguration: """Populate a CreateHNSW configuration with default values""" if config.get("space") is None: @@ -522,7 +523,7 @@ def json_to_update_spann_configuration( class UpdateCollectionConfiguration(TypedDict, total=False): hnsw: Optional[UpdateHNSWConfiguration] spann: Optional[UpdateSpannConfiguration] - embedding_function: Optional[EmbeddingFunction] # type: ignore + embedding_function: Optional[EmbeddingFunction[Embeddable]] def update_collection_configuration_from_legacy_collection_metadata( @@ -697,9 +698,9 @@ def overwrite_spann_configuration( # TODO: make warnings prettier and add link to migration docs def overwrite_embedding_function( - existing_embedding_function: EmbeddingFunction, # type: ignore - update_embedding_function: EmbeddingFunction, # type: ignore -) -> EmbeddingFunction: # type: ignore + existing_embedding_function: EmbeddingFunction[Embeddable], + update_embedding_function: EmbeddingFunction[Embeddable], +) -> EmbeddingFunction[Embeddable]: """Overwrite an EmbeddingFunction with a new configuration""" # Check for legacy embedding functions if existing_embedding_function.is_legacy() or update_embedding_function.is_legacy(): @@ -768,8 +769,8 @@ def overwrite_collection_configuration( def validate_embedding_function_conflict_on_create( - embedding_function: Optional[EmbeddingFunction], # type: ignore - configuration_ef: Optional[EmbeddingFunction], # type: ignore + embedding_function: Optional[EmbeddingFunction[Embeddable]], + configuration_ef: Optional[EmbeddingFunction[Embeddable]], ) -> None: """ Validates that there are no conflicting embedding functions between function parameter @@ -800,7 +801,7 @@ def validate_embedding_function_conflict_on_create( # if there is an issue with deserializing the config, an error shouldn't be raised # at get time. CollectionCommon.py will raise an error at _embed time if there is an issue deserializing. 
def validate_embedding_function_conflict_on_get( - embedding_function: Optional[EmbeddingFunction], # type: ignore + embedding_function: Optional[EmbeddingFunction[Embeddable]], persisted_ef_config: Optional[Dict[str, Any]], ) -> None: """ diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 70d3f6cf7a3..44497302775 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -117,7 +117,7 @@ def __init__( model: CollectionModel, embedding_function: Optional[ EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + ] = ef.DefaultEmbeddingFunction(), data_loader: Optional[DataLoader[Loadable]] = None, ): """Initializes a new instance of the Collection class.""" diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index ccc72ba45c2..2db94ec4bef 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -413,7 +413,7 @@ def test_delete_add_after_persist(settings: Settings) -> None: "hnsw:batch_size": 3, "hnsw:sync_threshold": 3, }, - embedding_function=DefaultEmbeddingFunction(), # type: ignore[arg-type] + embedding_function=DefaultEmbeddingFunction(), id=UUID("0851f751-2f11-4424-ab23-4ae97074887a"), dimension=2, dtype=None, diff --git a/chromadb/utils/__init__.py b/chromadb/utils/__init__.py index fe6bb81853b..281d34538d4 100644 --- a/chromadb/utils/__init__.py +++ b/chromadb/utils/__init__.py @@ -1,6 +1,8 @@ import importlib from typing import Type, TypeVar, cast +from chromadb.api.types import Document, Documents, Embeddable + C = TypeVar("C") @@ -10,3 +12,16 @@ def get_class(fqn: str, type: Type[C]) -> Type[C]: module = importlib.import_module(module_name) cls = getattr(module, class_name) return cast(Type[C], cls) + + +def text_only_embeddable_check(input: Embeddable, embedding_function_name: str) -> Documents: + """ + Helper function to determine if a given Embeddable is text-only. 
+ + Once the minimum supported python version is bumped up to 3.10, this should + be replaced with TypeGuard: + https://docs.python.org/3.10/library/typing.html#typing.TypeGuard + """ + if not all(isinstance(item, Document) for item in input): + raise ValueError(f"{embedding_function_name} only supports text documents, not images") + return cast(Documents, input) diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 28984f139cf..a80114177cc 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -1,11 +1,14 @@ -from typing import Dict, Any, Type, Set +from typing import Dict, Any, Type, Set, cast from chromadb.api.types import ( + Document, + Embeddable, EmbeddingFunction, Embeddings, Documents, ) # Import all embedding functions +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.cohere_embedding_function import ( CohereEmbeddingFunction, ) @@ -106,13 +109,13 @@ def get_builtins() -> Set[str]: return _all_classes -class DefaultEmbeddingFunction(EmbeddingFunction[Documents]): +class DefaultEmbeddingFunction(EmbeddingFunction[Embeddable]): def __init__(self) -> None: if is_thin_client: return - def __call__(self, input: Documents) -> Embeddings: - # Delegate to ONNXMiniLM_L6_V2 + def __call__(self, input: Embeddable) -> Embeddings: + # Delegate to ONNXMiniLM_L6_V2 return ONNXMiniLM_L6_V2()(input) @staticmethod @@ -136,7 +139,7 @@ def validate_config(config: Dict[str, Any]) -> None: # Dictionary of supported embedding functions -known_embedding_functions: Dict[str, Type[EmbeddingFunction]] = { # type: ignore +known_embedding_functions: Dict[str, Type[EmbeddingFunction[Embeddable]]] = { "cohere": CohereEmbeddingFunction, "openai": OpenAIEmbeddingFunction, "huggingface": HuggingFaceEmbeddingFunction, @@ -197,7 +200,7 @@ def _register(cls): # type: ignore # Function to convert config to embedding function -def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction: # type: ignore +def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction[Embeddable]: """Convert a config dictionary to an embedding function. Args: diff --git a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py index 886701b0dca..97d211a7576 100644 --- a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py +++ b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py @@ -1,11 +1,12 @@ +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction from typing import Dict, Any, cast import json import numpy as np -class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): +class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to generate embeddings for a list of texts using Amazon Bedrock. """ @@ -51,7 +52,7 @@ def __init__( **kwargs, ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. 
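Example (illustrative, not part of this patch): the text_only_embeddable_check helper added in chromadb/utils/__init__.py behaves as follows; "MyEmbedder" is an arbitrary placeholder name.

    import numpy as np
    from chromadb.utils import text_only_embeddable_check

    # Text passes through unchanged, narrowed from Embeddable to Documents.
    docs = text_only_embeddable_check(["Hello, world!", "How are you?"], "MyEmbedder")

    # Images (numpy arrays) raise the shared error message:
    # ValueError: MyEmbedder only supports text documents, not images
    try:
        text_only_embeddable_check([np.zeros((2, 2, 3), dtype=np.uint8)], "MyEmbedder")
    except ValueError as err:
        print(err)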
@@ -65,6 +66,7 @@ def __call__(self, input: Documents) -> Embeddings: content_type = "application/json" embeddings = [] + input = text_only_embeddable_check(input, "Amazon Bedrock") for text in input: input_body = {"inputText": text} body = json.dumps(input_body) @@ -86,7 +88,7 @@ def name() -> str: return "amazon_bedrock" @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": try: import boto3 except ImportError: diff --git a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py index 5f3f8029d89..915e74e7f0c 100644 --- a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py +++ b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py @@ -1,12 +1,10 @@ from chromadb.api.types import ( - Documents, Embeddings, - Images, Embeddable, EmbeddingFunction, ) from chromadb.utils.embedding_functions.schemas import validate_config_schema -from typing import List, Dict, Any, Union, cast, Sequence +from typing import List, Dict, Any, cast, Sequence import numpy as np @@ -100,7 +98,7 @@ def embed_image(self, uris: List[str]) -> List[List[float]]: "The provided embedding function does not support image embeddings." ) - def __call__(self, input: Union[Documents, Images]) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts or images. @@ -134,7 +132,7 @@ def name() -> str: @staticmethod def build_from_config( config: Dict[str, Any] - ) -> "EmbeddingFunction[Union[Documents, Images]]": + ) -> "EmbeddingFunction[Embeddable]": # This is a placeholder implementation since we can't easily serialize and deserialize # langchain embedding functions. Users will need to recreate the langchain embedding function # and pass it to create_langchain_embedding. diff --git a/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py index b48cf22ebf8..71368048d63 100644 --- a/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py +++ b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py @@ -1,4 +1,5 @@ from chromadb.api.types import ( + Embeddable, Embeddings, Documents, EmbeddingFunction, @@ -6,6 +7,7 @@ ) from typing import List, Dict, Any, Optional import os +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import cast import warnings @@ -14,7 +16,7 @@ GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1" -class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]): +class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to get embeddings for a list of texts using the Cloudflare Workers AI API. It requires an API key and a model name. @@ -70,7 +72,7 @@ def __init__( {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. @@ -80,10 +82,7 @@ def __call__(self, input: Documents) -> Embeddings: Returns: Embeddings for the documents. 
""" - if not all(isinstance(item, str) for item in input): - raise ValueError( - "Cloudflare Workers AI only supports text documents, not images" - ) + input = text_only_embeddable_check(input, "Cloudflare Workers AI") payload: Dict[str, Any] = { "text": input, @@ -107,7 +106,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") account_id = config.get("account_id") diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py index b08e1f5dd6f..bb96b5fc62a 100644 --- a/chromadb/utils/embedding_functions/google_embedding_function.py +++ b/chromadb/utils/embedding_functions/google_embedding_function.py @@ -1,13 +1,14 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space from typing import List, Dict, Any, cast, Optional import os import numpy as np import numpy.typing as npt +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema import warnings -class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): +class GooglePalmEmbeddingFunction(EmbeddingFunction[Embeddable]): """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key.""" def __init__( @@ -48,19 +49,18 @@ def __init__( palm.configure(api_key=self.api_key) self._palm = palm - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: - input: Documents or images to generate embeddings for. + input: Documents to generate embeddings for. Returns: Embeddings for the documents. """ # Google PaLM only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError("Google PaLM only supports text documents, not images") + input = text_only_embeddable_check(input, "Google PaLM") return [ np.array( @@ -83,7 +83,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") @@ -119,7 +119,7 @@ def validate_config(config: Dict[str, Any]) -> None: validate_config_schema(config, "google_palm") -class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]): +class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Embeddable]): """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key.""" def __init__( @@ -165,21 +165,18 @@ def __init__( genai.configure(api_key=self.api_key) self._genai = genai - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: - input: Documents or images to generate embeddings for. + input: Documents to generate embeddings for. Returns: Embeddings for the documents. 
""" # Google Generative AI only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError( - "Google Generative AI only supports text documents, not images" - ) + input = text_only_embeddable_check(input, "Google Generative AI") embeddings_list: List[npt.NDArray[np.float32]] = [] for text in input: @@ -206,7 +203,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") task_type = config.get("task_type") @@ -251,7 +248,7 @@ def validate_config(config: Dict[str, Any]) -> None: validate_config_schema(config, "google_generative_ai") -class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): +class GoogleVertexEmbeddingFunction(EmbeddingFunction[Embeddable]): """To use this EmbeddingFunction, you must have the vertexai Python package installed and have Google Cloud credentials configured.""" def __init__( @@ -301,19 +298,18 @@ def __init__( vertexai.init(project=project_id, location=region) self._model = TextEmbeddingModel.from_pretrained(model_name) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: - input: Documents or images to generate embeddings for. + input: Documents to generate embeddings for. Returns: Embeddings for the documents. """ # Google Vertex only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError("Google Vertex only supports text documents, not images") + input = text_only_embeddable_check(input, "Google Vertex") embeddings_list: List[npt.NDArray[np.float32]] = [] for text in input: @@ -336,7 +332,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") project_id = config.get("project_id") diff --git a/chromadb/utils/embedding_functions/huggingface_embedding_function.py b/chromadb/utils/embedding_functions/huggingface_embedding_function.py index 193a74b0eda..3d6c5171856 100644 --- a/chromadb/utils/embedding_functions/huggingface_embedding_function.py +++ b/chromadb/utils/embedding_functions/huggingface_embedding_function.py @@ -1,12 +1,13 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space from typing import List, Dict, Any, Optional import os import numpy as np +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema import warnings -class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): +class HuggingFaceEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to get embeddings for a list of texts using the HuggingFace API. It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". 
@@ -51,12 +52,12 @@ def __init__( self._session = httpx.Client() self._session.headers.update({"Authorization": f"Bearer {self.api_key}"}) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts. Args: - input (Documents): A list of texts to get embeddings for. + input (Embeddable): A list of texts to get embeddings for. Returns: Embeddings: The embeddings for the texts. @@ -66,6 +67,8 @@ def __call__(self, input: Documents) -> Embeddings: >>> texts = ["Hello, world!", "How are you?"] >>> embeddings = hugging_face(texts) """ + input = text_only_embeddable_check(input, "HuggingFace") + # Call HuggingFace Embedding API for each document response = self._session.post( self._api_url, @@ -86,7 +89,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") @@ -122,7 +125,7 @@ def validate_config(config: Dict[str, Any]) -> None: validate_config_schema(config, "huggingface") -class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): +class HuggingFaceEmbeddingServer(EmbeddingFunction[Embeddable]): """ This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). @@ -171,12 +174,12 @@ def __init__( if self.api_key is not None: self._session.headers.update({"Authorization": f"Bearer {self.api_key}"}) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts. Args: - input (Documents): A list of texts to get embeddings for. + input (Embeddable): A list of texts to get embeddings for. Returns: Embeddings: The embeddings for the texts. 
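Example (illustrative, not part of this patch): the retyped registry and build_from_config can be used the same way load_collection_configuration_from_json does earlier in this diff; the config keys mirror the fields read above, and the environment variable name is an assumption that must point at a valid API key.

    from chromadb.utils.embedding_functions import known_embedding_functions

    # Look up the registered class by name, then rebuild it from a persisted config.
    ef_class = known_embedding_functions["huggingface"]
    ef = ef_class.build_from_config(
        {
            "api_key_env_var": "CHROMA_HUGGINGFACE_API_KEY",  # assumed to be set
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        }
    )
    embeddings = ef(["Hello, world!"])  # text-only at runtime, typed over Embeddable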
@@ -186,6 +189,9 @@ def __call__(self, input: Documents) -> Embeddings: >>> texts = ["Hello, world!", "How are you?"] >>> embeddings = hugging_face(texts) """ + + input = text_only_embeddable_check(input, "HuggingFace Server") + # Call HuggingFace Embedding Server API for each document response = self._session.post(self._api_url, json={"inputs": input}).json() @@ -203,7 +209,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": url = config.get("url") api_key_env_var = config.get("api_key_env_var") if url is None: diff --git a/chromadb/utils/embedding_functions/instructor_embedding_function.py b/chromadb/utils/embedding_functions/instructor_embedding_function.py index 0d6b053c640..4b23a8f6f8d 100644 --- a/chromadb/utils/embedding_functions/instructor_embedding_function.py +++ b/chromadb/utils/embedding_functions/instructor_embedding_function.py @@ -1,10 +1,11 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any, Optional import numpy as np -class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): +class InstructorEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to generate embeddings for a list of texts using the Instructor embedding model. """ @@ -41,19 +42,18 @@ def __init__( self._model = INSTRUCTOR(model_name_or_path=model_name, device=device) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: - input: Documents or images to generate embeddings for. + input: Documents to generate embeddings for. Returns: Embeddings for the documents. 
""" # Instructor only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError("Instructor only supports text documents, not images") + input = text_only_embeddable_check(input, "Instructor") if self.instruction is None: embeddings = self._model.encode(input, convert_to_numpy=True) @@ -77,7 +77,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": model_name = config.get("model_name") device = config.get("device") instruction = config.get("instruction") diff --git a/chromadb/utils/embedding_functions/jina_embedding_function.py b/chromadb/utils/embedding_functions/jina_embedding_function.py index b627295259b..3fffbe73cca 100644 --- a/chromadb/utils/embedding_functions/jina_embedding_function.py +++ b/chromadb/utils/embedding_functions/jina_embedding_function.py @@ -1,4 +1,5 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any, Union, Optional import os @@ -6,7 +7,7 @@ import warnings -class JinaEmbeddingFunction(EmbeddingFunction[Documents]): +class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to get embeddings for a list of texts using the Jina AI API. It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". @@ -81,12 +82,12 @@ def __init__( {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts. Args: - input (Documents): A list of texts to get embeddings for. + input (Embeddable): A list of texts to get embeddings for. Returns: Embeddings: The embeddings for the texts. 
@@ -96,8 +97,7 @@ def __call__(self, input: Documents) -> Embeddings: >>> input = ["Hello, world!", "How are you?"] """ # Jina AI only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError("Jina AI only supports text documents, not images") + input = text_only_embeddable_check(input, "Jina AI") payload: Dict[str, Any] = { "input": input, @@ -150,7 +150,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") task = config.get("task") diff --git a/chromadb/utils/embedding_functions/mistral_embedding_function.py b/chromadb/utils/embedding_functions/mistral_embedding_function.py index b4d67712354..177930400f4 100644 --- a/chromadb/utils/embedding_functions/mistral_embedding_function.py +++ b/chromadb/utils/embedding_functions/mistral_embedding_function.py @@ -1,11 +1,12 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any import os import numpy as np -class MistralEmbeddingFunction(EmbeddingFunction[Documents]): +class MistralEmbeddingFunction(EmbeddingFunction[Embeddable]): def __init__( self, model: str, @@ -31,15 +32,14 @@ def __init__( raise ValueError(f"The {api_key_env_var} environment variable is not set.") self.client = Mistral(api_key=self.api_key) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts. Args: - input (Documents): A list of texts to get embeddings for. + input (Embeddable): A list of texts to get embeddings for. 
""" - if not all(isinstance(item, str) for item in input): - raise ValueError("Mistral only supports text documents, not images") + input = text_only_embeddable_check(input, "Mistral") output = self.client.embeddings.create( model=self.model, inputs=input, @@ -59,7 +59,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": model = config.get("model") api_key_env_var = config.get("api_key_env_var") diff --git a/chromadb/utils/embedding_functions/morph_embedding_function.py b/chromadb/utils/embedding_functions/morph_embedding_function.py index ac8e569714d..0ce04e7a0a6 100644 --- a/chromadb/utils/embedding_functions/morph_embedding_function.py +++ b/chromadb/utils/embedding_functions/morph_embedding_function.py @@ -1,12 +1,13 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, Documents, EmbeddingFunction, Space from typing import List, Dict, Any, Optional import os import numpy as np +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema import warnings -class MorphEmbeddingFunction(EmbeddingFunction[Documents]): +class MorphEmbeddingFunction(EmbeddingFunction[Embeddable]): def __init__( self, api_key: Optional[str] = None, @@ -60,7 +61,7 @@ def __init__( base_url=self.api_base, ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. @@ -70,6 +71,8 @@ def __call__(self, input: Documents) -> Embeddings: Returns: Embeddings for the documents. 
""" + input = text_only_embeddable_check(input, type(self).__name__) + # Handle empty input if not input: return [] @@ -99,7 +102,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": # Extract parameters from config api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") @@ -144,4 +147,4 @@ def validate_config(config: Dict[str, Any]) -> None: Raises: ValidationError: If the configuration does not match the schema """ - validate_config_schema(config, "morph") \ No newline at end of file + validate_config_schema(config, "morph") diff --git a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py index 4e0ebe3b005..7164c168e76 100644 --- a/chromadb/utils/embedding_functions/ollama_embedding_function.py +++ b/chromadb/utils/embedding_functions/ollama_embedding_function.py @@ -1,4 +1,5 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any import numpy as np @@ -7,7 +8,7 @@ DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32" -class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): +class OllamaEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). @@ -47,12 +48,12 @@ def __init__( self._client = Client(host=self._base_url, timeout=timeout) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Get the embeddings for a list of texts. Args: - input (Documents): A list of texts to get embeddings for. + input (Embeddable): A list of texts to get embeddings for. Returns: Embeddings: The embeddings for the texts. 
@@ -62,6 +63,8 @@ def __call__(self, input: Documents) -> Embeddings: >>> texts = ["Hello, world!", "How are you?"] >>> embeddings = ollama_ef(texts) """ + input = text_only_embeddable_check(input, "Ollama") + # Call Ollama client response = self._client.embed(model=self.model_name, input=input) @@ -82,7 +85,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": url = config.get("url") model_name = config.get("model_name") timeout = config.get("timeout") diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py index 4084121cd6e..79a1994cebb 100644 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -13,7 +13,8 @@ import httpx from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random -from chromadb.api.types import Documents, Embeddings, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ def _verify_sha256(fname: str, expected_sha256: str) -> bool: # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. # visit https://github.com/chroma-core/onnx-embedding for the source code to generate # and verify the ONNX model. -class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): +class ONNXMiniLM_L6_V2(EmbeddingFunction[Embeddable]): MODEL_NAME = "all-MiniLM-L6-v2" DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME EXTRACTED_FOLDER_NAME = "onnx" @@ -255,7 +256,7 @@ def model(self) -> Any: sess_options=so, ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. @@ -266,6 +267,8 @@ def __call__(self, input: Documents) -> Embeddings: Embeddings for the documents. 
""" + input = text_only_embeddable_check(input, type(self).__name__) + # Only download the model when it is actually used self._download_model_if_not_exists() @@ -336,7 +339,7 @@ def max_tokens(self) -> int: return 256 @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": preferred_providers = config.get("preferred_providers") return ONNXMiniLM_L6_V2(preferred_providers=preferred_providers) diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py index 6c6ef9cd454..ad087a76e6a 100644 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -142,7 +142,7 @@ def supported_spaces(self) -> List[Space]: @staticmethod def build_from_config( config: Dict[str, Any] - ) -> "EmbeddingFunction[Union[Documents, Images]]": + ) -> "EmbeddingFunction[Embeddable]": model_name = config.get("model_name") checkpoint = config.get("checkpoint") device = config.get("device") diff --git a/chromadb/utils/embedding_functions/openai_embedding_function.py b/chromadb/utils/embedding_functions/openai_embedding_function.py index b34babd57ff..74df0ddba41 100644 --- a/chromadb/utils/embedding_functions/openai_embedding_function.py +++ b/chromadb/utils/embedding_functions/openai_embedding_function.py @@ -1,12 +1,13 @@ -from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space +from chromadb.api.types import Embeddable, Embeddings, EmbeddingFunction, Space from typing import List, Dict, Any, Optional import os import numpy as np +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema import warnings -class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): +class OpenAIEmbeddingFunction(EmbeddingFunction[Embeddable]): def __init__( self, api_key: Optional[str] = None, @@ -102,7 +103,7 @@ def __init__( default_headers=self.default_headers, ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: @@ -110,6 +111,8 @@ def __call__(self, input: Documents) -> Embeddings: Returns: Embeddings for the documents. 
""" + input = text_only_embeddable_check(input, type(self).__name__) + # Handle batching if not input: return [] @@ -141,7 +144,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": # Extract parameters from config api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") diff --git a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py index c5b2cfb47b1..9f683a6cafb 100644 --- a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py +++ b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py @@ -1,10 +1,11 @@ -from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents +from chromadb.api.types import Embeddable, EmbeddingFunction, Space, Embeddings, Documents from typing import List, Dict, Any import numpy as np +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema -class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): +class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Embeddable]): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} @@ -46,7 +47,7 @@ def __init__( ) self._model = self.models[model_name] - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """Generate embeddings for the given documents. Args: @@ -55,6 +56,7 @@ def __call__(self, input: Documents) -> Embeddings: Returns: Embeddings for the documents. """ + input = text_only_embeddable_check(input, "Sentence Transformers") embeddings = self._model.encode( list(input), convert_to_numpy=True, @@ -75,7 +77,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": model_name = config.get("model_name") device = config.get("device") normalize_embeddings = config.get("normalize_embeddings") diff --git a/chromadb/utils/embedding_functions/text2vec_embedding_function.py b/chromadb/utils/embedding_functions/text2vec_embedding_function.py index 4c9838f55ac..594374165e9 100644 --- a/chromadb/utils/embedding_functions/text2vec_embedding_function.py +++ b/chromadb/utils/embedding_functions/text2vec_embedding_function.py @@ -1,10 +1,11 @@ -from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents +from chromadb.api.types import Embeddable, EmbeddingFunction, Space, Embeddings +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any import numpy as np -class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): +class Text2VecEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to generate embeddings for a list of texts using the Text2Vec model. 
""" @@ -27,19 +28,18 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): self.model_name = model_name self._model = SentenceModel(model_name_or_path=model_name) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. Args: - input: Documents or images to generate embeddings for. + input: Documents to generate embeddings for. Returns: Embeddings for the documents. """ # Text2Vec only works with text documents - if not all(isinstance(item, str) for item in input): - raise ValueError("Text2Vec only supports text documents, not images") + input = text_only_embeddable_check(input, "Text2Vec") embeddings = self._model.encode(list(input), convert_to_numpy=True) @@ -57,7 +57,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": model_name = config.get("model_name") if model_name is None: diff --git a/chromadb/utils/embedding_functions/together_ai_embedding_function.py b/chromadb/utils/embedding_functions/together_ai_embedding_function.py index d258ca22a64..d69c08074d1 100644 --- a/chromadb/utils/embedding_functions/together_ai_embedding_function.py +++ b/chromadb/utils/embedding_functions/together_ai_embedding_function.py @@ -1,11 +1,12 @@ from chromadb.api.types import ( + Embeddable, Embeddings, - Documents, EmbeddingFunction, Space, ) from typing import List, Dict, Any, Optional import os +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import cast import warnings @@ -13,7 +14,7 @@ ENDPOINT = "https://api.together.xyz/v1/embeddings" -class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]): +class TogetherAIEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to get embeddings for a list of texts using the Together AI API. """ @@ -68,7 +69,7 @@ def __init__( } ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Embed a list of texts using the Together AI API. 
@@ -82,8 +83,7 @@ def __call__(self, input: Documents) -> Embeddings: if not isinstance(input, list): raise ValueError("Input must be a list") - if not all(isinstance(item, str) for item in input): - raise ValueError("All items in input must be strings") + input = text_only_embeddable_check(input, "Together AI") response = self._session.post( ENDPOINT, @@ -109,7 +109,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") diff --git a/chromadb/utils/embedding_functions/voyageai_embedding_function.py b/chromadb/utils/embedding_functions/voyageai_embedding_function.py index 94a7051e46e..80f789fd3f4 100644 --- a/chromadb/utils/embedding_functions/voyageai_embedding_function.py +++ b/chromadb/utils/embedding_functions/voyageai_embedding_function.py @@ -1,4 +1,5 @@ -from chromadb.api.types import EmbeddingFunction, Space, Embeddings, Documents +from chromadb.api.types import Embeddable, EmbeddingFunction, Space, Embeddings +from chromadb.utils import text_only_embeddable_check from chromadb.utils.embedding_functions.schemas import validate_config_schema from typing import List, Dict, Any, Optional import os @@ -6,7 +7,7 @@ import warnings -class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]): +class VoyageAIEmbeddingFunction(EmbeddingFunction[Embeddable]): """ This class is used to generate embeddings for a list of texts using the VoyageAI API. """ @@ -57,7 +58,7 @@ def __init__( self.truncation = truncation self._client = voyageai.Client(api_key=self.api_key) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Embeddable) -> Embeddings: """ Generate embeddings for the given documents. @@ -67,6 +68,7 @@ def __call__(self, input: Documents) -> Embeddings: Returns: Embeddings for the documents. """ + input = text_only_embeddable_check(input, "VoyageAI") embeddings = self._client.embed( texts=input, model=self.model_name, @@ -90,7 +92,7 @@ def supported_spaces(self) -> List[Space]: return ["cosine", "l2", "ip"] @staticmethod - def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") input_type = config.get("input_type")
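Example (illustrative, not part of this patch): a third-party embedding function can follow the same pattern applied to the built-ins above, typing the class over Embeddable and narrowing to Documents at call time; the class name and toy embedding logic are placeholders.

    import numpy as np

    from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
    from chromadb.utils import text_only_embeddable_check


    class ToyEmbeddingFunction(EmbeddingFunction[Embeddable]):
        """Text-only embedding function using the shared guard."""

        def __call__(self, input: Embeddable) -> Embeddings:
            # Narrow Embeddable -> Documents; raises ValueError for image input.
            documents = text_only_embeddable_check(input, "ToyEmbeddingFunction")
            # Toy 2-d embedding: document length and a character-sum hash.
            return [
                np.array([len(doc), sum(map(ord, doc)) % 997], dtype=np.float32)
                for doc in documents
            ]


    toy = ToyEmbeddingFunction()
    print(toy(["Hello, world!", "How are you?"]))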