 import glob
 from pathlib import Path
 import functools
-from google.cloud import storage
-import re
-
 import ml_collections
-
 import jax
-
 import grain.python as grain
 
+from MaxText.utils import gcs_utils
 from MaxText.input_pipeline import _input_pipeline_utils
 from MaxText.input_pipeline import _grain_tokenizer
 from MaxText import multihost_dataloading
@@ -36,20 +32,12 @@
 def find_data_files(data_file_pattern):
   """Find data files matching the pattern."""
   if data_file_pattern.startswith("gs://"):
-    storage_client = storage.Client()
-    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
-    if not match:
-      raise ValueError("Invalid GCS path pattern.")
-    bucket_name, glob_pattern = match.groups()
-    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
-    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
   else:
-    # Use glob for local files
+    # Local files
     data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-
   if not data_files:
     raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
-
   max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   return data_files
 
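The inline GCS listing removed above now lives in a shared helper. A minimal sketch of what gcs_utils.gcs_glob_pattern is assumed to do, mirroring the deleted logic; the actual helper in MaxText.utils.gcs_utils may differ:

import re
from google.cloud import storage

def gcs_glob_pattern(data_file_pattern):
  """Expand a gs://bucket/glob pattern into a list of gs:// file paths."""
  match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
  if not match:
    raise ValueError("Invalid GCS path pattern.")
  bucket_name, glob_pattern = match.groups()
  # match_glob pushes glob evaluation to the GCS list API instead of filtering client-side
  blobs = storage.Client().list_blobs(bucket_name, match_glob=glob_pattern)
  return [f"gs://{bucket_name}/{blob.name}" for blob in blobs]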
@@ -74,15 +62,22 @@ def get_datasets(
       dataset_list = [
           grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns
       ]
-      dataset = grain.MapDataset.mix(dataset_list, weights)
+      # shuffle, shard by host, and turn each dataset into an iterator before mixing
+      for i, ds in enumerate(dataset_list):
+        if shuffle:
+          ds = ds.shuffle(seed=shuffle_seed)
+        ds = ds.repeat(num_epoch)
+        ds = ds[dataloading_host_index::dataloading_host_count]  # sharding
+        dataset_list[i] = ds.to_iter_dataset()
+      dataset = grain.IterDataset.mix(dataset_list, weights)
     else:
       data_files = find_data_files(data_file_pattern)
       dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
-    if shuffle:
-      dataset = dataset.shuffle(seed=shuffle_seed)
-    dataset = dataset.repeat(num_epoch)
-    dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
-    dataset = dataset.to_iter_dataset()
+      if shuffle:
+        dataset = dataset.shuffle(seed=shuffle_seed)
+      dataset = dataset.repeat(num_epoch)
+      dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
+      dataset = dataset.to_iter_dataset()
   elif data_file_type == "parquet":
     data_files = find_data_files(data_file_pattern)
     dataset = grain.MapDataset.source(data_files)
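For reference, the per-dataset mixing path added above can be exercised in isolation. A minimal sketch, assuming two small in-memory sources, a single dataloading host, and illustrative weights rather than MaxText's actual config:

import grain.python as grain

# hypothetical stand-ins for ArrayRecord sources
sources = [list(range(100)), list(range(100, 200))]
dataset_list = [grain.MapDataset.source(s) for s in sources]
for i, ds in enumerate(dataset_list):
  ds = ds.shuffle(seed=0)   # per-dataset shuffle
  ds = ds.repeat(1)         # one epoch
  ds = ds[0::1]             # shard for host 0 of 1
  dataset_list[i] = ds.to_iter_dataset()
mixed = grain.IterDataset.mix(dataset_list, weights=[0.5, 0.5])
it = iter(mixed)
print([next(it) for _ in range(4)])  # peek at a few mixed examples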