22 changes: 19 additions & 3 deletions src/MaxText/input_pipeline/_hf_data_processing.py
@@ -27,7 +27,7 @@
import numpy as np

from MaxText.input_pipeline import _input_pipeline_utils
from MaxText import multihost_dataloading
from MaxText import multihost_dataloading, maxtext_utils


def vision_sft_preprocessing_pipeline(
@@ -39,6 +39,8 @@ def vision_sft_preprocessing_pipeline(
text_columns,
image_column,
global_batch_size,
microbatch_size_to_run=None,
input_data_sharding=None
):
"""pipeline for multimodal SFT with HF dataset"""

@@ -137,7 +139,7 @@ def vision_sft_preprocessing_pipeline(
read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
)

multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, microbatch_size_to_run=microbatch_size_to_run, input_data_sharding=input_data_sharding)

# Return multi-host jax.Array prep iterator
return multihost_gen
@@ -167,6 +169,8 @@ def preprocessing_pipeline(
use_sft=None,
sft_train_on_completion_only=True,
grain_worker_count=1, # only support 0 or 1
microbatch_size_to_run=None,
input_data_sharding=None
):
"""pipeline for preprocessing HF dataset"""

@@ -302,7 +306,7 @@ def lists2array(x):
read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128),
)

multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch)
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, generate_padding_batch, microbatch_size_to_run, input_data_sharding)

# Return multi-host jax.Array prep iterator
return multihost_gen
@@ -314,6 +318,8 @@ def make_hf_train_iterator(
process_indices_train,
):
"""Load, preprocess dataset and return iterators"""
input_data_sharding = maxtext_utils.get_input_data_sharding(config, global_mesh)

train_ds = datasets.load_dataset(
config.hf_path,
data_dir=config.hf_data_dir,
@@ -332,6 +338,8 @@
text_columns=config.train_data_columns,
image_column=config.train_image_column,
global_batch_size=config.global_batch_size_to_load,
microbatch_size_to_run=config.micro_batch_size_to_train_on,
input_data_sharding=input_data_sharding,
)
else:
train_iter = preprocessing_pipeline(
@@ -354,6 +362,8 @@
use_dpo=config.use_dpo,
use_sft=config.use_sft,
sft_train_on_completion_only=config.sft_train_on_completion_only,
microbatch_size_to_run=config.micro_batch_size_to_train_on,
input_data_sharding=input_data_sharding,
)
return train_iter

@@ -364,6 +374,8 @@ def make_hf_eval_iterator(
process_indices_eval,
):
"""Make Hugging Face evaluation iterator. Load and preprocess eval dataset: and return iterator."""
input_data_sharding = maxtext_utils.get_input_data_sharding(config, global_mesh)

eval_ds = datasets.load_dataset(
config.hf_path,
data_dir=config.hf_data_dir,
@@ -382,6 +394,8 @@
text_columns=config.eval_data_columns,
image_column=config.eval_image_column,
global_batch_size=config.global_batch_size_to_load_eval,
microbatch_size_to_run=config.micro_batch_size_to_eval_on,
input_data_sharding=input_data_sharding,
)
else:
eval_iter = preprocessing_pipeline(
@@ -404,5 +418,7 @@
use_dpo=config.use_dpo,
use_sft=config.use_sft,
sft_train_on_completion_only=config.sft_train_on_completion_only,
microbatch_size_to_run=config.micro_batch_size_to_eval_on,
input_data_sharding=input_data_sharding
)
return eval_iter
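
Editor's note, not part of the diff: both iterators above now thread an `input_data_sharding` value from `maxtext_utils.get_input_data_sharding(config, global_mesh)` into the data loader. As a rough mental model, that value is presumably a `jax.sharding.NamedSharding` that spreads the leading (batch) dimension across the data-parallel mesh axes. The sketch below illustrates that idea only; `example_input_data_sharding` and the axis names "data"/"model" are assumptions, not MaxText's actual helper or mesh config.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-in for maxtext_utils.get_input_data_sharding(config, global_mesh):
# shard the leading (batch) dimension across a data-parallel mesh axis.
def example_input_data_sharding(global_mesh):
  return NamedSharding(global_mesh, P("data", None))

ndev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(ndev, 1), axis_names=("data", "model"))
sharding = example_input_data_sharding(mesh)

# Place a toy batch so its rows are distributed along the "data" axis.
batch = jax.device_put(np.zeros((2 * ndev, 128), dtype=np.int32), sharding)
print(batch.sharding)
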
45 changes: 34 additions & 11 deletions src/MaxText/multihost_dataloading.py
@@ -18,8 +18,9 @@
Adapted from Sholto's:
https://github.com/sholtodouglas/multihost_dataloading
"""
import itertools
from functools import partial
from typing import Union, Sequence
from typing import Union, Sequence, Optional
from collections.abc import Iterator, Iterable
import time

@@ -68,18 +69,13 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
class MultiHostDataLoadIterator:
"""fold get_next_batch_sharded into a iterator class"""

def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, generate_padding_batch: bool = False):
def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, generate_padding_batch: bool = False, microbatch_size_to_run: Optional[int]=None, input_data_sharding=None):
self.global_mesh = global_mesh
self.dataloader = dataloader
if isinstance(self.dataloader, tf.data.Dataset):
self.local_iterator = self.dataloader.as_numpy_iterator()
elif isinstance(self.dataloader, Iterable):
self.local_iterator = iter(self.dataloader)
else:
raise ValueError("Type error: dataloader should be either tf.data.Dataset or Iterable.")
self.out_of_data = False
self.last_local_data = None
self.generate_padding_batch = generate_padding_batch
self.microbatch_size_to_run = microbatch_size_to_run
self.input_data_sharding = input_data_sharding
self.reset()

def reset(self):
if isinstance(self.dataloader, tf.data.Dataset):
@@ -91,12 +87,39 @@ def reset(self):
self.out_of_data = False
self.last_local_data = None

sharded_iter = self._base_iter()
if self.microbatch_size_to_run:
self.local_iterator = itertools.chain.from_iterable(
self.explode_to_micro(b) for b in sharded_iter
)
else:
self.local_iterator = sharded_iter

Review comment (Collaborator) on the block above: "This can be simplified:"

def microbatch_generator():
    for global_batch in sharded_iter:
        yield from self.explode_to_micro(global_batch)

if self.microbatch_size_to_run:
  self.local_iterator = microbatch_generator()
else:
  self.local_iterator = sharded_iter

def __iter__(self):
self.reset()
return self

def __next__(self):
return self._get_next_batch_sharded()
return next(self.local_iterator)

def _base_iter(self):
while True:
yield self._get_next_batch_sharded()

def explode_to_micro(self, batch):
"""Splits larger batch into smaller equally sized batches"""
mb = self.microbatch_size_to_run
# `batch` is a dict-like PyTree of jax.Arrays
k0 = next(iter(batch))
B = batch[k0].shape[0]
assert B % mb == 0, f"global batch {B} not divisible by microbatch {mb}"
M = B // mb
reshaped = {k: v.reshape((M, mb) + v.shape[1:]) for k, v in batch.items()}
for i in range(M):
microbatch = {k: reshaped[k][i] for k in reshaped}
if self.input_data_sharding is not None:
microbatch = jax.lax.with_sharding_constraint(microbatch, self.input_data_sharding)
yield microbatch

def _get_next_batch_sharded(self) -> jax.Array:
"""Splits the host loaded data equally over all devices."""
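
Editor's note, not part of the diff: the splitting in `explode_to_micro` is a reshape to `(num_microbatches, microbatch, ...)` followed by indexing along the new leading axis. A minimal numpy sketch of that arithmetic follows; in MaxText the values are sharded `jax.Array`s and, when `input_data_sharding` is given, each microbatch additionally goes through `with_sharding_constraint`. The names `toy_batch` and `mb` are illustrative.

import numpy as np

# Toy dict standing in for the PyTree explode_to_micro receives.
toy_batch = {"inputs": np.arange(32).reshape(8, 4), "targets": np.arange(32).reshape(8, 4)}
mb = 2  # microbatch size

B = toy_batch["inputs"].shape[0]
assert B % mb == 0, "global batch must be divisible by microbatch"
M = B // mb

# Reshape each leaf to (M, mb, ...) and slice out one microbatch at a time.
reshaped = {k: v.reshape((M, mb) + v.shape[1:]) for k, v in toy_batch.items()}
microbatches = [{k: reshaped[k][i] for k in reshaped} for i in range(M)]
print(len(microbatches), microbatches[0]["inputs"].shape)  # 4 microbatches of shape (2, 4)
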
7 changes: 0 additions & 7 deletions src/MaxText/train.py
@@ -94,13 +94,6 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
loss: average loss
aux: a dictionary including intermediate_outputs, total_loss, and total_weights
"""
# decimate proportion of data when per_device_batch_size<1
if is_train:
for k, v in data.items():
data[k] = v[: config.micro_batch_size_to_train_on, :]
else:
for k, v in data.items():
data[k] = v[: config.micro_batch_size_to_eval_on, :]
mutable_collections = ["intermediates"]
if config.mtp_num_layers > 0 and is_train:
# The single model.apply call now triggers the entire chain if MTP is enabled:
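
Editor's note, not part of the diff: the deleted lines above truncated each loaded batch to the first `micro_batch_size` rows inside `loss_fn`; with this PR the iterator presumably yields each micro-sized slice as its own step instead. The sketch below only shows that the first exploded microbatch coincides with what the old truncation kept; variable names are illustrative.

import numpy as np

global_batch = {"inputs": np.arange(32).reshape(8, 4)}
micro = 2

# Old behavior (removed above): truncate the loaded batch inside loss_fn.
old_view = {k: v[:micro] for k, v in global_batch.items()}

# New behavior: the iterator reshapes to (num_micro, micro, ...) and yields one
# micro-sized dict per step; the first one equals the old truncated view.
reshaped = {k: v.reshape((v.shape[0] // micro, micro) + v.shape[1:]) for k, v in global_batch.items()}
first_micro = {k: reshaped[k][0] for k in reshaped}

assert np.array_equal(old_view["inputs"], first_micro["inputs"])
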