
Commit 26c2bd9

Update transformers multimodal docs to reflect new changes with Chat and Batching examples
1 parent d12666a commit 26c2bd9

1 file changed (+111, -44 lines)

docs/features/models/transformers_multimodal.md

Lines changed: 111 additions & 44 deletions
@@ -10,18 +10,18 @@ The Outlines `TransformersMultiModal` model inherits from `Transformers` and sha
 
 To load the model, you can use the `from_transformers` function. It takes 2 arguments:
 
-- `model`: a `transformers` model (created with `AutoModelForCausalLM` for instance)
+- `model`: a `transformers` model (created with `AutoModelForImageTextToText` for instance)
 - `tokenizer_or_processor`: a `transformers` processor (created with `AutoProcessor` for instance, it must be an instance of `ProcessorMixin`)
 
 For instance:
 
 ```python
 import outlines
-from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModelForImageTextToText, AutoProcessor
 
 # Create the transformers model and processor
-hf_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
-hf_processor = AutoProcessor.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+hf_model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+hf_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
 
 # Create the Outlines model
 model = outlines.from_transformers(hf_model, hf_processor)
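As an aside for readers of this excerpt: the unchanged portion of the file (visible as context in the next hunk) already calls the loaded model directly. A minimal sketch of that calling convention, assuming a hypothetical local `cat.png` and a processor whose template expects one `<image>` tag per image (LLaVA-style; neither assumption is part of the commit):

```python
from PIL import Image as PILImage

from outlines.inputs import Image

# Hypothetical local image; PIL sets .format automatically for files read from disk.
image = PILImage.open("cat.png")

# One <image> tag in the text, one Image asset in the input list.
result = model(["<image>Describe the image.", Image(image)], max_new_tokens=50)
print(result)
```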
@@ -76,54 +76,66 @@ result = model(
 print(result) # '{"specie": "cat", "color": "white", "weight": 4}'
 print(Animal.model_validate_json(result)) # specie=cat, color=white, weight=4
 ```
+!!! Warning
+
+    Make sure your prompt contains the tags expected by your processor to correctly inject the assets in the prompt. For some vision multimodal models, for instance, you need to add as many `<image>` tags in your prompt as there are image assets included in your model input. The `Chat` method, by contrast, does not require this step.
+
 
-The `TransformersMultiModal` model supports batch generation. To use it, invoke the `batch` method with a list of lists. You will receive as a result a list of completions.
+### Chat
+The `Chat` interface offers a more convenient way to work with multimodal inputs: you don't need to manually add asset tags such as `<image>`, because the model's HF processor handles the chat templating and asset placement for you automatically.
+To use it, call the model with a `Chat` instance in the multimodal chat format. Assets must be wrapped in the `outlines.inputs` types (`Image`, `Audio`, `Video`), and only the `image`, `video` and `audio` content types are supported.
 
 For instance:
 
 ```python
+import outlines
+from outlines.inputs import Chat, Image
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from PIL import Image as PILImage
 from io import BytesIO
 from urllib.request import urlopen
+import torch
 
-from PIL import Image as PILImage
-from transformers import (
-    LlavaForConditionalGeneration,
-    AutoProcessor,
-)
-
-import outlines
-from outlines.inputs import Image
-
-TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
-IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
-IMAGE_URL_2 ="https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg"
+model_kwargs = {
+    "torch_dtype": torch.bfloat16,
+    "attn_implementation": "flash_attention_2",
+    "device_map": "auto",
+}
 
 def get_image_from_url(image_url):
     img_byte_stream = BytesIO(urlopen(image_url).read())
     image = PILImage.open(img_byte_stream).convert("RGB")
     image.format = "PNG"
     return image
-
-# Create a model
+
+# Create the model
 model = outlines.from_transformers(
-    LlavaForConditionalGeneration.from_pretrained(TEST_MODEL),
-    AutoProcessor.from_pretrained(TEST_MODEL),
+    AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs),
+    AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs)
 )
 
-# Call the batch method with a list of model input dicts
-result = model.batch(
-    [
-        ["<image>Describe the image.", Image(get_image_from_url(IMAGE_URL))],
-        ["<image>Describe the image.", Image(get_image_from_url(IMAGE_URL_2))],
-    ]
-)
-print(result) # ['The image shows a cat', 'The image shows an astronaut']
+IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+
+# Create the chat multimodal input
+prompt = Chat([
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL))},
+            {"type": "text", "text": "Describe the image in a few words."}
+        ],
+    }
+])
+
+# Call the model to generate a response
+response = model(prompt, max_new_tokens=50)
+print(response) # 'A Siamese cat with blue eyes is sitting on a cat tree, looking alert and curious.'
 ```
 
-### Chat
-You can use chat inputs with the `TransformersMultiModal` model. To do so, call the model with a `Chat` instance.
+### Batching
+The `TransformersMultiModal` model supports batching through the `batch` method. To use it, provide a list of prompts (in any of the formats described above) to the `batch` method; you will receive a list of completions in return.
 
-For instance:
+An example using the Chat format:
 
 ```python
 import outlines
@@ -133,18 +145,22 @@ from PIL import Image as PILImage
 from io import BytesIO
 from urllib.request import urlopen
 import torch
+from pydantic import BaseModel
 
 model_kwargs = {
     "torch_dtype": torch.bfloat16,
     "attn_implementation": "flash_attention_2",
     "device_map": "auto",
 }
 
+class Animal(BaseModel):
+    animal: str
+    color: str
+
 def get_image_from_url(image_url):
     img_byte_stream = BytesIO(urlopen(image_url).read())
     image = PILImage.open(img_byte_stream).convert("RGB")
     image.format = "PNG"
-    image.save("image.png")
     return image
 
 # Create the model
@@ -153,25 +169,76 @@ model = outlines.from_transformers(
     AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", **model_kwargs)
 )
 
-IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+IMAGE_URL_1 = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+IMAGE_URL_2 = "https://upload.wikimedia.org/wikipedia/commons/a/af/Golden_retriever_eating_pigs_foot.jpg"
 
-# Create the chat mutimodal input
-prompt = Chat([
+# Create the chat multimodal messages
+messages = [
     {
         "role": "user",
         "content": [
-            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL))},
-            {"type": "text", "text": "Describe the image in few words."}
+            {"type": "text", "text": "Describe the image in a few words."},
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL_1))},
         ],
-    }
-])
+    },
+]
+
+messages_2 = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in a few words."},
+            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL_2))},
+        ],
+    },
+]
+
+prompts = [Chat(messages), Chat(messages_2)]
 
 # Call the model to generate a response
-response = model(prompt, max_new_tokens=50)
-print(response) # 'A Siamese cat with blue eyes is sitting on a cat tree, looking alert and curious.'
+responses = model.batch(prompts, output_type=Animal, max_new_tokens=100)
+print(responses) # ['{ "animal": "cat", "color": "white and gray" }', '{ "animal": "dog", "color": "white" }']
+print([Animal.model_validate_json(i) for i in responses]) # [Animal(animal='cat', color='white and gray'), Animal(animal='dog', color='white')]
 ```
 
 
-!!! Warning
+An example using a list of lists with asset tags:
 
-    Make sure your prompt contains the tags expected by your processor to correctly inject the assets in the prompt. For some vision multimodal models for instance, you need to add as many `<image>` tags in your prompt as there are image assets included in your model input.
+```python
+from io import BytesIO
+from urllib.request import urlopen
+
+from PIL import Image as PILImage
+from transformers import (
+    LlavaForConditionalGeneration,
+    AutoProcessor,
+)
+
+import outlines
+from outlines.inputs import Image
+
+TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
+IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
+IMAGE_URL_2 = "https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg"
+
+def get_image_from_url(image_url):
+    img_byte_stream = BytesIO(urlopen(image_url).read())
+    image = PILImage.open(img_byte_stream).convert("RGB")
+    image.format = "PNG"
+    return image
+
+# Create a model
+model = outlines.from_transformers(
+    LlavaForConditionalGeneration.from_pretrained(TEST_MODEL),
+    AutoProcessor.from_pretrained(TEST_MODEL),
+)
+
+# Call the batch method with a list of model inputs
+result = model.batch(
+    [
+        ["<image>Describe the image.", Image(get_image_from_url(IMAGE_URL))],
+        ["<image>Describe the image.", Image(get_image_from_url(IMAGE_URL_2))],
+    ]
+)
+print(result) # ['The image shows a cat', 'The image shows an astronaut']
+```
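To make the warning introduced earlier in this commit concrete: when you bypass `Chat`, the number of tags must match the number of image assets. A hedged sketch (not part of the diff), reusing `model`, `get_image_from_url`, `IMAGE_URL`, and `IMAGE_URL_2` from the example above:

```python
# Illustrative only: two images in one prompt means two <image> tags in the text.
result = model(
    [
        "<image><image>Compare the two images.",
        Image(get_image_from_url(IMAGE_URL)),
        Image(get_image_from_url(IMAGE_URL_2)),
    ],
    max_new_tokens=50,
)
print(result)
```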
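The structured-output pattern shown with `batch` should carry over to a single `Chat` call as well; a sketch, assuming the Qwen model, the `Animal` schema, and `IMAGE_URL_1` defined in the Chat batching example:

```python
# Sketch: constrain one chat response to the Animal JSON schema.
prompt = Chat([
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the animal in the image."},
            {"type": "image", "image": Image(get_image_from_url(IMAGE_URL_1))},
        ],
    }
])
response = model(prompt, output_type=Animal, max_new_tokens=100)
print(Animal.model_validate_json(response))
```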
