from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
-    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
-    get_image_size,
-    make_flat_list_of_images,
-    valid_images,
)
from ...processing_utils import Unpack
from ...utils import (
    is_torchvision_v2_available,
    logging,
)
-from ...video_utils import VideoInput
from .image_processing_glm4v import smart_resize


if is_torchvision_available():
-    from ...image_utils import pil_torch_interpolation_mapping
-
if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
else:
@@ -96,19 +89,12 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(self, **kwargs: Unpack[Glm4vFastImageProcessorKwargs]):
-        size = kwargs.pop("size", None)
-        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
-            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
-        else:
-            size = self.size
-
-        super().__init__(size=size, **kwargs)
+        super().__init__(**kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
-        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
@@ -118,65 +104,19 @@ def _preprocess(
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
-        do_convert_rgb: bool,
-        input_data_format: Optional[Union[str, ChannelDimension]],
-        device: Optional[Union[str, torch.device]],
        disable_grouping: Optional[bool],
-    ):
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
-
-        Args:
-            images (`ImageInput`):
-                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
-            vision_info (`List[Dict]`, *optional*):
-                Optional list of dictionaries containing additional information about vision inputs.
-            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
-                Whether to resize the image.
-            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
-                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
-            interpolation (`InterpolationMode`):
-                Resampling filter to use if resizing the image.
-            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
-                Whether to rescale the image.
-            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
-                Scale factor to use if rescaling the image.
-            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
-                Whether to normalize the image.
-            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
-                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
-            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
-                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
-            patch_size (`int`, *optional*, defaults to `self.patch_size`):
-                The spatial patch size of the vision encoder.
-            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
-                The temporal patch size of the vision encoder.
-            merge_size (`int`, *optional*, defaults to `self.merge_size`):
-                The merge size of the vision encoder to llm encoder.
-            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
-                Whether to convert the image to RGB.
-            input_data_format (`ChannelDimension` or `str`, *optional*):
-                The channel dimension format for the input image. Can be one of:
-                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
-            device (`torch.device`, *optional*):
-                The device to process the images on. If unset, the device is inferred from the input images.
        """
-        images = self._prepare_input_images(
-            images=images,
-            do_convert_rgb=do_convert_rgb,
-            input_data_format=input_data_format,
-            device=device,
-        )
-
-        height, width = get_image_size(images[0], channel_dim=ChannelDimension.FIRST)
-        resized_height, resized_width = height, width

        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
+            height, width = stacked_images.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    num_frames=temporal_patch_size,
@@ -185,183 +125,73 @@ def _preprocess(
                    temporal_factor=temporal_patch_size,
                    factor=patch_size * merge_size,
                )
-                stacked_images = F.resize(
-                    stacked_images, size=(resized_height, resized_width), interpolation=interpolation
+                stacked_images = self.resize(
+                    stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
+        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
-            processed_images_grouped[shape] = stacked_images
+            # add a temporal dimension
+            patches = stacked_images.unsqueeze(1)
+            if patches.shape[1] % temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+                patches = torch.cat([patches, repeats], dim=1)
+            batch_size, grid_t, channel = patches.shape[:3]
+            grid_t = grid_t // temporal_patch_size
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+            patches = patches.view(
+                batch_size,
+                grid_t,
+                temporal_patch_size,
+                channel,
+                grid_h // merge_size,
+                merge_size,
+                patch_size,
+                grid_w // merge_size,
+                merge_size,
+                patch_size,
+            )
+            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+            flatten_patches = patches.reshape(
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * temporal_patch_size * patch_size * patch_size,
+            )

-        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        patches = torch.stack(processed_images, dim=0)
-        if patches.shape[0] % temporal_patch_size != 0:
-            repeats = patches[-1].unsqueeze(0).repeat(temporal_patch_size - 1, 1, 1, 1)
-            patches = torch.cat([patches, repeats], dim=0)
+            processed_images_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

-        channel = patches.shape[1]
-        grid_t = patches.shape[0] // temporal_patch_size
-        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_grids = reorder_images(processed_grids, grouped_images_index)
+        pixel_values = torch.stack(processed_images, dim=0)
+        image_grid_thw = torch.tensor(processed_grids)

-        patches = patches.view(
-            grid_t,
-            temporal_patch_size,
-            channel,
-            grid_h // merge_size,
-            merge_size,
-            patch_size,
-            grid_w // merge_size,
-            merge_size,
-            patch_size,
-        )
-        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
-        flatten_patches = patches.reshape(
-            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
+        return BatchFeature(
+            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
        )

-        return flatten_patches, (grid_t, grid_h, grid_w)
-
    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
-        videos: VideoInput = None,
-        do_resize: Optional[bool] = None,
-        size: Optional[dict[str, int]] = None,
-        resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
-        do_rescale: Optional[bool] = None,
-        rescale_factor: Optional[float] = None,
-        do_normalize: Optional[bool] = None,
-        image_mean: Optional[Union[float, list[float]]] = None,
-        image_std: Optional[Union[float, list[float]]] = None,
-        patch_size: Optional[int] = None,
-        temporal_patch_size: Optional[int] = None,
-        merge_size: Optional[int] = None,
-        do_convert_rgb: Optional[bool] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
-        input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        device: Optional["torch.device"] = None,
-        disable_grouping: Optional[bool] = False,
-        **kwargs,
-    ):
-        r"""
-        patch_size (`int`, *optional*, defaults to 14):
-            The spatial patch size of the vision encoder.
-        temporal_patch_size (`int`, *optional*, defaults to 2):
-            The temporal patch size of the vision encoder.
-        merge_size (`int`, *optional*, defaults to 2):
-            The merge size of the vision encoder to llm encoder.
+        **kwargs: Unpack[Glm4vFastImageProcessorKwargs],
+    ) -> BatchFeature:
        """
-
-        do_resize = do_resize if do_resize is not None else self.do_resize
-        size = size if size is not None else self.size
-        resample = resample if resample is not None else self.resample
-        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
-        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
-        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
-        image_mean = image_mean if image_mean is not None else self.image_mean
-        image_std = image_std if image_std is not None else self.image_std
-        patch_size = patch_size if patch_size is not None else self.patch_size
-        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
-        merge_size = merge_size if merge_size is not None else self.merge_size
-        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
-
-        # Make hashable for cache
-        size = SizeDict(**size) if size is not None else None
-        image_mean = tuple(image_mean) if image_mean is not None else None
-        image_std = tuple(image_std) if image_std is not None else None
-
-        self._validate_preprocess_kwargs(
-            do_rescale=do_rescale,
-            rescale_factor=rescale_factor,
-            do_normalize=do_normalize,
-            image_mean=image_mean,
-            image_std=image_std,
-            do_resize=do_resize,
-            size=size,
-            resample=resample,
-            return_tensors=return_tensors,
-            data_format=data_format,
-        )
-        interpolation = (
-            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
-        )
-
-        if images is not None:
-            images = make_flat_list_of_images(images)
-
-        if images is not None and not valid_images(images):
-            raise ValueError(
-                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
-                "torch.Tensor, tf.Tensor or jax.ndarray."
-            )
-
-        data = {}
-        if images is not None:
-            pixel_values, vision_grid_thws = [], []
-            for image in images:
-                patches, image_grid_thw = self._preprocess(
-                    image,
-                    do_resize=do_resize,
-                    size=size,
-                    interpolation=interpolation,
-                    do_rescale=do_rescale,
-                    rescale_factor=rescale_factor,
-                    do_normalize=do_normalize,
-                    image_mean=image_mean,
-                    image_std=image_std,
-                    patch_size=patch_size,
-                    temporal_patch_size=temporal_patch_size,
-                    merge_size=merge_size,
-                    do_convert_rgb=do_convert_rgb,
-                    input_data_format=input_data_format,
-                    device=device,
-                    disable_grouping=disable_grouping,
-                )
-                pixel_values.extend(patches)
-                vision_grid_thws.append(image_grid_thw)
-            pixel_values = torch.stack(pixel_values)
-            vision_grid_thws = torch.tensor(vision_grid_thws)
-            data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
-
-        return BatchFeature(data=data, tensor_type=return_tensors)
-
-    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+        Preprocess an image or batch of images.
        """
-        A utility that returns number of image patches for a given image size.
-
-        Args:
-            height (`int`):
-                Height of the input image.
-            width (`int`):
-                Width of the input image.
-            images_kwargs (`dict`, *optional*)
-                Any kwargs to override defaults of the image processor.
-        Returns:
-            `int`: Number of image patches per image.
-        """
-        patch_size = images_kwargs.get("patch_size", None) or self.patch_size
-        merge_size = images_kwargs.get("merge_size", None) or self.merge_size
-
-        factor = patch_size * merge_size
-        resized_height, resized_width = smart_resize(
-            num_frames=self.temporal_patch_size,
-            height=height,
-            width=width,
-            temporal_factor=self.temporal_patch_size,
-            factor=factor,
-        )
-        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
-        return grid_h * grid_w
+        return super().preprocess(images, **kwargs)


__all__ = ["Glm4vImageProcessorFast"]
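
For context on the reworked `_preprocess`, here is a minimal, self-contained sketch of the batched patch flattening it now performs per shape group. The input shapes (a batch of 2 RGB images at 56x56) are illustrative assumptions; the patch, temporal-patch, and merge sizes mirror the GLM-4V defaults quoted in the removed docstring above (14, 2, 2).

import torch

# Illustrative configuration (assumed values, reusing the names from the diff)
batch_size, channel = 2, 3
temporal_patch_size, patch_size, merge_size = 2, 14, 2
resized_height, resized_width = 56, 56  # already multiples of patch_size * merge_size

# Stand-in for one group of rescaled/normalized images: (batch, channel, H, W)
stacked_images = torch.randn(batch_size, channel, resized_height, resized_width)

# Add a temporal dimension and pad it up to a multiple of temporal_patch_size
patches = stacked_images.unsqueeze(1)
if patches.shape[1] % temporal_patch_size != 0:
    repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
    patches = torch.cat([patches, repeats], dim=1)

grid_t = patches.shape[1] // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

# Split the spatial dims into merge blocks of patch_size, then flatten per patch
patches = patches.view(
    batch_size, grid_t, temporal_patch_size, channel,
    grid_h // merge_size, merge_size, patch_size,
    grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
    batch_size, grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
)
print(flatten_patches.shape)  # torch.Size([2, 16, 1176]): (batch, grid_t*grid_h*grid_w, flattened patch)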