Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/guides/data_input_pipeline/data_input_grain.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ MOUNT_PATH=$MOUNT_PATH \
3. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` to match the file pattern on the mounted local path.
4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.

5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and colon (:) for weights. The weights will be automatically normalized to sum to 1.0. For example:
5. For multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate sources and a comma (,) to attach each source's weight. The weights will be automatically normalized to sum to 1.0. For example:
```
# Blend two data sources with 30% from first source and 70% from second source
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,0.3;/tmp/gcsfuse/dataset2.array_record*,0.7
# Blend three data sources with equal weights (will be normalized to 0.33 each)
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,1;/tmp/gcsfuse/dataset2.array_record*,1;/tmp/gcsfuse/dataset3.array_record*,1
```
Note: When using multiple data sources, only the ArrayRecord format is supported.

Expand Down
4 changes: 2 additions & 2 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,8 @@ hf_eval_files: ''
hf_access_token: ''
# for Grain input pipeline (dataset_type=grain)
# Path to grain data files. Can be a single pattern or multiple patterns with weights.
# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
# For multiple patterns, separate them with a semicolon (;) and append each pattern's weight after a comma (,).
# Example: "path/to/data1.array_record*,0.3;path/to/data2.array_record*,0.7"
# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
grain_train_files: ''
Expand Down
14 changes: 10 additions & 4 deletions src/MaxText/input_pipeline/_grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import glob
from pathlib import Path
import functools

import ml_collections

import jax

import grain.python as grain

from MaxText.utils import gcs_utils
from MaxText.input_pipeline import _input_pipeline_utils
from MaxText.input_pipeline import _grain_tokenizer
from MaxText import multihost_dataloading
Expand All @@ -32,8 +32,14 @@


def find_data_files(data_file_pattern):
  """Find data files matching a glob pattern.

  Supports two kinds of patterns:
    * ``gs://...`` patterns, expanded server-side via the GCS client
      (``gcs_utils.gcs_glob_pattern``), returning full ``gs://`` URIs.
    * Local filesystem patterns, expanded with ``glob`` after resolving
      ``~`` and relative components.

  Args:
    data_file_pattern: A glob pattern, either a GCS URI or a local path.

  Returns:
    A non-empty list of matching file paths.

  Raises:
    FileNotFoundError: If no file matches the pattern. An explicit
      exception is used (rather than ``assert``) so the check survives
      running Python with ``-O``.
  """
  if data_file_pattern.startswith("gs://"):
    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
  else:
    # Local files: expand ~ and resolve the path before globbing.
    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
  if not data_files:
    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
  max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
  return data_files

Expand All @@ -51,7 +57,7 @@ def get_datasets(
"""Load dataset from array_record files for using with grain"""
if data_file_type == "arrayrecord":
if ";" in data_file_pattern:
data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
weights = [float(weight) for weight in weights]
weights = [round(weight / sum(weights), 4) for weight in weights]
Expand Down
11 changes: 11 additions & 0 deletions src/MaxText/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ def gcs_list_directories(directory_path):
return directories


def gcs_glob_pattern(pattern):
  """Expand a GCS glob pattern into a list of full ``gs://`` object paths.

  Args:
    pattern: A ``gs://bucket/...`` style pattern; the part after the bucket
      is matched server-side via the ``match_glob`` option of ``list_blobs``.

  Returns:
    A list of ``gs://bucket/object`` strings, one per matching blob
    (empty if nothing matches).
  """
  bucket_name, glob_pattern = parse_gcs_bucket_and_prefix(pattern)
  client = storage.Client()
  matching_blobs = client.list_blobs(bucket_name, match_glob=glob_pattern)
  return [f"gs://{bucket_name}/{blob.name}" for blob in matching_blobs]


def read_json_from_gcs(file_path):
"""
Read a json file from gcs bucket.
Expand Down
4 changes: 2 additions & 2 deletions tests/grain_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def setUp(self):
temp_dir = tempfile.gettempdir()
# We use the same dataset for testing, but you can use different datasets by changing the file patterns.
grain_train_files = [
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.3",
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.7",
]
grain_train_files = ";".join(grain_train_files)
self.config = pyconfig.initialize(
Expand Down
Loading