72 changes: 72 additions & 0 deletions chromadb/test/ef/test_cloudflare_ef.py
@@ -0,0 +1,72 @@
import os

import pytest

from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
    CloudflareWorkersAIEmbeddingFunction,
)


@pytest.mark.skipif(
    "CF_API_TOKEN" not in os.environ or "CF_ACCOUNT_ID" not in os.environ,
    reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_token_and_account() -> None:
    ef = CloudflareWorkersAIEmbeddingFunction(
        api_token=os.environ.get("CF_API_TOKEN", ""),
        account_id=os.environ.get("CF_ACCOUNT_ID"),
    )
    embeddings = ef(["test doc"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "CF_API_TOKEN" not in os.environ or "CF_GATEWAY_ENDPOINT" not in os.environ,
    reason="CF_API_TOKEN and CF_GATEWAY_ENDPOINT not set, skipping test.",
)
def test_cf_ef_gateway() -> None:
    ef = CloudflareWorkersAIEmbeddingFunction(
        api_token=os.environ.get("CF_API_TOKEN", ""),
        gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"),
    )
    embeddings = ef(["test doc"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "CF_API_TOKEN" not in os.environ,
    reason="CF_API_TOKEN not set, skipping test.",
)
def test_cf_ef_large_batch() -> None:
    ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy")
    with pytest.raises(ValueError, match="Batch too large"):
        ef(["test doc"] * 101)


@pytest.mark.skipif(
    "CF_API_TOKEN" not in os.environ,
    reason="CF_API_TOKEN not set, skipping test.",
)
def test_cf_ef_missing_account_or_gateway() -> None:
    with pytest.raises(
        ValueError, match="Please provide either an account_id or a gateway_url"
    ):
        CloudflareWorkersAIEmbeddingFunction(api_token="dummy")


@pytest.mark.skipif(
    "CF_API_TOKEN" not in os.environ,
    reason="CF_API_TOKEN not set, skipping test.",
)
def test_cf_ef_with_account_and_gateway() -> None:
    with pytest.raises(
        ValueError,
        match="Please provide either an account_id or a gateway_url, not both",
    ):
        CloudflareWorkersAIEmbeddingFunction(
            api_token="dummy", account_id="dummy", gateway_url="dummy"
        )
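
Note (not part of the diff): the large-batch test above exercises the 100-document cap enforced by the embedding function; splitting bigger inputs is left to the caller. A minimal caller-side chunking sketch, where embed_in_chunks is a hypothetical helper name and the only assumption is that the embedding function returns one vector per input document:

from typing import Any, List


def embed_in_chunks(ef: Any, docs: List[str], batch_size: int = 100) -> List[Any]:
    # Call the embedding function on slices no larger than batch_size and
    # concatenate the per-slice results in the original order.
    embeddings: List[Any] = []
    for start in range(0, len(docs), batch_size):
        embeddings.extend(ef(docs[start : start + batch_size]))
    return embeddings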
1 change: 1 addition & 0 deletions chromadb/test/ef/test_ef.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ def test_get_builtins_holds() -> None:
"SentenceTransformerEmbeddingFunction",
"Text2VecEmbeddingFunction",
"ChromaLangchainEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
}

assert expected_builtins == embedding_functions.get_builtins()
63 changes: 63 additions & 0 deletions chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py
@@ -0,0 +1,63 @@
import logging
from typing import Optional, Dict, cast

import httpx

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)


class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
    # Follow the API Quickstart for Cloudflare Workers AI
    # https://developers.cloudflare.com/workers-ai/
    # Information about the text embedding models in Cloudflare Workers AI
    # https://developers.cloudflare.com/workers-ai/models/embedding/
    def __init__(
        self,
        api_token: str,
        account_id: Optional[str] = None,
        model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5",
        gateway_url: Optional[
            str
        ] = None,  # use Cloudflare AI Gateway instead of the usual endpoint
        # right now the endpoint schema supports up to 100 docs at a time
        # https://developers.cloudflare.com/workers-ai/models/bge-small-en-v1.5/#api-schema (Input JSON Schema)
        max_batch_size: Optional[int] = 100,
        headers: Optional[Dict[str, str]] = None,
    ):
        if not gateway_url and not account_id:
            raise ValueError("Please provide either an account_id or a gateway_url.")
        if gateway_url and account_id:
            raise ValueError(
                "Please provide either an account_id or a gateway_url, not both."
            )
        if gateway_url is not None and not gateway_url.endswith("/"):
            gateway_url += "/"
        self._api_url = (
            f"{gateway_url}{model_name}"
            if gateway_url is not None
            else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}"
        )
        self._session = httpx.Client()
        self._session.headers.update(headers or {})
        self._session.headers.update({"Authorization": f"Bearer {api_token}"})
        self._max_batch_size = max_batch_size

    def __call__(self, texts: Documents) -> Embeddings:
        # The endpoint accepts up to 100 items at a time; anything larger is rejected.
        # It is up to the caller to split the input into smaller batches.
        if self._max_batch_size and len(texts) > self._max_batch_size:
            raise ValueError(
                f"Batch too large {len(texts)} > {self._max_batch_size} (maximum batch size)."
            )

        logger.debug("Calling Cloudflare Workers AI endpoint %s", self._api_url)

        response = self._session.post(self._api_url, json={"text": texts})
        response.raise_for_status()
        _json = response.json()
        if "result" in _json and "data" in _json["result"]:
            return cast(Embeddings, _json["result"]["data"])
        else:
            raise ValueError(f"Error calling Cloudflare Workers AI: {response.text}")
82 changes: 82 additions & 0 deletions clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts
@@ -0,0 +1,82 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

export class CloudflareWorkersAIEmbeddingFunction
  implements IEmbeddingFunction
{
  private apiUrl: string;
  private headers: { [key: string]: string };
  private maxBatchSize: number;

  constructor({
    apiToken,
    model,
    accountId,
    gatewayUrl,
    maxBatchSize,
    headers,
  }: {
    apiToken: string;
    model?: string;
    accountId?: string;
    gatewayUrl?: string;
    maxBatchSize?: number;
    headers?: { [key: string]: string };
  }) {
    model = model || "@cf/baai/bge-base-en-v1.5";
    this.maxBatchSize = maxBatchSize || 100;
    if (accountId === undefined && gatewayUrl === undefined) {
      throw new Error("Please provide either an accountId or a gatewayUrl.");
    }
    if (accountId !== undefined && gatewayUrl !== undefined) {
      throw new Error(
        "Please provide either an accountId or a gatewayUrl, not both.",
      );
    }
    if (gatewayUrl !== undefined) {
      // Always append the model to the gateway URL, adding a trailing slash
      // first if it is missing (mirrors the Python client behavior).
      if (!gatewayUrl.endsWith("/")) {
        gatewayUrl += "/";
      }
      gatewayUrl += model;
    }
    this.apiUrl =
      gatewayUrl ||
      `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`;
    this.headers = headers || {};
    this.headers["Authorization"] = `Bearer ${apiToken}`;
  }

  public async generate(texts: string[]) {
    if (texts.length === 0) {
      return [];
    }
    if (texts.length > this.maxBatchSize) {
      throw new Error(
        `Batch too large ${texts.length} > ${this.maxBatchSize} (maximum batch size).`,
      );
    }
    try {
      const response = await fetch(this.apiUrl, {
        method: "POST",
        headers: this.headers,
        body: JSON.stringify({
          text: texts,
        }),
      });

      const data = (await response.json()) as {
        success?: boolean;
        messages: any[];
        errors?: any[];
        result: { shape: any[]; data: number[][] };
      };
      if (data.success === false) {
        throw new Error(`${JSON.stringify(data.errors)}`);
      }
      return data.result.data;
    } catch (error) {
      console.error(error);
      throw new Error(`Error calling CF API: ${error}`);
    }
  }
}
1 change: 1 addition & 0 deletions clients/js/src/index.ts
@@ -11,6 +11,7 @@ export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbe
export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction";
export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction";
export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction";
export { CloudflareWorkersAIEmbeddingFunction } from "./embeddings/CloudflareWorkersAIEmbeddingFunction";

export {
  IncludeEnum,
99 changes: 99 additions & 0 deletions clients/js/test/embeddings/cloudflare.test.ts
@@ -0,0 +1,99 @@
import { expect, test } from "@jest/globals";
import { DOCUMENTS } from "../data";
import { CloudflareWorkersAIEmbeddingFunction } from "../../src";

if (!process.env.CF_API_TOKEN) {
  test.skip("it should generate Cloudflare embeddings with apiToken and AccountId", async () => {});
} else {
  test("it should generate Cloudflare embeddings with apiToken and AccountId", async () => {
    const embedder = new CloudflareWorkersAIEmbeddingFunction({
      apiToken: process.env.CF_API_TOKEN as string,
      accountId: process.env.CF_ACCOUNT_ID,
    });
    const embeddings = await embedder.generate(DOCUMENTS);
    expect(embeddings).toBeDefined();
    expect(embeddings.length).toBe(DOCUMENTS.length);
  });
}

if (!process.env.CF_API_TOKEN) {
  test.skip("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => {});
} else {
  test("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => {
    const embedder = new CloudflareWorkersAIEmbeddingFunction({
      apiToken: process.env.CF_API_TOKEN as string,
      accountId: process.env.CF_ACCOUNT_ID,
      model: "@cf/baai/bge-small-en-v1.5",
    });
    const embeddings = await embedder.generate(DOCUMENTS);
    expect(embeddings).toBeDefined();
    expect(embeddings.length).toBe(DOCUMENTS.length);
  });
}

if (!process.env.CF_API_TOKEN) {
  test.skip("it should generate Cloudflare embeddings with apiToken and gateway", async () => {});
} else {
  test("it should generate Cloudflare embeddings with apiToken and gateway", async () => {
    const embedder = new CloudflareWorkersAIEmbeddingFunction({
      apiToken: process.env.CF_API_TOKEN as string,
      gatewayUrl: process.env.CF_GATEWAY_ENDPOINT,
    });
    const embeddings = await embedder.generate(DOCUMENTS);
    expect(embeddings).toBeDefined();
    expect(embeddings.length).toBe(DOCUMENTS.length);
  });
}

if (!process.env.CF_API_TOKEN) {
  test.skip("it should fail when batch too large", async () => {});
} else {
  test("it should fail when batch too large", async () => {
    const embedder = new CloudflareWorkersAIEmbeddingFunction({
      apiToken: process.env.CF_API_TOKEN as string,
      gatewayUrl: process.env.CF_GATEWAY_ENDPOINT,
    });
    const largeBatch = Array(100)
      .fill([...DOCUMENTS])
      .flat();
    await expect(embedder.generate(largeBatch)).rejects.toThrow(
      "Batch too large",
    );
  });
}

if (!process.env.CF_API_TOKEN) {
  test.skip("it should fail when gateway endpoint and account id are both provided", async () => {});
} else {
  test("it should fail when gateway endpoint and account id are both provided", async () => {
    expect(() => {
      new CloudflareWorkersAIEmbeddingFunction({
        apiToken: process.env.CF_API_TOKEN as string,
        accountId: "dummy-account",
        gatewayUrl: "https://dummy-gateway/",
      });
    }).toThrow("Please provide either an accountId or a gatewayUrl, not both.");
  });
}

if (!process.env.CF_API_TOKEN) {
  test.skip("it should fail when neither gateway endpoint nor account id are provided", async () => {});
} else {
  test("it should fail when neither gateway endpoint nor account id are provided", async () => {
    expect(() => {
      new CloudflareWorkersAIEmbeddingFunction({
        apiToken: process.env.CF_API_TOKEN as string,
      });
    }).toThrow("Please provide either an accountId or a gatewayUrl.");
  });
}
17 changes: 9 additions & 8 deletions docs/docs.trychroma.com/pages/guides/embeddings.md
@@ -9,15 +9,16 @@ Chroma provides lightweight wrappers around popular embedding providers, making
{% special_table %}
{% /special_table %}

| | Python | JS |
|--------------|-----------|---------------|
| [OpenAI](/integrations/openai) | ✅ | ✅ |
| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
| [Cohere](/integrations/cohere) | ✅ | ✅ |
| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
| [Instructor](/integrations/instructor) | ✅ | ➖ |
| | Python | JS |
|--------------------------------------------------------------------|-----------|---------------|
| [OpenAI](/integrations/openai) | ✅ | ✅ |
| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
| [Cohere](/integrations/cohere) | ✅ | ✅ |
| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
| [Instructor](/integrations/instructor) | ✅ | ➖ |
| [Hugging Face Embedding Server](/integrations/hugging-face-server) | ✅ | ✅ |
| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
| [Cloudflare Workers AI](/integrations/cloudflare) | ✅ | ✅ |

We welcome pull requests to add new Embedding Functions to the community.

1 change: 1 addition & 0 deletions docs/docs.trychroma.com/pages/integrations/_sidenav.js
@@ -11,6 +11,7 @@ export const items = [
{ href: '/integrations/jinaai', children: 'JinaAI' },
{ href: '/integrations/roboflow', children: 'Roboflow' },
{ href: '/integrations/ollama', children: 'Ollama Embeddings' },
{ href: '/integrations/cloudflare', children: 'Cloudflare Workers AI Embeddings' },
]
},
{