Skip to content

Commit f8bc1e8

Browse files
Merge pull request #2677 from AI-Hypercomputer:aireen/grain_gcs
PiperOrigin-RevId: 832005402
2 parents b853688 + aa79bb2 commit f8bc1e8

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ MOUNT_PATH=$MOUNT_PATH \
4141
3. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` to match the file pattern on the mounted local path.
4242
4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
4343

44-
5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and colon (:) for weights. The weights will be automatically normalized to sum to 1.0. For example:
44+
5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example:
4545
```
4646
# Blend two data sources with 30% from first source and 70% from second source
47-
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7
47+
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,0.3;/tmp/gcsfuse/dataset2.array_record*,0.7
4848
4949
# Blend three data sources with equal weights (will be normalized to 0.33 each)
50-
grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
50+
grain_train_files=/tmp/gcsfuse/dataset1.array_record*,1;/tmp/gcsfuse/dataset2.array_record*,1;/tmp/gcsfuse/dataset3.array_record*,1
5151
```
5252
Note: When using multiple data sources, only the ArrayRecord format is supported.
5353

src/MaxText/configs/base.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,8 @@ hf_eval_files: ''
569569
hf_access_token: ''
570570
# for Grain input pipeline (dataset_type=grain)
571571
# Path to grain data files. Can be a single pattern or multiple patterns with weights.
572-
# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
573-
# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
572+
# For multiple patterns, use semicolon (;) to separate and comma (,) to specify weights.
573+
# Example: "path/to/data1.array_record*,0.3;path/to/data2.array_record*,0.7"
574574
# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
575575
# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
576576
grain_train_files: ''

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import glob
1818
from pathlib import Path
1919
import functools
20-
2120
import ml_collections
2221

2322
import jax
2423

2524
import grain.python as grain
2625

26+
from MaxText.utils import gcs_utils
2727
from MaxText.input_pipeline import _input_pipeline_utils
2828
from MaxText.input_pipeline import _grain_tokenizer
2929
from MaxText import multihost_dataloading
@@ -32,8 +32,14 @@
3232

3333

3434
def find_data_files(data_file_pattern):
  """Locate data files matching *data_file_pattern*.

  Patterns beginning with ``gs://`` are resolved through the GCS glob
  helper; any other pattern is treated as a local path, expanded
  (user-home + symlinks) and globbed with the standard library.

  Args:
    data_file_pattern: a glob pattern, either a ``gs://`` URI or a local path.

  Returns:
    A non-empty list of matching file paths.

  Raises:
    FileNotFoundError: if no file matches the pattern.
  """
  if data_file_pattern.startswith("gs://"):
    # Remote pattern: delegate globbing to the GCS utility.
    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
  else:
    # Local pattern: expand "~" and resolve before globbing.
    resolved_pattern = str(Path(data_file_pattern).expanduser().resolve())
    data_files = glob.glob(resolved_pattern)

  if not data_files:
    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")

  max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
  return data_files
3945

@@ -51,7 +57,7 @@ def get_datasets(
5157
"""Load dataset from array_record files for using with grain"""
5258
if data_file_type == "arrayrecord":
5359
if ";" in data_file_pattern:
54-
data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
60+
data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
5561
assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
5662
weights = [float(weight) for weight in weights]
5763
weights = [round(weight / sum(weights), 4) for weight in weights]

src/MaxText/utils/gcs_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def gcs_list_directories(directory_path):
145145
return directories
146146

147147

148+
def gcs_glob_pattern(pattern):
  """
  Globs GCS files and returns a list of full GCS paths.
  """
  client = storage.Client()
  # Split the gs:// URI into its bucket and the glob to match within it.
  bucket_name, glob_pattern = parse_gcs_bucket_and_prefix(pattern)
  matching_blobs = client.list_blobs(bucket_name, match_glob=glob_pattern)
  # Rebuild fully-qualified gs:// paths from the matched blob names.
  return [f"gs://{bucket_name}/{blob.name}" for blob in matching_blobs]
157+
158+
148159
def read_json_from_gcs(file_path):
149160
"""
150161
Read a json file from gcs bucket.

tests/grain_data_processing_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def setUp(self):
114114
temp_dir = tempfile.gettempdir()
115115
# We use the same dataset for testing, but you can use different datasets by changing the file patterns.
116116
grain_train_files = [
117-
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
118-
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
117+
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.3",
118+
f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.7",
119119
]
120120
grain_train_files = ";".join(grain_train_files)
121121
self.config = pyconfig.initialize(

0 commit comments

Comments
 (0)