From ffc5e9135e1e5cc471849a8cf4fc7607d287a4ff Mon Sep 17 00:00:00 2001 From: David Reguera Date: Thu, 11 Apr 2024 08:13:28 +0200 Subject: [PATCH 01/23] 1965 - Init the module --- .../utils/{embedding_functions.py => embedding_functions.old.py} | 0 chromadb/utils/embedding_functions/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename chromadb/utils/{embedding_functions.py => embedding_functions.old.py} (100%) create mode 100644 chromadb/utils/embedding_functions/__init__.py diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.old.py similarity index 100% rename from chromadb/utils/embedding_functions.py rename to chromadb/utils/embedding_functions.old.py diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py new file mode 100644 index 00000000000..e69de29bb2d From 1c28da4f6469a964b003ec45cee1d1dcf5e63004 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 11:41:05 +0200 Subject: [PATCH 02/23] 1965 - Move over `AmazonBedrockEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 47 --------------- .../utils/embedding_functions/__init__.py | 2 + .../amazon_bedrock_embedding_function.py | 60 +++++++++++++++++++ 3 files changed, 62 insertions(+), 47 deletions(-) create mode 100644 chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index 3f0a1ce043b..08350b493c6 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -809,53 +809,6 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: return embeddings - -class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__( - self, - session: "boto3.Session", # noqa: F821 # Quote for forward reference - model_name: str = "amazon.titan-embed-text-v1", - **kwargs: Any, - ): - """Initialize AmazonBedrockEmbeddingFunction. - - Args: - session (boto3.Session): The boto3 session to use. - model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" - **kwargs: Additional arguments to pass to the boto3 client. 
- - Example: - >>> import boto3 - >>> session = boto3.Session(profile_name="profile", region_name="us-east-1") - >>> bedrock = AmazonBedrockEmbeddingFunction(session=session) - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = bedrock(texts) - """ - - self._model_name = model_name - - self._client = session.client( - service_name="bedrock-runtime", - **kwargs, - ) - - def __call__(self, input: Documents) -> Embeddings: - accept = "application/json" - content_type = "application/json" - embeddings = [] - for text in input: - input_body = {"inputText": text} - body = json.dumps(input_body) - response = self._client.invoke_model( - body=body, - modelId=self._model_name, - accept=accept, - contentType=content_type, - ) - embedding = json.load(response.get("body")).get("embedding") - embeddings.append(embedding) - return embeddings - class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): """ diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index e69de29bb2d..fa12acab253 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -0,0 +1,2 @@ + +from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py new file mode 100644 index 00000000000..11a3fba875c --- /dev/null +++ b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py @@ -0,0 +1,60 @@ +import logging + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + +from typing import Any +import json + +logger = logging.getLogger(__name__) + + + +class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, + session: "boto3.Session", # noqa: F821 # Quote for forward reference + model_name: str = "amazon.titan-embed-text-v1", + **kwargs: Any, + ): + """Initialize AmazonBedrockEmbeddingFunction. + + Args: + session (boto3.Session): The boto3 session to use. + model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" + **kwargs: Additional arguments to pass to the boto3 client. 
+ + Example: + >>> import boto3 + >>> session = boto3.Session(profile_name="profile", region_name="us-east-1") + >>> bedrock = AmazonBedrockEmbeddingFunction(session=session) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = bedrock(texts) + """ + + self._model_name = model_name + + self._client = session.client( + service_name="bedrock-runtime", + **kwargs, + ) + + def __call__(self, input: Documents) -> Embeddings: + accept = "application/json" + content_type = "application/json" + embeddings = [] + for text in input: + input_body = {"inputText": text} + body = json.dumps(input_body) + response = self._client.invoke_model( + body=body, + modelId=self._model_name, + accept=accept, + contentType=content_type, + ) + embedding = json.load(response.get("body")).get("embedding") + embeddings.append(embedding) + return embeddings \ No newline at end of file From 6e4f190eed54d9bf337072c11e0d0afac8b9b9e5 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 11:43:53 +0200 Subject: [PATCH 03/23] 1965 - Move over `create_langchain_embedding` --- chromadb/utils/embedding_functions.old.py | 63 ------------ .../utils/embedding_functions/__init__.py | 3 +- .../chroma_langchain_embedding_function.py | 96 +++++++++++++++++++ 3 files changed, 98 insertions(+), 64 deletions(-) create mode 100644 chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index 08350b493c6..5c54431345a 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -852,69 +852,6 @@ def __call__(self, input: Documents) -> Embeddings: Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() ) - -def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore - try: - from langchain_core.embeddings import Embeddings as LangchainEmbeddings - except ImportError: - raise ValueError( - "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" - ) - - class ChromaLangchainEmbeddingFunction( - LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore - ): - """ - This class is used as bridge between langchain embedding functions and custom chroma embedding functions. - """ - - def __init__(self, embedding_function: LangchainEmbeddings) -> None: - """ - Initialize the ChromaLangchainEmbeddingFunction - - Args: - embedding_function : The embedding function implementing Embeddings from langchain_core. - """ - self.embedding_function = embedding_function - - def embed_documents(self, documents: Documents) -> List[List[float]]: - return self.embedding_function.embed_documents(documents) # type: ignore - - def embed_query(self, query: str) -> List[float]: - return self.embedding_function.embed_query(query) # type: ignore - - def embed_image(self, uris: List[str]) -> List[List[float]]: - if hasattr(self.embedding_function, "embed_image"): - return self.embedding_function.embed_image(uris) # type: ignore - else: - raise ValueError( - "The provided embedding function does not support image embeddings." - ) - - def __call__(self, input: Documents) -> Embeddings: # type: ignore - """ - Get the embeddings for a list of texts or images. - - Args: - input (Documents | Images): A list of texts or images to get embeddings for. 
- Images should be provided as a list of URIs passed through the langchain data loader - - Returns: - Embeddings: The embeddings for the texts or images. - - Example: - >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = langchain_embedding(texts) - """ - # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images - if input[0] == "images": - return self.embed_image(list(input[1])) # type: ignore - - return self.embed_documents(list(input)) # type: ignore - - return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) - class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index fa12acab253..314abf4441f 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -1,2 +1,3 @@ -from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction +from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py new file mode 100644 index 00000000000..4908efd2883 --- /dev/null +++ b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py @@ -0,0 +1,96 @@ +import hashlib +import logging +from functools import cached_property + +from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception + +from chromadb.api.types import ( + Document, + Documents, + Embedding, + Image, + Images, + EmbeddingFunction, + Embeddings, + is_image, + is_document, +) + +from io import BytesIO +from pathlib import Path +import os +import tarfile +import requests +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast +import numpy as np +import numpy.typing as npt +import importlib +import inspect +import json +import sys +import base64 + +logger = logging.getLogger(__name__) + + +def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore + try: + from langchain_core.embeddings import Embeddings as LangchainEmbeddings + except ImportError: + raise ValueError( + "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" + ) + + class ChromaLangchainEmbeddingFunction( + LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore + ): + """ + This class is used as bridge between langchain embedding functions and custom chroma embedding functions. + """ + + def __init__(self, embedding_function: LangchainEmbeddings) -> None: + """ + Initialize the ChromaLangchainEmbeddingFunction + + Args: + embedding_function : The embedding function implementing Embeddings from langchain_core. 
+ """ + self.embedding_function = embedding_function + + def embed_documents(self, documents: Documents) -> List[List[float]]: + return self.embedding_function.embed_documents(documents) # type: ignore + + def embed_query(self, query: str) -> List[float]: + return self.embedding_function.embed_query(query) # type: ignore + + def embed_image(self, uris: List[str]) -> List[List[float]]: + if hasattr(self.embedding_function, "embed_image"): + return self.embedding_function.embed_image(uris) # type: ignore + else: + raise ValueError( + "The provided embedding function does not support image embeddings." + ) + + def __call__(self, input: Documents) -> Embeddings: # type: ignore + """ + Get the embeddings for a list of texts or images. + + Args: + input (Documents | Images): A list of texts or images to get embeddings for. + Images should be provided as a list of URIs passed through the langchain data loader + + Returns: + Embeddings: The embeddings for the texts or images. + + Example: + >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = langchain_embedding(texts) + """ + # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images + if input[0] == "images": + return self.embed_image(list(input[1])) # type: ignore + + return self.embed_documents(list(input)) # type: ignore + + return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) \ No newline at end of file From 40dee43392efa4fe14a2d1afd8099646bb563828 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 11:48:11 +0200 Subject: [PATCH 04/23] 1965 - Move over `CohereEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 22 ------------- .../utils/embedding_functions/__init__.py | 3 +- .../cohere_embedding_function.py | 31 +++++++++++++++++++ 3 files changed, 33 insertions(+), 23 deletions(-) create mode 100644 chromadb/utils/embedding_functions/cohere_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index 5c54431345a..b61e37ce152 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -226,28 +226,6 @@ def __call__(self, input: Documents) -> Embeddings: ) -class CohereEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, api_key: str, model_name: str = "large"): - try: - import cohere - except ImportError: - raise ValueError( - "The cohere python package is not installed. Please install it with `pip install cohere`" - ) - - self._client = cohere.Client(api_key) - self._model_name = model_name - - def __call__(self, input: Documents) -> Embeddings: - # Call Cohere Embedding API for each document. - return [ - embeddings - for embeddings in self._client.embed( - texts=input, model=self._model_name, input_type="search_document" - ) - ] - - class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to get embeddings for a list of texts using the HuggingFace API. 
diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py
index 314abf4441f..9a2ccc85971 100644
--- a/chromadb/utils/embedding_functions/__init__.py
+++ b/chromadb/utils/embedding_functions/__init__.py
@@ -1,3 +1,4 @@
 
 from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction
-from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding
\ No newline at end of file
+from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding
+from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
\ No newline at end of file
diff --git a/chromadb/utils/embedding_functions/cohere_embedding_function.py b/chromadb/utils/embedding_functions/cohere_embedding_function.py
new file mode 100644
index 00000000000..d7b703be231
--- /dev/null
+++ b/chromadb/utils/embedding_functions/cohere_embedding_function.py
@@ -0,0 +1,31 @@
+
+import logging
+from chromadb.api.types import (
+    Documents,
+    EmbeddingFunction,
+    Embeddings,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
+    def __init__(self, api_key: str, model_name: str = "large"):
+        try:
+            import cohere
+        except ImportError:
+            raise ValueError(
+                "The cohere python package is not installed. Please install it with `pip install cohere`"
+            )
+
+        self._client = cohere.Client(api_key)
+        self._model_name = model_name
+
+    def __call__(self, input: Documents) -> Embeddings:
+        # Call Cohere Embedding API for each document.
+        return [
+            embeddings
+            for embeddings in self._client.embed(
+                texts=input, model=self._model_name, input_type="search_document"
+            )
+        ]
\ No newline at end of file

From 385dcc0c1843eea74c5f657a1bc619b207b97e32 Mon Sep 17 00:00:00 2001
From: David Reguera
Date: Sat, 13 Apr 2024 11:52:17 +0200
Subject: [PATCH 05/23] 1965 - Move over `google_embedding_function`

---
 chromadb/utils/embedding_functions.old.py     | 103 ----------------
 .../utils/embedding_functions/__init__.py     |   3 +-
 .../google_embedding_function.py              | 115 ++++++++++++++++++
 3 files changed, 117 insertions(+), 104 deletions(-)
 create mode 100644 chromadb/utils/embedding_functions/google_embedding_function.py

diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py
index b61e37ce152..be4e18e5183 100644
--- a/chromadb/utils/embedding_functions.old.py
+++ b/chromadb/utils/embedding_functions.old.py
@@ -557,109 +557,6 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]:
     return ONNXMiniLM_L6_V2()
 
 
-class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
-    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
-
-    def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"):
-        if not api_key:
-            raise ValueError("Please provide a PaLM API key.")
-
-        if not model_name:
-            raise ValueError("Please provide the model name.")
-
-        try:
-            import google.generativeai as palm
-        except ImportError:
-            raise ValueError(
-                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
-            )
-
-        palm.configure(api_key=api_key)
-        self._palm = palm
-        self._model_name = model_name
-
-    def __call__(self, input: Documents) -> Embeddings:
-        return [
-            self._palm.generate_embeddings(model=self._model_name, text=text)[
-                "embedding"
-            ]
-            for text in input
-        ]
-
-
-class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
-    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key."""
-
-    """Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY for the task_type for retrieval."""
-
-    def __init__(
-        self,
-        api_key: str,
-        model_name: str = "models/embedding-001",
-        task_type: str = "RETRIEVAL_DOCUMENT",
-    ):
-        if not api_key:
-            raise ValueError("Please provide a Google API key.")
-
-        if not model_name:
-            raise ValueError("Please provide the model name.")
-
-        try:
-            import google.generativeai as genai
-        except ImportError:
-            raise ValueError(
-                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
-            )
-
-        genai.configure(api_key=api_key)
-        self._genai = genai
-        self._model_name = model_name
-        self._task_type = task_type
-        self._task_title = None
-        if self._task_type == "RETRIEVAL_DOCUMENT":
-            self._task_title = "Embedding of single string"
-
-    def __call__(self, input: Documents) -> Embeddings:
-        return [
-            self._genai.embed_content(
-                model=self._model_name,
-                content=text,
-                task_type=self._task_type,
-                title=self._task_title,
-            )["embedding"]
-            for text in input
-        ]
-
-
-class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
-    # Follow API Quickstart for Google Vertex AI
-    # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
-    # Information about the text embedding modules in Google Vertex AI
-    # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
-    def __init__(
-        self,
-        api_key: str,
-        model_name: str = "textembedding-gecko",
-        project_id: str = "cloud-large-language-models",
-        region: str = "us-central1",
-    ):
-        self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/goole/models/{model_name}:predict"
-        self._session = requests.Session()
-        self._session.headers.update({"Authorization": f"Bearer {api_key}"})
-
-    def __call__(self, input: Documents) -> Embeddings:
-        embeddings = []
-        for text in input:
-            response = self._session.post(
-                self._api_url, json={"instances": [{"content": text}]}
-            ).json()
-
-            if "predictions" in response:
-                embeddings.append(response["predictions"]["embeddings"]["values"])
-
-        return embeddings
-
-
 class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
     def __init__(
         self,
diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py
index 9a2ccc85971..9cc8e5ad7aa 100644
--- a/chromadb/utils/embedding_functions/__init__.py
+++ b/chromadb/utils/embedding_functions/__init__.py
@@ -1,4 +1,5 @@
 
 from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction
 from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding
-from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
\ No newline at end of file
+from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
+from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction)
\ No newline at end of file
diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py
new file mode 100644
index 00000000000..2aea0bd7171
--- /dev/null
+++ b/chromadb/utils/embedding_functions/google_embedding_function.py
@@ -0,0 +1,115 @@
+import logging
+
+from chromadb.api.types import (
+    Documents,
+    EmbeddingFunction,
+    Embeddings,
+)
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
+    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
+
+    def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"):
+        if not api_key:
+            raise ValueError("Please provide a PaLM API key.")
+
+        if not model_name:
+            raise ValueError("Please provide the model name.")
+
+        try:
+            import google.generativeai as palm
+        except ImportError:
+            raise ValueError(
+                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
+            )
+
+        palm.configure(api_key=api_key)
+        self._palm = palm
+        self._model_name = model_name
+
+    def __call__(self, input: Documents) -> Embeddings:
+        return [
+            self._palm.generate_embeddings(model=self._model_name, text=text)[
+                "embedding"
+            ]
+            for text in input
+        ]
+
+
+class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
+    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key."""
+
+    """Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY for the task_type for retrieval."""
+
+    def __init__(
+        self,
+        api_key: str,
+        model_name: str = "models/embedding-001",
+        task_type: str = "RETRIEVAL_DOCUMENT",
+    ):
+        if not api_key:
+            raise ValueError("Please provide a Google API key.")
+
+        if not model_name:
+            raise ValueError("Please provide the model name.")
+
+        try:
+            import google.generativeai as genai
+        except ImportError:
+            raise ValueError(
+                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
+            )
+
+        genai.configure(api_key=api_key)
+        self._genai = genai
+        self._model_name = model_name
+        self._task_type = task_type
+        self._task_title = None
+        if self._task_type == "RETRIEVAL_DOCUMENT":
+            self._task_title = "Embedding of single string"
+
+    def __call__(self, input: Documents) -> Embeddings:
+        return [
+            self._genai.embed_content(
+                model=self._model_name,
+                content=text,
+                task_type=self._task_type,
+                title=self._task_title,
+            )["embedding"]
+            for text in input
+        ]
+
+
+class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
+    # Follow API Quickstart for Google Vertex AI
+    # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
+    # Information about the text embedding modules in Google Vertex AI
+    # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
+    def __init__(
+        self,
+        api_key: str,
+        model_name: str = "textembedding-gecko",
+        project_id: str = "cloud-large-language-models",
+        region: str = "us-central1",
+    ):
+        self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model_name}:predict"
+        self._session = requests.Session()
+        self._session.headers.update({"Authorization": f"Bearer {api_key}"})
+
+    def __call__(self, input: Documents) -> Embeddings:
+        embeddings = []
+        for text in input:
+            response = self._session.post(
+                self._api_url, json={"instances": [{"content": text}]}
+            ).json()
+
+            if "predictions" in response:
+                embeddings.append(response["predictions"]["embeddings"]["values"])
+
+        return embeddings
\ No newline at end of file

From ed206d665912a83f1151fd1d95e023ef66172bd9 Mon Sep 17 00:00:00 2001
From: David Reguera
Date: Sat, 13 Apr 2024 11:53:41 +0200
Subject: [PATCH 06/23] 1965 - Move over `huggingface_embedding_function`

---
 chromadb/utils/embedding_functions.old.py     | 88 ---------------
 .../utils/embedding_functions/__init__.py     |  3 +-
 .../huggingface_embedding_function.py         | 102 ++++++++++++++++++
 3 files changed, 104 insertions(+), 89 deletions(-)
 create mode 100644 chromadb/utils/embedding_functions/huggingface_embedding_function.py

diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py
index be4e18e5183..f5bb9812d93 100644
--- a/chromadb/utils/embedding_functions.old.py
+++ b/chromadb/utils/embedding_functions.old.py
@@ -226,51 +226,6 @@ def __call__(self, input: Documents) -> Embeddings:
         )
 
 
-class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
-    """
-    This class is used to get embeddings for a list of texts using the HuggingFace API.
-    It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2".
-    """
-
-    def __init__(
-        self, api_key: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
-    ):
-        """
-        Initialize the HuggingFaceEmbeddingFunction.
-
-        Args:
-            api_key (str): Your API key for the HuggingFace API.
-            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
-        """
-        self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
-        self._session = requests.Session()
-        self._session.headers.update({"Authorization": f"Bearer {api_key}"})
-
-    def __call__(self, input: Documents) -> Embeddings:
-        """
-        Get the embeddings for a list of texts.
-
-        Args:
-            texts (Documents): A list of texts to get embeddings for.
- - Returns: - Embeddings: The embeddings for the texts. - - Example: - >>> hugging_face = HuggingFaceEmbeddingFunction(api_key="your_api_key") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = hugging_face(texts) - """ - # Call HuggingFace Embedding API for each document - return cast( - Embeddings, - self._session.post( - self._api_url, - json={"inputs": input, "options": {"wait_for_model": True}}, - ).json(), - ) - - class JinaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to get embeddings for a list of texts using the Jina AI API. @@ -684,49 +639,6 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: return embeddings - -class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): - """ - This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). - The embedding model is configured in the server. - """ - - def __init__(self, url: str): - """ - Initialize the HuggingFaceEmbeddingServer. - - Args: - url (str): The URL of the HuggingFace Embedding Server. - """ - try: - import requests - except ImportError: - raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" - ) - self._api_url = f"{url}" - self._session = requests.Session() - - def __call__(self, input: Documents) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - texts (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. - - Example: - >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = hugging_face(texts) - """ - # Call HuggingFace Embedding Server API for each document - return cast( - Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() - ) - class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 9cc8e5ad7aa..50860f9c174 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -2,4 +2,5 @@ from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction -from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction) \ No newline at end of file +from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction) +from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer) \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/huggingface_embedding_function.py b/chromadb/utils/embedding_functions/huggingface_embedding_function.py new file mode 100644 index 00000000000..95230ea696d --- /dev/null +++ b/chromadb/utils/embedding_functions/huggingface_embedding_function.py @@ -0,0 +1,102 @@ +import logging + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) 
+ +from typing import cast +import requests + +logger = logging.getLogger(__name__) + + + +class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the HuggingFace API. + It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". + """ + + def __init__( + self, api_key: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2" + ): + """ + Initialize the HuggingFaceEmbeddingFunction. + + Args: + api_key (str): Your API key for the HuggingFace API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2". + """ + self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}" + self._session = requests.Session() + self._session.headers.update({"Authorization": f"Bearer {api_key}"}) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> hugging_face = HuggingFaceEmbeddingFunction(api_key="your_api_key") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = hugging_face(texts) + """ + # Call HuggingFace Embedding API for each document + return cast( + Embeddings, + self._session.post( + self._api_url, + json={"inputs": input, "options": {"wait_for_model": True}}, + ).json(), + ) + + + +class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). + The embedding model is configured in the server. + """ + + def __init__(self, url: str): + """ + Initialize the HuggingFaceEmbeddingServer. + + Args: + url (str): The URL of the HuggingFace Embedding Server. + """ + try: + import requests + except ImportError: + raise ValueError( + "The requests python package is not installed. Please install it with `pip install requests`" + ) + self._api_url = f"{url}" + self._session = requests.Session() + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. 
+ + Example: + >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = hugging_face(texts) + """ + # Call HuggingFace Embedding Server API for each document + return cast( + Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() + ) \ No newline at end of file From 50076e4f1018ff4e9ed455d7b58d425a471ed46c Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 11:55:53 +0200 Subject: [PATCH 07/23] 1965 - Move over `InstructorEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 27 ------------- .../utils/embedding_functions/__init__.py | 3 +- .../instructor_embedding_function.py | 39 +++++++++++++++++++ 3 files changed, 41 insertions(+), 28 deletions(-) create mode 100644 chromadb/utils/embedding_functions/instructor_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index f5bb9812d93..1c83c940e8b 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -278,33 +278,6 @@ def __call__(self, input: Documents) -> Embeddings: return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) -class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): - # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" - # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list - def __init__( - self, - model_name: str = "hkunlp/instructor-base", - device: str = "cpu", - instruction: Optional[str] = None, - ): - try: - from InstructorEmbedding import INSTRUCTOR - except ImportError: - raise ValueError( - "The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`" - ) - self._model = INSTRUCTOR(model_name, device=device) - self._instruction = instruction - - def __call__(self, input: Documents) -> Embeddings: - if self._instruction is None: - return cast(Embeddings, self._model.encode(input).tolist()) - - texts_with_instructions = [[self._instruction, text] for text in input] - - return cast(Embeddings, self._model.encode(texts_with_instructions).tolist()) - - # In order to remove dependencies on sentence-transformers, which in turn depends on # pytorch and sentence-piece we have created a default ONNX embedding function that # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. 
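The same smoke test works for this move. A minimal sketch, not part of the patch, assuming the `InstructorEmbedding` package (and its torch dependency) is installed; the instruction string is an illustrative placeholder:

    from chromadb.utils.embedding_functions import InstructorEmbeddingFunction

    # CPU default; with a 6GB+ GPU the code comments suggest
    # model_name="hkunlp/instructor-xl" and device="cuda".
    instructor_ef = InstructorEmbeddingFunction(
        model_name="hkunlp/instructor-base",
        device="cpu",
        instruction="Represent the document for retrieval:",
    )
    embeddings = instructor_ef(["Hello, world!", "How are you?"])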
diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py
index 50860f9c174..c0f5e553e1c 100644
--- a/chromadb/utils/embedding_functions/__init__.py
+++ b/chromadb/utils/embedding_functions/__init__.py
@@ -3,4 +3,5 @@
 from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding
 from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
 from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction)
-from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer)
\ No newline at end of file
+from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer)
+from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction
\ No newline at end of file
diff --git a/chromadb/utils/embedding_functions/instructor_embedding_function.py b/chromadb/utils/embedding_functions/instructor_embedding_function.py
new file mode 100644
index 00000000000..13d8c715aeb
--- /dev/null
+++ b/chromadb/utils/embedding_functions/instructor_embedding_function.py
@@ -0,0 +1,39 @@
+import logging
+
+from chromadb.api.types import (
+    Documents,
+    EmbeddingFunction,
+    Embeddings,
+)
+
+from typing import Optional, cast
+
+logger = logging.getLogger(__name__)
+
+
+class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
+    # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
+    # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
+    def __init__(
+        self,
+        model_name: str = "hkunlp/instructor-base",
+        device: str = "cpu",
+        instruction: Optional[str] = None,
+    ):
+        try:
+            from InstructorEmbedding import INSTRUCTOR
+        except ImportError:
+            raise ValueError(
+                "The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`"
+            )
+        self._model = INSTRUCTOR(model_name, device=device)
+        self._instruction = instruction
+
+    def __call__(self, input: Documents) -> Embeddings:
+        if self._instruction is None:
+            return cast(Embeddings, self._model.encode(input).tolist())
+
+        texts_with_instructions = [[self._instruction, text] for text in input]
+
+        return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())
\ No newline at end of file

From 0aeb92ecbe793edbfb41726acd476e9f6e365e55 Mon Sep 17 00:00:00 2001
From: David Reguera
Date: Sat, 13 Apr 2024 11:57:27 +0200
Subject: [PATCH 08/23] 1965 - Move over `JinaEmbeddingFunction`

---
 chromadb/utils/embedding_functions.old.py     | 52 ---------------
 .../utils/embedding_functions/__init__.py     |  3 +-
 .../jina_embedding_function.py                | 65 +++++++++++++++++++
 3 files changed, 67 insertions(+), 53 deletions(-)
 create mode 100644 chromadb/utils/embedding_functions/jina_embedding_function.py

diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py
index 1c83c940e8b..db3c73d2360 100644
--- a/chromadb/utils/embedding_functions.old.py
+++ b/chromadb/utils/embedding_functions.old.py
@@ -226,58 +226,6 @@ def __call__(self, input: Documents) -> Embeddings:
         )
 
 
-class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
-    """
-    This class is used to get embeddings for a list of texts using the Jina AI API.
-    It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
-    """
-
-    def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"):
-        """
-        Initialize the JinaEmbeddingFunction.
-
-        Args:
-            api_key (str): Your API key for the Jina AI API.
-            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".
-        """
-        self._model_name = model_name
-        self._api_url = "https://api.jina.ai/v1/embeddings"
-        self._session = requests.Session()
-        self._session.headers.update(
-            {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
-        )
-
-    def __call__(self, input: Documents) -> Embeddings:
-        """
-        Get the embeddings for a list of texts.
-
-        Args:
-            texts (Documents): A list of texts to get embeddings for.
-
-        Returns:
-            Embeddings: The embeddings for the texts.
-
-        Example:
-            >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key")
-            >>> input = ["Hello, world!", "How are you?"]
-            >>> embeddings = jina_ai_fn(input)
-        """
-        # Call Jina AI Embedding API
-        resp = self._session.post(
-            self._api_url, json={"input": input, "model": self._model_name}
-        ).json()
-        if "data" not in resp:
-            raise RuntimeError(resp["detail"])
-
-        embeddings = resp["data"]
-
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
-
-        # Return just the embeddings
-        return cast(Embeddings, [result["embedding"] for result in sorted_embeddings])
-
-
 # In order to remove dependencies on sentence-transformers, which in turn depends on
 # pytorch and sentence-piece we have created a default ONNX embedding function that
 # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers.
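Likewise for the Jina move. A minimal sketch, not part of the patch, assuming a valid Jina AI API key ("your_api_key" is a placeholder); the model name defaults to "jina-embeddings-v2-base-en":

    from chromadb.utils.embedding_functions import JinaEmbeddingFunction

    # POSTs to https://api.jina.ai/v1/embeddings and returns embeddings sorted by index.
    jina_ef = JinaEmbeddingFunction(api_key="your_api_key")
    embeddings = jina_ef(["Hello, world!", "How are you?"])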
diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index c0f5e553e1c..4266924de10 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -4,4 +4,5 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction) from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer) -from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction +from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/jina_embedding_function.py b/chromadb/utils/embedding_functions/jina_embedding_function.py new file mode 100644 index 00000000000..26f71c7ee18 --- /dev/null +++ b/chromadb/utils/embedding_functions/jina_embedding_function.py @@ -0,0 +1,65 @@ +import logging + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + +import requests +from typing import cast + + +logger = logging.getLogger(__name__) + + +class JinaEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the Jina AI API. + It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". + """ + + def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"): + """ + Initialize the JinaEmbeddingFunction. + + Args: + api_key (str): Your API key for the Jina AI API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en". + """ + self._model_name = model_name + self._api_url = "https://api.jina.ai/v1/embeddings" + self._session = requests.Session() + self._session.headers.update( + {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} + ) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. 
+ + Example: + >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key") + >>> input = ["Hello, world!", "How are you?"] + >>> embeddings = jina_ai_fn(input) + """ + # Call Jina AI Embedding API + resp = self._session.post( + self._api_url, json={"input": input, "model": self._model_name} + ).json() + if "data" not in resp: + raise RuntimeError(resp["detail"]) + + embeddings = resp["data"] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) + + # Return just the embeddings + return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) \ No newline at end of file From 18926e9f0e1afdc9f8d42b7ed170800729ef58d3 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 11:58:38 +0200 Subject: [PATCH 09/23] 1965 - Move over `OllamaEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 56 ---------------- .../utils/embedding_functions/__init__.py | 3 +- .../ollama_embedding_function.py | 67 +++++++++++++++++++ 3 files changed, 69 insertions(+), 57 deletions(-) create mode 100644 chromadb/utils/embedding_functions/ollama_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index db3c73d2360..49a4a3ceb4b 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -560,62 +560,6 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: return embeddings - -class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): - """ - This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). - """ - - def __init__(self, url: str, model_name: str) -> None: - """ - Initialize the Ollama Embedding Function. - - Args: - url (str): The URL of the Ollama Server. - model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). - """ - try: - import requests - except ImportError: - raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" - ) - self._api_url = f"{url}" - self._model_name = model_name - self._session = requests.Session() - - def __call__(self, input: Documents) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - input (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. 
- - Example: - >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = ollama_ef(texts) - """ - # Call Ollama Server API for each document - texts = input if isinstance(input, list) else [input] - embeddings = [ - self._session.post( - self._api_url, json={"model": self._model_name, "prompt": text} - ).json() - for text in texts - ] - return cast( - Embeddings, - [ - embedding["embedding"] - for embedding in embeddings - if "embedding" in embedding - ], - ) - # List of all classes in this module _classes = [ diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 4266924de10..df5d0940b6c 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -5,4 +5,5 @@ from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction) from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer) from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction -from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction +from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py new file mode 100644 index 00000000000..b2b896d7cc9 --- /dev/null +++ b/chromadb/utils/embedding_functions/ollama_embedding_function.py @@ -0,0 +1,67 @@ +import logging + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + +from typing import cast + +logger = logging.getLogger(__name__) + + +class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). + """ + + def __init__(self, url: str, model_name: str) -> None: + """ + Initialize the Ollama Embedding Function. + + Args: + url (str): The URL of the Ollama Server. + model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). + """ + try: + import requests + except ImportError: + raise ValueError( + "The requests python package is not installed. Please install it with `pip install requests`" + ) + self._api_url = f"{url}" + self._model_name = model_name + self._session = requests.Session() + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + input (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. 
+ + Example: + >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = ollama_ef(texts) + """ + # Call Ollama Server API for each document + texts = input if isinstance(input, list) else [input] + embeddings = [ + self._session.post( + self._api_url, json={"model": self._model_name, "prompt": text} + ).json() + for text in texts + ] + return cast( + Embeddings, + [ + embedding["embedding"] + for embedding in embeddings + if "embedding" in embedding + ], + ) \ No newline at end of file From 1ec3d2a449b49c13880a9694ff4d8c68ba489537 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 12:02:07 +0200 Subject: [PATCH 10/23] 1965 - Move over `ONNXMiniLM_L6_V2` --- chromadb/utils/embedding_functions.old.py | 213 ---------------- .../utils/embedding_functions/__init__.py | 3 +- .../embedding_functions/onnx_mini_lm_l6_v2.py | 236 ++++++++++++++++++ 3 files changed, 238 insertions(+), 214 deletions(-) create mode 100644 chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index 49a4a3ceb4b..7bc55c88a69 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -35,23 +35,10 @@ except ImportError: is_thin_client = False -if TYPE_CHECKING: - from onnxruntime import InferenceSession - from tokenizers import Tokenizer logger = logging.getLogger(__name__) -def _verify_sha256(fname: str, expected_sha256: str) -> bool: - sha256_hash = hashlib.sha256() - with open(fname, "rb") as f: - # Read and update hash in chunks to avoid using too much memory - for byte_block in iter(lambda: f.read(4096), b""): - sha256_hash.update(byte_block) - - return sha256_hash.hexdigest() == expected_sha256 - - class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} @@ -226,206 +213,6 @@ def __call__(self, input: Documents) -> Embeddings: ) -# In order to remove dependencies on sentence-transformers, which in turn depends on -# pytorch and sentence-piece we have created a default ONNX embedding function that -# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. -# visit https://github.com/chroma-core/onnx-embedding for the source code to generate -# and verify the ONNX model. -class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): - MODEL_NAME = "all-MiniLM-L6-v2" - DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME - EXTRACTED_FOLDER_NAME = "onnx" - ARCHIVE_FILENAME = "onnx.tar.gz" - MODEL_DOWNLOAD_URL = ( - "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" - ) - _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" - - # https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if - # no args - def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: - # Import dependencies on demand to mirror other embedding functions. This - # breaks typechecking, thus the ignores. 
- # convert the list to set for unique values - if preferred_providers and not all( - [isinstance(i, str) for i in preferred_providers] - ): - raise ValueError("Preferred providers must be a list of strings") - # check for duplicate providers - if preferred_providers and len(preferred_providers) != len( - set(preferred_providers) - ): - raise ValueError("Preferred providers must be unique") - self._preferred_providers = preferred_providers - try: - # Equivalent to import onnxruntime - self.ort = importlib.import_module("onnxruntime") - except ImportError: - raise ValueError( - "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" - ) - try: - # Equivalent to from tokenizers import Tokenizer - self.Tokenizer = importlib.import_module("tokenizers").Tokenizer - except ImportError: - raise ValueError( - "The tokenizers python package is not installed. Please install it with `pip install tokenizers`" - ) - try: - # Equivalent to from tqdm import tqdm - self.tqdm = importlib.import_module("tqdm").tqdm - except ImportError: - raise ValueError( - "The tqdm python package is not installed. Please install it with `pip install tqdm`" - ) - - # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 - # Download with tqdm to preserve the sentence-transformers experience - @retry( - reraise=True, - stop=stop_after_attempt(3), - wait=wait_random(min=1, max=3), - retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), - ) - def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: - resp = requests.get(url, stream=True) - total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, self.tqdm( - desc=str(fname), - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in resp.iter_content(chunk_size=chunk_size): - size = file.write(data) - bar.update(size) - if not _verify_sha256(fname, self._MODEL_SHA256): - # if the integrity of the file is not verified, remove it - os.remove(fname) - raise ValueError( - f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file." 
- ) - - # Use pytorches default epsilon for division by zero - # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html - def _normalize(self, v: npt.NDArray) -> npt.NDArray: - norm = np.linalg.norm(v, axis=1) - norm[norm == 0] = 1e-12 - return cast(npt.NDArray, v / norm[:, np.newaxis]) - - def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: - # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values - self.tokenizer = cast(self.Tokenizer, self.tokenizer) - self.model = cast(self.ort.InferenceSession, self.model) - all_embeddings = [] - for i in range(0, len(documents), batch_size): - batch = documents[i : i + batch_size] - encoded = [self.tokenizer.encode(d) for d in batch] - input_ids = np.array([e.ids for e in encoded]) - attention_mask = np.array([e.attention_mask for e in encoded]) - onnx_input = { - "input_ids": np.array(input_ids, dtype=np.int64), - "attention_mask": np.array(attention_mask, dtype=np.int64), - "token_type_ids": np.array( - [np.zeros(len(e), dtype=np.int64) for e in input_ids], - dtype=np.int64, - ), - } - model_output = self.model.run(None, onnx_input) - last_hidden_state = model_output[0] - # Perform mean pooling with attention weighting - input_mask_expanded = np.broadcast_to( - np.expand_dims(attention_mask, -1), last_hidden_state.shape - ) - embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip( - input_mask_expanded.sum(1), a_min=1e-9, a_max=None - ) - embeddings = self._normalize(embeddings).astype(np.float32) - all_embeddings.append(embeddings) - return np.concatenate(all_embeddings) - - @cached_property - def tokenizer(self) -> "Tokenizer": - tokenizer = self.Tokenizer.from_file( - os.path.join( - self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" - ) - ) - # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128 - # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480 - tokenizer.enable_truncation(max_length=256) - tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256) - return tokenizer - - @cached_property - def model(self) -> "InferenceSession": - if self._preferred_providers is None or len(self._preferred_providers) == 0: - if len(self.ort.get_available_providers()) > 0: - logger.debug( - f"WARNING: No ONNX providers provided, defaulting to available providers: " - f"{self.ort.get_available_providers()}" - ) - self._preferred_providers = self.ort.get_available_providers() - elif not set(self._preferred_providers).issubset( - set(self.ort.get_available_providers()) - ): - raise ValueError( - f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}" - ) - - # Suppress onnxruntime warnings. This produces logspew, mainly when onnx tries to use CoreML, which doesn't fit this model. 
- so = self.ort.SessionOptions() - so.log_severity_level = 3 - - return self.ort.InferenceSession( - os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"), - # Since 1.9 onnyx runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html - # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs - providers=self._preferred_providers, - sess_options=so, - ) - - def __call__(self, input: Documents) -> Embeddings: - # Only download the model when it is actually used - self._download_model_if_not_exists() - return cast(Embeddings, self._forward(input).tolist()) - - def _download_model_if_not_exists(self) -> None: - onnx_files = [ - "config.json", - "model.onnx", - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - "vocab.txt", - ] - extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME) - onnx_files_exist = True - for f in onnx_files: - if not os.path.exists(os.path.join(extracted_folder, f)): - onnx_files_exist = False - break - # Model is not downloaded yet - if not onnx_files_exist: - os.makedirs(self.DOWNLOAD_PATH, exist_ok=True) - if not os.path.exists( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME) - ) or not _verify_sha256( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - self._MODEL_SHA256, - ): - self._download( - url=self.MODEL_DOWNLOAD_URL, - fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - ) - with tarfile.open( - name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - mode="r:gz", - ) as tar: - tar.extractall(path=self.DOWNLOAD_PATH) - - def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: if is_thin_client: return None diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index df5d0940b6c..d2d0728f7f3 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -6,4 +6,5 @@ from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer) from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction -from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction +from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py new file mode 100644 index 00000000000..3aff9a37a0e --- /dev/null +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -0,0 +1,236 @@ +import hashlib +import logging +from functools import cached_property + +from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + +from pathlib import Path +import os +import tarfile +import requests +from typing import TYPE_CHECKING, List, Optional, cast +import numpy as np +import numpy.typing as npt +import importlib + +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from 
onnxruntime import InferenceSession + from tokenizers import Tokenizer + + +def _verify_sha256(fname: str, expected_sha256: str) -> bool: + sha256_hash = hashlib.sha256() + with open(fname, "rb") as f: + # Read and update hash in chunks to avoid using too much memory + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + + return sha256_hash.hexdigest() == expected_sha256 + +# In order to remove dependencies on sentence-transformers, which in turn depends on +# pytorch and sentence-piece we have created a default ONNX embedding function that +# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. +# visit https://github.com/chroma-core/onnx-embedding for the source code to generate +# and verify the ONNX model. +class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): + MODEL_NAME = "all-MiniLM-L6-v2" + DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME + EXTRACTED_FOLDER_NAME = "onnx" + ARCHIVE_FILENAME = "onnx.tar.gz" + MODEL_DOWNLOAD_URL = ( + "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" + ) + _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" + + # https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if + # no args + def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: + # Import dependencies on demand to mirror other embedding functions. This + # breaks typechecking, thus the ignores. + # convert the list to set for unique values + if preferred_providers and not all( + [isinstance(i, str) for i in preferred_providers] + ): + raise ValueError("Preferred providers must be a list of strings") + # check for duplicate providers + if preferred_providers and len(preferred_providers) != len( + set(preferred_providers) + ): + raise ValueError("Preferred providers must be unique") + self._preferred_providers = preferred_providers + try: + # Equivalent to import onnxruntime + self.ort = importlib.import_module("onnxruntime") + except ImportError: + raise ValueError( + "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" + ) + try: + # Equivalent to from tokenizers import Tokenizer + self.Tokenizer = importlib.import_module("tokenizers").Tokenizer + except ImportError: + raise ValueError( + "The tokenizers python package is not installed. Please install it with `pip install tokenizers`" + ) + try: + # Equivalent to from tqdm import tqdm + self.tqdm = importlib.import_module("tqdm").tqdm + except ImportError: + raise ValueError( + "The tqdm python package is not installed. 
Please install it with `pip install tqdm`"
+            )
+
+    # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
+    # Download with tqdm to preserve the sentence-transformers experience
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random(min=1, max=3),
+        retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
+    )
+    def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
+        resp = requests.get(url, stream=True)
+        total = int(resp.headers.get("content-length", 0))
+        with open(fname, "wb") as file, self.tqdm(
+            desc=str(fname),
+            total=total,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar:
+            for data in resp.iter_content(chunk_size=chunk_size):
+                size = file.write(data)
+                bar.update(size)
+        if not _verify_sha256(fname, self._MODEL_SHA256):
+            # if the integrity of the file is not verified, remove it
+            os.remove(fname)
+            raise ValueError(
+                f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file."
+            )
+
+    # Use PyTorch's default epsilon to avoid division by zero
+    # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
+        norm = np.linalg.norm(v, axis=1)
+        norm[norm == 0] = 1e-12
+        return cast(npt.NDArray, v / norm[:, np.newaxis])
+
+    def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
+        # Cast to the correct type: the type checker doesn't know that the cached properties below set these values
+        self.tokenizer = cast(self.Tokenizer, self.tokenizer)
+        self.model = cast(self.ort.InferenceSession, self.model)
+        all_embeddings = []
+        for i in range(0, len(documents), batch_size):
+            batch = documents[i : i + batch_size]
+            encoded = [self.tokenizer.encode(d) for d in batch]
+            input_ids = np.array([e.ids for e in encoded])
+            attention_mask = np.array([e.attention_mask for e in encoded])
+            onnx_input = {
+                "input_ids": np.array(input_ids, dtype=np.int64),
+                "attention_mask": np.array(attention_mask, dtype=np.int64),
+                "token_type_ids": np.array(
+                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
+                    dtype=np.int64,
+                ),
+            }
+            model_output = self.model.run(None, onnx_input)
+            last_hidden_state = model_output[0]
+            # Perform mean pooling with attention weighting
+            input_mask_expanded = np.broadcast_to(
+                np.expand_dims(attention_mask, -1), last_hidden_state.shape
+            )
+            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
+                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
+            )
+            embeddings = self._normalize(embeddings).astype(np.float32)
+            all_embeddings.append(embeddings)
+        return np.concatenate(all_embeddings)
+
+    @cached_property
+    def tokenizer(self) -> "Tokenizer":
+        tokenizer = self.Tokenizer.from_file(
+            os.path.join(
+                self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
+            )
+        )
+        # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
+        # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
+        tokenizer.enable_truncation(max_length=256)
+        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
+        return tokenizer
+
+    @cached_property
+    def model(self) -> "InferenceSession":
+        if self._preferred_providers is None or len(self._preferred_providers) == 0:
+            if len(self.ort.get_available_providers()) > 0:
+                logger.debug(
+                    f"WARNING: No ONNX providers provided, defaulting to available providers: "
+                    f"{self.ort.get_available_providers()}"
+                )
+                self._preferred_providers = self.ort.get_available_providers()
+        elif not set(self._preferred_providers).issubset(
+            set(self.ort.get_available_providers())
+        ):
+            raise ValueError(
+                f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}"
+            )
+
+        # Suppress onnxruntime warnings. Otherwise onnxruntime produces logspew, mainly when it tries to use CoreML, which doesn't fit this model.
+        so = self.ort.SessionOptions()
+        so.log_severity_level = 3
+
+        return self.ort.InferenceSession(
+            os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"),
+            # Since 1.9, ONNX Runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html
+            # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs
+            providers=self._preferred_providers,
+            sess_options=so,
+        )
+
+    def __call__(self, input: Documents) -> Embeddings:
+        # Only download the model when it is actually used
+        self._download_model_if_not_exists()
+        return cast(Embeddings, self._forward(input).tolist())
+
+    def _download_model_if_not_exists(self) -> None:
+        onnx_files = [
+            "config.json",
+            "model.onnx",
+            "special_tokens_map.json",
+            "tokenizer_config.json",
+            "tokenizer.json",
+            "vocab.txt",
+        ]
+        extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
+        onnx_files_exist = True
+        for f in onnx_files:
+            if not os.path.exists(os.path.join(extracted_folder, f)):
+                onnx_files_exist = False
+                break
+        # Model is not downloaded yet
+        if not onnx_files_exist:
+            os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
+            if not os.path.exists(
+                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
+            ) or not _verify_sha256(
+                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                self._MODEL_SHA256,
+            ):
+                self._download(
+                    url=self.MODEL_DOWNLOAD_URL,
+                    fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                )
+            with tarfile.open(
+                name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                mode="r:gz",
+            ) as tar:
+                tar.extractall(path=self.DOWNLOAD_PATH)
\ No newline at end of file

From 36420587787cc76681952ea63a6db8430efbe75e Mon Sep 17 00:00:00 2001
From: David Reguera
Date: Sat, 13 Apr 2024 12:03:14 +0200
Subject: [PATCH 11/23] 1965 - Move over `OpenCLIPEmbeddingFunction`

---
 chromadb/utils/embedding_functions.old.py     | 60 --------------
 .../utils/embedding_functions/__init__.py     |  3 +-
 .../open_clip_embedding_function.py           | 78 +++++++++++++++++++
 3 files changed, 80 insertions(+), 61 deletions(-)
 create mode 100644 chromadb/utils/embedding_functions/open_clip_embedding_function.py

diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py
index 7bc55c88a69..b86e660ed95 100644
--- a/chromadb/utils/embedding_functions.old.py
+++ b/chromadb/utils/embedding_functions.old.py
@@ -220,66 +220,6 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]:
     return ONNXMiniLM_L6_V2()
 
 
-class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
-    def __init__(
-        self,
-        model_name: str = "ViT-B-32",
-        checkpoint: str = "laion2b_s34b_b79k",
-        device: Optional[str] = "cpu",
-    ) -> None:
-        try:
-            import open_clip
-        except ImportError:
-            raise ValueError(
-                "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. 
https://github.com/mlfoundations/open_clip" - ) - try: - self._torch = importlib.import_module("torch") - except ImportError: - raise ValueError( - "The torch python package is not installed. Please install it with `pip install torch`" - ) - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: - raise ValueError( - "The PIL python package is not installed. Please install it with `pip install pillow`" - ) - - model, _, preprocess = open_clip.create_model_and_transforms( - model_name=model_name, pretrained=checkpoint - ) - self._model = model - self._model.to(device) - self._preprocess = preprocess - self._tokenizer = open_clip.get_tokenizer(model_name=model_name) - - def _encode_image(self, image: Image) -> Embedding: - pil_image = self._PILImage.fromarray(image) - with self._torch.no_grad(): - image_features = self._model.encode_image( - self._preprocess(pil_image).unsqueeze(0) - ) - image_features /= image_features.norm(dim=-1, keepdim=True) - return cast(Embedding, image_features.squeeze().tolist()) - - def _encode_text(self, text: Document) -> Embedding: - with self._torch.no_grad(): - text_features = self._model.encode_text(self._tokenizer(text)) - text_features /= text_features.norm(dim=-1, keepdim=True) - return cast(Embedding, text_features.squeeze().tolist()) - - def __call__(self, input: Union[Documents, Images]) -> Embeddings: - embeddings: Embeddings = [] - for item in input: - if is_image(item): - embeddings.append(self._encode_image(cast(Image, item))) - elif is_document(item): - embeddings.append(self._encode_text(cast(Document, item))) - return embeddings - - class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): def __init__( self, api_key: str = "", api_url = "https://infer.roboflow.com" diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index d2d0728f7f3..0cac11542fe 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -7,4 +7,5 @@ from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction -from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 \ No newline at end of file +from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 +from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py new file mode 100644 index 00000000000..8dac31983d4 --- /dev/null +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -0,0 +1,78 @@ +import logging + +from chromadb.api.types import ( + Document, + Documents, + Embedding, + Image, + Images, + EmbeddingFunction, + Embeddings, + is_image, + is_document, +) + +from typing import Optional, Union, cast +import importlib + +logger = logging.getLogger(__name__) + + +class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): + def __init__( + self, + model_name: str = "ViT-B-32", + checkpoint: str = "laion2b_s34b_b79k", + device: Optional[str] = "cpu", + ) -> None: + try: + import open_clip + except 
ImportError: + raise ValueError( + "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" + ) + try: + self._torch = importlib.import_module("torch") + except ImportError: + raise ValueError( + "The torch python package is not installed. Please install it with `pip install torch`" + ) + + try: + self._PILImage = importlib.import_module("PIL.Image") + except ImportError: + raise ValueError( + "The PIL python package is not installed. Please install it with `pip install pillow`" + ) + + model, _, preprocess = open_clip.create_model_and_transforms( + model_name=model_name, pretrained=checkpoint + ) + self._model = model + self._model.to(device) + self._preprocess = preprocess + self._tokenizer = open_clip.get_tokenizer(model_name=model_name) + + def _encode_image(self, image: Image) -> Embedding: + pil_image = self._PILImage.fromarray(image) + with self._torch.no_grad(): + image_features = self._model.encode_image( + self._preprocess(pil_image).unsqueeze(0) + ) + image_features /= image_features.norm(dim=-1, keepdim=True) + return cast(Embedding, image_features.squeeze().tolist()) + + def _encode_text(self, text: Document) -> Embedding: + with self._torch.no_grad(): + text_features = self._model.encode_text(self._tokenizer(text)) + text_features /= text_features.norm(dim=-1, keepdim=True) + return cast(Embedding, text_features.squeeze().tolist()) + + def __call__(self, input: Union[Documents, Images]) -> Embeddings: + embeddings: Embeddings = [] + for item in input: + if is_image(item): + embeddings.append(self._encode_image(cast(Image, item))) + elif is_document(item): + embeddings.append(self._encode_text(cast(Document, item))) + return embeddings \ No newline at end of file From 0196264d24866e3fbe4458688e9224f9f61715dd Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 12:04:30 +0200 Subject: [PATCH 12/23] 1965 - Move over `OpenAIEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 113 ---------------- .../utils/embedding_functions/__init__.py | 3 +- .../openai_embedding_function.py | 125 ++++++++++++++++++ 3 files changed, 127 insertions(+), 114 deletions(-) create mode 100644 chromadb/utils/embedding_functions/openai_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index b86e660ed95..cfae7c32e68 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -100,119 +100,6 @@ def __call__(self, input: Documents) -> Embeddings: ) # noqa E501 -class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "text-embedding-ada-002", - organization_id: Optional[str] = None, - api_base: Optional[str] = None, - api_type: Optional[str] = None, - api_version: Optional[str] = None, - deployment_id: Optional[str] = None, - default_headers: Optional[Mapping[str, str]] = None, - ): - """ - Initialize the OpenAIEmbeddingFunction. - Args: - api_key (str, optional): Your API key for the OpenAI API. If not - provided, it will raise an error to provide an OpenAI API key. - organization_id(str, optional): The OpenAI organization ID if applicable - model_name (str, optional): The name of the model to use for text - embeddings. Defaults to "text-embedding-ada-002". - api_base (str, optional): The base path for the API. If not provided, - it will use the base path for the OpenAI API. 
This can be used to - point to a different deployment, such as an Azure deployment. - api_type (str, optional): The type of the API deployment. This can be - used to specify a different deployment, such as 'azure'. If not - provided, it will use the default OpenAI deployment. - api_version (str, optional): The api version for the API. If not provided, - it will use the api version for the OpenAI API. This can be used to - point to a different deployment, such as an Azure deployment. - deployment_id (str, optional): Deployment ID for Azure OpenAI. - default_headers (Mapping, optional): A mapping of default headers to be sent with each API request. - - """ - try: - import openai - except ImportError: - raise ValueError( - "The openai python package is not installed. Please install it with `pip install openai`" - ) - - if api_key is not None: - openai.api_key = api_key - # If the api key is still not set, raise an error - elif openai.api_key is None: - raise ValueError( - "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys" - ) - - if api_base is not None: - openai.api_base = api_base - - if api_version is not None: - openai.api_version = api_version - - self._api_type = api_type - if api_type is not None: - openai.api_type = api_type - - if organization_id is not None: - openai.organization = organization_id - - self._v1 = openai.__version__.startswith("1.") - if self._v1: - if api_type == "azure": - self._client = openai.AzureOpenAI( - api_key=api_key, - api_version=api_version, - azure_endpoint=api_base, - default_headers=default_headers, - ).embeddings - else: - self._client = openai.OpenAI( - api_key=api_key, base_url=api_base, default_headers=default_headers - ).embeddings - else: - self._client = openai.Embedding - self._model_name = model_name - self._deployment_id = deployment_id - - def __call__(self, input: Documents) -> Embeddings: - # replace newlines, which can negatively affect performance. 
- input = [t.replace("\n", " ") for t in input] - - # Call the OpenAI Embedding API - if self._v1: - embeddings = self._client.create( - input=input, model=self._deployment_id or self._model_name - ).data - - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e.index) - - # Return just the embeddings - return cast(Embeddings, [result.embedding for result in sorted_embeddings]) - else: - if self._api_type == "azure": - embeddings = self._client.create( - input=input, engine=self._deployment_id or self._model_name - )["data"] - else: - embeddings = self._client.create(input=input, model=self._model_name)[ - "data" - ] - - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) - - # Return just the embeddings - return cast( - Embeddings, [result["embedding"] for result in sorted_embeddings] - ) - - def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: if is_thin_client: return None diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 0cac11542fe..4ec9096da0a 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -8,4 +8,5 @@ from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 -from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction +from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/openai_embedding_function.py b/chromadb/utils/embedding_functions/openai_embedding_function.py new file mode 100644 index 00000000000..0ae2978d30f --- /dev/null +++ b/chromadb/utils/embedding_functions/openai_embedding_function.py @@ -0,0 +1,125 @@ + +import logging +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + +from typing import Mapping, Optional, cast + + +logger = logging.getLogger(__name__) + + +class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "text-embedding-ada-002", + organization_id: Optional[str] = None, + api_base: Optional[str] = None, + api_type: Optional[str] = None, + api_version: Optional[str] = None, + deployment_id: Optional[str] = None, + default_headers: Optional[Mapping[str, str]] = None, + ): + """ + Initialize the OpenAIEmbeddingFunction. + Args: + api_key (str, optional): Your API key for the OpenAI API. If not + provided, it will raise an error to provide an OpenAI API key. + organization_id(str, optional): The OpenAI organization ID if applicable + model_name (str, optional): The name of the model to use for text + embeddings. Defaults to "text-embedding-ada-002". + api_base (str, optional): The base path for the API. If not provided, + it will use the base path for the OpenAI API. This can be used to + point to a different deployment, such as an Azure deployment. + api_type (str, optional): The type of the API deployment. This can be + used to specify a different deployment, such as 'azure'. 
If not + provided, it will use the default OpenAI deployment. + api_version (str, optional): The api version for the API. If not provided, + it will use the api version for the OpenAI API. This can be used to + point to a different deployment, such as an Azure deployment. + deployment_id (str, optional): Deployment ID for Azure OpenAI. + default_headers (Mapping, optional): A mapping of default headers to be sent with each API request. + + """ + try: + import openai + except ImportError: + raise ValueError( + "The openai python package is not installed. Please install it with `pip install openai`" + ) + + if api_key is not None: + openai.api_key = api_key + # If the api key is still not set, raise an error + elif openai.api_key is None: + raise ValueError( + "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys" + ) + + if api_base is not None: + openai.api_base = api_base + + if api_version is not None: + openai.api_version = api_version + + self._api_type = api_type + if api_type is not None: + openai.api_type = api_type + + if organization_id is not None: + openai.organization = organization_id + + self._v1 = openai.__version__.startswith("1.") + if self._v1: + if api_type == "azure": + self._client = openai.AzureOpenAI( + api_key=api_key, + api_version=api_version, + azure_endpoint=api_base, + default_headers=default_headers, + ).embeddings + else: + self._client = openai.OpenAI( + api_key=api_key, base_url=api_base, default_headers=default_headers + ).embeddings + else: + self._client = openai.Embedding + self._model_name = model_name + self._deployment_id = deployment_id + + def __call__(self, input: Documents) -> Embeddings: + # replace newlines, which can negatively affect performance. + input = [t.replace("\n", " ") for t in input] + + # Call the OpenAI Embedding API + if self._v1: + embeddings = self._client.create( + input=input, model=self._deployment_id or self._model_name + ).data + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.index) + + # Return just the embeddings + return cast(Embeddings, [result.embedding for result in sorted_embeddings]) + else: + if self._api_type == "azure": + embeddings = self._client.create( + input=input, engine=self._deployment_id or self._model_name + )["data"] + else: + embeddings = self._client.create(input=input, model=self._model_name)[ + "data" + ] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) + + # Return just the embeddings + return cast( + Embeddings, [result["embedding"] for result in sorted_embeddings] + ) \ No newline at end of file From 929a8d4a4c8df641c79b9cfcf9c6c626e3b27cab Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 12:05:50 +0200 Subject: [PATCH 13/23] 1965 - Move over `RoboflowEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 68 --------------- .../utils/embedding_functions/__init__.py | 3 +- .../roboflow_embedding_function.py | 85 +++++++++++++++++++ 3 files changed, 87 insertions(+), 69 deletions(-) create mode 100644 chromadb/utils/embedding_functions/roboflow_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index cfae7c32e68..5b10b21f5bb 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -106,74 +106,6 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: else: return 
ONNXMiniLM_L6_V2() - -class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, api_key: str = "", api_url = "https://infer.roboflow.com" - ) -> None: - """ - Create a RoboflowEmbeddingFunction. - - Args: - api_key (str): Your API key for the Roboflow API. - api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com". - """ - if not api_key: - api_key = os.environ.get("ROBOFLOW_API_KEY") - - self._api_url = api_url - self._api_key = api_key - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: - raise ValueError( - "The PIL python package is not installed. Please install it with `pip install pillow`" - ) - - def __call__(self, input: Union[Documents, Images]) -> Embeddings: - embeddings = [] - - for item in input: - if is_image(item): - image = self._PILImage.fromarray(item) - - buffer = BytesIO() - image.save(buffer, format="JPEG") - base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - infer_clip_payload = { - "image": { - "type": "base64", - "value": base64_image, - }, - } - - res = requests.post( - f"{self._api_url}/clip/embed_image?api_key={self._api_key}", - json=infer_clip_payload, - ) - - result = res.json()['embeddings'] - - embeddings.append(result[0]) - - elif is_document(item): - infer_clip_payload = { - "text": input, - } - - res = requests.post( - f"{self._api_url}/clip/embed_text?api_key={self._api_key}", - json=infer_clip_payload, - ) - - result = res.json()['embeddings'] - - embeddings.append(result[0]) - - return embeddings - # List of all classes in this module _classes = [ diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 4ec9096da0a..60affc6b051 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -9,4 +9,5 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction -from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction +from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/roboflow_embedding_function.py b/chromadb/utils/embedding_functions/roboflow_embedding_function.py new file mode 100644 index 00000000000..e3a16fa943a --- /dev/null +++ b/chromadb/utils/embedding_functions/roboflow_embedding_function.py @@ -0,0 +1,85 @@ +import logging + +from chromadb.api.types import ( + Documents, + Images, + EmbeddingFunction, + Embeddings, + is_image, + is_document, +) + +from io import BytesIO +import os +import requests +from typing import Union +import importlib +import base64 + +logger = logging.getLogger(__name__) + + +class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): + def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None: + """ + Create a RoboflowEmbeddingFunction. + + Args: + api_key (str): Your API key for the Roboflow API. + api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com". 
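+
+        Example:
+            >>> # A minimal usage sketch; assumes a valid Roboflow API key.
+            >>> roboflow_ef = RoboflowEmbeddingFunction(api_key="YOUR_API_KEY")
+            >>> embeddings = roboflow_ef(["a photo of a forklift"])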
+        """
+        if not api_key:
+            api_key = os.environ.get("ROBOFLOW_API_KEY")
+
+        self._api_url = api_url
+        self._api_key = api_key
+
+        try:
+            self._PILImage = importlib.import_module("PIL.Image")
+        except ImportError:
+            raise ValueError(
+                "The PIL python package is not installed. Please install it with `pip install pillow`"
+            )
+
+    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
+        embeddings = []
+
+        for item in input:
+            if is_image(item):
+                image = self._PILImage.fromarray(item)
+
+                buffer = BytesIO()
+                image.save(buffer, format="JPEG")
+                base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+                infer_clip_payload = {
+                    "image": {
+                        "type": "base64",
+                        "value": base64_image,
+                    },
+                }
+
+                res = requests.post(
+                    f"{self._api_url}/clip/embed_image?api_key={self._api_key}",
+                    json=infer_clip_payload,
+                )
+
+                result = res.json()["embeddings"]
+
+                embeddings.append(result[0])
+
+            elif is_document(item):
+                # Embed the single document, not the whole input list
+                infer_clip_payload = {
+                    "text": item,
+                }
+
+                res = requests.post(
+                    f"{self._api_url}/clip/embed_text?api_key={self._api_key}",
+                    json=infer_clip_payload,
+                )
+
+                result = res.json()["embeddings"]
+
+                embeddings.append(result[0])
+
+        return embeddings
\ No newline at end of file

From c2e2cc88364c3b9353b3001c9b27cc6f2b3a8a98 Mon Sep 17 00:00:00 2001
From: David Reguera
Date: Sat, 13 Apr 2024 12:07:28 +0200
Subject: [PATCH 14/23] 1965 - Move over `SentenceTransformerEmbeddingFunction`

---
 chromadb/utils/embedding_functions.old.py     | 45 -----------------
 .../utils/embedding_functions/__init__.py     |  3 +-
 ...sentence_transformer_embedding_function.py | 49 +++++++++++++++++++
 3 files changed, 51 insertions(+), 46 deletions(-)
 create mode 100644 chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py

diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py
index 5b10b21f5bb..4accc48a9f9 100644
--- a/chromadb/utils/embedding_functions.old.py
+++ b/chromadb/utils/embedding_functions.old.py
@@ -39,51 +39,6 @@
 logger = logging.getLogger(__name__)
 
 
-class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
-    # Since we do dynamic imports we have to type this as Any
-    models: Dict[str, Any] = {}
-
-    # If you have a beefier machine, try "gtr-t5-large".
-    # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
-    def __init__(
-        self,
-        model_name: str = "all-MiniLM-L6-v2",
-        device: str = "cpu",
-        normalize_embeddings: bool = False,
-        **kwargs: Any,
-    ):
-        """Initialize SentenceTransformerEmbeddingFunction.
-
-        Args:
-            model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2"
-            device (str, optional): Device used for computation, defaults to "cpu"
-            normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False
-            **kwargs: Additional arguments to pass to the SentenceTransformer model.
-        """
-        if model_name not in self.models:
-            try:
-                from sentence_transformers import SentenceTransformer
-            except ImportError:
-                raise ValueError(
-                    "The sentence_transformers python package is not installed. 
Please install it with `pip install sentence_transformers`" - ) - self.models[model_name] = SentenceTransformer( - model_name, device=device, **kwargs - ) - self._model = self.models[model_name] - self._normalize_embeddings = normalize_embeddings - - def __call__(self, input: Documents) -> Embeddings: - return cast( - Embeddings, - self._model.encode( - list(input), - convert_to_numpy=True, - normalize_embeddings=self._normalize_embeddings, - ).tolist(), - ) - - class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): try: diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 60affc6b051..be9b70db196 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -10,4 +10,5 @@ from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction -from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction +from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py new file mode 100644 index 00000000000..682dfc1f1e2 --- /dev/null +++ b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py @@ -0,0 +1,49 @@ +import logging +from typing import Any, Dict, cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings + +logger = logging.getLogger(__name__) + +class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): + # Since we do dynamic imports we have to type this as Any + models: Dict[str, Any] = {} + + # If you have a beefier machine, try "gtr-t5-large". + # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html + def __init__( + self, + model_name: str = "all-MiniLM-L6-v2", + device: str = "cpu", + normalize_embeddings: bool = False, + **kwargs: Any, + ): + """Initialize SentenceTransformerEmbeddingFunction. + + Args: + model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2" + device (str, optional): Device used for computation, defaults to "cpu" + normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False + **kwargs: Additional arguments to pass to the SentenceTransformer model. + """ + if model_name not in self.models: + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ValueError( + "The sentence_transformers python package is not installed. 
Please install it with `pip install sentence_transformers`" + ) + self.models[model_name] = SentenceTransformer( + model_name, device=device, **kwargs + ) + self._model = self.models[model_name] + self._normalize_embeddings = normalize_embeddings + + def __call__(self, input: Documents) -> Embeddings: + return cast( + Embeddings, + self._model.encode( + list(input), + convert_to_numpy=True, + normalize_embeddings=self._normalize_embeddings, + ).tolist(), + ) From 2632601490144c761a0333a1821beac8e70767b6 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 12:08:33 +0200 Subject: [PATCH 15/23] 1965 - Move over `Text2VecEmbeddingFunction` --- chromadb/utils/embedding_functions.old.py | 16 -------------- .../utils/embedding_functions/__init__.py | 3 ++- .../text2vec_embedding_function.py | 21 +++++++++++++++++++ 3 files changed, 23 insertions(+), 17 deletions(-) create mode 100644 chromadb/utils/embedding_functions/text2vec_embedding_function.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py index 4accc48a9f9..87ad6f81f0b 100644 --- a/chromadb/utils/embedding_functions.old.py +++ b/chromadb/utils/embedding_functions.old.py @@ -39,22 +39,6 @@ logger = logging.getLogger(__name__) -class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): - try: - from text2vec import SentenceModel - except ImportError: - raise ValueError( - "The text2vec python package is not installed. Please install it with `pip install text2vec`" - ) - self._model = SentenceModel(model_name_or_path=model_name) - - def __call__(self, input: Documents) -> Embeddings: - return cast( - Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist() - ) # noqa E501 - - def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: if is_thin_client: return None diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index be9b70db196..09e12ebc61a 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -11,4 +11,5 @@ from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction -from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction +from chromadb.utils.embedding_functions.text2vec_embedding_function import Text2VecEmbeddingFunction \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/text2vec_embedding_function.py b/chromadb/utils/embedding_functions/text2vec_embedding_function.py new file mode 100644 index 00000000000..7b73c6b889d --- /dev/null +++ b/chromadb/utils/embedding_functions/text2vec_embedding_function.py @@ -0,0 +1,21 @@ +from typing import cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +import logging + +logger = logging.getLogger(__name__) + + +class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): + try: + from text2vec import SentenceModel + except 
ImportError: + raise ValueError( + "The text2vec python package is not installed. Please install it with `pip install text2vec`" + ) + self._model = SentenceModel(model_name_or_path=model_name) + + def __call__(self, input: Documents) -> Embeddings: + return cast( + Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist() + ) # noqa E501 From 6770d21ec547cf744bbc5c6218de3e0e305c9c72 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 13 Apr 2024 12:10:03 +0200 Subject: [PATCH 16/23] 1965 - Move remaining functions --- chromadb/utils/embedding_functions.old.py | 58 ------------------- .../utils/embedding_functions/__init__.py | 30 +++++++++- 2 files changed, 29 insertions(+), 59 deletions(-) delete mode 100644 chromadb/utils/embedding_functions.old.py diff --git a/chromadb/utils/embedding_functions.old.py b/chromadb/utils/embedding_functions.old.py deleted file mode 100644 index 87ad6f81f0b..00000000000 --- a/chromadb/utils/embedding_functions.old.py +++ /dev/null @@ -1,58 +0,0 @@ -import hashlib -import logging -from functools import cached_property - -from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception - -from chromadb.api.types import ( - Document, - Documents, - Embedding, - Image, - Images, - EmbeddingFunction, - Embeddings, - is_image, - is_document, -) - -from io import BytesIO -from pathlib import Path -import os -import tarfile -import requests -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast -import numpy as np -import numpy.typing as npt -import importlib -import inspect -import json -import sys -import base64 - -try: - from chromadb.is_thin_client import is_thin_client -except ImportError: - is_thin_client = False - - -logger = logging.getLogger(__name__) - - -def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: - if is_thin_client: - return None - else: - return ONNXMiniLM_L6_V2() - - -# List of all classes in this module -_classes = [ - name - for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) - if obj.__module__ == __name__ -] - - -def get_builtins() -> List[str]: - return _classes diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 09e12ebc61a..2fa9bfb8138 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -1,4 +1,8 @@ +import inspect +import sys +from typing import List, Optional +from chromadb.api.types import Documents, EmbeddingFunction from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction @@ -12,4 +16,28 @@ from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction -from chromadb.utils.embedding_functions.text2vec_embedding_function import Text2VecEmbeddingFunction \ No newline at end of file +from chromadb.utils.embedding_functions.text2vec_embedding_function import Text2VecEmbeddingFunction + + +try: + from chromadb.is_thin_client import is_thin_client +except ImportError: + is_thin_client = 
False + + +def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: + if is_thin_client: + return None + else: + return ONNXMiniLM_L6_V2() + + +_classes = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) + if obj.__module__ == __name__ +] + + +def get_builtins() -> List[str]: + return _classes \ No newline at end of file From 6ad759829df125d17e6b26d20ebb987b70861705 Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 20 Apr 2024 11:58:43 +0200 Subject: [PATCH 17/23] 1965 - Lint Files --- .../utils/embedding_functions/__init__.py | 82 +++++++++++++++---- .../amazon_bedrock_embedding_function.py | 19 ++--- .../chroma_langchain_embedding_function.py | 35 +------- .../cohere_embedding_function.py | 10 +-- .../google_embedding_function.py | 13 +-- .../huggingface_embedding_function.py | 14 +--- .../instructor_embedding_function.py | 12 +-- .../jina_embedding_function.py | 13 +-- .../ollama_embedding_function.py | 13 +-- .../open_clip_embedding_function.py | 13 ++- .../openai_embedding_function.py | 33 +++++--- .../roboflow_embedding_function.py | 34 ++++---- ...sentence_transformer_embedding_function.py | 2 + .../text2vec_embedding_function.py | 3 +- 14 files changed, 151 insertions(+), 145 deletions(-) diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 2fa9bfb8138..1226a0b034b 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -3,21 +3,73 @@ from typing import List, Optional from chromadb.api.types import Documents, EmbeddingFunction -from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction -from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import create_langchain_embedding -from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction -from chromadb.utils.embedding_functions.google_embedding_function import (GoogleGenerativeAiEmbeddingFunction, GooglePalmEmbeddingFunction, GoogleVertexEmbeddingFunction) -from chromadb.utils.embedding_functions.huggingface_embedding_function import (HuggingFaceEmbeddingFunction, HuggingFaceEmbeddingServer) -from chromadb.utils.embedding_functions.instructor_embedding_function import InstructorEmbeddingFunction -from chromadb.utils.embedding_functions.jina_embedding_function import JinaEmbeddingFunction -from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction -from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2, _verify_sha256 -from chromadb.utils.embedding_functions.open_clip_embedding_function import OpenCLIPEmbeddingFunction -from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction -from chromadb.utils.embedding_functions.roboflow_embedding_function import RoboflowEmbeddingFunction -from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction -from chromadb.utils.embedding_functions.text2vec_embedding_function import Text2VecEmbeddingFunction +from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( + AmazonBedrockEmbeddingFunction, +) +from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import ( + create_langchain_embedding, +) +from chromadb.utils.embedding_functions.cohere_embedding_function import ( + CohereEmbeddingFunction, +) +from 
chromadb.utils.embedding_functions.google_embedding_function import ( + GoogleGenerativeAiEmbeddingFunction, + GooglePalmEmbeddingFunction, + GoogleVertexEmbeddingFunction, +) +from chromadb.utils.embedding_functions.huggingface_embedding_function import ( + HuggingFaceEmbeddingFunction, + HuggingFaceEmbeddingServer, +) +from chromadb.utils.embedding_functions.instructor_embedding_function import ( + InstructorEmbeddingFunction, +) +from chromadb.utils.embedding_functions.jina_embedding_function import ( + JinaEmbeddingFunction, +) +from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, +) +from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( + ONNXMiniLM_L6_V2, + _verify_sha256, +) +from chromadb.utils.embedding_functions.open_clip_embedding_function import ( + OpenCLIPEmbeddingFunction, +) +from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, +) +from chromadb.utils.embedding_functions.roboflow_embedding_function import ( + RoboflowEmbeddingFunction, +) +from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import ( + SentenceTransformerEmbeddingFunction, +) +from chromadb.utils.embedding_functions.text2vec_embedding_function import ( + Text2VecEmbeddingFunction, +) +__all__ = [ + "AmazonBedrockEmbeddingFunction", + "create_langchain_embedding", + "CohereEmbeddingFunction", + "GoogleGenerativeAiEmbeddingFunction", + "GooglePalmEmbeddingFunction", + "GoogleVertexEmbeddingFunction", + "HuggingFaceEmbeddingFunction", + "HuggingFaceEmbeddingServer", + "InstructorEmbeddingFunction", + "JinaEmbeddingFunction", + "OllamaEmbeddingFunction", + "OpenCLIPEmbeddingFunction", + "ONNXMiniLM_L6_V2", + "OpenAIEmbeddingFunction", + "RoboflowEmbeddingFunction", + "SentenceTransformerEmbeddingFunction", + "Text2VecEmbeddingFunction", + "_verify_sha256", +] try: from chromadb.is_thin_client import is_thin_client @@ -40,4 +92,4 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: def get_builtins() -> List[str]: - return _classes \ No newline at end of file + return _classes diff --git a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py index 11a3fba875c..67103ab7ffd 100644 --- a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py +++ b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py @@ -1,29 +1,24 @@ +import json import logging - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - from typing import Any -import json -logger = logging.getLogger(__name__) +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +logger = logging.getLogger(__name__) class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, - session: "boto3.Session", # noqa: F821 # Quote for forward reference + session: Any, model_name: str = "amazon.titan-embed-text-v1", **kwargs: Any, ): """Initialize AmazonBedrockEmbeddingFunction. Args: - session (boto3.Session): The boto3 session to use. + session (boto3.Session): The boto3 session to use. You need to have boto3 + installed, `pip install boto3`. model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" **kwargs: Additional arguments to pass to the boto3 client. 
@@ -57,4 +52,4 @@ def __call__(self, input: Documents) -> Embeddings: ) embedding = json.load(response.get("body")).get("embedding") embeddings.append(embedding) - return embeddings \ No newline at end of file + return embeddings diff --git a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py index 4908efd2883..445cca5b128 100644 --- a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py +++ b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py @@ -1,34 +1,7 @@ -import hashlib import logging -from functools import cached_property - -from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception - -from chromadb.api.types import ( - Document, - Documents, - Embedding, - Image, - Images, - EmbeddingFunction, - Embeddings, - is_image, - is_document, -) - -from io import BytesIO -from pathlib import Path -import os -import tarfile -import requests -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast -import numpy as np -import numpy.typing as npt -import importlib -import inspect -import json -import sys -import base64 +from typing import Any, List, Union + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Images logger = logging.getLogger(__name__) @@ -93,4 +66,4 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore return self.embed_documents(list(input)) # type: ignore - return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) \ No newline at end of file + return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) diff --git a/chromadb/utils/embedding_functions/cohere_embedding_function.py b/chromadb/utils/embedding_functions/cohere_embedding_function.py index d7b703be231..ef9c33e24b9 100644 --- a/chromadb/utils/embedding_functions/cohere_embedding_function.py +++ b/chromadb/utils/embedding_functions/cohere_embedding_function.py @@ -1,10 +1,6 @@ - import logging -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -28,4 +24,4 @@ def __call__(self, input: Documents) -> Embeddings: for embeddings in self._client.embed( texts=input, model=self._model_name, input_type="search_document" ) - ] \ No newline at end of file + ] diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py index 2aea0bd7171..5db890e5a2f 100644 --- a/chromadb/utils/embedding_functions/google_embedding_function.py +++ b/chromadb/utils/embedding_functions/google_embedding_function.py @@ -1,15 +1,10 @@ import logging -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - import requests -logger = logging.getLogger(__name__) +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +logger = logging.getLogger(__name__) class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): @@ -84,7 +79,7 @@ def __call__(self, input: Documents) -> Embeddings: )["embedding"] for text in input ] - + class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): # Follow API Quickstart for Google Vertex AI @@ -112,4 +107,4 @@ def __call__(self, input: Documents) -> Embeddings: if "predictions" in response: embeddings.append(response["predictions"]["embeddings"]["values"]) - return 
embeddings \ No newline at end of file + return embeddings diff --git a/chromadb/utils/embedding_functions/huggingface_embedding_function.py b/chromadb/utils/embedding_functions/huggingface_embedding_function.py index 95230ea696d..541c3a4ca4b 100644 --- a/chromadb/utils/embedding_functions/huggingface_embedding_function.py +++ b/chromadb/utils/embedding_functions/huggingface_embedding_function.py @@ -1,16 +1,11 @@ import logging - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - from typing import cast + import requests -logger = logging.getLogger(__name__) +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +logger = logging.getLogger(__name__) class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): @@ -58,7 +53,6 @@ def __call__(self, input: Documents) -> Embeddings: ) - class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): """ This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). @@ -99,4 +93,4 @@ def __call__(self, input: Documents) -> Embeddings: # Call HuggingFace Embedding Server API for each document return cast( Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() - ) \ No newline at end of file + ) diff --git a/chromadb/utils/embedding_functions/instructor_embedding_function.py b/chromadb/utils/embedding_functions/instructor_embedding_function.py index 13d8c715aeb..a9ea6b26038 100644 --- a/chromadb/utils/embedding_functions/instructor_embedding_function.py +++ b/chromadb/utils/embedding_functions/instructor_embedding_function.py @@ -1,15 +1,9 @@ import logging - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - from typing import Optional, cast -logger = logging.getLogger(__name__) +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +logger = logging.getLogger(__name__) class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): @@ -36,4 +30,4 @@ def __call__(self, input: Documents) -> Embeddings: texts_with_instructions = [[self._instruction, text] for text in input] - return cast(Embeddings, self._model.encode(texts_with_instructions).tolist()) \ No newline at end of file + return cast(Embeddings, self._model.encode(texts_with_instructions).tolist()) diff --git a/chromadb/utils/embedding_functions/jina_embedding_function.py b/chromadb/utils/embedding_functions/jina_embedding_function.py index 26f71c7ee18..99baa4089a9 100644 --- a/chromadb/utils/embedding_functions/jina_embedding_function.py +++ b/chromadb/utils/embedding_functions/jina_embedding_function.py @@ -1,14 +1,9 @@ import logging - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) +from typing import List, cast, Union import requests -from typing import cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -56,10 +51,10 @@ def __call__(self, input: Documents) -> Embeddings: if "data" not in resp: raise RuntimeError(resp["detail"]) - embeddings = resp["data"] + embeddings: List[dict[str, Union[str, List[float]]]] = resp["data"] # Sort resulting embeddings by index sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # Return just the embeddings - return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) \ No newline at end of file + return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) diff --git 
a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py index b2b896d7cc9..6cc1e0e4c7b 100644 --- a/chromadb/utils/embedding_functions/ollama_embedding_function.py +++ b/chromadb/utils/embedding_functions/ollama_embedding_function.py @@ -1,12 +1,7 @@ import logging +from typing import Union, cast -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - -from typing import cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -34,7 +29,7 @@ def __init__(self, url: str, model_name: str) -> None: self._model_name = model_name self._session = requests.Session() - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, input: Union[Documents, str]) -> Embeddings: """ Get the embeddings for a list of texts. @@ -64,4 +59,4 @@ def __call__(self, input: Documents) -> Embeddings: for embedding in embeddings if "embedding" in embedding ], - ) \ No newline at end of file + ) diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py index 8dac31983d4..712cd871905 100644 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -1,20 +1,19 @@ +import importlib import logging +from typing import Optional, Union, cast from chromadb.api.types import ( Document, Documents, Embedding, - Image, - Images, EmbeddingFunction, Embeddings, - is_image, + Image, + Images, is_document, + is_image, ) -from typing import Optional, Union, cast -import importlib - logger = logging.getLogger(__name__) @@ -75,4 +74,4 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: embeddings.append(self._encode_image(cast(Image, item))) elif is_document(item): embeddings.append(self._encode_text(cast(Document, item))) - return embeddings \ No newline at end of file + return embeddings diff --git a/chromadb/utils/embedding_functions/openai_embedding_function.py b/chromadb/utils/embedding_functions/openai_embedding_function.py index 0ae2978d30f..03eff5437b3 100644 --- a/chromadb/utils/embedding_functions/openai_embedding_function.py +++ b/chromadb/utils/embedding_functions/openai_embedding_function.py @@ -1,13 +1,7 @@ - import logging -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - from typing import Mapping, Optional, cast +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -92,6 +86,21 @@ def __init__( self._deployment_id = deployment_id def __call__(self, input: Documents) -> Embeddings: + """ + Generate the embeddings for the given `input`. + + # About ignoring types + We are not enforcing the openai library, therefore, `mypy` has hard times trying + to figure out what the types are for `self._client.create()` which throws an + error when trying to sort the list. If, eventually we include the `openai` lib + we can remove the type ignore tag. + + Args: + input (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the given input sorted by index + """ # replace newlines, which can negatively affect performance. 
input = [t.replace("\n", " ") for t in input] @@ -102,7 +111,9 @@ def __call__(self, input: Documents) -> Embeddings: ).data # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e.index) + sorted_embeddings = sorted( + embeddings, key=lambda e: e.index # type: ignore + ) # Return just the embeddings return cast(Embeddings, [result.embedding for result in sorted_embeddings]) @@ -117,9 +128,11 @@ def __call__(self, input: Documents) -> Embeddings: ] # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) + sorted_embeddings = sorted( + embeddings, key=lambda e: e["index"] # type: ignore + ) # Return just the embeddings return cast( Embeddings, [result["embedding"] for result in sorted_embeddings] - ) \ No newline at end of file + ) diff --git a/chromadb/utils/embedding_functions/roboflow_embedding_function.py b/chromadb/utils/embedding_functions/roboflow_embedding_function.py index e3a16fa943a..4fa3b0e43b2 100644 --- a/chromadb/utils/embedding_functions/roboflow_embedding_function.py +++ b/chromadb/utils/embedding_functions/roboflow_embedding_function.py @@ -1,26 +1,28 @@ +import base64 +import importlib import logging +import os +from io import BytesIO +from typing import Union + +import requests from chromadb.api.types import ( Documents, - Images, EmbeddingFunction, Embeddings, - is_image, + Images, is_document, + is_image, ) -from io import BytesIO -import os -import requests -from typing import Union -import importlib -import base64 - logger = logging.getLogger(__name__) class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None: + def __init__( + self, api_key: str = "", api_url: str = "https://infer.roboflow.com" + ) -> None: """ Create a RoboflowEmbeddingFunction. @@ -29,7 +31,7 @@ def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> N api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com". 
""" if not api_key: - api_key = os.environ.get("ROBOFLOW_API_KEY") + api_key = os.environ.get("ROBOFLOW_API_KEY", "") self._api_url = api_url self._api_key = api_key @@ -52,7 +54,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: image.save(buffer, format="JPEG") base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - infer_clip_payload = { + infer_clip_payload_image = { "image": { "type": "base64", "value": base64_image, @@ -61,7 +63,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: res = requests.post( f"{self._api_url}/clip/embed_image?api_key={self._api_key}", - json=infer_clip_payload, + json=infer_clip_payload_image, ) result = res.json()["embeddings"] @@ -69,17 +71,17 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: embeddings.append(result[0]) elif is_document(item): - infer_clip_payload = { + infer_clip_payload_text = { "text": input, } res = requests.post( f"{self._api_url}/clip/embed_text?api_key={self._api_key}", - json=infer_clip_payload, + json=infer_clip_payload_text, ) result = res.json()["embeddings"] embeddings.append(result[0]) - return embeddings \ No newline at end of file + return embeddings diff --git a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py index 682dfc1f1e2..2ca530b0a30 100644 --- a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py +++ b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py @@ -1,9 +1,11 @@ import logging from typing import Any, Dict, cast + from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) + class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} diff --git a/chromadb/utils/embedding_functions/text2vec_embedding_function.py b/chromadb/utils/embedding_functions/text2vec_embedding_function.py index 7b73c6b889d..86a45deff24 100644 --- a/chromadb/utils/embedding_functions/text2vec_embedding_function.py +++ b/chromadb/utils/embedding_functions/text2vec_embedding_function.py @@ -1,6 +1,7 @@ +import logging from typing import cast + from chromadb.api.types import Documents, EmbeddingFunction, Embeddings -import logging logger = logging.getLogger(__name__) From 8f08d60ebdfbe8aa7f6ef37c03a6ca1ffceb6a0e Mon Sep 17 00:00:00 2001 From: David Reguera Date: Sat, 20 Apr 2024 12:12:52 +0200 Subject: [PATCH 18/23] 1965 - Lint onnx embedding function --- .../embedding_functions/onnx_mini_lm_l6_v2.py | 97 ++++++++----------- 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py index 3aff9a37a0e..d1c798c3745 100644 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -1,30 +1,22 @@ import hashlib import logging -from functools import cached_property - -from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, -) - -from pathlib import Path import os import tarfile -import requests -from typing import TYPE_CHECKING, List, Optional, cast +from functools import cached_property +from pathlib import Path +from typing import List, Optional, cast + import numpy as np 
import numpy.typing as npt -import importlib - -logger = logging.getLogger(__name__) +import requests +from onnxruntime import InferenceSession, get_available_providers, SessionOptions +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random +from tokenizers import Tokenizer +from tqdm import tqdm +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings -if TYPE_CHECKING: - from onnxruntime import InferenceSession - from tokenizers import Tokenizer +logger = logging.getLogger(__name__) def _verify_sha256(fname: str, expected_sha256: str) -> bool: @@ -36,6 +28,7 @@ def _verify_sha256(fname: str, expected_sha256: str) -> bool: return sha256_hash.hexdigest() == expected_sha256 + # In order to remove dependencies on sentence-transformers, which in turn depends on # pytorch and sentence-piece we have created a default ONNX embedding function that # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. @@ -67,40 +60,27 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: ): raise ValueError("Preferred providers must be unique") self._preferred_providers = preferred_providers - try: - # Equivalent to import onnxruntime - self.ort = importlib.import_module("onnxruntime") - except ImportError: - raise ValueError( - "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" - ) - try: - # Equivalent to from tokenizers import Tokenizer - self.Tokenizer = importlib.import_module("tokenizers").Tokenizer - except ImportError: - raise ValueError( - "The tokenizers python package is not installed. Please install it with `pip install tokenizers`" - ) - try: - # Equivalent to from tqdm import tqdm - self.tqdm = importlib.import_module("tqdm").tqdm - except ImportError: - raise ValueError( - "The tqdm python package is not installed. Please install it with `pip install tqdm`" - ) # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 # Download with tqdm to preserve the sentence-transformers experience - @retry( + @retry( # type: ignore reraise=True, stop=stop_after_attempt(3), wait=wait_random(min=1, max=3), retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), ) def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: + """ + Download the onnx model from the URL and save it to the file path. + + About ignored types: + tenacity.retry decorator is a bit convoluted when it comes to type annotations + which makes mypy unhappy. If some smart folk knows how to fix this in an + elegant way, please do so. 
+ """ resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, self.tqdm( + with open(fname, "wb") as file, tqdm( desc=str(fname), total=total, unit="iB", @@ -119,15 +99,14 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: # Use pytorches default epsilon for division by zero # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html - def _normalize(self, v: npt.NDArray) -> npt.NDArray: + def _normalize(self, v: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: norm = np.linalg.norm(v, axis=1) norm[norm == 0] = 1e-12 - return cast(npt.NDArray, v / norm[:, np.newaxis]) + return cast(npt.NDArray[np.float32], v / norm[:, np.newaxis]) - def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: - # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values - self.tokenizer = cast(self.Tokenizer, self.tokenizer) - self.model = cast(self.ort.InferenceSession, self.model) + def _forward( + self, documents: List[str], batch_size: int = 32 + ) -> npt.NDArray[np.float32]: all_embeddings = [] for i in range(0, len(documents), batch_size): batch = documents[i : i + batch_size] @@ -156,8 +135,8 @@ def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: return np.concatenate(all_embeddings) @cached_property - def tokenizer(self) -> "Tokenizer": - tokenizer = self.Tokenizer.from_file( + def tokenizer(self) -> Tokenizer: + tokenizer = Tokenizer.from_file( os.path.join( self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" ) @@ -169,26 +148,26 @@ def tokenizer(self) -> "Tokenizer": return tokenizer @cached_property - def model(self) -> "InferenceSession": + def model(self) -> InferenceSession: if self._preferred_providers is None or len(self._preferred_providers) == 0: - if len(self.ort.get_available_providers()) > 0: + if len(get_available_providers()) > 0: logger.debug( f"WARNING: No ONNX providers provided, defaulting to available providers: " - f"{self.ort.get_available_providers()}" + f"{get_available_providers()}" ) - self._preferred_providers = self.ort.get_available_providers() + self._preferred_providers = get_available_providers() elif not set(self._preferred_providers).issubset( - set(self.ort.get_available_providers()) + set(get_available_providers()) ): raise ValueError( - f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}" + f"Preferred providers must be subset of available providers: {get_available_providers()}" ) # Suppress onnxruntime warnings. This produces logspew, mainly when onnx tries to use CoreML, which doesn't fit this model. 
- so = self.ort.SessionOptions() + so = SessionOptions() so.log_severity_level = 3 - return self.ort.InferenceSession( + return InferenceSession( os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"), # Since 1.9 onnyx runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs @@ -233,4 +212,4 @@ def _download_model_if_not_exists(self) -> None: name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), mode="r:gz", ) as tar: - tar.extractall(path=self.DOWNLOAD_PATH) \ No newline at end of file + tar.extractall(path=self.DOWNLOAD_PATH) From fc6b3c86d9a1c4b6b734f8afc636aad75c7f03cc Mon Sep 17 00:00:00 2001 From: David Reguera Date: Wed, 24 Apr 2024 11:07:53 +0200 Subject: [PATCH 19/23] 1965 - Ensure that `get_builtins()` holds after the migration. --- chromadb/test/ef/test_ef.py | 34 +++++++++++++++++++ .../utils/embedding_functions/__init__.py | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 chromadb/test/ef/test_ef.py diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py new file mode 100644 index 00000000000..70b4b6e65d2 --- /dev/null +++ b/chromadb/test/ef/test_ef.py @@ -0,0 +1,34 @@ +from chromadb.utils import embedding_functions + + +def test_get_builtins_holds() -> None: + """ + Ensure that `get_builtins` is consistent after the ef migration. + + This test is intended to be temporary until the ef migration is complete as + these expected builtins are likely to grow as long as users add new + embedding functions. + + The hardcoded list of builtins was generated by running `get_builtins()` + on this commit: df65e5a65628ef9231f67ccc748a7d6b114c9c02 + """ + expected_builtins = { + "AmazonBedrockEmbeddingFunction", + "CohereEmbeddingFunction", + "GoogleGenerativeAiEmbeddingFunction", + "GooglePalmEmbeddingFunction", + "GoogleVertexEmbeddingFunction", + "HuggingFaceEmbeddingFunction", + "HuggingFaceEmbeddingServer", + "InstructorEmbeddingFunction", + "JinaEmbeddingFunction", + "ONNXMiniLM_L6_V2", + "OllamaEmbeddingFunction", + "OpenAIEmbeddingFunction", + "OpenCLIPEmbeddingFunction", + "RoboflowEmbeddingFunction", + "SentenceTransformerEmbeddingFunction", + "Text2VecEmbeddingFunction", + } + + assert expected_builtins == set(embedding_functions.get_builtins()) diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 1226a0b034b..3340e35b8ef 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -87,7 +87,7 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: _classes = [ name for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) - if obj.__module__ == __name__ + if __name__ in obj.__module__ ] From 5c563870ceb0b404da19eb70d2d9d66521bb267c Mon Sep 17 00:00:00 2001 From: atroyn Date: Thu, 20 Jun 2024 13:26:50 -0700 Subject: [PATCH 20/23] Automate imports of EF in module --- chromadb/api/types.py | 11 +- chromadb/test/ef/test_default_ef.py | 5 +- chromadb/test/ef/test_ef.py | 3 +- .../utils/embedding_functions/__init__.py | 109 +++++------------- 4 files changed, 45 insertions(+), 83 deletions(-) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 01448cdc1b1..8c7a7113888 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,7 +1,7 @@ from typing import 
Optional, Union, TypeVar, List, Dict, Any, Tuple, cast from numpy.typing import NDArray import numpy as np -from typing_extensions import Literal, TypedDict, Protocol +from typing_extensions import Literal, TypedDict, Protocol, runtime_checkable import chromadb.errors as errors from chromadb.types import ( Metadata, @@ -99,7 +99,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: # Images -ImageDType = Union[np.uint, np.int_, np.float_] +ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined] Image = NDArray[ImageDType] Images = List[Image] @@ -182,6 +182,7 @@ class IndexMetadata(TypedDict): time_created: float +@runtime_checkable class EmbeddingFunction(Protocol[D]): def __call__(self, input: D) -> Embeddings: ... @@ -197,8 +198,10 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings: setattr(cls, "__call__", __call__) - def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings: - return retry(**retry_kwargs)(self.__call__)(input) + def embed_with_retries( + self, input: D, **retry_kwargs: Dict[str, Any] + ) -> Embeddings: + return cast(Embeddings, retry(**retry_kwargs)(self.__call__)(input)) def validate_embedding_function( diff --git a/chromadb/test/ef/test_default_ef.py b/chromadb/test/ef/test_default_ef.py index 6d8fb623698..a80ccd2813b 100644 --- a/chromadb/test/ef/test_default_ef.py +++ b/chromadb/test/ef/test_default_ef.py @@ -7,7 +7,10 @@ import pytest from hypothesis import given, settings -from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256 +from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( + ONNXMiniLM_L6_V2, + _verify_sha256, +) def unique_by(x: Hashable) -> Hashable: diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py index 70b4b6e65d2..9f5793203c2 100644 --- a/chromadb/test/ef/test_ef.py +++ b/chromadb/test/ef/test_ef.py @@ -29,6 +29,7 @@ def test_get_builtins_holds() -> None: "RoboflowEmbeddingFunction", "SentenceTransformerEmbeddingFunction", "Text2VecEmbeddingFunction", + "ChromaLangchainEmbeddingFunction", } - assert expected_builtins == set(embedding_functions.get_builtins()) + assert expected_builtins == embedding_functions.get_builtins() diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 3340e35b8ef..3df92b7654e 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -1,75 +1,18 @@ -import inspect -import sys -from typing import List, Optional +import os +import importlib +import pkgutil +from types import ModuleType +from typing import Optional, Set, cast from chromadb.api.types import Documents, EmbeddingFunction -from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( - AmazonBedrockEmbeddingFunction, -) -from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import ( + +# Langchain embedding function is a special snowflake +from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import ( # noqa: F401 create_langchain_embedding, ) -from chromadb.utils.embedding_functions.cohere_embedding_function import ( - CohereEmbeddingFunction, -) -from chromadb.utils.embedding_functions.google_embedding_function import ( - GoogleGenerativeAiEmbeddingFunction, - GooglePalmEmbeddingFunction, - GoogleVertexEmbeddingFunction, -) -from chromadb.utils.embedding_functions.huggingface_embedding_function import ( - HuggingFaceEmbeddingFunction, - 
HuggingFaceEmbeddingServer, -) -from chromadb.utils.embedding_functions.instructor_embedding_function import ( - InstructorEmbeddingFunction, -) -from chromadb.utils.embedding_functions.jina_embedding_function import ( - JinaEmbeddingFunction, -) -from chromadb.utils.embedding_functions.ollama_embedding_function import ( - OllamaEmbeddingFunction, -) -from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( - ONNXMiniLM_L6_V2, - _verify_sha256, -) -from chromadb.utils.embedding_functions.open_clip_embedding_function import ( - OpenCLIPEmbeddingFunction, -) -from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, -) -from chromadb.utils.embedding_functions.roboflow_embedding_function import ( - RoboflowEmbeddingFunction, -) -from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import ( - SentenceTransformerEmbeddingFunction, -) -from chromadb.utils.embedding_functions.text2vec_embedding_function import ( - Text2VecEmbeddingFunction, -) -__all__ = [ - "AmazonBedrockEmbeddingFunction", - "create_langchain_embedding", - "CohereEmbeddingFunction", - "GoogleGenerativeAiEmbeddingFunction", - "GooglePalmEmbeddingFunction", - "GoogleVertexEmbeddingFunction", - "HuggingFaceEmbeddingFunction", - "HuggingFaceEmbeddingServer", - "InstructorEmbeddingFunction", - "JinaEmbeddingFunction", - "OllamaEmbeddingFunction", - "OpenCLIPEmbeddingFunction", - "ONNXMiniLM_L6_V2", - "OpenAIEmbeddingFunction", - "RoboflowEmbeddingFunction", - "SentenceTransformerEmbeddingFunction", - "Text2VecEmbeddingFunction", - "_verify_sha256", -] +_all_classes: Set[str] = set() +_all_classes.add("ChromaLangchainEmbeddingFunction") try: from chromadb.is_thin_client import is_thin_client @@ -77,19 +20,31 @@ is_thin_client = False +_module_dir = os.path.dirname(__file__) +for _, module_name, _ in pkgutil.iter_modules([_module_dir]): # type: ignore[assignment] + module: ModuleType = importlib.import_module(f"{__name__}.{module_name}") + + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + isinstance(attr, type) + and issubclass(attr, EmbeddingFunction) + and attr is not EmbeddingFunction # Don't re-export the type + ): + globals()[attr.__name__] = attr + _all_classes.add(attr.__name__) + + +# Define and export the default embedding function def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: if is_thin_client: return None else: - return ONNXMiniLM_L6_V2() - - -_classes = [ - name - for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) - if __name__ in obj.__module__ -] + return cast( + EmbeddingFunction[Documents], + ONNXMiniLM_L6_V2(), # type: ignore[name-defined] # noqa: F821 + ) -def get_builtins() -> List[str]: - return _classes +def get_builtins() -> Set[str]: + return _all_classes From a548218ed74f5858b62e72185f4ccb66ee680f44 Mon Sep 17 00:00:00 2001 From: atroyn Date: Thu, 20 Jun 2024 13:57:44 -0700 Subject: [PATCH 21/23] Automate imports of EF in module --- .../utils/embedding_functions/__init__.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index 3df92b7654e..2f0bf0f5cf2 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -20,19 +20,29 @@ is_thin_client = False -_module_dir = os.path.dirname(__file__) -for _, module_name, _ in pkgutil.iter_modules([_module_dir]): # type: 
ignore[assignment] - module: ModuleType = importlib.import_module(f"{__name__}.{module_name}") +def _import_all_efs() -> Set[str]: + imported_classes = set() + _module_dir = os.path.dirname(__file__) + for _, module_name, _ in pkgutil.iter_modules([_module_dir]): + # Skip the current module + if module_name == __name__: + continue - for attr_name in dir(module): - attr = getattr(module, attr_name) - if ( - isinstance(attr, type) - and issubclass(attr, EmbeddingFunction) - and attr is not EmbeddingFunction # Don't re-export the type - ): - globals()[attr.__name__] = attr - _all_classes.add(attr.__name__) + module: ModuleType = importlib.import_module(f"{__name__}.{module_name}") + + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + isinstance(attr, type) + and issubclass(attr, EmbeddingFunction) + and attr is not EmbeddingFunction # Don't re-export the type + ): + globals()[attr.__name__] = attr + imported_classes.add(attr.__name__) + return imported_classes + + +_all_classes.update(_import_all_efs()) # Define and export the default embedding function @@ -42,6 +52,7 @@ def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: else: return cast( EmbeddingFunction[Documents], + # This is implicitly imported above ONNXMiniLM_L6_V2(), # type: ignore[name-defined] # noqa: F821 ) From cbb0b0338c48165a8853b8470f60f7f0d325b2c6 Mon Sep 17 00:00:00 2001 From: atroyn Date: Thu, 20 Jun 2024 14:05:15 -0700 Subject: [PATCH 22/23] Additional tests --- chromadb/test/ef/test_ef.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py index 9f5793203c2..c93502e3fc8 100644 --- a/chromadb/test/ef/test_ef.py +++ b/chromadb/test/ef/test_ef.py @@ -1,4 +1,5 @@ from chromadb.utils import embedding_functions +from chromadb.api.types import EmbeddingFunction def test_get_builtins_holds() -> None: @@ -9,8 +10,7 @@ def test_get_builtins_holds() -> None: these expected builtins are likely to grow as long as users add new embedding functions. 
- The hardcoded list of builtins was generated by running `get_builtins()` - on this commit: df65e5a65628ef9231f67ccc748a7d6b114c9c02 + REMOVE ME ON THE NEXT EF ADDITION """ expected_builtins = { "AmazonBedrockEmbeddingFunction", @@ -33,3 +33,21 @@ def test_get_builtins_holds() -> None: } assert expected_builtins == embedding_functions.get_builtins() + + +def test_default_ef_exists() -> None: + assert hasattr(embedding_functions, "DefaultEmbeddingFunction") + default_ef = embedding_functions.DefaultEmbeddingFunction() + + assert default_ef is not None + assert isinstance(default_ef, EmbeddingFunction) + + +def test_ef_imports() -> None: + for ef in embedding_functions.get_builtins(): + # Langchain embedding function is a special snowflake + if ef == "ChromaLangchainEmbeddingFunction": + continue + assert hasattr(embedding_functions, ef) + assert isinstance(getattr(embedding_functions, ef), type) + assert issubclass(getattr(embedding_functions, ef), EmbeddingFunction) From 41a3e911d165af5f3ba41423de8db25e6690d0f7 Mon Sep 17 00:00:00 2001 From: atroyn Date: Thu, 20 Jun 2024 14:19:40 -0700 Subject: [PATCH 23/23] httpx everywhere --- .../google_embedding_function.py | 4 +- .../huggingface_embedding_function.py | 12 ++--- .../jina_embedding_function.py | 4 +- .../ollama_embedding_function.py | 10 ++-- .../embedding_functions/onnx_mini_lm_l6_v2.py | 49 +++++++++++++------ .../roboflow_embedding_function.py | 6 +-- 6 files changed, 48 insertions(+), 37 deletions(-) diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py index 5db890e5a2f..0534d790674 100644 --- a/chromadb/utils/embedding_functions/google_embedding_function.py +++ b/chromadb/utils/embedding_functions/google_embedding_function.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from chromadb.api.types import Documents, EmbeddingFunction, Embeddings @@ -94,7 +94,7 @@ def __init__( region: str = "us-central1", ): self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/goole/models/{model_name}:predict" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) def __call__(self, input: Documents) -> Embeddings: diff --git a/chromadb/utils/embedding_functions/huggingface_embedding_function.py b/chromadb/utils/embedding_functions/huggingface_embedding_function.py index 541c3a4ca4b..376a98fa4ae 100644 --- a/chromadb/utils/embedding_functions/huggingface_embedding_function.py +++ b/chromadb/utils/embedding_functions/huggingface_embedding_function.py @@ -1,7 +1,7 @@ import logging from typing import cast -import requests +import httpx from chromadb.api.types import Documents, EmbeddingFunction, Embeddings @@ -25,7 +25,7 @@ def __init__( model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2". """ self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) def __call__(self, input: Documents) -> Embeddings: @@ -66,14 +66,8 @@ def __init__(self, url: str): Args: url (str): The URL of the HuggingFace Embedding Server. """ - try: - import requests - except ImportError: - raise ValueError( - "The requests python package is not installed. 
Please install it with `pip install requests`" - ) self._api_url = f"{url}" - self._session = requests.Session() + self._session = httpx.Client() def __call__(self, input: Documents) -> Embeddings: """ diff --git a/chromadb/utils/embedding_functions/jina_embedding_function.py b/chromadb/utils/embedding_functions/jina_embedding_function.py index 99baa4089a9..f631bef4df8 100644 --- a/chromadb/utils/embedding_functions/jina_embedding_function.py +++ b/chromadb/utils/embedding_functions/jina_embedding_function.py @@ -1,7 +1,7 @@ import logging from typing import List, cast, Union -import requests +import httpx from chromadb.api.types import Documents, EmbeddingFunction, Embeddings @@ -24,7 +24,7 @@ def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en") """ self._model_name = model_name self._api_url = "https://api.jina.ai/v1/embeddings" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update( {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} ) diff --git a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py index 6cc1e0e4c7b..a6293e36075 100644 --- a/chromadb/utils/embedding_functions/ollama_embedding_function.py +++ b/chromadb/utils/embedding_functions/ollama_embedding_function.py @@ -1,6 +1,8 @@ import logging from typing import Union, cast +import httpx + from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -19,15 +21,9 @@ def __init__(self, url: str, model_name: str) -> None: url (str): The URL of the Ollama Server. model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). """ - try: - import requests - except ImportError: - raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" - ) self._api_url = f"{url}" self._model_name = model_name - self._session = requests.Session() + self._session = httpx.Client() def __call__(self, input: Union[Documents, str]) -> Embeddings: """ diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py index d1c798c3745..3120f3ffad8 100644 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -1,4 +1,5 @@ import hashlib +import importlib import logging import os import tarfile @@ -8,11 +9,10 @@ import numpy as np import numpy.typing as npt -import requests +import httpx from onnxruntime import InferenceSession, get_available_providers, SessionOptions from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random from tokenizers import Tokenizer -from tqdm import tqdm from chromadb.api.types import Documents, EmbeddingFunction, Embeddings @@ -60,6 +60,27 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: ): raise ValueError("Preferred providers must be unique") self._preferred_providers = preferred_providers + try: + # Equivalent to import onnxruntime + self.ort = importlib.import_module("onnxruntime") + except ImportError: + raise ValueError( + "The onnxruntime python package is not installed. 
Please install it with `pip install onnxruntime`" + ) + try: + # Equivalent to from tokenizers import Tokenizer + self.Tokenizer = importlib.import_module("tokenizers").Tokenizer + except ImportError: + raise ValueError( + "The tokenizers python package is not installed. Please install it with `pip install tokenizers`" + ) + try: + # Equivalent to from tqdm import tqdm + self.tqdm = importlib.import_module("tqdm").tqdm + except ImportError: + raise ValueError( + "The tqdm python package is not installed. Please install it with `pip install tqdm`" + ) # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 # Download with tqdm to preserve the sentence-transformers experience @@ -78,18 +99,18 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: which makes mypy unhappy. If some smart folk knows how to fix this in an elegant way, please do so. """ - resp = requests.get(url, stream=True) - total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, tqdm( - desc=str(fname), - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in resp.iter_content(chunk_size=chunk_size): - size = file.write(data) - bar.update(size) + with httpx.stream("GET", url) as resp: + total = int(resp.headers.get("content-length", 0)) + with open(fname, "wb") as file, self.tqdm( + desc=str(fname), + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in resp.iter_bytes(chunk_size=chunk_size): + size = file.write(data) + bar.update(size) if not _verify_sha256(fname, self._MODEL_SHA256): # if the integrity of the file is not verified, remove it os.remove(fname) diff --git a/chromadb/utils/embedding_functions/roboflow_embedding_function.py b/chromadb/utils/embedding_functions/roboflow_embedding_function.py index 4fa3b0e43b2..b118aa01c64 100644 --- a/chromadb/utils/embedding_functions/roboflow_embedding_function.py +++ b/chromadb/utils/embedding_functions/roboflow_embedding_function.py @@ -5,7 +5,7 @@ from io import BytesIO from typing import Union -import requests +import httpx from chromadb.api.types import ( Documents, @@ -61,7 +61,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: }, } - res = requests.post( + res = httpx.post( f"{self._api_url}/clip/embed_image?api_key={self._api_key}", json=infer_clip_payload_image, ) @@ -75,7 +75,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: "text": input, } - res = requests.post( + res = httpx.post( f"{self._api_url}/clip/embed_text?api_key={self._api_key}", json=infer_clip_payload_text, )
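[Editor's note, not part of the patch] The final patch swaps `requests` for `httpx` across every embedding function that talks to an HTTP endpoint. For these call sites the two libraries map almost one-to-one; below is a minimal sketch of the mapping, with a placeholder URL and token (illustrative only, not part of the chromadb API):

    import httpx

    # requests.Session() -> httpx.Client(): same headers.update() and
    # post(url, json=...) surface used throughout these diffs.
    client = httpx.Client()
    client.headers.update({"Authorization": "Bearer <token>"})
    resp = client.post("https://example.invalid/embed", json={"inputs": ["hello"]})
    embeddings = resp.json()

    # Streaming downloads differ: requests uses get(url, stream=True) plus
    # iter_content(); httpx uses a context manager plus iter_bytes(), as in
    # the ONNX model download hunk above.
    with httpx.stream("GET", "https://example.invalid/model.tar.gz") as stream_resp:
        total = int(stream_resp.headers.get("content-length", 0))
        for chunk in stream_resp.iter_bytes(chunk_size=1024):
            pass  # write chunk to disk and update the progress bar here

One behavioral difference worth flagging (my observation, not covered by the diff): httpx applies a default timeout of five seconds per request, whereas requests waits indefinitely, so long-running calls such as the model download may warrant an explicit `timeout=` argument.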