File tree Expand file tree Collapse file tree 3 files changed +16
-15
lines changed Expand file tree Collapse file tree 3 files changed +16
-15
lines changed Original file line number Diff line number Diff line change 1717import glob
1818from pathlib import Path
1919import functools
20- from google .cloud import storage
21- import re
22-
2320import ml_collections
2421
2522import jax
2623
2724import grain .python as grain
2825
26+ from MaxText .utils import gcs_utils
2927from MaxText .input_pipeline import _input_pipeline_utils
3028from MaxText .input_pipeline import _grain_tokenizer
3129from MaxText import multihost_dataloading
3634def find_data_files (data_file_pattern ):
3735 """Find data files matching the pattern."""
3836 if data_file_pattern .startswith ("gs://" ):
39- storage_client = storage .Client ()
40- match = re .match (r"gs://([a-z0-9._-]+)/(.+)" , data_file_pattern )
41- if not match :
42- raise ValueError ("Invalid GCS path pattern." )
43- bucket_name , glob_pattern = match .groups ()
44- blobs = storage_client .list_blobs (bucket_name , match_glob = glob_pattern )
45- data_files = [f"gs://{ bucket_name } /{ blob .name } " for blob in blobs ]
37+ data_files = gcs_utils .gcs_glob_pattern (data_file_pattern )
4638 else :
47- # Use glob for local files
39+ # Local files
4840 data_files = glob .glob (str (Path (data_file_pattern ).expanduser ().resolve ()))
49-
5041 if not data_files :
5142 raise FileNotFoundError (f"No files found matching pattern: { data_file_pattern } " )
52-
5343 max_logging .log (f"Found { len (data_files )} files for train/eval with grain" )
5444 return data_files
5545
Original file line number Diff line number Diff line change @@ -145,6 +145,17 @@ def gcs_list_directories(directory_path):
145145 return directories
146146
147147
148+ def gcs_glob_pattern (pattern ):
149+ """
150+ Globs GCS files and returns a list of full GCS paths.
151+ """
152+ storage_client = storage .Client ()
153+ bucket_name , glob_pattern = parse_gcs_bucket_and_prefix (pattern )
154+ blobs = storage_client .list_blobs (bucket_name , match_glob = glob_pattern )
155+ data_files = [f"gs://{ bucket_name } /{ blob .name } " for blob in blobs ]
156+ return data_files
157+
158+
148159def read_json_from_gcs (file_path ):
149160 """
150161 Read a json file from gcs bucket.
Original file line number Diff line number Diff line change @@ -114,8 +114,8 @@ def setUp(self):
114114 temp_dir = tempfile .gettempdir ()
115115 # We use the same dataset for testing, but you can use different datasets by changing the file patterns.
116116 grain_train_files = [
117- f"{ temp_dir } /gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*: 0.3" ,
118- f"{ temp_dir } /gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*: 0.7" ,
117+ f"{ temp_dir } /gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*, 0.3" ,
118+ f"{ temp_dir } /gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*, 0.7" ,
119119 ]
120120 grain_train_files = ";" .join (grain_train_files )
121121 self .config = pyconfig .initialize (
You can’t perform that action at this time.
0 commit comments