
Commit 4673643

Merge pull request #2301 from AI-Hypercomputer:aireen/padding_batch
PiperOrigin-RevId: 807767938
2 parents: 7b33d57 + 0a65506

17 files changed (+95, -90 lines)

docs/guides/data_input_grain.md
Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 * **Debug training anomalies**: When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues.

 ## Data shuffling
-* **Global shuffle**: This feature is only available when using Grain with [ArrayRecord](https://github.com/google/array_record) (random access) format, achieved by shuffling indices globally at the beginning of each epoch and then reading the elements according to the random order. This is usually fast enough, even when using hard drives and distributed file systems.
+* **Global shuffle**: This feature is only available when using Grain with [ArrayRecord](https://github.com/google/array_record) (random access) format, achieved by shuffling indices globally at the beginning of each epoch and then reading the elements according to the random order. This shuffle method effectively prevents local overfitting, leading to better training results.
 * **Hierarchical shuffle**: For sequential access format [Parquet](https://arrow.apache.org/docs/python/parquet.html), shuffle is performed by these steps: file shuffling, interleave from files, and window shuffle using a fixed size buffer.

 ## Using Grain
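The global shuffle described in this doc change amounts to drawing a fresh permutation of record indices each epoch and reading records in that order. Below is a minimal sketch of that idea in plain NumPy; it is illustrative only (not Grain's actual implementation), and the `reader` name in the comment is a hypothetical random-access source such as an ArrayRecord file.

```python
import numpy as np

def global_shuffle_order(num_records: int, epoch: int, seed: int = 0) -> np.ndarray:
  """Return a fresh global permutation of record indices for the given epoch."""
  rng = np.random.default_rng([seed, epoch])  # different permutation every epoch
  return rng.permutation(num_records)

# With a random-access format such as ArrayRecord, records can then be read
# directly in this order, e.g. [reader[i] for i in global_shuffle_order(n, epoch)].
```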

docs/guides/data_input_pipeline.md
Lines changed: 2 additions & 2 deletions

@@ -37,11 +37,11 @@ The approaches to solve these challenges depend on whether your dataset supports
 Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.<br>
 In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges:
 * **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.
-* **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_example` flag. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. **Note**: When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_example` flag still solves this. However, as more hosts begin generating padding, you will observe a decrease in total_weights and a slower change in the training loss. If all hosts exhaust their data before the target step count is reached, both total_weights and loss will drop to 0.
+* **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. **Note**: When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this. However, as more hosts begin generating padding, you will observe a decrease in total_weights and a slower change in the training loss. If all hosts exhaust their data before the target step count is reached, both total_weights and loss will drop to 0.

 ### Sequential access dataset
 * **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
-* **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_example` flag to handle hosts that finish their file shards early (currently only supported in Hugging Face pipeline, not available in TFDS pipeline).
+* **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.

 ```{toctree}
 :hidden:
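To make the padding behaviour described in the updated "Uneven completion" bullets concrete, here is a minimal, self-contained sketch of the idea. It is illustrative only and not MaxText's actual `MultiHostDataLoadIterator`; the batch shape and column names are assumptions. Once a host exhausts its real data, it keeps emitting all-zero batches so every host can reach the same target step count, and those batches contribute zero weight to the loss.

```python
import numpy as np

def make_padding_batch(batch_size: int = 4, seq_len: int = 8) -> dict:
  """All-zero batch; zero targets/segmentation mean it adds no weight to the loss."""
  columns = ("inputs", "targets", "inputs_segmentation", "targets_segmentation")
  return {col: np.zeros((batch_size, seq_len), dtype=np.int32) for col in columns}

def with_padding_batches(batch_iter, total_steps: int):
  """Yield real batches until the underlying iterator is exhausted, then padding
  batches until total_steps is reached, keeping all hosts stepping in lockstep."""
  emitted = 0
  for batch in batch_iter:
    yield batch
    emitted += 1
    if emitted == total_steps:
      return
  while emitted < total_steps:
    yield make_padding_batch()
    emitted += 1
```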

requirements_with_jax_ai_image.txt
Lines changed: 3 additions & 3 deletions

@@ -4,7 +4,7 @@ datasets
 flax>=0.11.0
 google-api-python-client
 google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
-grain[parquet]>=0.2.6
+grain[parquet]>=0.2.12
 jaxtyping
 jsonlines
 mlperf-logging@git+https://github.com/mlperf/logging.git
@@ -13,12 +13,12 @@ orbax-checkpoint>=0.11.22
 pathwaysutils>=0.1.1
 pillow>=11.1.0
 pre-commit
-protobuf==3.20.3
+protobuf>=5.29.5
 pyink
 pylint
 pytest
 pytype
-sentencepiece==0.1.97
+sentencepiece>=0.2.0
 tensorflow-datasets
 tensorflow-text>=2.17.0
 tiktoken

requirements_with_jax_stable_stack_0_6_1_pipreqs.txt
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ datasets==3.6.0
 etils==1.12.2
 evaluate==0.4.4
 flax==0.11.0
-grain==0.2.10
+grain==0.2.12
 grpcio==1.72.0rc1
 huggingface_hub==0.33.0
 jax==0.6.0

src/MaxText/assets/tokenizer
Binary file (2.37 KB) not shown.

src/MaxText/configs/base.yml
Lines changed: 2 additions & 0 deletions

@@ -470,6 +470,8 @@ eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
 eval_image_column: 'image'
 packing: True
 num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
+generate_padding_batch_train: False
+generate_padding_batch_eval: False

 # direct preference optimization (DPO)
 use_dpo: False

src/MaxText/data_loader.py
Lines changed: 3 additions & 4 deletions

@@ -26,7 +26,6 @@
     maybe_record_goodput,
 )

-
 class DataLoader:
   """
   Loads preprocessed data for training.
@@ -51,10 +50,10 @@ def load_next_batch(self):
       self.last_batch = jax.lax.with_sharding_constraint(example_batch, self.input_data_shardings)
       self.check_example_batch()
     except Exception as e:  # pylint: disable=broad-except
-      if "StopIteration" in str(e):
-        raise exceptions.StopTraining("You may have run out of training data.")
+      if isinstance(e, StopIteration):
+        raise exceptions.StopTraining(f"You may have run out of training data. Received {type(e)} exception: ({e})")
       else:
-        raise exceptions.StopTraining(f"`load_next_batch()` failed ({e}).")
+        raise exceptions.StopTraining(f"`load_next_batch()` failed with {type(e)} exception: ({e}).")
     return self.last_batch

   def check_example_batch(self):

src/MaxText/input_pipeline/_grain_data_processing.py
Lines changed: 11 additions & 6 deletions

@@ -120,11 +120,11 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
           data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
       )
   )
-
   # Pack and Batch examples.
+  batch_size = config.global_batch_size_to_load // jax.process_count()
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
-    dataset = grain.experimental.FirstFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=30)
+    dataset = grain.experimental.FirstFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
     rekey_dict = {
         "targets_segmentation": "targets_segment_ids",
         "inputs_segmentation": "inputs_segment_ids",
@@ -134,7 +134,8 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
     dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
   else:
     dataset = dataset.map(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
-  dataset = dataset.batch(batch_size=config.global_batch_size_to_load // jax.process_count(), drop_remainder=False)
+  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
+  dataset = dataset.batch(batch_size, batch_fn=batch_fn)

   # Shift inputs for teacher-forced training
   dataset = dataset.map(
@@ -175,7 +176,9 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
   )

   dataset = dataset.map(_input_pipeline_utils.PadToMaxLength(config.max_target_length, pad_id))
-  dataset = dataset.batch(batch_size=config.global_batch_size_to_load // jax.process_count(), drop_remainder=False)
+  batch_size = config.global_batch_size_to_load // jax.process_count()
+  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
+  dataset = dataset.batch(batch_size, batch_fn=batch_fn)
   dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
   return dataset

@@ -216,7 +219,9 @@ def make_grain_train_iterator(
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
     )
-    return multihost_dataloading.MultiHostDataLoadIterator(train_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(
+        train_dataloader, global_mesh, config.generate_padding_batch_train
+    )
   else:
     get_ds_fn = functools.partial(
         get_datasets,
@@ -283,7 +288,7 @@ def make_grain_eval_iterator(
         tokenize=config.tokenize_eval_data,
        grain_worker_count=config.grain_worker_count_eval,
    )
-    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh, config.generate_padding_batch_eval)
   else:
     get_ds_fn = functools.partial(
        get_datasets,
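The batching change above swaps `drop_remainder=False` for grain's `batch_and_pad`, so the final short batch is padded up to the full per-host batch size instead of arriving smaller. A small sketch of the call pattern, mirroring the diff; the toy `examples` data and the `grain.MapDataset.source(...).to_iter_dataset()` construction are assumptions for illustration and not part of this change, and the module alias is assumed to match the `grain` import already used in this file.

```python
import functools
import numpy as np
import grain  # assumed to be the same `grain` module alias used in this file

# Ten toy examples of a single already-padded column, batched 4 at a time.
examples = [{"inputs": np.full(8, i, dtype=np.int32)} for i in range(10)]
batch_size, pad_id = 4, 0

dataset = grain.MapDataset.source(examples).to_iter_dataset()
batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
dataset = dataset.batch(batch_size, batch_fn=batch_fn)

for batch in dataset:
  # Each batch keeps the full leading dimension of 4; the last one holds the two
  # remaining real examples plus rows filled with pad_id rather than being dropped.
  print(batch["inputs"].shape)  # expected: (4, 8)
```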

src/MaxText/input_pipeline/_hf_data_processing.py
Lines changed: 5 additions & 9 deletions

@@ -102,7 +102,6 @@ def vision_sft_preprocessing_pipeline(
       dataloading_host_index=dataloading_host_index,
       dataloading_host_count=dataloading_host_count,
       num_threads=1,
-      generate_padding_example=True,
       max_target_length=config.max_target_length,
       data_column_names=text_columns,
   )
@@ -162,8 +161,8 @@ def preprocessing_pipeline(
     packing=True,
     shift=True,
     num_threads=1,
-    drop_remainder=False,
-    generate_padding_example=False,
+    drop_remainder=True,
+    generate_padding_batch=False,
     use_dpo=None,
     use_sft=None,
     sft_train_on_completion_only=True,
@@ -239,7 +238,6 @@
       dataloading_host_index,
       dataloading_host_count,
       num_threads,
-      generate_padding_example,
       max_target_length,
       data_column_names,
   )
@@ -304,7 +302,7 @@ def lists2array(x):
       read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128),
   )

-  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)
+  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch)

   # Return multi-host jax.Array prep iterator
   return multihost_gen
@@ -352,7 +350,7 @@ def make_hf_train_iterator(
       add_bos=config.add_bos,
       add_eos=config.add_eos,
       packing=config.packing,
-      generate_padding_example=False,
+      generate_padding_batch=config.generate_padding_batch_train,
       use_dpo=config.use_dpo,
       use_sft=config.use_sft,
       sft_train_on_completion_only=config.sft_train_on_completion_only,
@@ -374,8 +372,6 @@ def make_hf_eval_iterator(
       streaming=True,
       token=config.hf_access_token,
   )
-
-  eval_generate_padding_example = config.eval_steps > 0
   if config.use_sft and config.use_multimodal:
     eval_iter = vision_sft_preprocessing_pipeline(
         dataset=eval_ds,
@@ -404,7 +400,7 @@
       add_bos=config.add_bos,
       add_eos=config.add_eos,
       packing=config.packing,
-      generate_padding_example=eval_generate_padding_example,
+      generate_padding_batch=config.generate_padding_batch_eval,
       use_dpo=config.use_dpo,
       use_sft=config.use_sft,
       sft_train_on_completion_only=config.sft_train_on_completion_only,

src/MaxText/input_pipeline/_input_pipeline_utils.py
Lines changed: 5 additions & 27 deletions

@@ -157,9 +157,7 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
   for message in example[data_column_name]:
     if message["role"] == "user":
       prompt = message
-      prompt_in_chat_template = tokenizer_model.apply_chat_template(
-          [prompt], add_generation_prompt=False, tokenize=False
-      )
+      prompt_in_chat_template = tokenizer_model.apply_chat_template([prompt], add_generation_prompt=False, tokenize=False)
       messages.append(prompt_in_chat_template)
       is_prompt.append(True)
     elif message["role"] == "assistant":
@@ -266,15 +264,13 @@ def __init__(
       dataloading_host_index: int,
       dataloading_host_count: int,
       num_threads: int,
-      generate_padding_example: bool,
       max_target_length: int,
       data_column_names: list[str],
   ):
     self.dataset = dataset
     self.num_threads = num_threads
     self.dataloading_host_count = dataloading_host_count
     self.dataloading_host_index = dataloading_host_index
-    self.generate_padding_example = generate_padding_example
     self.max_target_lenth = max_target_length
     self.data_column_names = data_column_names
     if hasattr(dataset, "n_shards"):
@@ -285,7 +281,6 @@ def __init__(
     self.dataset_shards = [dataloading_host_index * self.num_threads + i for i in range(self.num_threads)]
     self.datasets = [split_dataset_by_node(dataset, world_size=self.n_shards, rank=x) for x in self.dataset_shards]
     self.data_iters = []
-    self.out_of_data = False

   def _check_shard_count(self):
     if self.n_shards < (self.dataloading_host_count * self.num_threads):
@@ -300,20 +295,13 @@ def _update_shard(self, idx):
     """update shard"""
     new_shard = self.dataset_shards[idx] + self.dataloading_host_count * self.num_threads
     if new_shard < self.n_shards:
-      max_logging.log(
-          f"Updating host {self.dataloading_host_index} dataset {idx}, was on shard {self.dataset_shards[idx]}"
-      )
+      max_logging.log(f"Updating host {self.dataloading_host_index} dataset {idx}, was on shard {self.dataset_shards[idx]}")
      max_logging.log(f"New shard is {new_shard}")
      self.dataset_shards[idx] = new_shard
      self.datasets[idx] = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx])
      self.data_iters[idx] = iter(self.datasets[idx])
    else:
-      max_logging.log(f"Run out of shards on host {self.dataloading_host_index}, shard {new_shard} is not available")
-      self.out_of_data = True
-      if self.generate_padding_example:
-        max_logging.log(
-            f"Host {self.dataloading_host_index} will start generating all-0 padding examples until step number is met."
-        )
+      raise StopIteration(f"Run out of shards on host {self.dataloading_host_index}, shard {new_shard} is not available")

   def __len__(self):
     """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length,
@@ -329,20 +317,10 @@ def __getitem__(self, index):

     while True:
       try:
-        if self.out_of_data:
-          if self.generate_padding_example:
-            return {
-                column_name: np.zeros(self.max_target_lenth, dtype=np.int32) for column_name in self.data_column_names
-            }
-          else:
-            raise StopIteration("Running out of data")
        data = next(self.data_iters[idx])
        return data
-      except StopIteration as e:
-        if not self.out_of_data:
-          self._update_shard(idx)
-        else:
-          raise e
+      except StopIteration:
+        self._update_shard(idx)


########## Functions used by Grain pipeline
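Taken together with the `data_loader.py` change above, shard exhaustion now surfaces as a plain `StopIteration` from the input pipeline, which `DataLoader.load_next_batch` detects by type and converts into a stop-training signal. A condensed, self-contained sketch of that control flow follows; the `StopTraining` stand-in class and the toy batch data are hypothetical stand-ins, not MaxText code (the real class lives in MaxText's `exceptions` module).

```python
class StopTraining(Exception):
  """Stand-in for MaxText's exceptions.StopTraining."""

def load_next_batch(iterator):
  """Mirrors the new control flow: detect StopIteration by type, not by string matching."""
  try:
    return next(iterator)
  except Exception as e:  # pylint: disable=broad-except
    if isinstance(e, StopIteration):
      raise StopTraining(f"You may have run out of training data. Received {type(e)} exception: ({e})") from e
    raise StopTraining(f"`load_next_batch()` failed with {type(e)} exception: ({e}).") from e

# Toy usage: two real batches, then exhaustion on the third call.
batches = iter([{"inputs": [1, 2]}, {"inputs": [3, 4]}])
print(load_next_batch(batches))
print(load_next_batch(batches))
# A third call would raise StopTraining, because next() raises StopIteration here.
```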
