@@ -17,6 +17,8 @@
 import glob
 from pathlib import Path
 import functools
+from google.cloud import storage
+import re
 
 import ml_collections
 
@@ -32,8 +34,22 @@
 
 
 def find_data_files(data_file_pattern):
-  data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-  assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}."
+  """Find data files matching the pattern."""
+  if data_file_pattern.startswith("gs://"):
+    storage_client = storage.Client()
+    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
+    if not match:
+      raise ValueError("Invalid GCS path pattern.")
+    bucket_name, glob_pattern = match.groups()
+    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
+    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+  else:
+    # Use glob for local files
+    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+
+  if not data_files:
+    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
+
   max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   return data_files
 
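With this hunk, `find_data_files` resolves both local and GCS patterns to a concrete file list: local patterns go through `glob`, while `gs://` patterns are split into bucket and glob via the regex and matched server-side through `list_blobs(..., match_glob=...)`. A minimal usage sketch (the bucket name and paths below are hypothetical, and the GCS branch assumes `google-cloud-storage` is installed and authenticated):

    # Local pattern, expanded with glob after ~ and symlink resolution:
    local_files = find_data_files("~/data/train-*.arrayrecord")

    # GCS pattern: "my-bucket" is captured as the bucket, the remainder
    # "datasets/train-*.arrayrecord" is passed as the server-side glob:
    gcs_files = find_data_files("gs://my-bucket/datasets/train-*.arrayrecord")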
@@ -51,7 +67,7 @@ def get_datasets(
   """Load dataset from array_record files for using with grain"""
   if data_file_type == "arrayrecord":
     if ";" in data_file_pattern:
-      data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")])
+      data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
       assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
       weights = [float(weight) for weight in weights]
       weights = [round(weight / sum(weights), 4) for weight in weights]
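This hunk changes the pattern/weight separator from ":" to ",", presumably because GCS paths themselves contain ":" in the "gs://" scheme and would break the old split. A sketch of the expected mixed-pattern format, derived from the parsing code above (bucket names hypothetical):

    # Two weighted sources, separated by ";", each written as "<pattern>,<weight>":
    data_file_pattern = "gs://my-bucket/a-*.arrayrecord,0.3;gs://my-bucket/b-*.arrayrecord,0.7"
    # Per the code above, weights are cast to float, normalized to sum to 1,
    # and rounded to 4 decimal places.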