
Commit be68edd

Authored by ChinmayBansal, davidsbatista, and anakin87

feat: add image support to LlamaCppChatGenerator (#2197)

* feat: add multimodal support to LlamaCppChatGenerator
* address PR feedback
* feat: address PR feedback
* simplify; smaller models; workflow maintenance
* type fix

---------

Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: anakin87 <[email protected]>

1 parent 2fc25e2 commit be68edd

5 files changed: +306 -23 lines changed

.github/workflows/llama_cpp.yml
Lines changed: 2 additions & 3 deletions

@@ -30,9 +30,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest] # we don't test on windows because of issues on llama-cpp-python
-        # we test with 3.12 since sentencepiece has issues with 3.13: https://github.com/google/sentencepiece/issues/1103
-        python-version: ["3.9", "3.12"]
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: ["3.9", "3.13"]
 
     steps:
       - name: Support longpaths
integrations/llama_cpp/pyproject.toml
Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.13.0", "llama-cpp-python>=0.2.87"]
+dependencies = ["haystack-ai>=2.16.1", "llama-cpp-python>=0.2.87"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme"

@@ -60,7 +60,7 @@ dependencies = [
   "pytest-rerunfailures",
   "mypy",
   "pip",
-  "transformers[sentencepiece]"
+  "transformers[sentencepiece]",
 ]
 
 [tool.hatch.envs.test.scripts]
integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
Lines changed: 108 additions & 13 deletions

@@ -7,7 +7,9 @@
 from haystack.dataclasses import (
     ChatMessage,
     ComponentInfo,
+    ImageContent,
     StreamingCallbackT,
+    TextContent,
     ToolCall,
     ToolCallDelta,
     select_streaming_callback,
@@ -25,12 +27,15 @@
     ChatCompletionMessageToolCall,
     ChatCompletionRequestAssistantMessage,
     ChatCompletionRequestMessage,
+    ChatCompletionRequestMessageContentPart,
     ChatCompletionResponseChoice,
     ChatCompletionTool,
     CreateChatCompletionResponse,
     CreateChatCompletionStreamResponse,
     Llama,
+    llama_chat_format,
 )
+from llama_cpp.llama_chat_format import Llava15ChatHandler
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 
 logger = logging.getLogger(__name__)
@@ -42,6 +47,8 @@
     "function_call": "tool_calls",
 }
 
+SUPPORTED_IMAGE_FORMATS = ["image/jpeg", "image/jpg", "image/png", "image/gif", "image/webp"]
+
 
 def _convert_message_to_llamacpp_format(message: ChatMessage) -> ChatCompletionRequestMessage:
     """
@@ -50,16 +57,24 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> ChatCompletionR
     text_contents = message.texts
     tool_calls = message.tool_calls
     tool_call_results = message.tool_call_results
+    images = message.images
 
-    if not text_contents and not tool_calls and not tool_call_results:
-        msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
+    if not text_contents and not tool_calls and not tool_call_results and not images:
+        msg = (
+            "A `ChatMessage` must contain at least one `TextContent`, `ImageContent`, `ToolCall`, or `ToolCallResult`."
+        )
         raise ValueError(msg)
     elif len(text_contents) + len(tool_call_results) > 1:
-        msg = "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`."
+        msg = "For llama.cpp compatibility, a `ChatMessage` can contain at most one `TextContent` or `ToolCallResult`."
         raise ValueError(msg)
 
     role = message._role.value
 
+    # Check that images are only in user messages
+    if images and role != "user":
+        msg = "Image content is only supported for user messages"
+        raise ValueError(msg)
+
     if role == "tool" and tool_call_results:
         if tool_call_results[0].origin.id is None:
             msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
@@ -71,12 +86,34 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> ChatCompletionR
         }
 
     if role == "system":
-        content = text_contents[0] if text_contents else None
-        return {"role": "system", "content": content}
+        return {"role": "system", "content": text_contents[0]}
 
     if role == "user":
-        content = text_contents[0] if text_contents else None
-        return {"role": "user", "content": content}
+        # Handle multimodal content (text + images) preserving order
+        if images:
+            # Check image constraints for LlamaCpp
+            for image in images:
+                if image.mime_type not in SUPPORTED_IMAGE_FORMATS:
+                    supported_formats = ", ".join(SUPPORTED_IMAGE_FORMATS)
+                    msg = (
+                        f"Unsupported image format: {image.mime_type}. "
+                        f"LlamaCpp supports the following formats: {supported_formats}"
+                    )
+                    raise ValueError(msg)
+
+            content_parts: list[ChatCompletionRequestMessageContentPart] = []
+            for part in message._content:
+                if isinstance(part, TextContent) and part.text:
+                    content_parts.append({"type": "text", "text": part.text})
+                elif isinstance(part, ImageContent):
+                    # LlamaCpp expects base64 data URI format
+                    image_url = f"data:{part.mime_type};base64,{part.base64_image}"
+                    content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
+
+            return {"role": "user", "content": content_parts}
+
+        # Simple text-only message
+        return {"role": "user", "content": text_contents[0]}
 
     if role == "assistant":
         result: ChatCompletionRequestAssistantMessage = {"role": "assistant"}
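For reference, the conversion logic above flattens a mixed text-and-image user message into llama.cpp's content-parts format, rejecting any image whose MIME type is outside SUPPORTED_IMAGE_FORMATS. A minimal sketch of the expected payload; the base64 string is a truncated placeholder:

```python
from haystack.dataclasses import ChatMessage, ImageContent

# Placeholder base64 payload; a real ImageContent would carry actual image bytes.
image = ImageContent(base64_image="iVBORw0KGgo...", mime_type="image/png")
message = ChatMessage.from_user(content_parts=["What's in this image?", image])

# _convert_message_to_llamacpp_format(message) should return:
# {
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "What's in this image?"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
#     ],
# }
```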
@@ -113,6 +150,7 @@ class LlamaCppChatGenerator:
 
     [llama.cpp](https://github.com/ggml-org/llama.cpp) is a project written in C/C++ for efficient inference of LLMs.
     It employs the quantized GGUF format, suitable for running these models on standard machines (even without GPUs).
+    Supports both text-only and multimodal (text + image) models like LLaVA.
 
     Usage example:
     ```python
@@ -121,7 +159,30 @@ class LlamaCppChatGenerator:
     generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
 
     print(generator.run(user_message, generation_kwargs={"max_tokens": 128}))
-    # {"replies": [ChatMessage(content="John Cusack", role=<ChatRole.ASSISTANT: "assistant">, name=None, meta={...}]}
+    # {"replies": [ChatMessage(content="John Cusack", role=<ChatRole.ASSISTANT: "assistant">, name=None, meta={...})]}
+    ```
+
+    Usage example with multimodal (image + text):
+    ```python
+    from haystack.dataclasses import ChatMessage, ImageContent
+
+    # Create an image from file path or base64
+    image_content = ImageContent.from_file_path("path/to/your/image.jpg")
+
+    # Create a multimodal message with both text and image
+    messages = [ChatMessage.from_user(content_parts=["What's in this image?", image_content])]
+
+    # Initialize with multimodal support
+    generator = LlamaCppChatGenerator(
+        model="llava-v1.5-7b-q4_0.gguf",
+        chat_handler_name="Llava15ChatHandler",  # Use llava-1-5 handler
+        model_clip_path="mmproj-model-f16.gguf",  # CLIP model
+        n_ctx=4096  # Larger context for image processing
+    )
+    generator.warm_up()
+
+    result = generator.run(messages)
+    print(result)
     ```
     """
@@ -135,6 +196,8 @@ def __init__(
         *,
         tools: Optional[Union[List[Tool], Toolset]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
+        chat_handler_name: Optional[str] = None,
+        model_clip_path: Optional[str] = None,
     ):
         """
         :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
@@ -153,6 +216,12 @@ def __init__(
             A list of tools or a Toolset for which the model can prepare calls.
             This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         :param streaming_callback: A callback function that is called when a new token is received from the stream.
+        :param chat_handler_name: Name of the chat handler for multimodal models.
+            Common options include: "Llava16ChatHandler", "MoondreamChatHandler", "Qwen25VLChatHandler".
+            For other handlers, check the
+            [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models).
+        :param model_clip_path: Path to the CLIP model for vision processing (e.g., "mmproj.bin").
+            Required when chat_handler_name is provided for multimodal models.
         """
 
         model_kwargs = model_kwargs or {}
@@ -166,6 +235,19 @@ def __init__(
 
         _check_duplicate_tool_names(list(tools or []))
 
+        handler: Optional[Llava15ChatHandler] = None
+        # Validate multimodal requirements
+        if chat_handler_name is not None:
+            if model_clip_path is None:
+                msg = "model_clip_path must be provided when chat_handler_name is specified for multimodal models"
+                raise ValueError(msg)
+            # Validate chat handler by attempting to import it
+            try:
+                handler = getattr(llama_chat_format, chat_handler_name)
+            except AttributeError as e:
+                msg = f"Failed to import chat handler '{chat_handler_name}'."
+                raise ValueError(msg) from e
+
         self.model_path = model
         self.n_ctx = n_ctx
         self.n_batch = n_batch
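These constructor-time checks fail fast, before any model file is opened (the `Llama` instance is only built in `warm_up`). A short sketch of both error paths, assuming the package's usual re-export path for the class; the model paths are placeholders that are never read here, and `NotARealHandler` is a deliberately invalid name:

```python
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

# Requesting a chat handler without a CLIP model path -> ValueError.
try:
    LlamaCppChatGenerator(model="llava-v1.5-7b-q4_0.gguf", chat_handler_name="Llava15ChatHandler")
except ValueError as err:
    print(err)  # model_clip_path must be provided when chat_handler_name is specified ...

# A handler name that does not exist in llama_cpp.llama_chat_format -> ValueError.
try:
    LlamaCppChatGenerator(
        model="llava-v1.5-7b-q4_0.gguf",
        chat_handler_name="NotARealHandler",
        model_clip_path="mmproj-model-f16.gguf",
    )
except ValueError as err:
    print(err)  # Failed to import chat handler 'NotARealHandler'.
```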
@@ -174,14 +256,25 @@ def __init__(
         self._model: Optional[Llama] = None
         self.tools = tools
         self.streaming_callback = streaming_callback
+        self.chat_handler_name = chat_handler_name
+        self.model_clip_path = model_clip_path
+        self._handler = handler
 
     def warm_up(self):
-        if "hf_tokenizer_path" in self.model_kwargs and "tokenizer" not in self.model_kwargs:
-            tokenizer = LlamaHFTokenizer.from_pretrained(self.model_kwargs["hf_tokenizer_path"])
-            self.model_kwargs["tokenizer"] = tokenizer
+        if self._model is not None:
+            return
 
-        if self._model is None:
-            self._model = Llama(**self.model_kwargs)
+        kwargs = self.model_kwargs.copy()
+        if "hf_tokenizer_path" in kwargs and "tokenizer" not in kwargs:
+            tokenizer = LlamaHFTokenizer.from_pretrained(kwargs["hf_tokenizer_path"])
+            kwargs["tokenizer"] = tokenizer
+
+        # Handle multimodal initialization
+        if self._handler is not None and self.model_clip_path is not None:
+            # the following command is correct, but mypy complains because handlers also have a __call__ method
+            kwargs["chat_handler"] = self._handler(clip_model_path=self.model_clip_path)  # type: ignore[call-arg]
+
+        self._model = Llama(**kwargs)
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -200,6 +293,8 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             tools=serialize_tools_or_toolset(self.tools),
             streaming_callback=callback_name,
+            chat_handler_name=self.chat_handler_name,
+            model_clip_path=self.model_clip_path,
         )
 
     @classmethod
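With chat_handler_name and model_clip_path now included in to_dict, the multimodal configuration survives serialization. A sketch of the round trip, assuming the existing from_dict classmethod (hinted at by the @classmethod context line above) mirrors to_dict, and that to_dict nests its arguments under "init_parameters" as Haystack's default_to_dict does:

```python
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

generator = LlamaCppChatGenerator(
    model="llava-v1.5-7b-q4_0.gguf",
    chat_handler_name="Llava15ChatHandler",
    model_clip_path="mmproj-model-f16.gguf",
)

# Serialize; the two new fields should appear alongside the existing ones.
data = generator.to_dict()
assert data["init_parameters"]["chat_handler_name"] == "Llava15ChatHandler"
assert data["init_parameters"]["model_clip_path"] == "mmproj-model-f16.gguf"

# Rebuild from the dict; the multimodal settings are restored.
restored = LlamaCppChatGenerator.from_dict(data)
assert restored.chat_handler_name == generator.chat_handler_name
assert restored.model_clip_path == generator.model_clip_path
```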
