Skip to content

Commit 6d3a363

Browse files
committed
Support GCS file pattern in grain
1 parent fca6e8f commit 6d3a363

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import glob
1818
from pathlib import Path
1919
import functools
20+
from google.cloud import storage
21+
import re
2022

2123
import ml_collections
2224

@@ -32,8 +34,22 @@
3234

3335

3436
def find_data_files(data_file_pattern):
  """Find data files matching the pattern.

  A pattern starting with ``gs://`` is split into a bucket name and an object
  glob; the glob is handed to the storage client's ``list_blobs`` call as
  ``match_glob``. Any other pattern is treated as a local path: ``~`` is
  expanded, the path is resolved to an absolute one, and the result is
  expanded with ``glob.glob``.

  Args:
    data_file_pattern: A local glob pattern, or a GCS pattern of the form
      ``gs://<bucket>/<object glob>``.

  Returns:
    A non-empty list of matching files. GCS matches are returned as
    ``gs://<bucket>/<blob name>`` strings; local matches are the paths
    returned by ``glob.glob``.

  Raises:
    ValueError: If a ``gs://`` pattern does not contain a bucket name made of
      lowercase letters, digits, ``.``, ``_`` or ``-`` followed by ``/`` and a
      non-empty object glob.
    FileNotFoundError: If no file matches the pattern.
  """
  if data_file_pattern.startswith("gs://"):
    # Validate the pattern before building the storage client so that a
    # malformed pattern fails fast, without any client setup.
    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
    if not match:
      raise ValueError("Invalid GCS path pattern.")
    bucket_name, glob_pattern = match.groups()
    storage_client = storage.Client()
    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
    # Re-attach the scheme and bucket so callers get full GCS URIs.
    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
  else:
    # Use glob for local files
    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))

  if not data_files:
    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")

  max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
  return data_files
3955

@@ -51,7 +67,7 @@ def get_datasets(
5167
"""Load dataset from array_record files for using with grain"""
5268
if data_file_type == "arrayrecord":
5369
if ";" in data_file_pattern:
54-
data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
70+
data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
5571
assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
5672
weights = [float(weight) for weight in weights]
5773
weights = [round(weight / sum(weights), 4) for weight in weights]

0 commit comments

Comments
 (0)