Skip to content

Commit 129f948

Browse files
authored
feat: add Cohere Image Document Embedder (#2190)
* draft * improvements + tests * async * reorganize tests * fmt * small fix * fix comments * adjust test * add test_live_run_async
1 parent 6adb8f1 commit 129f948

File tree

9 files changed

+753
-7
lines changed

9 files changed

+753
-7
lines changed

integrations/cohere/pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.15.1", "cohere>=5.16.0"]
26+
dependencies = ["haystack-ai>=2.16.1", "cohere>=5.16.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme"
@@ -57,7 +57,9 @@ dependencies = [
5757
"pytest-cov",
5858
"pytest-rerunfailures",
5959
"mypy",
60-
"pip"
60+
"pip",
61+
"pillow", # image support
62+
"pypdfium2" # image support
6163
]
6264

6365
[tool.hatch.envs.test.scripts]

integrations/cohere/src/haystack_integrations/components/embedders/cohere/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
from .document_embedder import CohereDocumentEmbedder
5+
from .document_image_embedder import CohereDocumentImageEmbedder
56
from .text_embedder import CohereTextEmbedder
67

7-
__all__ = ["CohereDocumentEmbedder", "CohereTextEmbedder"]
8+
__all__ = ["CohereDocumentEmbedder", "CohereDocumentImageEmbedder", "CohereTextEmbedder"]

integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class CohereDocumentEmbedder:
2222
Usage example:
2323
```python
2424
from haystack import Document
25-
from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder
25+
from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder
2626
2727
doc = Document(content="I love pizza!")
2828
@@ -42,7 +42,7 @@ def __init__(
4242
input_type: str = "search_document",
4343
api_base_url: str = "https://api.cohere.com",
4444
truncate: str = "END",
45-
timeout: int = 120,
45+
timeout: float = 120.0,
4646
batch_size: int = 32,
4747
progress_bar: bool = True,
4848
meta_fields_to_embed: Optional[List[str]] = None,
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from dataclasses import replace
6+
from typing import Any, Optional, Tuple
7+
8+
from haystack import Document, component, default_from_dict, default_to_dict, logging
9+
from haystack.components.converters.image.image_utils import (
10+
_batch_convert_pdf_pages_to_images,
11+
_encode_image_to_base64,
12+
_extract_image_sources_info,
13+
_PDFPageInfo,
14+
)
15+
from haystack.dataclasses import ByteStream
16+
from haystack.utils.auth import Secret, deserialize_secrets_inplace
17+
from tqdm import tqdm
18+
19+
from cohere import AsyncClientV2, ClientV2
20+
21+
from .embedding_types import EmbeddingTypes
22+
23+
# PDF is not officially supported, but we convert PDFs to JPEG images
24+
SUPPORTED_IMAGE_MIME_TYPES = ["image/jpeg", "image/png", "application/pdf"]
25+
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
@component
31+
class CohereDocumentImageEmbedder:
32+
"""
33+
A component for computing Document embeddings based on images using Cohere models.
34+
35+
The embedding of each Document is stored in the `embedding` field of the Document.
36+
37+
### Usage example
38+
```python
39+
from haystack import Document
40+
from haystack_integrations.components.embedders.cohere import CohereDocumentImageEmbedder
41+
42+
embedder = CohereDocumentImageEmbedder(model="embed-v4.0")
43+
44+
documents = [
45+
Document(content="A photo of a cat", meta={"file_path": "cat.jpg"}),
46+
Document(content="A photo of a dog", meta={"file_path": "dog.jpg"}),
47+
]
48+
49+
result = embedder.run(documents=documents)
50+
documents_with_embeddings = result["documents"]
51+
print(documents_with_embeddings)
52+
53+
# [Document(id=...,
54+
# content='A photo of a cat',
55+
# meta={'file_path': 'cat.jpg',
56+
# 'embedding_source': {'type': 'image', 'file_path_meta_field': 'file_path'}},
57+
# embedding=vector of size 1536),
58+
# ...]
59+
```
60+
"""
61+
62+
def __init__(
63+
self,
64+
*,
65+
file_path_meta_field: str = "file_path",
66+
root_path: Optional[str] = None,
67+
image_size: Optional[Tuple[int, int]] = None,
68+
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
69+
model: str = "embed-v4.0",
70+
api_base_url: str = "https://api.cohere.com",
71+
timeout: float = 120.0,
72+
embedding_dimension: Optional[int] = None,
73+
embedding_type: EmbeddingTypes = EmbeddingTypes.FLOAT,
74+
progress_bar: bool = True,
75+
) -> None:
76+
"""
77+
Creates a CohereDocumentImageEmbedder component.
78+
79+
:param file_path_meta_field:
80+
The metadata field in the Document that contains the file path to the image or PDF.
81+
:param root_path:
82+
The root directory path where document files are located. If provided, file paths in
83+
document metadata will be resolved relative to this path. If None, file paths are treated as absolute paths.
84+
:param image_size:
85+
If provided, resizes the image to fit within the specified dimensions (width, height) while
86+
maintaining aspect ratio. This reduces file size, memory usage, and processing time, which is beneficial
87+
when working with models that have resolution constraints or when transmitting images to remote services.
88+
:param api_key:
89+
The Cohere API key.
90+
:param model:
91+
The Cohere model to use for calculating embeddings.
92+
Read [Cohere documentation](https://docs.cohere.com/docs/models#embed) for a list of all supported models.
93+
:param api_base_url:
94+
The Cohere API base URL.
95+
:param timeout:
96+
Request timeout in seconds.
97+
:param embedding_dimension:
98+
The dimension of the embeddings to return. Only valid for v4 and newer models.
99+
Read [Cohere API reference](https://docs.cohere.com/reference/embed) for a list possible values and
100+
supported models.
101+
:param embedding_type:
102+
The type of embeddings to return. Defaults to float embeddings.
103+
Specifying a type different from float is only supported for Embed v3.0 and newer models.
104+
:param progress_bar:
105+
Whether to show a progress bar or not. Can be helpful to disable in production deployments
106+
to keep the logs clean.
107+
"""
108+
109+
self.file_path_meta_field = file_path_meta_field
110+
self.root_path = root_path or ""
111+
self.image_size = image_size
112+
self.model = model
113+
self.embedding_dimension = embedding_dimension
114+
self.embedding_type = embedding_type
115+
self.progress_bar = progress_bar
116+
117+
self._api_key = api_key
118+
self._api_base_url = api_base_url
119+
self._timeout = timeout
120+
121+
self._client = ClientV2(
122+
api_key=self._api_key.resolve_value(),
123+
base_url=self._api_base_url,
124+
timeout=self._timeout,
125+
client_name="haystack",
126+
)
127+
self._async_client = AsyncClientV2(
128+
api_key=self._api_key.resolve_value(),
129+
base_url=self._api_base_url,
130+
timeout=self._timeout,
131+
client_name="haystack",
132+
)
133+
134+
def to_dict(self) -> dict[str, Any]:
135+
"""
136+
Serializes the component to a dictionary.
137+
138+
:returns:
139+
Dictionary with serialized data.
140+
"""
141+
serialization_dict = default_to_dict(
142+
self,
143+
file_path_meta_field=self.file_path_meta_field,
144+
root_path=self.root_path,
145+
image_size=self.image_size,
146+
model=self.model,
147+
progress_bar=self.progress_bar,
148+
api_key=self._api_key.to_dict(),
149+
api_base_url=self._api_base_url,
150+
timeout=self._timeout,
151+
embedding_dimension=self.embedding_dimension,
152+
embedding_type=self.embedding_type.value,
153+
)
154+
return serialization_dict
155+
156+
@classmethod
157+
def from_dict(cls, data: dict[str, Any]) -> "CohereDocumentImageEmbedder":
158+
"""
159+
Deserializes the component from a dictionary.
160+
161+
:param data:
162+
Dictionary to deserialize from.
163+
:returns:
164+
Deserialized component.
165+
"""
166+
init_params = data["init_parameters"]
167+
deserialize_secrets_inplace(init_params, keys=["api_key"])
168+
init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"])
169+
170+
return default_from_dict(cls, data)
171+
172+
def _extract_images_to_embed(self, documents: list[Document]) -> list[str]:
173+
"""
174+
Validates the input documents and extracts the images to embed in the format expected by the Cohere API.
175+
176+
:param documents:
177+
Documents to embed.
178+
179+
:returns:
180+
List of images to embed in the format expected by the Cohere API.
181+
182+
:raises TypeError:
183+
If the input is not a list of `Documents`.
184+
:raises ValueError:
185+
If the input contains unsupported image MIME types.
186+
:raises RuntimeError:
187+
If the conversion of some documents fails.
188+
"""
189+
if not isinstance(documents, list) or not all(isinstance(d, Document) for d in documents):
190+
msg = (
191+
"CohereDocumentImageEmbedder expects a list of Documents as input. "
192+
"In case you want to embed a string, please use the CohereTextEmbedder."
193+
)
194+
raise TypeError(msg)
195+
196+
images_source_info = _extract_image_sources_info(
197+
documents=documents, file_path_meta_field=self.file_path_meta_field, root_path=self.root_path
198+
)
199+
200+
for img_info in images_source_info:
201+
if img_info["mime_type"] not in SUPPORTED_IMAGE_MIME_TYPES:
202+
msg = (
203+
f"Unsupported image MIME type: {img_info['mime_type']}. "
204+
f"Supported types are: {', '.join(SUPPORTED_IMAGE_MIME_TYPES)}"
205+
)
206+
raise ValueError(msg)
207+
208+
images_to_embed: list[Optional[str]] = [None] * len(documents)
209+
pdf_page_infos: list[_PDFPageInfo] = []
210+
211+
for doc_idx, image_source_info in enumerate(images_source_info):
212+
if image_source_info["mime_type"] == "application/pdf":
213+
# Store PDF documents for later processing
214+
page_number = image_source_info.get("page_number")
215+
assert page_number is not None # checked in _extract_image_sources_info but mypy doesn't know that
216+
pdf_page_info: _PDFPageInfo = {
217+
"doc_idx": doc_idx,
218+
"path": image_source_info["path"],
219+
"page_number": page_number,
220+
}
221+
pdf_page_infos.append(pdf_page_info)
222+
else:
223+
# Process images directly
224+
image_byte_stream = ByteStream.from_file_path(
225+
filepath=image_source_info["path"], mime_type=image_source_info["mime_type"]
226+
)
227+
mime_type, base64_image = _encode_image_to_base64(bytestream=image_byte_stream, size=self.image_size)
228+
images_to_embed[doc_idx] = f"data:{mime_type};base64,{base64_image}"
229+
230+
base64_jpeg_images_by_doc_idx = _batch_convert_pdf_pages_to_images(
231+
pdf_page_infos=pdf_page_infos, return_base64=True, size=self.image_size
232+
)
233+
for doc_idx, base64_jpeg_image in base64_jpeg_images_by_doc_idx.items():
234+
images_to_embed[doc_idx] = f"data:image/jpeg;base64,{base64_jpeg_image}"
235+
236+
none_images_doc_ids = [documents[doc_idx].id for doc_idx, image in enumerate(images_to_embed) if image is None]
237+
if none_images_doc_ids:
238+
msg = f"Conversion failed for some documents. Document IDs: {none_images_doc_ids}."
239+
raise RuntimeError(msg)
240+
241+
# tested above that image is not None, but mypy doesn't know that
242+
return images_to_embed # type: ignore[return-value]
243+
244+
@component.output_types(documents=list[Document])
245+
def run(self, documents: list[Document]) -> dict[str, list[Document]]:
246+
"""
247+
Embed a list of image documents.
248+
249+
:param documents:
250+
Documents to embed.
251+
252+
:returns:
253+
A dictionary with the following keys:
254+
- `documents`: Documents with embeddings.
255+
"""
256+
257+
images_to_embed = self._extract_images_to_embed(documents)
258+
259+
embeddings = []
260+
261+
# The Cohere API only supports passing one image at a time
262+
for doc, image in tqdm(zip(documents, images_to_embed), desc="Embedding images", disable=not self.progress_bar):
263+
try:
264+
response = self._client.embed(
265+
model=self.model,
266+
images=[image],
267+
input_type="image",
268+
output_dimension=self.embedding_dimension,
269+
embedding_types=[self.embedding_type.value],
270+
)
271+
embedding = getattr(response.embeddings, self.embedding_type.value)[0]
272+
except Exception as e:
273+
msg = f"Error embedding Document {doc.id}"
274+
raise RuntimeError(msg) from e
275+
276+
embeddings.append(embedding)
277+
278+
docs_with_embeddings = []
279+
for doc, emb in zip(documents, embeddings):
280+
# we store this information for later inspection
281+
new_meta = {
282+
**doc.meta,
283+
"embedding_source": {"type": "image", "file_path_meta_field": self.file_path_meta_field},
284+
}
285+
new_doc = replace(doc, meta=new_meta, embedding=emb)
286+
docs_with_embeddings.append(new_doc)
287+
288+
return {"documents": docs_with_embeddings}
289+
290+
@component.output_types(documents=list[Document])
291+
async def run_async(self, documents: list[Document]) -> dict[str, list[Document]]:
292+
"""
293+
Asynchronously embed a list of image documents.
294+
295+
:param documents:
296+
Documents to embed.
297+
298+
:returns:
299+
A dictionary with the following keys:
300+
- `documents`: Documents with embeddings.
301+
"""
302+
303+
images_to_embed = self._extract_images_to_embed(documents)
304+
305+
embeddings = []
306+
307+
# The Cohere API only supports passing one image at a time
308+
for doc, image in tqdm(zip(documents, images_to_embed), desc="Embedding images", disable=not self.progress_bar):
309+
try:
310+
response = await self._async_client.embed(
311+
model=self.model,
312+
images=[image],
313+
input_type="image",
314+
output_dimension=self.embedding_dimension,
315+
embedding_types=[self.embedding_type.value],
316+
)
317+
embedding = getattr(response.embeddings, self.embedding_type.value)[0]
318+
except Exception as e:
319+
msg = f"Error embedding Document {doc.id}"
320+
raise RuntimeError(msg) from e
321+
322+
embeddings.append(embedding)
323+
324+
docs_with_embeddings = []
325+
for doc, emb in zip(documents, embeddings):
326+
# we store this information for later inspection
327+
new_meta = {
328+
**doc.meta,
329+
"embedding_source": {"type": "image", "file_path_meta_field": self.file_path_meta_field},
330+
}
331+
new_doc = replace(doc, meta=new_meta, embedding=emb)
332+
docs_with_embeddings.append(new_doc)
333+
334+
return {"documents": docs_with_embeddings}

integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class CohereTextEmbedder:
1919
2020
Usage example:
2121
```python
22-
from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder
22+
from haystack_integrations.components.embedders.cohere import CohereTextEmbedder
2323
2424
text_to_embed = "I love pizza!"
2525
@@ -39,7 +39,7 @@ def __init__(
3939
input_type: str = "search_query",
4040
api_base_url: str = "https://api.cohere.com",
4141
truncate: str = "END",
42-
timeout: int = 120,
42+
timeout: float = 120.0,
4343
embedding_type: Optional[EmbeddingTypes] = None,
4444
):
4545
"""

0 commit comments

Comments
 (0)