
Commit 18528fc: leverage gcs_utils.py

1 parent 4268f5a

File tree: 2 files changed (+18 -23 lines)

src/MaxText/input_pipeline/_grain_data_processing.py
Lines changed: 16 additions & 21 deletions

@@ -17,15 +17,11 @@
 import glob
 from pathlib import Path
 import functools
-from google.cloud import storage
-import re
-
 import ml_collections
-
 import jax
-
 import grain.python as grain
 
+from MaxText.utils import gcs_utils
 from MaxText.input_pipeline import _input_pipeline_utils
 from MaxText.input_pipeline import _grain_tokenizer
 from MaxText import multihost_dataloading
@@ -36,20 +32,12 @@
 def find_data_files(data_file_pattern):
   """Find data files matching the pattern."""
   if data_file_pattern.startswith("gs://"):
-    storage_client = storage.Client()
-    match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
-    if not match:
-      raise ValueError("Invalid GCS path pattern.")
-    bucket_name, glob_pattern = match.groups()
-    blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
-    data_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs]
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
   else:
-    # Use glob for local files
+    # Local files
     data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
-
   if not data_files:
     raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
-
   max_logging.log(f"Found {len(data_files)} files for train/eval with grain")
   return data_files
 
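gcs_utils.py itself is not touched by this commit, so for context here is a minimal sketch of what a gcs_glob_pattern-style helper could look like, reconstructed from the inline logic removed above (bucket/glob split plus list_blobs with match_glob); the actual implementation in MaxText.utils.gcs_utils may differ.

# Hypothetical sketch only; the real gcs_utils.gcs_glob_pattern is not shown
# in this diff. It mirrors the inline code that find_data_files used to carry.
import re

from google.cloud import storage


def gcs_glob_pattern(data_file_pattern: str) -> list[str]:
  """Expands a gs://bucket/glob pattern into fully qualified gs:// paths."""
  match = re.match(r"gs://([a-z0-9._-]+)/(.+)", data_file_pattern)
  if not match:
    raise ValueError("Invalid GCS path pattern.")
  bucket_name, glob_pattern = match.groups()
  storage_client = storage.Client()
  # match_glob lets the GCS list API do the glob matching server-side.
  blobs = storage_client.list_blobs(bucket_name, match_glob=glob_pattern)
  return [f"gs://{bucket_name}/{blob.name}" for blob in blobs]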

@@ -74,15 +62,22 @@ def get_datasets(
       dataset_list = [
           grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns
       ]
-      dataset = grain.MapDataset.mix(dataset_list, weights)
+      # create iterator per dataset with unique index
+      for ds in dataset_list:
+        if shuffle:
+          ds = ds.shuffle(seed=shuffle_seed)
+        ds = ds.repeat(num_epoch)
+        ds = ds[dataloading_host_index::dataloading_host_count]  # sharding
+        ds = ds.to_iter_dataset()
+      dataset = grain.IterDataset.mix(dataset_list, weights)
     else:
       data_files = find_data_files(data_file_pattern)
       dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
-    if shuffle:
-      dataset = dataset.shuffle(seed=shuffle_seed)
-    dataset = dataset.repeat(num_epoch)
-    dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
-    dataset = dataset.to_iter_dataset()
+      if shuffle:
+        dataset = dataset.shuffle(seed=shuffle_seed)
+      dataset = dataset.repeat(num_epoch)
+      dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
+      dataset = dataset.to_iter_dataset()
   elif data_file_type == "parquet":
     data_files = find_data_files(data_file_pattern)
     dataset = grain.MapDataset.source(data_files)
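For context, a condensed, self-contained sketch of the mixture path this hunk introduces, assuming only the Grain calls that appear in the diff (MapDataset.source, shuffle, repeat, slicing for per-host sharding, to_iter_dataset, IterDataset.mix); the helper name and signature are illustrative. The sketch collects the transformed per-source datasets into a new list so that the shuffle/repeat/shard steps are what actually get mixed.

# Illustrative sketch, not part of the commit: build one pipeline per source,
# then mix the resulting IterDatasets with the given weights.
import grain.python as grain


def build_mixed_dataset(  # hypothetical helper name
    per_source_files,  # list of file lists, one per mixture component
    weights,
    shuffle,
    shuffle_seed,
    num_epoch,
    dataloading_host_index,
    dataloading_host_count,
):
  iter_datasets = []
  for files in per_source_files:
    ds = grain.MapDataset.source(grain.ArrayRecordDataSource(files))
    if shuffle:
      ds = ds.shuffle(seed=shuffle_seed)
    ds = ds.repeat(num_epoch)
    ds = ds[dataloading_host_index::dataloading_host_count]  # per-host shard
    iter_datasets.append(ds.to_iter_dataset())
  # Weighted mixture over the already sharded, per-source iterators.
  return grain.IterDataset.mix(iter_datasets, weights)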

tests/grain_data_processing_test.py
Lines changed: 2 additions & 2 deletions

@@ -114,8 +114,8 @@ def setUp(self):
     temp_dir = tempfile.gettempdir()
     # We use the same dataset for testing, but you can use different datasets by changing the file patterns.
     grain_train_files = [
-        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3",
-        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7",
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.3",
+        f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*,0.7",
     ]
     grain_train_files = ";".join(grain_train_files)
     self.config = pyconfig.initialize(
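The test change only swaps the separator between a file pattern and its mixture weight from ":" to "," (multiple pattern,weight pairs are still joined with ";"). As a hedged sketch, assuming that format, a string like the one built above could be split into patterns and weights along these lines; parse_patterns_and_weights is a hypothetical name and MaxText's actual parsing may differ.

# Hypothetical parsing sketch for "pattern,weight;pattern,weight" strings;
# the real MaxText config/input-pipeline parsing is not shown in this diff.
def parse_patterns_and_weights(grain_train_files: str):
  patterns, weights = [], []
  for entry in grain_train_files.split(";"):
    pattern, _, weight = entry.rpartition(",")  # split on the last comma
    patterns.append(pattern)
    weights.append(float(weight))
  total = sum(weights)
  return patterns, [w / total for w in weights]  # normalize for mixing


# With the two entries built in setUp above (weights 0.3 and 0.7),
# this yields two identical patterns and weights summing to 1.0.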
