Commit 82807e5

[Fast image processor] refactor fast image processor glm4v (#39490)
refactor fast image processor glm4v
1 parent 4b4f04f commit 82807e5
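
For context, the processor still exposes the same two outputs, `pixel_values` and `image_grid_thw` (its `model_input_names` in the diff below). A minimal usage sketch follows; the checkpoint id and the dummy image are placeholders for illustration, not part of this commit:

from PIL import Image
from transformers import AutoImageProcessor

# Placeholder checkpoint id; any GLM-4V checkpoint that ships a fast image processor
# is expected to behave the same way.
processor = AutoImageProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking", use_fast=True)

image = Image.new("RGB", (640, 480), color="white")  # stand-in for a real image
inputs = processor(images=image, return_tensors="pt")

print(sorted(inputs.keys()))     # ['image_grid_thw', 'pixel_values']
print(inputs["image_grid_thw"])  # one (t, h, w) patch grid per image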

File tree

2 files changed: +54 -225 lines changed

- src/transformers/models/glm4v/image_processing_glm4v_fast.py
- src/transformers/models/glm4v/video_processing_glm4v.py

src/transformers/models/glm4v/image_processing_glm4v_fast.py

Lines changed: 49 additions & 219 deletions
@@ -28,13 +28,9 @@
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
     OPENAI_CLIP_STD,
-    ChannelDimension,
     ImageInput,
     PILImageResampling,
     SizeDict,
-    get_image_size,
-    make_flat_list_of_images,
-    valid_images,
 )
 from ...processing_utils import Unpack
 from ...utils import (
@@ -45,7 +41,6 @@
     is_torchvision_v2_available,
     logging,
 )
-from ...video_utils import VideoInput
 from .image_processing_glm4v import smart_resize


@@ -54,8 +49,6 @@


 if is_torchvision_available():
-    from ...image_utils import pil_torch_interpolation_mapping
-
 if is_torchvision_v2_available():
     from torchvision.transforms.v2 import functional as F
 else:
@@ -96,19 +89,12 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
     model_input_names = ["pixel_values", "image_grid_thw"]

     def __init__(self, **kwargs: Unpack[Glm4vFastImageProcessorKwargs]):
-        size = kwargs.pop("size", None)
-        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
-            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-        else:
-            size = self.size
-
-        super().__init__(size=size, **kwargs)
+        super().__init__(**kwargs)

     def _preprocess(
         self,
         images: list["torch.Tensor"],
         do_resize: bool,
-        size: SizeDict,
         interpolation: Optional["F.InterpolationMode"],
         do_rescale: bool,
         rescale_factor: float,
@@ -118,65 +104,19 @@ def _preprocess(
         patch_size: int,
         temporal_patch_size: int,
         merge_size: int,
-        do_convert_rgb: bool,
-        input_data_format: Optional[Union[str, ChannelDimension]],
-        device: Optional[Union[str, torch.device]],
         disable_grouping: Optional[bool],
-    ):
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
         """
         Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
-
-        Args:
-            images (`ImageInput`):
-                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
-            vision_info (`List[Dict]`, *optional*):
-                Optional list of dictionaries containing additional information about vision inputs.
-            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
-                Whether to resize the image.
-            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
-                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
-            interpolation (`InterpolationMode`):
-                Resampling filter to use if resizing the image.
-            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
-                Whether to rescale the image.
-            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
-                Scale factor to use if rescaling the image.
-            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
-                Whether to normalize the image.
-            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
-                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
-            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
-                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
-            patch_size (`int`, *optional*, defaults to `self.patch_size`):
-                The spatial patch size of the vision encoder.
-            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
-                The temporal patch size of the vision encoder.
-            merge_size (`int`, *optional*, defaults to `self.merge_size`):
-                The merge size of the vision encoder to llm encoder.
-            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
-                Whether to convert the image to RGB.
-            input_data_format (`ChannelDimension` or `str`, *optional*):
-                The channel dimension format for the input image. Can be one of:
-                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
-            device (`torch.device`, *optional*):
-                The device to process the images on. If unset, the device is inferred from the input images.
         """
-        images = self._prepare_input_images(
-            images=images,
-            do_convert_rgb=do_convert_rgb,
-            input_data_format=input_data_format,
-            device=device,
-        )
-
-        height, width = get_image_size(images[0], channel_dim=ChannelDimension.FIRST)
-        resized_height, resized_width = height, width

         # Group images by size for batched resizing
         grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
+            height, width = stacked_images.shape[-2:]
             if do_resize:
                 resized_height, resized_width = smart_resize(
                     num_frames=temporal_patch_size,
@@ -185,183 +125,73 @@ def _preprocess(
                     temporal_factor=temporal_patch_size,
                     factor=patch_size * merge_size,
                 )
-                stacked_images = F.resize(
-                    stacked_images, size=(resized_height, resized_width), interpolation=interpolation
+                stacked_images = self.resize(
+                    stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                 )
             resized_images_grouped[shape] = stacked_images
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
         # Group images by size for further processing
         # Needed in case do_resize is False, or resize returns images with different sizes
         grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
         processed_images_grouped = {}
+        processed_grids = {}
         for shape, stacked_images in grouped_images.items():
             # Fused rescale and normalize
             stacked_images = self.rescale_and_normalize(
                 stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
             )
-            processed_images_grouped[shape] = stacked_images
+            # add a temporal dimension
+            patches = stacked_images.unsqueeze(1)
+            if patches.shape[1] % temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+                patches = torch.cat([patches, repeats], dim=1)
+            batch_size, grid_t, channel = patches.shape[:3]
+            grid_t = grid_t // temporal_patch_size
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+            patches = patches.view(
+                batch_size,
+                grid_t,
+                temporal_patch_size,
+                channel,
+                grid_h // merge_size,
+                merge_size,
+                patch_size,
+                grid_w // merge_size,
+                merge_size,
+                patch_size,
+            )
+            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+            flatten_patches = patches.reshape(
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * temporal_patch_size * patch_size * patch_size,
+            )

-        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        patches = torch.stack(processed_images, dim=0)
-        if patches.shape[0] % temporal_patch_size != 0:
-            repeats = patches[-1].unsqueeze(0).repeat(temporal_patch_size - 1, 1, 1, 1)
-            patches = torch.cat([patches, repeats], dim=0)
+            processed_images_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

-        channel = patches.shape[1]
-        grid_t = patches.shape[0] // temporal_patch_size
-        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_grids = reorder_images(processed_grids, grouped_images_index)
+        pixel_values = torch.stack(processed_images, dim=0)
+        image_grid_thw = torch.tensor(processed_grids)

-        patches = patches.view(
-            grid_t,
-            temporal_patch_size,
-            channel,
-            grid_h // merge_size,
-            merge_size,
-            patch_size,
-            grid_w // merge_size,
-            merge_size,
-            patch_size,
-        )
-        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
-        flatten_patches = patches.reshape(
-            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
+        return BatchFeature(
+            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
         )

-        return flatten_patches, (grid_t, grid_h, grid_w)
-
     @auto_docstring
     def preprocess(
         self,
         images: ImageInput,
-        videos: VideoInput = None,
-        do_resize: Optional[bool] = None,
-        size: Optional[dict[str, int]] = None,
-        resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
-        do_rescale: Optional[bool] = None,
-        rescale_factor: Optional[float] = None,
-        do_normalize: Optional[bool] = None,
-        image_mean: Optional[Union[float, list[float]]] = None,
-        image_std: Optional[Union[float, list[float]]] = None,
-        patch_size: Optional[int] = None,
-        temporal_patch_size: Optional[int] = None,
-        merge_size: Optional[int] = None,
-        do_convert_rgb: Optional[bool] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
-        input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        device: Optional["torch.device"] = None,
-        disable_grouping: Optional[bool] = False,
-        **kwargs,
-    ):
-        r"""
-        patch_size (`int`, *optional*, defaults to 14):
-            The spatial patch size of the vision encoder.
-        temporal_patch_size (`int`, *optional*, defaults to 2):
-            The temporal patch size of the vision encoder.
-        merge_size (`int`, *optional*, defaults to 2):
-            The merge size of the vision encoder to llm encoder.
+        **kwargs: Unpack[Glm4vFastImageProcessorKwargs],
+    ) -> BatchFeature:
         """
-
-        do_resize = do_resize if do_resize is not None else self.do_resize
-        size = size if size is not None else self.size
-        resample = resample if resample is not None else self.resample
-        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
-        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
-        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
-        image_mean = image_mean if image_mean is not None else self.image_mean
-        image_std = image_std if image_std is not None else self.image_std
-        patch_size = patch_size if patch_size is not None else self.patch_size
-        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
-        merge_size = merge_size if merge_size is not None else self.merge_size
-        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
-
-        # Make hashable for cache
-        size = SizeDict(**size) if size is not None else None
-        image_mean = tuple(image_mean) if image_mean is not None else None
-        image_std = tuple(image_std) if image_std is not None else None
-
-        self._validate_preprocess_kwargs(
-            do_rescale=do_rescale,
-            rescale_factor=rescale_factor,
-            do_normalize=do_normalize,
-            image_mean=image_mean,
-            image_std=image_std,
-            do_resize=do_resize,
-            size=size,
-            resample=resample,
-            return_tensors=return_tensors,
-            data_format=data_format,
-        )
-        interpolation = (
-            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
-        )
-
-        if images is not None:
-            images = make_flat_list_of_images(images)
-
-        if images is not None and not valid_images(images):
-            raise ValueError(
-                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
-                "torch.Tensor, tf.Tensor or jax.ndarray."
-            )
-
-        data = {}
-        if images is not None:
-            pixel_values, vision_grid_thws = [], []
-            for image in images:
-                patches, image_grid_thw = self._preprocess(
-                    image,
-                    do_resize=do_resize,
-                    size=size,
-                    interpolation=interpolation,
-                    do_rescale=do_rescale,
-                    rescale_factor=rescale_factor,
-                    do_normalize=do_normalize,
-                    image_mean=image_mean,
-                    image_std=image_std,
-                    patch_size=patch_size,
-                    temporal_patch_size=temporal_patch_size,
-                    merge_size=merge_size,
-                    do_convert_rgb=do_convert_rgb,
-                    input_data_format=input_data_format,
-                    device=device,
-                    disable_grouping=disable_grouping,
-                )
-                pixel_values.extend(patches)
-                vision_grid_thws.append(image_grid_thw)
-            pixel_values = torch.stack(pixel_values)
-            vision_grid_thws = torch.tensor(vision_grid_thws)
-            data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
-
-        return BatchFeature(data=data, tensor_type=return_tensors)
-
-    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+        Preprocess an image or batch of images.
         """
-        A utility that returns number of image patches for a given image size.
-
-        Args:
-            height (`int`):
-                Height of the input image.
-            width (`int`):
-                Width of the input image.
-            images_kwargs (`dict`, *optional*)
-                Any kwargs to override defaults of the image processor.
-        Returns:
-            `int`: Number of image patches per image.
-        """
-        patch_size = images_kwargs.get("patch_size", None) or self.patch_size
-        merge_size = images_kwargs.get("merge_size", None) or self.merge_size
-
-        factor = patch_size * merge_size
-        resized_height, resized_width = smart_resize(
-            num_frames=self.temporal_patch_size,
-            height=height,
-            width=width,
-            temporal_factor=self.temporal_patch_size,
-            factor=factor,
-        )
-        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
-        return grid_h * grid_w
+        return super().preprocess(images, **kwargs)


 __all__ = ["Glm4vImageProcessorFast"]
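
To make the new in-loop patchification easier to follow, here is a self-contained sketch of the same view/permute/reshape sequence on dummy data. The tensor sizes and the patch_size / temporal_patch_size / merge_size values are illustrative assumptions (they echo the defaults cited in the removed docstring), not something this commit defines:

import torch

# Illustrative settings; the spatial sizes are arbitrary multiples of patch_size * merge_size.
patch_size, temporal_patch_size, merge_size = 14, 2, 2
batch_size, channel = 2, 3
resized_height, resized_width = 112, 140

# One shape group of already resized, rescaled and normalized images,
# i.e. the tensor the refactored loop body receives.
stacked_images = torch.randn(batch_size, channel, resized_height, resized_width)

# Add a temporal dimension and pad it to a multiple of temporal_patch_size.
patches = stacked_images.unsqueeze(1)
if patches.shape[1] % temporal_patch_size != 0:
    repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
    patches = torch.cat([patches, repeats], dim=1)

grid_t = patches.shape[1] // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

# Split H and W into (merge block, patch) factors, move the patch grid to the front,
# then flatten each patch into a single vector.
patches = patches.view(
    batch_size, grid_t, temporal_patch_size, channel,
    grid_h // merge_size, merge_size, patch_size,
    grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
    batch_size, grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
)

print(flatten_patches.shape)   # torch.Size([2, 80, 1176]) for the sizes above
print(grid_t, grid_h, grid_w)  # 1 8 10 -> the per-image entry in image_grid_thw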

src/transformers/models/glm4v/video_processing_glm4v.py

Lines changed: 5 additions & 6 deletions
@@ -53,8 +53,6 @@
 if is_vision_available():
     from ...image_utils import PILImageResampling

-import torch.nn.functional as F
-

 class Glm4vVideoProcessorInitKwargs(VideosKwargs):
     max_image_size: dict[str, int] = None
@@ -145,9 +143,8 @@ def _preprocess(
         self,
         videos: list[torch.Tensor],
         video_metadata: Optional[Union[list[VideoMetadata], list[dict]]] = None,
-        do_convert_rgb: bool = True,
         do_resize: bool = True,
-        size: SizeDict = None,
+        interpolation: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: float = 1 / 255.0,
         do_normalize: bool = True,
@@ -194,8 +191,10 @@ def _preprocess(
                     max_pixels=self.max_image_size["longest_edge"],
                 )
                 stacked_videos = stacked_videos.view(B * T, C, H, W)
-                stacked_videos = F.interpolate(
-                    stacked_videos, size=(resized_height, resized_width), mode="bicubic", align_corners=False
+                stacked_videos = self.resize(
+                    stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                 )
                 stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
             resized_videos_grouped[shape] = stacked_videos
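
The video path now also goes through `self.resize` instead of calling `torch.nn.functional.interpolate` directly. Below is a small sketch of the flatten-resize-unflatten pattern around that call, using torchvision's functional resize on dummy frames; the shapes, target size, and the bicubic choice are illustrative assumptions, and the real code delegates to the processor's own `resize` helper:

import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.v2 import functional as F

# Dummy video batch: B clips of T frames each (sizes made up for the sketch).
B, T, C, H, W = 2, 4, 3, 100, 80
resized_height, resized_width = 56, 56  # placeholder; the real target comes from smart_resize

stacked_videos = torch.randn(B, T, C, H, W)

# Fold the temporal dimension into the batch so all frames are resized in one call.
frames = stacked_videos.view(B * T, C, H, W)
frames = F.resize(
    frames,
    size=[resized_height, resized_width],
    interpolation=InterpolationMode.BICUBIC,
)
# Restore the (batch, time) layout with the new spatial size.
stacked_videos = frames.reshape(B, T, C, resized_height, resized_width)
print(stacked_videos.shape)  # torch.Size([2, 4, 3, 56, 56])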
