def reformat_prompt(example, column, image_placeholder, model_name):
  """Reformat the prompt column of an example for multimodal SFT.

  Determines how many images the example carries — ``example["images"]`` may
  hold either a list of images or a single image — and delegates to
  ``multimodal_utils.reformat_prompt`` so the correct number of image
  placeholders is inserted for the given model.
  """
  images = example["images"]
  # A non-list entry means the example has exactly one image.
  num_images = len(images) if isinstance(images, list) else 1
  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name, num_images)
  return example
def pre_process_image_sft(example, image_column, model_name):
  """Pre-process the image column of an example for multimodal SFT.

  Each image is converted to RGB, resized for the target model, turned into a
  numpy array, and passed through the model-specific preprocessing. Handles
  both a single image and a list of images stored in
  ``example[image_column]``.
  """

  def _prepare(raw_image):
    rgb = multimodal_utils.convert_to_RGB(raw_image)
    # TODO(aireenmei, hengtaoguo): add support for different image sizes
    resized = multimodal_utils.resize_image(rgb, model_name)
    return multimodal_utils.pre_process_image(np.array(resized), model_name)

  value = example[image_column]
  if isinstance(value, list):
    example[image_column] = [_prepare(item) for item in value]
  else:
    example[image_column] = _prepare(value)
  return example
def prepare_text_for_image_fusion(example, column_name, model_name):
  """Prepare the text column for image fusion and unpack raw pixel values.

  NOTE(review): reconstructed from a diff hunk — the ``def`` line comes from
  the hunk header and any original docstring was outside the visible context;
  verify against the full file.
  """
  example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
      example[column_name], model_name, processor_output=example["images"]
  )
  # Replace the processor output(s) with their pixel arrays, handling both a
  # single processor output and a list of them.
  images = example["images"]
  if isinstance(images, list):
    example["images"] = [item.pixel_values for item in images]
  else:
    example["images"] = images.pixel_values
  return example
@dataclasses.dataclass
class PadOrTrimToMaxLength(grain.MapTransform):
  """Pads or trims each input to the specified length.

  Adds per-column segmentation and position features, optionally records each
  column's true (pre-padding) length, and pads the "images" column with zero
  images up to a model-dependent maximum.
  """

  def __init__(self, max_length, pad_id=0, model_name=None, add_true_length=False, max_num_images_per_example=-1):
    # max_length: target sequence length for every text column.
    # pad_id: token id used to pad text columns (also defines segmentation).
    # model_name: required when image padding is needed (list-valued "images").
    # add_true_length: when True, emit a "<col>_true_length" feature per column.
    # max_num_images_per_example: optional cap on images per example (-1 = no cap).
    self.max_length = max_length
    self.pad_id = pad_id
    self.model_name = model_name
    self.add_true_length = add_true_length
    self.max_num_images_per_example = max_num_images_per_example

  def _pad_text(self, x, max_length, pad_id):
    """Pad x along axis 0 with pad_id up to max_length, then trim to max_length."""
    pad_amount = max(max_length - x.shape[0], 0)
    pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
    # BUG FIX: trim with the max_length parameter instead of self.max_length so
    # the helper honors its own argument. Callers currently pass
    # self.max_length, so runtime behavior is unchanged.
    return np.pad(x, pad_amount, constant_values=pad_id)[:max_length]

  def _pad_image(self, images):
    """Pad the stacked image array with zero-filled images up to the allowed maximum."""
    image_offsets = multimodal_utils.get_image_offsets(self.model_name, None)
    # -1 to reserve space for at least one text token.
    max_num_images = (self.max_length // image_offsets) - 1
    if self.max_num_images_per_example > 0:
      max_num_images = min(self.max_num_images_per_example, max_num_images)
    image_shape = multimodal_utils.get_dummy_image_shape_for_init(self.model_name)[2:]
    assert (
        images.shape[0] <= max_num_images
    ), f"Number of images {images.shape[0]} exceeds the maximum allowed {max_num_images}"
    if images.shape[0] < max_num_images:
      pad_size = max_num_images - images.shape[0]
      pad_shape = (pad_size,) + image_shape
      pad_images = np.zeros(pad_shape, dtype=images.dtype)
      # BUG FIX: `images` cannot be None here (its .shape was dereferenced
      # above), so the dead `images is not None` check was removed; only the
      # emptiness test is meaningful.
      if images.size > 0:
        images = np.concatenate([images, pad_images], axis=0)
      else:
        images = pad_images
    return images

  def map(self, element: dict[str, np.ndarray]):
    """Map applied to each element: add segmentation/position (and optional
    true-length) features for text columns, then pad text and images."""
    data_columns = list(element.keys())
    for data_column in data_columns:
      if data_column != "images":
        element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
        element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
        if self.add_true_length:
          element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)
    for key, _ in element.items():
      if key == "images":
        if isinstance(element["images"], list):
          assert self.model_name is not None, "model_name must be provided when padding images"
          element["images"] = self._pad_image(np.asarray(element["images"]))
        else:
          # A single image: add a leading num_images axis of size 1.
          element["images"] = np.asarray(element["images"])[None, ...]
      elif "true_length" not in key:
        element[key] = self._pad_text(element[key], self.max_length, self.pad_id)
    return element
0 commit comments