---
layout: post
title: "Serving Geospatial, Vision, and Beyond: Enabling Multimodal Output Processing in vLLM"
author: Christian Pinto (IBM Research Europe - Dublin), Michele Gazzetti (IBM Research Europe - Dublin), Michael Johnston (IBM Research Europe - Dublin), Maximilien Philippe Marie de Bayser (IBM Research - Brazil)
image: /assets/logos/vllm-logo-text-light.png
---
## Introduction

Until recently, generative AI infrastructure has been tightly coupled with autoregressive text-generation models that produce output token by token, typically in the form of natural language. vLLM followed this trend, initially supporting models that take text as input and produce text as output: the traditional LLMs. The trend then started shifting towards multimodal data with the introduction of MLLMs (Multimodal Large Language Models), capable of reasoning over text as well as data in other modalities (e.g., images, video, and audio). vLLM again followed suit, adding support for LLaVA-style MLLMs that reason over multimodal inputs and generate text.

We are now witnessing another shift, with a growing class of non-autoregressive models that generate multimodal outputs in a single inference pass, enabling faster and more efficient generation across a wide range of modalities. From the inference standpoint, these models can be seen as pooling models, but they require additional support for input and output handling. Applications for this type of model can be found in domains beyond text, from image classification and segmentation to audio synthesis and structured data generation.
We've taken the next step in vLLM and added support for this class of models.

Our initial integration focuses on geospatial foundation models: a class of convolutional or vision transformer models that require data beyond RGB channels (e.g., multispectral or radar bands) plus metadata (e.g., geolocation, date of image acquisition), and that are used for tasks such as disaster response or land-use classification from satellite imagery, among others. However, the changes are generic and pave the way for serving a wide variety of non-text-generating models.

As a concrete example, we've integrated all geospatial models from the [TerraTorch](https://github.com/IBM/terratorch) framework (some of which were developed in collaboration with NASA and ESA) into vLLM via a generic backend, making them first-class citizens in the vLLM ecosystem.

In the sections that follow, we describe the technical changes made to vLLM, starting with the requirements and challenges of serving geospatial foundation models.

## Integrating Geospatial Foundation Models in vLLM

Unlike text models, geospatial foundation models (often implemented as vision transformers) don’t need token decoding, i.e., they do not need output tokens to be transformed into text.
Instead, given an input image, a single inference pass generates the raw model output, which is then post-processed into an output image.
In addition, the input image sometimes needs to be partitioned and batched into a number of sub-images, or patches.
These patches are then fed to the model for inference, and the resulting per-patch outputs are stitched together to form the final output image.

<p align="center">
<picture>
<img src="/assets/figures/beyond-text/models-diff.png" width="80%">
</picture>
</p>
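
To make the patching and stitching described above concrete, here is a small, purely illustrative PyTorch sketch (not the TerraTorch or vLLM implementation); the 512-pixel patch size and helper names are assumptions, and padding of images whose sides are not multiples of the patch size is omitted.

```python
import torch

PATCH = 512  # patch edge length assumed by the model (illustrative)

def split_into_patches(image: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """Split a [C, H, W] image into a batch of [N, C, PATCH, PATCH] patches."""
    c, h, w = image.shape
    rows, cols = h // PATCH, w // PATCH
    patches = (
        image.unfold(1, PATCH, PATCH)   # [C, rows, W, PATCH]
             .unfold(2, PATCH, PATCH)   # [C, rows, cols, PATCH, PATCH]
             .permute(1, 2, 0, 3, 4)    # [rows, cols, C, PATCH, PATCH]
             .reshape(-1, c, PATCH, PATCH)
    )
    return patches, rows, cols

def stitch_patches(outputs: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    """Reassemble per-patch outputs [N, C, PATCH, PATCH] into one [C, H, W] image."""
    c = outputs.shape[1]
    grid = outputs.reshape(rows, cols, c, PATCH, PATCH)
    return grid.permute(2, 0, 3, 1, 4).reshape(c, rows * PATCH, cols * PATCH)
```

In the actual integration, the partitioning happens during input pre-processing and the stitching during output post-processing, as described in the rest of this post.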

Given these requirements, the obvious choice was to integrate geospatial foundation models into vLLM as pooling models. Pooling is a technique commonly used in deep learning to reduce the spatial dimensions of feature maps. Common types include max pooling, average pooling, and global pooling, each using a different strategy to aggregate information. In vLLM, pooling can be applied to [tasks](https://docs.vllm.ai/en/latest/models/pooling_models.html?h=pooling) such as embedding vector calculation and classification. In addition, vLLM supports identity poolers, which return the model's hidden states without applying any transformation: exactly what we need.
On the input side, we pre-process images into tensors that are then fed to the model for inference, leveraging vLLM's existing multimodal input capabilities.
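
For intuition, the self-contained PyTorch snippet below contrasts common pooling strategies with the identity behaviour we rely on; it is purely illustrative and does not use vLLM's pooler classes.

```python
import torch

# Toy hidden states for a batch of 2 sequences, 4 positions, hidden size 8.
hidden_states = torch.randn(2, 4, 8)

# Mean pooling: aggregate across positions, one embedding per sequence.
mean_pooled = hidden_states.mean(dim=1)          # shape: [2, 8]

# Max pooling: keep the largest activation per hidden dimension.
max_pooled = hidden_states.max(dim=1).values     # shape: [2, 8]

# "Identity" pooling: return the hidden states untouched, so a downstream
# step (here, the IO Processor) can post-process them into an image.
identity_pooled = hidden_states                  # shape: [2, 4, 8]
```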

Since we wanted to support multiple geospatial foundation models out of the box in vLLM, we also added a model implementation backend for TerraTorch models, following the same pattern as the backend for the Hugging Face Transformers library.

Getting this to work was no easy task, though.
Enabling these model classes required changes to various parts of vLLM, such as:

* adding support for attention-free models
* improving support for models that do not require a tokenizer
* enabling processing of raw input data, as opposed to the default multimodal input embeddings
* extending the vLLM serving API.

## Meet IO Processor: Flexible Input/Output Handling for Any Model

So far so good! However, this only brings us halfway to our goal.

With the above integration, we can indeed serve geospatial foundation models, but only in a tensor-to-tensor fashion.
Users still have to pre-process their images into tensors before sending them to the vLLM instance.
Similarly, post-processing of the raw tensor output has to happen outside vLLM.
The result: there is no endpoint that users can send an image to and get an image back from.

This problem existed because, before our changes, pre-processing of input data and post-processing of the model output were only partially supported in vLLM.
Specifically, pre-processing of multimodal input data was only possible via the processors available in the Transformers library.
However, the Transformers processors usually support only standard data types and do not handle more complex data formats such as GeoTIFF, an image format enriched with geospatial metadata.
On the output side, vLLM only supported de-tokenization into text or the application of poolers to the model's hidden states; no other output processing was possible.

This is where the new IO Processor plugin framework comes in.
The IO Processor framework allows developers to customize how model inputs and outputs are pre- and post-processed, all within the same vLLM serving instance.
Whether your model returns a string, a JSON object, an image tensor, or a custom data structure, an IO Processor can translate it into the desired format before returning it to the client.

<p align="center">
<picture>
<img src="/assets/figures/beyond-text/io-plugins-flow.png" width="70%">
</picture>
</p>

The IO Processor framework unlocks a new level of flexibility for vLLM users.
It means non-text models (e.g., image generators, image-to-segmentation-mask models, tabular classifiers) can be served using the standard vLLM infrastructure.
Via IO Processors, users can plug in custom logic to transform or enrich outputs, such as decoding model outputs into images or formatting responses for downstream systems.
This maintains a unified serving stack, reducing operational complexity and improving maintainability.

### Using vLLM IO Processor Plugins

Each IO Processor plugin implements a pre-defined [IO Processor interface](https://github.com/vllm-project/vllm/blob/main/vllm/plugins/io_processors/interface.py) and resides outside the vLLM source code tree.
At installation time, each plugin registers one or more entrypoints in the `vllm.io_processor_plugins` group.
This allows vLLM to automatically discover and load plugins at engine initialization time.
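
As an illustration of what this registration can look like with setuptools, the sketch below declares a hypothetical package that exposes an entrypoint in the `vllm.io_processor_plugins` group; the package and target names are made up, and the exact object an entrypoint should resolve to is defined by the IO Processor interface linked above.

```python
# setup.py of a hypothetical IO Processor plugin package (illustrative only)
from setuptools import find_packages, setup

setup(
    name="my_io_plugin",
    version="0.1.0",
    packages=find_packages(),
    install_requires=["vllm"],
    entry_points={
        # vLLM scans this group at engine initialization to discover plugins;
        # "my_plugin" is the name later passed via --io-processor-plugin.
        "vllm.io_processor_plugins": [
            "my_plugin = my_io_plugin.processor:MyIOProcessor",
        ],
    },
)
```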

Using an IO Processor plugin is as easy as installing it in the same Python environment as vLLM and adding the `--io-processor-plugin <plugin_name>` parameter when starting the serving instance.
Currently, each vLLM instance can load only one IO Processor plugin.

Once the serving instance is started, pre- and post-processing are automatically applied to the model input and output when serving the `/pooling` endpoint.
At this stage, IO Processors are only available for pooling models, but we expect other endpoints to be integrated in the future.

## Step-by-Step: Serving the Prithvi Model in vLLM

One example of a model class that can be served with vLLM using the TerraTorch backend is [Prithvi for flood detection](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11). A full plugin example for the Prithvi geospatial foundation model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin).

### The Prithvi IO Processor Plugin

To illustrate the flexibility of the IO Processor plugin approach, the pseudocode below shows the main steps of the Prithvi IO Processor's pre- and post-processing. What we want to highlight is the decoupling between the data-specific transformations and the data the model actually consumes and produces. This leaves room for virtually any model and any input/output data type, and even for multiple plugins applied to the same model output, depending on the downstream task that consumes the data.

```python
def pre_process(request_data: dict):
    # Download the GeoTIFF referenced by the request.
    # In this example the input image has 7 bands.
    image_url = request_data["url"]
    image_obj = download_image(image_url)

    # Extract image data:
    # - pixel_values([n, 6, 512, 512])
    #   - 6 input bands: R, G, B, + 3 multispectral wavelengths
    #   - n > 1 if the size of the input image is > [512, 512]
    # - metadata
    #   - GPS coordinates
    #   - date
    # The metadata is kept as plugin state for use in post_process.
    pixel_values, metadata = process_image(image_obj)

    # Process the image data into n vLLM prompts
    model_prompts = pixels_to_prompts(pixel_values)

    return model_prompts


def post_process(model_outputs: list[PoolingRequestOutput]):
    # Uses the previously extracted metadata to guarantee the output
    # contains the same georeferences and date.
    return image_object(model_outputs, metadata)
```

### Install the Python Requirements

Install the `terratorch` (>=1.1rc3) and `vllm` packages in your Python environment.
At the time of writing, the changes required to replicate this example are not yet part of a vLLM release (the latest is v0.10.1.1), so we recommend installing the [latest code](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html#install-the-latest-code_1).
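
As a sketch, the first step could look like the following; how you install the latest vLLM code depends on your setup, so the comment below is only a placeholder for the instructions linked above:

```bash
pip install "terratorch>=1.1rc3"
# Install vLLM following the "latest code" instructions linked above,
# since this example needs changes that are newer than v0.10.1.1.
```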

Download and install the IO Processor plugin for flood detection with Prithvi:

```bash
git clone git@github.com:christian-pinto/prithvi_io_processor_plugin.git
cd prithvi_io_processor_plugin
pip install .
```

This installs the `prithvi_to_tiff` plugin.
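
To double-check that the plugin is visible to vLLM, you can list the `vllm.io_processor_plugins` entrypoint group from Python; the entrypoint name printed is whatever the plugin declares, `prithvi_to_tiff` in this case.

```python
from importlib.metadata import entry_points

# Print every entrypoint registered in the group that vLLM scans
# for IO Processor plugins (requires Python 3.10+).
for ep in entry_points(group="vllm.io_processor_plugins"):
    print(ep.name, "->", ep.value)
```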

### Start a vLLM Serving Instance

Start a vLLM serving instance that loads the `prithvi_to_tiff` plugin and the Prithvi model for flood detection:

```bash
vllm serve \
    --model=ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11 \
    --model-impl terratorch \
    --task embed --trust-remote-code \
    --skip-tokenizer-init --enforce-eager \
    --io-processor-plugin prithvi_to_tiff
```

Once the instance is running, it is ready to serve requests with the selected plugin.
The log entries below confirm that your vLLM instance is up and running and that it is listening on port `8000`.

```bash
INFO: Starting vLLM API server 0 on http://0.0.0.0:8000
...
...
INFO: Started server process [409128]
INFO: Waiting for application startup.
INFO: Application startup complete.
```

### Send Requests to the Model

The Python script below sends a request to the vLLM `/pooling` endpoint with a specific JSON payload, where the `model` and `softmax` arguments are pre-defined, while the `data` field is defined by the user and depends on the plugin in use.

> [!NOTE]
> Setting the `softmax` field to `False` is required to ensure the plugin receives the raw model output.

In this case, we send the input image to vLLM as a URL and request that the response be a GeoTIFF image in base64 encoding.
The script decodes the image and writes it to disk as a GeoTIFF (`.tiff`) file.

```python
import base64
import os

import requests


def main():
    # Example GeoTIFF of the Valencia area hosted on Hugging Face.
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"
    server_endpoint = "http://localhost:8000/pooling"

    request_payload = {
        # Plugin-specific input: pass the image as a URL and request a
        # base64-encoded GeoTIFF back.
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "model": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "softmax": False,
    }

    ret = requests.post(server_endpoint, json=request_payload)

    if ret.status_code == 200:
        response = ret.json()

        # Decode the base64 GeoTIFF returned by the plugin and write it to disk.
        decoded_image = base64.b64decode(response["data"]["data"])

        out_path = os.path.join(os.getcwd(), "online_prediction.tiff")

        with open(out_path, "wb") as f:
            f.write(decoded_image)
    else:
        print(f"Response status_code: {ret.status_code}")
        print(f"Response reason: {ret.reason}")


if __name__ == "__main__":
    main()
```

Below is an example of the input and the expected output.
The input image (left) is a satellite picture of Valencia, Spain, during the 2024 flood.
The output image (right) shows the areas predicted as flooded (in white) by the Prithvi model.

<p align="center">
<picture>
<img src="/assets/figures/beyond-text/prithvi-prediction.png" width="100%">
</picture>
</p>

## What’s Next

This is just the beginning.
We plan to expand IO Processor plugins to more TerraTorch models, more modalities, and beyond, while making installation seamless.
Longer term, we envision IO Processor-powered vision-language systems, structured reasoning agents, and multimodal pipelines, all served from the same vLLM stack. We're also excited to see how the community uses IO Processors to push the boundaries of what’s possible with vLLM.
We also plan to continue working with and contributing to the vLLM community to enable more multimodal models and end-to-end use cases.

**Contributions, feedback, and ideas are always welcome!**

To get started with IO Processor plugins, check the [documentation](https://docs.vllm.ai/en/latest/design/io_processor_plugins.html) and explore the [examples](https://github.com/vllm-project/vllm/tree/main/examples).
More information on IBM's TerraTorch is available [here](https://github.com/IBM/terratorch).

## Acknowledgement

We would like to thank the members of the vLLM community for their help with improving our contribution. In particular, we would like to thank [Cyrus Leung](https://github.com/DarkLight1337) for his support in helping shape the overall concept of extending vLLM beyond text generation. Finally, we would like to thank the TerraTorch team at IBM, especially [Paolo Fraccaro](https://github.com/paolo-fraccaro) and [Joao Lucas de Sousa Almeida](https://github.com/Joao-L-S-Almeida), for their help with integrating the generic TerraTorch backend in vLLM.