Skip to content

Commit 6756831

Browse files
committed
Fix handling of chat multimodal inputs in TransformersMultiModal model and update docs for the usage of Chat in TransformersMultiModal
1 parent b3fba13 commit 6756831

File tree

2 files changed

+63
-15
lines changed

2 files changed

+63
-15
lines changed

docs/features/models/transformers_multimodal.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,58 @@ result = model.batch(
120120
print(result) # ['The image shows a cat', 'The image shows an astronaut']
121121
```
122122

123+
### Chat
124+
You can use chat inputs with the `TransformersMultiModal` model. To do so, call the model with a `Chat` instance.
125+
126+
For instance:
127+
128+
```python
129+
import outlines
130+
from outlines.inputs import Chat, Image
131+
from transformers import AutoModelForImageTextToText, AutoProcessor
132+
from PIL import Image as PILImage
133+
from io import BytesIO
134+
from urllib.request import urlopen
135+
import torch
136+
137+
model_kwargs = {
138+
"torch_dtype": torch.bfloat16,
139+
"attn_implementation": "flash_attention_2",
140+
"device_map": "auto",
141+
}
142+
143+
def get_image_from_url(image_url):
144+
img_byte_stream = BytesIO(urlopen(image_url).read())
145+
image = PILImage.open(img_byte_stream).convert("RGB")
146+
image.format = "PNG"
147+
image.save("image.png")
148+
return image
149+
150+
# Create the model
151+
model = outlines.from_transformers(
152+
AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs),
153+
AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs)
154+
)
155+
156+
IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
157+
158+
# Create the chat mutimodal input
159+
prompt = Chat([
160+
{
161+
"role": "user",
162+
"content": [
163+
{"type": "image", "image": Image(get_image_from_url(IMAGE_URL))},
164+
{"type": "text", "text": "Describe the image in few words."}
165+
],
166+
}
167+
])
168+
169+
# Call the model to generate a response
170+
response = model(prompt, max_new_tokens=50)
171+
print(response) # 'A Siamese cat with blue eyes is sitting on a cat tree, looking alert and curious.'
172+
```
173+
174+
123175
!!! Warning
124176

125177
Make sure your prompt contains the tags expected by your processor to correctly inject the assets in the prompt. For some vision multimodal models for instance, you need to add as many `<image>` tags in your prompt as there are image assets included in your model input.

outlines/models/transformers.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -441,26 +441,22 @@ def format_dict_input(self, model_input: dict) -> dict:
441441

442442
@format_input.register(Chat)
443443
def format_chat_input(self, model_input: Chat) -> dict:
444-
# we need to separate the images from the messages
445-
# to apply the chat template to the messages without images
444+
# we need to separate the assets from the messages
445+
# to apply the chat template to the messages without assets
446446
messages = model_input.messages
447-
images = []
448-
messages_without_images = []
447+
assets = []
449448
for message in messages:
450449
if isinstance(message["content"], list):
451-
images.extend(message["content"][1:])
452-
messages_without_images.append({
453-
"role": message["role"],
454-
"content": message["content"][0],
455-
})
456-
else:
457-
messages_without_images.append(message)
450+
for item in message["content"]:
451+
if item["type"] != "text":
452+
assets.append(item[item["type"]])
458453
formatted_prompt = self.tokenizer.apply_chat_template(
459-
messages_without_images,
460-
tokenize=False
454+
messages, # full message for applying chat template
455+
tokenize=False,
456+
add_generation_prompt=True
461457
)
462-
# use the formatted prompt and the images to format the input
463-
return self.format_list_input([formatted_prompt, *images])
458+
# use the formatted prompt and the assets to format the input
459+
return self.format_list_input([formatted_prompt, *assets])
464460

465461
@format_input.register(list)
466462
def format_list_input(self, model_input: list) -> dict:

0 commit comments

Comments
 (0)