
Commit 4268f5a

Support GCS file pattern in grain
1 parent fca6e8f commit 4268f5a

3 files changed: +24 −8 lines changed


docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 3 additions & 3 deletions
@@ -41,13 +41,13 @@ MOUNT_PATH=$MOUNT_PATH \
 3. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` to match the file pattern on the mounted local path.
 4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
 
-5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and colon (:) for weights. The weights will be automatically normalized to sum to 1.0. For example:
+5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example:
 ```
 # Blend two data sources with 30% from first source and 70% from second source
-grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7
+grain_train_files=/tmp/gcsfuse/dataset1.array_record*,0.3;/tmp/gcsfuse/dataset2.array_record*,0.7
 
 # Blend three data sources with equal weights (will be normalized to 0.33 each)
-grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1
+grain_train_files=/tmp/gcsfuse/dataset1.array_record*,1;/tmp/gcsfuse/dataset2.array_record*,1;/tmp/gcsfuse/dataset3.array_record*,1
 ```
 Note: When using multiple data sources, only the ArrayRecord format is supported.
 

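The normalization mentioned in step 5 is just division by the sum of the weights. A minimal standalone sketch in Python (variable names are mine; the `spec` string reuses the doc's three-source example, and the split/normalize logic mirrors the `get_datasets` hunk in `_grain_data_processing.py` further down):

```python
# Hypothetical blended spec using the new comma-delimited weight syntax.
spec = "/tmp/gcsfuse/dataset1.array_record*,1;/tmp/gcsfuse/dataset2.array_record*,1;/tmp/gcsfuse/dataset3.array_record*,1"

# Split sources on ';', then each "pattern,weight" pair on ','.
patterns, weights = zip(*(part.split(",") for part in spec.split(";")))
weights = [float(w) for w in weights]
weights = [round(w / sum(weights), 4) for w in weights]

print(patterns)  # the three glob patterns, in order
print(weights)   # [0.3333, 0.3333, 0.3333] -- normalized to sum to ~1.0
```

This is also why a comma had to replace the colon as the weight delimiter: a colon can no longer be reserved for weights once `gs://` patterns, which contain a colon, are valid sources.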
src/MaxText/configs/base.yml

Lines changed: 2 additions & 2 deletions
@@ -569,8 +569,8 @@ hf_eval_files: ''
 hf_access_token: ''
 # for Grain input pipeline (dataset_type=grain)
 # Path to grain data files. Can be a single pattern or multiple patterns with weights.
-# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
-# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
+# For multiple patterns, use semicolon (;) to separate and comma (,) to specify weights.
+# Example: "path/to/data1.array_record*,0.3;path/to/data2.array_record*,0.7"
 # Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
 # For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
 grain_train_files: ''

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 19 additions & 3 deletions
@@ -17,6 +17,8 @@
 import glob
 from pathlib import Path
 import functools
+from google.cloud import storage
+import re
 
 import ml_collections
 
@@ -32,8 +34,22 @@
 
 
 def find_data_files(data_file_pattern):
-  data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-  assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
+  """Find data files matching the pattern."""
+  if data_file_pattern.startswith("gs://"):
+    storage_client = storage.Client()
+    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
+    if not match:
+      raise ValueError("Invalid GCS path pattern.")
+    bucket_name, glob_pattern = match.groups()
+    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
+    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+  else:
+    # Use glob for local files
+    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+
+  if not data_files:
+    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
+
   max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   return data_files
 
@@ -51,7 +67,7 @@ def get_datasets(
   """Load dataset from array_record files for using with grain"""
   if data_file_type == "arrayrecord":
     if ";" in data_file_pattern:
-      data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
+      data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
       assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
       weights = [float(weight) for weight in weights]
       weights = [round(weight / sum(weights), 4) for weight in weights]

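The key piece above is `match_glob`, a server-side filter on `Client.list_blobs` available in recent versions of the `google-cloud-storage` library: the bucket is never mounted and no object data is read; only matching object names come back. A minimal standalone sketch of the same logic (the helper name, bucket, and pattern are hypothetical; the library and valid GCS credentials are assumed):

```python
# Sketch of the GCS listing branch added in this commit, as standalone code.
# Assumes `google-cloud-storage` is installed (match_glob needs a recent
# version) and application-default credentials are available.
import re

from google.cloud import storage


def list_gcs_files(pattern: str) -> list[str]:
  """Expand a gs://bucket/glob pattern into a list of gs:// URIs."""
  match = re.match(r"gs://([a-z0-9._-]+)/(.+)", pattern)
  if not match:
    raise ValueError(f"Invalid GCS path pattern: {pattern}")
  bucket_name, glob_pattern = match.groups()
  # match_glob applies the glob on the server, so only matching object
  # names are returned; object contents are never downloaded.
  blobs = storage.Client().list_blobs(bucket_name, match_glob=glob_pattern)
  return [f"gs://{bucket_name}/{blob.name}" for blob in blobs]


if __name__ == "__main__":
  # Hypothetical bucket and pattern, for illustration only.
  print(list_gcs_files("gs://my-example-bucket/c4/en/*.array_record"))
```

Note that the regex `gs://([a-z0-9._-]+)/(.+)` rejects uppercase bucket names, consistent with GCS bucket naming rules.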