Commit e2dba41

committed
leverage gcs_util
1 parent 4268f5a commit e2dba41

File tree

3 files changed: +16, -15 lines changed


src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 3 additions & 13 deletions

@@ -17,15 +17,13 @@
 import glob
 from pathlib import Path
 import functools
-from google.cloud import storage
-import re
-
 import ml_collections

 import jax

 import grain.python as grain

+from MaxText.utils import gcs_utils
 from MaxText.input_pipeline import _input_pipeline_utils
 from MaxText.input_pipeline import _grain_tokenizer
 from MaxText import multihost_dataloading
@@ -36,20 +34,12 @@
 def find_data_files(data_file_pattern):
   """Find data files matching the pattern."""
   if data_file_pattern.startswith("gs://"):
-    storage_client = storage.Client()
-    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
-    if not match:
-      raise ValueError("Invalid GCS path pattern.")
-    bucket_name, glob_pattern = match.groups()
-    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
-    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
   else:
-    # Use glob for local files
+    # Local files
     data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-
   if not data_files:
     raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
-
   max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   return data_files

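The net effect is that find_data_files keeps its behavior but delegates the gs:// branch to the shared gcs_utils helper. A minimal usage sketch under that assumption (both the local path and the bucket name below are hypothetical):

from MaxText.input_pipeline._grain_data_processing import find_data_files

# Local patterns are expanded with glob.glob on the user-expanded, resolved path.
local_files = find_data_files("~/data/c4/en/3.0.1/c4-train.array_record*")

# gs:// patterns are delegated to gcs_utils.gcs_glob_pattern.
gcs_files = find_data_files("gs://my-bucket/c4/en/3.0.1/c4-train.array_record*")

Either branch raises FileNotFoundError when nothing matches, so callers can rely on a non-empty list.
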
src/MaxText/utils/gcs_utils.py

Lines changed: 11 additions & 0 deletions

@@ -145,6 +145,17 @@ def gcs_list_directories(directory_path):
   return directories


+def gcs_glob_pattern(pattern):
+  """
+  Globs GCS files and returns a list of full GCS paths.
+  """
+  storage_client = storage.Client()
+  bucket_name, glob_pattern = parse_gcs_bucket_and_prefix(pattern)
+  blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
+  data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+  return data_files
+
+
 def read_json_from_gcs(file_path):
   """
   Read a json file from gcs bucket.

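A quick usage sketch of the new helper. The bucket and object names here are made up; parse_gcs_bucket_and_prefix is the existing gcs_utils routine the function relies on to split the bucket from the glob expression, and match_glob support in list_blobs requires a reasonably recent google-cloud-storage release:

from MaxText.utils import gcs_utils

# The glob is evaluated server-side via list_blobs(match_glob=...), and each
# matching blob is re-qualified as a full gs:// path.
files = gcs_utils.gcs_glob_pattern("gs://my-bucket/c4/en/3.0.1/c4-train.array_record*")
# e.g. ["gs://my-bucket/c4/en/3.0.1/c4-train.array_record-00000-of-01024", ...]

Centralizing this in gcs_utils means the input pipeline no longer needs its own google.cloud.storage client or ad hoc gs:// regex parsing.
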
tests/grain_data_processing_test.py

Lines changed: 2 additions & 2 deletions

@@ -114,8 +114,8 @@ def setUp(self):
     temp_dir = tempfile.gettempdir()
     # We use the same dataset for testing, but you can use different datasets by changing the file patterns.
     grain_train_files = [
-        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
-        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.3",
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.7",
     ]
     grain_train_files = ";".join(grain_train_files)
     self.config = pyconfig.initialize(

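The test update switches the per-pattern weight separator from ":" to ",", presumably so it cannot clash with ":" appearing in paths. For illustration only, a sketch of how a ";"-joined "pattern,weight" string could be split back into pairs; split_weighted_patterns is a hypothetical helper, not MaxText's actual parser:

def split_weighted_patterns(joined):
  """Split "pattern,weight;pattern,weight" into (pattern, float weight) pairs.

  rpartition on "," keeps any earlier commas inside the pattern intact.
  """
  pairs = []
  for entry in joined.split(";"):
    pattern, _, weight = entry.rpartition(",")
    pairs.append((pattern, float(weight)))
  return pairs

# split_weighted_patterns("a.array_record*,0.3;b.array_record*,0.7")
# -> [("a.array_record*", 0.3), ("b.array_record*", 0.7)]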