
Commit 0a364fe

Merge pull request #2350 from AI-Hypercomputer:aireen/gemma3-multi-image
PiperOrigin-RevId: 808811878
2 parents d4495e1 + 8fdac10 commit 0a364fe

File tree: 10 files changed, +150 -111 lines

src/MaxText/configs/base.yml

Lines changed: 5 additions & 1 deletion

@@ -798,9 +798,13 @@ freeze_vision_encoder_params: True
 dtype_mm: "float32" # Data type for multimodal model's vision encoder
 remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
-image_path: "" # Local image path used for decoding
+image_path: "" # Local image path used for decoding; can be multiple paths separated by commas, e.g. "/path/image1.jpg,/path/image2.jpg"
 image_placeholder: "<|image|>"
 posemb_type_for_vit: "learn"
+# max_num_images_per_example only applies to training, when your image column is a list of images.
+# -1 means no limit, and will pad to the max possible number of images determined by the sequence length.
+# Set it to avoid unnecessary padding if you know the maximum number of images per example.
+max_num_images_per_example: -1

 ### llama4 multi modal configs
 hidden_size_for_vit: 1408
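
For context, the budget that max_num_images_per_example bounds follows from the sequence length and the per-image token cost. A minimal sketch of that rule with assumed numbers (the real per-image cost comes from multimodal_utils.get_image_offsets and is model-dependent; 256 here is illustrative):

  # Sketch of the image budget bounded by max_num_images_per_example (numbers assumed).
  max_target_length = 2048         # training sequence length
  image_offsets = 256              # assumed tokens consumed per image; model-dependent
  max_num_images_per_example = -1  # the config default: no explicit limit

  max_num_images = max_target_length // image_offsets - 1  # reserve space for at least one text token
  if max_num_images_per_example > 0:
    max_num_images = min(max_num_images_per_example, max_num_images)
  print(max_num_images)  # 7 with these assumed numbers; every example pads its image list to this count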

src/MaxText/decode.py

Lines changed: 11 additions & 10 deletions

@@ -16,7 +16,7 @@

 import os
 from typing import Sequence
-
+import numpy as np
 import jax
 import jax.numpy as jnp

@@ -101,13 +101,14 @@ def main(argv: Sequence[str]) -> None:
   prefill_length = config.max_prefill_predict_length
   processor_output = multimodal_utils.PreprocessorOutput()
   if config.use_multimodal:
-    text = multimodal_utils.reformat_prompt(
-        text, image_placeholder=config.image_placeholder, model_name=config.model_name
+    image_path = config.image_path.split(",")
+    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
+    processor_outputs = [multimodal_utils.pre_process_image(img, model_name=config.model_name) for img in images]
+    image_offsets = sum(
+        [multimodal_utils.get_image_offsets(config.model_name, processor_output=po) for po in processor_outputs]
     )
-    # TODO(hengtaoguo): Support multiple images as input.
-    images = multimodal_utils.load_image_from_path(config.image_path)
-    processor_output = multimodal_utils.pre_process_image(images, model_name=config.model_name)
-    prefill_length -= multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_output)
+    prefill_length -= image_offsets
+    text = multimodal_utils.reformat_prompt(text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images))

   metadata = engine.get_tokenizer()
   tokenizer_model = engine.build_tokenizer(metadata)

@@ -119,9 +120,9 @@ def main(argv: Sequence[str]) -> None:
   tokens, true_length = tokenizer_model.encode(text, is_bos=not has_chat_template, prefill_lengths=[prefill_length])
   if config.use_multimodal:
     tokens = multimodal_utils.prepare_text_for_image_fusion(
-        tokens, model_name=config.model_name, processor_output=processor_output
+        tokens, model_name=config.model_name, processor_output=processor_outputs
     )
-    true_length += multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_output)
+    true_length += image_offsets

   assert (
       true_length <= config.max_prefill_predict_length

@@ -147,7 +148,7 @@ def main(argv: Sequence[str]) -> None:
     prefill_result, first_token = engine.prefill(
         params=params,
         padded_tokens=tokens,
-        images=processor_output.pixel_values,
+        images=np.stack([po.pixel_values for po in processor_outputs]) if config.use_multimodal else None,
         true_length=true_length,
         rng=rng_prefill,
         slot=i,
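
The new decode flow in isolation: split the comma-separated paths, preprocess each image, subtract the combined image token cost from the prefill window, and stack the pixel values for the engine. A minimal sketch assuming a Gemma3 model name and an illustrative prefill length; the paths are placeholders:

  import numpy as np
  from MaxText import multimodal_utils

  image_path = "/path/image1.jpg,/path/image2.jpg"  # comma-separated, per the base.yml comment
  images = [multimodal_utils.load_image_from_path(p) for p in image_path.split(",")]
  processor_outputs = [multimodal_utils.pre_process_image(img, model_name="gemma3-4b") for img in images]

  # Each image consumes a fixed slice of the prefill window, so the text budget shrinks.
  image_offsets = sum(multimodal_utils.get_image_offsets("gemma3-4b", processor_output=po) for po in processor_outputs)
  prefill_length = 1024 - image_offsets  # 1024: assumed max_prefill_predict_length

  # All pixel values are stacked into one (num_images, ...) array for engine.prefill.
  stacked_pixels = np.stack([po.pixel_values for po in processor_outputs])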

src/MaxText/experimental/rl/grpo_input_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ def lists2array(x):

   operations = [
       grain.MapOperation(lists2array),
-      _input_pipeline_utils.PadOrTrimToMaxLength(max_target_length),
+      _input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
       grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
   ]
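
The merged transform defaults add_true_length=False, so the GRPO pipeline opts back in to keep its *_true_length columns. A toy example of what that flag produces, with made-up values:

  import numpy as np
  from MaxText.input_pipeline import _input_pipeline_utils

  op = _input_pipeline_utils.PadOrTrimToMaxLength(max_length=8, pad_id=0, add_true_length=True)
  element = {"inputs": np.array([5, 6, 7, 8, 9])}
  out = op.map(element)
  # "inputs" is padded to length 8; the extra column records the unpadded length.
  print(out["inputs_true_length"])  # [5]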

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 2 additions & 2 deletions

@@ -133,7 +133,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
     }
     dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
   else:
-    dataset = dataset.map(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
+    dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
     batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
     dataset = dataset.batch(batch_size, batch_fn=batch_fn)

@@ -175,7 +175,7 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
     )
   )

-  dataset = dataset.map(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
+  dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
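
Functionally, the swap matters for over-long examples: PadToMaxLength only padded, while the merged transform's _pad_text also slices to max_length, so sequences longer than the target are now trimmed. A numpy-only sketch of that rule:

  import numpy as np

  def pad_or_trim(x, max_length, pad_id=0):
    # Mirrors _pad_text: pad up to max_length, then slice back down to it.
    pad_amount = max(max_length - x.shape[0], 0)
    pad_widths = [(0, pad_amount)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_widths, constant_values=pad_id)[:max_length]

  print(pad_or_trim(np.array([1, 2, 3]), 5))           # [1 2 3 0 0]
  print(pad_or_trim(np.array([1, 2, 3, 4, 5, 6]), 5))  # [1 2 3 4 5]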

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 16 additions & 7 deletions

@@ -43,11 +43,14 @@ def vision_sft_preprocessing_pipeline(
   """pipeline for multimodal SFT with HF dataset"""

   assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"
-
+  batch_size = global_batch_size // jax.process_count()
   if config.enable_data_shuffling:
     dataset = dataset.shuffle(seed=config.data_shuffle_seed)

   dataset = dataset.select_columns(text_columns + [image_column])
+  if image_column != "images":
+    dataset = dataset.rename_column(image_column, "images")
+
   dataset = dataset.map(
       _input_pipeline_utils.reformat_prompt,
       fn_kwargs={

@@ -60,8 +63,6 @@ def vision_sft_preprocessing_pipeline(
       _input_pipeline_utils.reformat_response,
       fn_kwargs={"column": text_columns[1], "model_name": config.model_name},
   )
-  if image_column != "images":
-    dataset = dataset.rename_column(image_column, "images")

   dataset = dataset.map(
       _input_pipeline_utils.pre_process_image_sft,

@@ -85,6 +86,7 @@ def vision_sft_preprocessing_pipeline(
   dataset = dataset.map(
       _input_pipeline_utils.tokenization,
       batched=True,
+      batch_size=global_batch_size,
       fn_kwargs={
           "hf_tokenizer": tokenizer,
           "truncation": False,

@@ -115,8 +117,15 @@ def vision_sft_preprocessing_pipeline(
       )
   )
   # TODO(aireenmei, hengtaoguo): support packing
-  operations.append(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
-  operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=True))
+  operations.append(
+      _input_pipeline_utils.PadOrTrimToMaxLength(
+          config.max_target_length,
+          pad_id,
+          model_name=config.model_name,
+          max_num_images_per_example=config.max_num_images_per_example,
+      )
+  )
+  operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True))
   operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
   dummy_index_sampler = grain.IndexSampler(
       num_records=len(dataset),

@@ -134,7 +143,7 @@ def vision_sft_preprocessing_pipeline(
       sampler=dummy_index_sampler,
       worker_count=1,  # only supports <=1 for now, more workers results in duplicated data
       worker_buffer_size=1,
-      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
+      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=batch_size * 4),
   )

   multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)

@@ -274,7 +283,7 @@ def lists2array(x):
     )
     operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
   else:
-    operations.append(_input_pipeline_utils.PadToMaxLength(max_target_length, pad_id))
+    operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
   operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))

   if shift and not use_dpo:
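
Note the ordering dependency behind the moved rename_column: reformat_prompt (see _input_pipeline_utils below) now counts images by reading example["images"], so the column must carry that name before the prompt is reformatted. A toy illustration of the count it computes:

  import numpy as np

  example = {"prompt": "describe <|image|> and <|image|>", "images": [np.zeros((8, 8, 3))] * 2}
  num_images = len(example["images"]) if isinstance(example["images"], list) else 1
  print(num_images)  # 2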

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 62 additions & 47 deletions

@@ -68,7 +68,11 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):

 def reformat_prompt(example, column, image_placeholder, model_name):
   """reformat prompt for multimodal SFT"""
-  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name)
+  if isinstance(example["images"], list):
+    num_images = len(example["images"])
+  else:
+    num_images = 1
+  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name, num_images)
   return example


@@ -80,11 +84,19 @@ def reformat_response(example, column, model_name):

 def pre_process_image_sft(example, image_column, model_name):
   """pre-process image for multimodal SFT"""
-  image = multimodal_utils.convert_to_RGB(example[image_column])
-  # TODO(aireenmei, hengtaoguo): add support for different image sizes
-  image = multimodal_utils.resize_image(image, model_name)
-  image = np.array(image)
-  example[image_column] = multimodal_utils.pre_process_image(image, model_name)
+
+  def _process_image_fn(image):
+    image = multimodal_utils.convert_to_RGB(image)
+    # TODO(aireenmei, hengtaoguo): add support for different image sizes
+    image = multimodal_utils.resize_image(image, model_name)
+    image = np.array(image)
+    image = multimodal_utils.pre_process_image(image, model_name)
+    return image
+
+  if isinstance(example[image_column], list):
+    example[image_column] = [_process_image_fn(img) for img in example[image_column]]
+  else:
+    example[image_column] = _process_image_fn(example[image_column])
   return example


@@ -93,7 +105,10 @@ def prepare_text_for_image_fusion(example, column_name, model_name):
   example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
       example[column_name], model_name, processor_output=example["images"]
   )
-  example["images"] = example["images"].pixel_values
+  if isinstance(example["images"], list):
+    example["images"] = [image.pixel_values for image in example["images"]]
+  else:
+    example["images"] = example["images"].pixel_values
   return example


@@ -400,58 +415,58 @@ def map(self, element):

 @dataclasses.dataclass
 class PadOrTrimToMaxLength(grain.MapTransform):
-  """Pads/Trims each input to the specified length
-  and returns true_length of input
-  """
-
-  def __init__(self, max_length):
-    self.max_length = max_length
-
-  def map(self, element: dict[str, np.ndarray]):
-    """map to each element"""
-
-    def _pad(x, max_length):
-      pad_amount = max(max_length - x.shape[0], 0)
-      pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
-      return np.pad(x, pad_amount)[:max_length]
-
-    data_columns = list(element.keys())
-    for data_column in data_columns:
-      element[f"{data_column}_segmentation"] = (element[data_column] != 0).astype(np.int32)
-      element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
-      element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)
-    for key, _ in element.items():
-      if "true_length" not in key:
-        element[key] = _pad(element[key], self.max_length)
-    # for data_column in data_columns:
-    #   data[f"{data_column}_true_length"] = _max_true_length(data[data_column], 0)
-    return element
-
+  """Pads or trims each input to the specified length,
+  and optionally adds the true length of the input."""

-@dataclasses.dataclass
-class PadToMaxLength(grain.MapTransform):
-  """Pads each input to the specified length"""
-
-  def __init__(self, max_length, pad_id):
+  def __init__(self, max_length, pad_id=0, model_name=None, add_true_length=False, max_num_images_per_example=-1):
     self.max_length = max_length
     self.pad_id = pad_id
+    self.model_name = model_name
+    self.add_true_length = add_true_length
+    self.max_num_images_per_example = max_num_images_per_example
+
+  def _pad_text(self, x, max_length, pad_id):
+    pad_amount = max(max_length - x.shape[0], 0)
+    pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
+    return np.pad(x, pad_amount, constant_values=pad_id)[: self.max_length]
+
+  def _pad_image(self, images):
+    image_offsets = multimodal_utils.get_image_offsets(self.model_name, None)
+    max_num_images = (self.max_length // image_offsets) - 1  # -1 to reserve space for at least one text token
+    if self.max_num_images_per_example > 0:
+      max_num_images = min(self.max_num_images_per_example, max_num_images)
+    image_shape = multimodal_utils.get_dummy_image_shape_for_init(self.model_name)[2:]
+    assert (
+        images.shape[0] <= max_num_images
+    ), f"Number of images {images.shape[0]} exceeds the maximum allowed {max_num_images}"
+    if images.shape[0] < max_num_images:
+      pad_size = max_num_images - images.shape[0]
+      pad_shape = (pad_size,) + image_shape
+      pad_images = np.zeros(pad_shape, dtype=images.dtype)
+      if images is not None and images.size > 0:
+        images = np.concatenate([images, pad_images], axis=0)
+      else:
+        images = pad_images
+    return images

   def map(self, element: dict[str, np.ndarray]):
     """map to each element"""
-
-    def _pad(x, max_length, pad_id):
-      pad_amount = max(max_length - x.shape[0], 0)
-      pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
-      return np.pad(x, pad_amount, constant_values=pad_id)
-
     data_columns = list(element.keys())
     for data_column in data_columns:
       if data_column != "images":
         element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
         element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
+        if self.add_true_length:
+          element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)
     for key, _ in element.items():
-      if key != "images":
-        element[key] = _pad(element[key], self.max_length, self.pad_id)
+      if key == "images":
+        if isinstance(element["images"], list):
+          assert self.model_name is not None, "model_name must be provided when padding images"
+          element["images"] = self._pad_image(np.asarray(element["images"]))
+        else:
+          element["images"] = np.asarray(element["images"])[None, ...]
+      elif "true_length" not in key:
+        element[key] = self._pad_text(element[key], self.max_length, self.pad_id)
     return element
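
To see what _pad_image does to a batch element, here is a numpy-only sketch with assumed numbers (256 tokens per image, a 2048-token target length, and a toy image shape):

  import numpy as np

  max_length, image_offsets = 2048, 256
  max_num_images = max_length // image_offsets - 1  # 7; one slot reserved for a text token
  images = np.zeros((2, 4, 4, 3), dtype=np.float32)  # an example holding 2 toy images

  assert images.shape[0] <= max_num_images
  pad = np.zeros((max_num_images - images.shape[0],) + images.shape[1:], dtype=images.dtype)
  padded = np.concatenate([images, pad], axis=0)
  print(padded.shape)  # (7, 4, 4, 3): every example now carries the same image count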

src/MaxText/maxengine.py

Lines changed: 9 additions & 6 deletions

@@ -46,6 +46,7 @@
 from MaxText import inference_utils
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText import multimodal_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.globals import MAXTEXT_PKG_DIR

@@ -331,7 +332,7 @@ def model_apply(_p, _rng):
         jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
         jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
         encoder_images=jnp.ones(
-            maxtext_utils.get_dummy_image_shape_for_init(self.config),
+            multimodal_utils.get_dummy_image_shape_for_init(self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on),
             dtype=jnp.float32,
         )
         if self.config.use_multimodal

@@ -474,10 +475,12 @@ def _prefill_jit(

     input_images = None
     if self.config.use_multimodal and images is not None:
-      if self.config.model_name.startswith("gemma3"):
-        input_images = images[jnp.newaxis, jnp.newaxis, ...]  # Add batch and sequence dimension [B, N, H, W, C]
-      elif self.config.model_name.startswith("llama4"):
-        input_images = images[jnp.newaxis, ...]  # Add batch dimension [B, T, C, H, W]
+      if images.ndim == 3:
+        # For Gemma3 single image, add batch and image count dimensions
+        input_images = images[jnp.newaxis, jnp.newaxis, ...]
+      elif images.ndim == 4:
+        # add batch dimension
+        input_images = images[jnp.newaxis, ...]

     # sequence_indicator will be concatenated to existing_prefix decoder_segment_ids
     start_to_n = jnp.arange(start_position, start_position + input_tokens.shape[1])

@@ -1524,7 +1527,7 @@ def init(abstract_params, page_state):
         (int(self.config.per_device_batch_size * self.mesh.size), 1),
         dtype=jnp.int32,
     )
-    dummy_image = jnp.ones(maxtext_utils.get_dummy_image_shape_for_init(self.config), dtype=jnp.int32)
+    dummy_image = jnp.ones(multimodal_utils.get_dummy_image_shape_for_init(self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on), dtype=jnp.int32)
     _, cache = self.model.apply(
         abstract_params,
         x,
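
The ndim dispatch above replaces the model-name check: a single preprocessed image arrives rank-3, while a stack of images (or llama4's tiled input) arrives rank-4. A sketch of the resulting shapes, with Gemma3's 896x896 default assumed:

  import jax.numpy as jnp

  single = jnp.ones((896, 896, 3))      # one Gemma3 image: (H, W, C), ndim == 3
  stacked = jnp.ones((2, 896, 896, 3))  # multiple images: (N, H, W, C), ndim == 4

  print(single[jnp.newaxis, jnp.newaxis, ...].shape)  # (1, 1, 896, 896, 3): batch + image-count axes
  print(stacked[jnp.newaxis, ...].shape)              # (1, 2, 896, 896, 3): batch axis only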
