diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index 8a44250003..3fefb54d3d 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -27,7 +27,7 @@ import numpy as np from MaxText.input_pipeline import _input_pipeline_utils -from MaxText import multihost_dataloading +from MaxText import multihost_dataloading, maxtext_utils def vision_sft_preprocessing_pipeline( @@ -39,6 +39,8 @@ def vision_sft_preprocessing_pipeline( text_columns, image_column, global_batch_size, + microbatch_size_to_run=None, + input_data_sharding=None ): """pipeline for multimodal SFT with HF dataset""" @@ -137,7 +139,7 @@ def vision_sft_preprocessing_pipeline( read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128), ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, microbatch_size_to_run=microbatch_size_to_run, input_data_sharding=input_data_sharding) # Return multi-host jax.Array prep iterator return multihost_gen @@ -167,6 +169,8 @@ def preprocessing_pipeline( use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 + microbatch_size_to_run=None, + input_data_sharding=None ): """pipeline for preprocessing HF dataset""" @@ -302,7 +306,7 @@ def lists2array(x): read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128), ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch, microbatch_size_to_run, input_data_sharding) # Return multi-host jax.Array prep iterator return multihost_gen @@ -314,6 +318,8 @@ def make_hf_train_iterator( process_indices_train, ): """Load, preprocess dataset and return iterators""" + input_data_sharding = maxtext_utils.get_input_data_sharding(config, global_mesh) + train_ds = datasets.load_dataset( config.hf_path, data_dir=config.hf_data_dir, @@ -332,6 +338,8 @@ def make_hf_train_iterator( text_columns=config.train_data_columns, image_column=config.train_image_column, global_batch_size=config.global_batch_size_to_load, + microbatch_size_to_run=config.micro_batch_size_to_train_on, + input_data_sharding=input_data_sharding, ) else: train_iter = preprocessing_pipeline( @@ -354,6 +362,8 @@ def make_hf_train_iterator( use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, + microbatch_size_to_run=config.micro_batch_size_to_train_on, + input_data_sharding=input_data_sharding, ) return train_iter @@ -364,6 +374,8 @@ def make_hf_eval_iterator( process_indices_eval, ): """Make Hugging Face evaluation iterator. Load and preprocess eval dataset: and return iterator.""" + input_data_sharding = maxtext_utils.get_input_data_sharding(config, global_mesh) + eval_ds = datasets.load_dataset( config.hf_path, data_dir=config.hf_data_dir, @@ -382,6 +394,8 @@ def make_hf_eval_iterator( text_columns=config.eval_data_columns, image_column=config.eval_image_column, global_batch_size=config.global_batch_size_to_load_eval, + microbatch_size_to_run=config.micro_batch_size_to_eval_on, + input_data_sharding=input_data_sharding, ) else: eval_iter = preprocessing_pipeline( @@ -404,5 +418,7 @@ def make_hf_eval_iterator( use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, + microbatch_size_to_run=config.micro_batch_size_to_eval_on, + input_data_sharding=input_data_sharding ) return eval_iter diff --git a/src/MaxText/multihost_dataloading.py b/src/MaxText/multihost_dataloading.py index c568f282f3..7bb09c029f 100644 --- a/src/MaxText/multihost_dataloading.py +++ b/src/MaxText/multihost_dataloading.py @@ -18,8 +18,9 @@ Adapted from Sholto's: https://github.com/sholtodouglas/multihost_dataloading """ +import itertools from functools import partial -from typing import Union, Sequence +from typing import Union, Sequence, Optional from collections.abc import Iterator, Iterable import time @@ -68,18 +69,13 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" - def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, generate_padding_batch: bool = False): + def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, generate_padding_batch: bool = False, microbatch_size_to_run: Optional[int]=None, input_data_sharding=None): self.global_mesh = global_mesh self.dataloader = dataloader - if isinstance(self.dataloader, tf.data.Dataset): - self.local_iterator = self.dataloader.as_numpy_iterator() - elif isinstance(self.dataloader, Iterable): - self.local_iterator = iter(self.dataloader) - else: - raise ValueError("Type error: dataloader should be either tf.data.Dataset or Iterable.") - self.out_of_data = False - self.last_local_data = None self.generate_padding_batch = generate_padding_batch + self.microbatch_size_to_run = microbatch_size_to_run + self.input_data_sharding = input_data_sharding + self.reset() def reset(self): if isinstance(self.dataloader, tf.data.Dataset): @@ -91,12 +87,39 @@ def reset(self): self.out_of_data = False self.last_local_data = None + sharded_iter = self._base_iter() + if self.microbatch_size_to_run: + self.local_iterator = itertools.chain.from_iterable( + self.explode_to_micro(b) for b in sharded_iter + ) + else: + self.local_iterator = sharded_iter + def __iter__(self): self.reset() return self def __next__(self): - return self._get_next_batch_sharded() + return next(self.local_iterator) + + def _base_iter(self): + while True: + yield self._get_next_batch_sharded() + + def explode_to_micro(self, batch): + """Splits larger batch into smaller equally sized batches""" + mb = self.microbatch_size_to_run + # `batch` is a dict-like PyTree of jax.Arrays + k0 = next(iter(batch)) + B = batch[k0].shape[0] + assert B % mb == 0, f"global batch {B} not divisible by microbatch {mb}" + M = B // mb + reshaped = {k: v.reshape((M, mb) + v.shape[1:]) for k, v in batch.items()} + for i in range(M): + microbatch = {k: reshaped[k][i] for k in reshaped} + if self.input_data_sharding is not None: + microbatch = jax.lax.with_sharding_constraint(microbatch, self.input_data_sharding) + yield microbatch def _get_next_batch_sharded(self) -> jax.Array: """Splits the host loaded data equally over all devices.""" diff --git a/src/MaxText/train.py b/src/MaxText/train.py index b1b0b2c7bf..de5b0a942d 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -94,13 +94,6 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): loss: average loss aux: a dictionary including intermediate_outputs, total_loss, and total_weights """ - # decimate proportion of data when per_device_batch_size<1 - if is_train: - for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_train_on, :] - else: - for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_eval_on, :] mutable_collections = ["intermediates"] if config.mtp_num_layers > 0 and is_train: # The single model.apply call now triggers the entire chain if MTP is enabled: