|
| 1 | +from typing import Any, Optional, cast |
| 2 | + |
| 3 | +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings |
| 4 | + |
| 5 | + |
| 6 | +class FastEmbedEmbeddingFunction(EmbeddingFunction[Documents]): |
| 7 | + """ |
| 8 | + This class is used to generate embeddings for a list of texts using FastEmbed - https://qdrant.github.io/fastembed/. |
| 9 | + Find the list of supported models at https://qdrant.github.io/fastembed/examples/Supported_Models/. |
| 10 | + """ |
| 11 | + |
| 12 | + def __init__( |
| 13 | + self, |
| 14 | + model_name: str = "BAAI/bge-small-en-v1.5", |
| 15 | + cache_dir: Optional[str] = None, |
| 16 | + threads: Optional[int] = None, |
| 17 | + **kwargs: Any, |
| 18 | + ) -> None: |
| 19 | + """ |
| 20 | + Initialize fastembed.TextEmbedding |
| 21 | +
|
| 22 | + Args: |
| 23 | + model_name (str): The name of the model to use. |
| 24 | + cache_dir (str, optional): The path to the model cache directory. |
| 25 | + Can also be set using the `FASTEMBED_CACHE_PATH` env variable. |
| 26 | + threads (int, optional): The number of threads single onnxruntime session can use.. |
| 27 | +
|
| 28 | + Raises: |
| 29 | + ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en. |
| 30 | + """ |
| 31 | + try: |
| 32 | + from fastembed import TextEmbedding |
| 33 | + except ImportError: |
| 34 | + raise ValueError( |
| 35 | + "The 'fastembed' package is not installed. Please install it with `pip install fastembed`" |
| 36 | + ) |
| 37 | + self._model = TextEmbedding( |
| 38 | + model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs |
| 39 | + ) |
| 40 | + |
| 41 | + def __call__(self, input: Documents) -> Embeddings: |
| 42 | + """ |
| 43 | + Get the embeddings for a list of texts. |
| 44 | +
|
| 45 | + Args: |
| 46 | + input (Documents): A list of texts to get embeddings for. |
| 47 | +
|
| 48 | + Returns: |
| 49 | + Embeddings: The embeddings for the texts. |
| 50 | +
|
| 51 | + Example: |
| 52 | + >>> fastembed_ef = FastEmbedEmbeddingFunction(model_name="sentence-transformers/all-MiniLM-L6-v2") |
| 53 | + >>> texts = ["Hello, world!", "How are you?"] |
| 54 | + >>> embeddings = fastembed_ef(texts) |
| 55 | + """ |
| 56 | + embeddings = self._model.embed(input) |
| 57 | + return cast( |
| 58 | + Embeddings, |
| 59 | + [embedding.tolist() for embedding in embeddings], |
| 60 | + ) |
0 commit comments