
Commit 69cff31

geetu040, Shakib-IO, and zucchini-nlp authored
Add support for DeepseekAI's DeepseekVL (#36248)
* upload initial code
* update deepseek-vl adaptor
* update hierarchy of vision model classes
* update aligner model
* add text model
* Added Image Processor
* Added Image Processor
* Added Image Processor
* apply masks
* remove projection; add aligner
* remove interpolate_pos_encoding
* remove unused params in config
* cleaning
* Add the __init__ file
* added processing deepseek_vl class
* modified the deepseek-vl processor
* modified the deepseek-vl processor
* update __init__
* Update the image processor class name
* Added Deepseek to src/transformers/__init__.py file
* Added Deepseek to image_processing_auto.py
* update the __init__ file
* update deepseek_vl image processor
* Update Deepseek Processor
* upload fast image processor
* Revert "upload fast image processor". This reverts commit 68c8fd5.
* update image processor
* flatten hierarchy
* remove DeepseekVLModel
* major update (complete modeling)
* auto modeling and other files
* formatting
* fix quality
* replace torchvision in modeling
* set default do_normalize to False
* add fast image processor template using tool
* update image processors
* add fast image processor to other files
* update license
* Added deepseek image testcases
* update image test
* update processor
* write CHAT_TEMPLATE
* update model for processor
* fix processor
* minor fixes and formatting
* fix image processing and tests
* fix interpolation in sam
* fix output_attentions in DeepseekVLModel
* upload test_modeling
* fix tests because of vocab size
* set use_high_res_vision=False in tests
* fix all modeling tests
* fix styling
* remove explicit background_color from image processors
* added test_processor
* added test_processor
* fix processor tests
* update docs
* update docs
* update docs
* update conversion script
* Fixed typos
* minor fixes from review
  - remove model_id comments in examples
  - remove from pre-trained auto mapping
  - move to image-text-to-text from vision-to-seq in auto mapping
  - add image_token_index to __init__ for config
  - remove outdated temporary config in conversion script
  - update example to use chat_template in docstring example
  - update license 2021->2025
* fix type in config docstring (Co-authored-by: Raushan Turganbay <[email protected]>)
* update get_image_features
* fix config
* improve DeepseekVLImageProcessor.preprocess
* return image_hidden_states
* use AutoTokenizer and AutoImageProcessor in Processor
* fix model outputs
* make num_image_tokens configurable
* fix docstring of processor
* move system prompt to chat template
* fix repo consistency
* fix return_dict
* replace SamVisionEncoder with SamVisionModel
* update to remove deepcopy
* 🛠️ Major Architectural Changes (Adds DeepseekVLHybrid)
* fix quality checks
* add missing hybrid in auto modeling
* run make style
* update sam_hq
* update high_res_size in test
* update docs following #36979
* update code with auto_docstring
* update conversion scripts
* fix style
* fix failing test because of tuple
* set weights_only=True in conversion script
* use safetensors.torch.load_file instead of torch.load in conversion script
* make output_dir optional in conversion script
* fix code snippets in docs (now the examples work fine)
* integration tests for DeepseekVL
* update expected texts
* make style
* integration tests for DeepseekVLHybrid
* fix class name
* update expected texts for hybrid
* run "make style"
* update since changes in main
* run make-style
* nits since changes in main
* undo changes in sam
* fix tests
* fix tests; update with main
* update with main: output_attention/output_hidden_states
* fix copied part in deepseek_vl
* run fix-copies
* fix output_hidden_states
* sam: fix _init_weigths
* use modular for DeepseekVL
* make image processor more modular
* modular: use JanusPreTrainedModel
* janus: provide kwargs in loss
* update processors in conversion script
* Revert "sam: fix _init_weigths". This reverts commit db625d0.
* run fix-copies

---------

Co-authored-by: Shakib-IO <[email protected]>
Co-authored-by: Raushan Turganbay <[email protected]>
1 parent a98bbc2 commit 69cff31

33 files changed (+5856 −4 lines)

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -725,6 +725,10 @@
       title: DAB-DETR
     - local: model_doc/deepseek_v2
       title: DeepSeek-V2
+    - local: model_doc/deepseek_vl
+      title: DeepseekVL
+    - local: model_doc/deepseek_vl_hybrid
+      title: DeepseekVLHybrid
     - local: model_doc/deformable_detr
       title: Deformable DETR
     - local: model_doc/deit
docs/source/en/model_doc/deepseek_vl.md

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
<!--Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# DeepseekVL

[Deepseek-VL](https://arxiv.org/abs/2403.05525) was introduced by the DeepSeek AI team. It is a vision-language model (VLM) designed to process both text and images and generate contextually relevant responses. The model uses [LLaMA](./llama) as its language backbone, while [SigLIP](./siglip) encodes the images.
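One quick way to see this composition in practice is to load and print the model's configuration; the snippet below is a small editorial sketch (not part of the committed file) that only assumes the checkpoint name used throughout this page:

```py
from transformers import AutoConfig

# Load the composite DeepseekVL configuration; printing it shows the nested
# language-model and vision-encoder sub-configurations described above.
config = AutoConfig.from_pretrained("deepseek-community/deepseek-vl-1.3b-chat")
print(config)
```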
You can find all the original Deepseek-VL checkpoints under the [DeepSeek-community](https://huggingface.co/deepseek-community) organization.

> [!TIP]
> Click on the Deepseek-VL models in the right sidebar for more examples of how to apply Deepseek-VL to different vision and language tasks.

The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipe = pipeline(
    task="image-text-to-text",
    model="deepseek-community/deepseek-vl-1.3b-chat",
    device=0,
    torch_dtype=torch.float16
)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ]
    }
]

pipe(text=messages, max_new_tokens=20, return_full_text=False)
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import DeepseekVLForConditionalGeneration, AutoProcessor

model = DeepseekVLForConditionalGeneration.from_pretrained(
    "deepseek-community/deepseek-vl-1.3b-chat",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)

processor = AutoProcessor.from_pretrained("deepseek-community/deepseek-vl-1.3b-chat")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
            },
            {
                "type": "text",
                "text": "Describe this image."
            }
        ]
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device, dtype=model.dtype)

generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

print(output_text)
```

</hfoption>
</hfoptions>
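For longer generations it can be convenient to stream tokens as they are produced instead of waiting for the full output. The snippet below is a small sketch (not part of the committed file) that reuses the `model`, `processor`, and `inputs` from the `AutoModel` example above together with transformers' `TextStreamer`:

```py
from transformers import TextStreamer

# Print decoded tokens to stdout as they are generated; skip the prompt and special tokens.
streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(**inputs, max_new_tokens=128, streamer=streamer)
```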
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [torchao](../quantization/torchao) to quantize only the weights to int4.

```python
import torch
from transformers import TorchAoConfig, DeepseekVLForConditionalGeneration, AutoProcessor

quantization_config = TorchAoConfig(
    "int4_weight_only",
    group_size=128
)

model = DeepseekVLForConditionalGeneration.from_pretrained(
    "deepseek-community/deepseek-vl-1.3b-chat",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)
```
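As a quick sanity check that the int4 weights actually reduce memory usage, the model's footprint can be printed; a small sketch (not part of the committed file), reusing the quantized `model` from above:

```py
# Report the in-memory size of the quantized weights (the exact number depends on the backend).
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```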
### Notes

- Do inference with multiple images in a single conversation.

    ```py
    import torch
    from transformers import DeepseekVLForConditionalGeneration, AutoProcessor

    model = DeepseekVLForConditionalGeneration.from_pretrained(
        "deepseek-community/deepseek-vl-1.3b-chat",
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="sdpa"
    )

    processor = AutoProcessor.from_pretrained("deepseek-community/deepseek-vl-1.3b-chat")

    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What’s the difference between"},
                    {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
                    {"type": "text", "text": " and "},
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
                ]
            }
        ],
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"},
                    {"type": "text", "text": "What do you see in this image?"}
                ]
            }
        ]
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        padding=True,
        truncation=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device, dtype=model.dtype)

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print(output_text)
    ```
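One more usage note, not in the committed file: images can also be passed as local paths or `PIL.Image` objects instead of URLs, following the generic transformers chat-template content format. A minimal sketch (the file path is hypothetical), reusing `processor` from above:

```py
from PIL import Image

# A local image can be passed in place of a URL (the path here is hypothetical).
image = Image.open("my_photo.jpg")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
)
```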
## DeepseekVLConfig

[[autodoc]] DeepseekVLConfig

## DeepseekVLProcessor

[[autodoc]] DeepseekVLProcessor

## DeepseekVLImageProcessor

[[autodoc]] DeepseekVLImageProcessor

## DeepseekVLModel

[[autodoc]] DeepseekVLModel
    - forward

## DeepseekVLForConditionalGeneration

[[autodoc]] DeepseekVLForConditionalGeneration
    - forward
