From 4b791aba0ff5a36cdcf2df59334b057ab2cf4514 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:44:24 +0800 Subject: [PATCH 1/3] Add Grain support to `image_dataset_from_directory` and `text_dataset_from_directory`. --- keras/src/utils/audio_dataset_utils.py | 2 +- keras/src/utils/dataset_utils.py | 91 ++++++- keras/src/utils/grain_utils.py | 33 +++ keras/src/utils/image_dataset_utils.py | 247 +++++++++++++++++- keras/src/utils/image_dataset_utils_test.py | 276 ++++++++++++++------ keras/src/utils/module_utils.py | 1 + keras/src/utils/text_dataset_utils.py | 149 ++++++++++- keras/src/utils/text_dataset_utils_test.py | 239 ++++++++++++----- 8 files changed, 857 insertions(+), 181 deletions(-) create mode 100644 keras/src/utils/grain_utils.py diff --git a/keras/src/utils/audio_dataset_utils.py b/keras/src/utils/audio_dataset_utils.py index b6f27d37c85c..ad2fb4e7f565 100644 --- a/keras/src/utils/audio_dataset_utils.py +++ b/keras/src/utils/audio_dataset_utils.py @@ -411,7 +411,7 @@ def paths_and_labels_to_dataset( """Constructs a fixed-size dataset of audio and labels.""" path_ds = tf.data.Dataset.from_tensor_slices(file_paths) if label_mode: - label_ds = dataset_utils.labels_to_dataset( + label_ds = dataset_utils.labels_to_dataset_tf( labels, label_mode, num_classes ) ds = tf.data.Dataset.zip((path_ds, label_ds)) diff --git a/keras/src/utils/dataset_utils.py b/keras/src/utils/dataset_utils.py index b8d2a534b1a2..85fe677fb3d9 100644 --- a/keras/src/utils/dataset_utils.py +++ b/keras/src/utils/dataset_utils.py @@ -8,7 +8,9 @@ from keras.src import tree from keras.src.api_export import keras_export +from keras.src.utils import file_utils from keras.src.utils import io_utils +from keras.src.utils.module_utils import grain from keras.src.utils.module_utils import tensorflow as tf @@ -299,6 +301,17 @@ def is_torch_dataset(dataset): return False +def is_grain_dataset(dataset): + if hasattr(dataset, "__class__"): + for parent in dataset.__class__.__mro__: + if parent.__name__ in ( + "MapDataset", + "IterDataset", + ) and str(parent.__module__).startswith("grain._src.python"): + return True + return False + + def _rescale_dataset_split_sizes(left_size, right_size, total_length): """Rescale the dataset split sizes. @@ -476,6 +489,10 @@ def _get_type_spec(dataset): from torch.utils.data import Dataset as TorchDataset return TorchDataset + elif is_grain_dataset(dataset): + from grain import MapDataset + + return MapDataset else: return None @@ -525,10 +542,17 @@ def index_directory( - class_names: names of the classes corresponding to these labels, in order. """ + if file_utils.is_remote_path(directory): + os_module = tf.io.gfile + path_module = tf.io.gfile + else: + os_module = os + path_module = os.path + if labels == "inferred": subdirs = [] - for subdir in sorted(tf.io.gfile.listdir(directory)): - if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)): + for subdir in sorted(os_module.listdir(directory)): + if path_module.isdir(path_module.join(directory, subdir)): if not subdir.startswith("."): if subdir.endswith("/"): subdir = subdir[:-1] @@ -566,7 +590,7 @@ def index_directory( results = [] filenames = [] - for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs): + for dirpath in (path_module.join(directory, subdir) for subdir in subdirs): results.append( pool.apply_async( index_subdirectory, @@ -608,7 +632,7 @@ def index_directory( ) pool.close() pool.join() - file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames] + file_paths = [path_module.join(directory, fname) for fname in filenames] if shuffle: # Shuffle globally to erase macro-structure @@ -623,8 +647,10 @@ def index_directory( def iter_valid_files(directory, follow_links, formats): + io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os + if not follow_links: - walk = tf.io.gfile.walk(directory) + walk = io_module.walk(directory) else: walk = os.walk(directory, followlinks=follow_links) for root, _, files in sorted(walk, key=lambda x: x[0]): @@ -648,14 +674,18 @@ def index_subdirectory(directory, class_indices, follow_links, formats): paths, and `labels` is a list of integer labels corresponding to these files. """ + path_module = ( + tf.io.gfile if file_utils.is_remote_path(directory) else os.path + ) + dirname = os.path.basename(directory) valid_files = iter_valid_files(directory, follow_links, formats) labels = [] filenames = [] for root, fname in valid_files: labels.append(class_indices[dirname]) - absolute_path = tf.io.gfile.join(root, fname) - relative_path = tf.io.gfile.join( + absolute_path = path_module.join(root, fname) + relative_path = path_module.join( dirname, os.path.relpath(absolute_path, directory) ) filenames.append(relative_path) @@ -700,7 +730,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset): return samples, labels -def labels_to_dataset(labels, label_mode, num_classes): +def labels_to_dataset_tf(labels, label_mode, num_classes): """Create a `tf.data.Dataset` from the list/tuple of labels. Args: @@ -730,6 +760,51 @@ def labels_to_dataset(labels, label_mode, num_classes): return label_ds +def labels_to_dataset_grain(labels, label_mode, num_classes): + """Create a `grain.MapDataset` from the list/tuple of labels. + + Args: + labels: list/tuple of labels to be converted into a `grain.MapDataset`. + label_mode: String describing the encoding of `labels`. Options are: + - `"binary"` indicates that the labels (there can be only 2) are encoded + as `float32` scalars with values 0 or 1 + (e.g. for `binary_crossentropy`). + - `"categorical"` means that the labels are mapped into a categorical + vector. (e.g. for `categorical_crossentropy` loss). + num_classes: number of classes of labels. + + Returns: + A `grain.MapDataset` instance. + """ + from keras.src import backend + from keras.src import ops + + if label_mode not in ("binary", "categorical", "int"): + raise ValueError( + f"Invalid `label_mode`: {label_mode}. " + "Expected one of: 'binary', 'categorical', 'int'." + ) + + def preprocess_labels_in_cpu(label_mode, x, num_classes): + with backend.device_scope("cpu"): + if label_mode == "binary": + return ops.expand_dims( + ops.convert_to_tensor(x, dtype="float32"), axis=-1 + ) + elif label_mode == "categorical": + return ops.one_hot( + ops.convert_to_tensor(x, dtype="int32"), num_classes + ) + else: + return ops.convert_to_tensor(x, dtype="int32") + + label_ds = grain.MapDataset.source(labels) + label_ds = label_ds.map( + lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes), + ) + return label_ds + + def check_validation_split_arg(validation_split, subset, shuffle, seed): """Raise errors in case of invalid argument values. diff --git a/keras/src/utils/grain_utils.py b/keras/src/utils/grain_utils.py new file mode 100644 index 000000000000..f0a562505dd6 --- /dev/null +++ b/keras/src/utils/grain_utils.py @@ -0,0 +1,33 @@ +from keras.src import backend +from keras.src import tree + + +def make_batch(values): + from keras.src import ops + + if not values: + raise ValueError("Cannot batch 0 values. Please file a bug.") + + with backend.device_scope("cpu"): + return tree.map_structure(lambda *xs: ops.stack(xs), *values) + + +def make_string_batch(values): + from keras.src import ops + + if not values: + raise ValueError("Cannot batch 0 values. Please file a bug.") + + def batch_fn(*xs): + if isinstance(xs[0], str): + if backend.backend() == "tensorflow": + import tensorflow as tf + + xs = [tf.convert_to_tensor(x, dtype=tf.string) for x in xs] + xs = tf.stack(xs) + return xs + else: + return ops.stack(xs) + + with backend.device_scope("cpu"): + return tree.map_structure(batch_fn, *values) diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py index c1918be73eef..35cb802f0e84 100755 --- a/keras/src/utils/image_dataset_utils.py +++ b/keras/src/utils/image_dataset_utils.py @@ -1,11 +1,27 @@ +import io +import pathlib + import numpy as np from keras.src.api_export import keras_export from keras.src.backend.config import standardize_data_format from keras.src.utils import dataset_utils from keras.src.utils import image_utils +from keras.src.utils.grain_utils import make_batch +from keras.src.utils.module_utils import grain from keras.src.utils.module_utils import tensorflow as tf +try: + from PIL import Image as pil_image + + try: + pil_image_resampling = pil_image.Resampling + except AttributeError: + pil_image_resampling = pil_image +except ImportError: + pil_image = None + pil_image_resampling = None + ALLOWLIST_FORMATS = (".bmp", ".gif", ".jpeg", ".jpg", ".png") @@ -32,6 +48,7 @@ def image_dataset_from_directory( crop_to_aspect_ratio=False, pad_to_aspect_ratio=False, data_format=None, + format="tf", verbose=True, ): """Generates a `tf.data.Dataset` from image files in a directory. @@ -125,12 +142,19 @@ def image_dataset_from_directory( preserved. data_format: If None uses keras.config.image_data_format() otherwise either 'channel_last' or 'channel_first'. + format: The format of the return object. Defaults to `"tf"`. Available + options are: + - `"tf"`: returns a `tf.data.Dataset` object. Requires + TensorFlow to be installed. + - `"grain"`: returns a `grain.IterDataset` object. Requires + Grain to be installed. verbose: Whether to display number information on classes and number of files found. Defaults to `True`. Returns: - A `tf.data.Dataset` object. + A `tf.data.Dataset` (`format="tf"`) or `grain.IterDataset` + (`format="grain"`) object. - If `label_mode` is `None`, it yields `float32` tensors of shape `(batch_size, image_size[0], image_size[1], num_channels)`, @@ -222,6 +246,11 @@ def image_dataset_from_directory( f"{supported_interpolations}. " f"Received: interpolation={interpolation}" ) + if format not in ("tf", "grain"): + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) dataset_utils.check_validation_split_arg( validation_split, subset, shuffle, seed @@ -289,6 +318,7 @@ def image_dataset_from_directory( shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, seed=seed, + format=format, ) val_dataset = paths_and_labels_to_dataset( @@ -303,14 +333,23 @@ def image_dataset_from_directory( pad_to_aspect_ratio=pad_to_aspect_ratio, data_format=data_format, shuffle=False, + format=format, ) - if batch_size is not None: - train_dataset = train_dataset.batch(batch_size) - val_dataset = val_dataset.batch(batch_size) - - train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) - val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + if format == "tf": + if batch_size is not None: + train_dataset = train_dataset.batch(batch_size) + val_dataset = val_dataset.batch(batch_size) + train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) + val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + else: + train_dataset = train_dataset.to_iter_dataset() + val_dataset = val_dataset.to_iter_dataset() + if batch_size is not None: + train_dataset = train_dataset.batch( + batch_size, batch_fn=make_batch + ) + val_dataset = val_dataset.batch(batch_size, batch_fn=make_batch) # Users may need to reference `class_names`. train_dataset.class_names = class_names @@ -345,12 +384,18 @@ def image_dataset_from_directory( shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, seed=seed, + format=format, ) - if batch_size is not None: - dataset = dataset.batch(batch_size) + if format == "tf": + if batch_size is not None: + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + else: + dataset = dataset.to_iter_dataset() + if batch_size is not None: + dataset = dataset.batch(batch_size, batch_fn=make_batch) - dataset = dataset.prefetch(tf.data.AUTOTUNE) # Users may need to reference `class_names`. dataset.class_names = class_names @@ -374,11 +419,66 @@ def paths_and_labels_to_dataset( shuffle=False, shuffle_buffer_size=None, seed=None, + format="tf", +): + """Constructs a dataset of images and labels.""" + if format == "tf": + return _paths_and_labels_to_dataset_tf( + image_paths=image_paths, + image_size=image_size, + num_channels=num_channels, + labels=labels, + label_mode=label_mode, + num_classes=num_classes, + interpolation=interpolation, + data_format=data_format, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + elif format == "grain": + return _paths_and_labels_to_dataset_grain( + image_paths=image_paths, + image_size=image_size, + num_channels=num_channels, + labels=labels, + label_mode=label_mode, + num_classes=num_classes, + interpolation=interpolation, + data_format=data_format, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + shuffle=shuffle, + seed=seed, + ) + else: + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) + + +def _paths_and_labels_to_dataset_tf( + image_paths, + image_size, + num_channels, + labels, + label_mode, + num_classes, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + shuffle=False, + shuffle_buffer_size=None, + seed=None, ): """Constructs a dataset of images and labels.""" path_ds = tf.data.Dataset.from_tensor_slices(image_paths) if label_mode: - label_ds = dataset_utils.labels_to_dataset( + label_ds = dataset_utils.labels_to_dataset_tf( labels, label_mode, num_classes ) ds = tf.data.Dataset.zip((path_ds, label_ds)) @@ -398,17 +498,18 @@ def paths_and_labels_to_dataset( ) if label_mode: ds = ds.map( - lambda x, y: (load_image(x, *args), y), + lambda x, y: (_load_image_tf(x, *args), y), num_parallel_calls=tf.data.AUTOTUNE, ) else: ds = ds.map( - lambda x: load_image(x, *args), num_parallel_calls=tf.data.AUTOTUNE + lambda x: _load_image_tf(x, *args), + num_parallel_calls=tf.data.AUTOTUNE, ) return ds -def load_image( +def _load_image_tf( path, image_size, num_channels, @@ -457,3 +558,121 @@ def load_image( else: img.set_shape((num_channels, image_size[0], image_size[1])) return img + + +def _paths_and_labels_to_dataset_grain( + image_paths, + image_size, + num_channels, + labels, + label_mode, + num_classes, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + shuffle=False, + seed=None, +): + """Constructs a dataset of images and labels.""" + path_ds = grain.MapDataset.source(image_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_grain( + labels, label_mode, num_classes + ) + ds = grain.experimental.ZipMapDataset([path_ds, label_ds]) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(seed=seed) + + args = ( + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio, + pad_to_aspect_ratio, + ) + if label_mode: + ds = ds.map(lambda data: (_load_image_grain(data[0], *args), data[1])) + else: + ds = ds.map(lambda x: _load_image_grain(x, *args)) + + return ds + + +def _load_image_grain( + path, + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, +): + """Load an image from a path and resize it.""" + from keras.src import backend + from keras.src import ops + + if pil_image is None: + raise ImportError( + "Could not import PIL.Image. The use of `load_img` requires PIL." + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`" + " can be set to `True`." + ) + + if isinstance(path, io.BytesIO): + img = pil_image.open(path) + elif isinstance(path, (pathlib.Path, bytes, str)): + if isinstance(path, pathlib.Path): + path = str(path.resolve()) + img = pil_image.open(path) + else: + raise TypeError( + f"path should be path-like or io.BytesIO, not {type(path)}" + ) + if num_channels == 1: + # if image is not already an 8-bit, 16-bit or 32-bit grayscale image + # convert it to an 8-bit grayscale image. + if img.mode not in ("L", "I;16", "I"): + img = img.convert("L") + elif num_channels == 4: + if img.mode != "RGBA": + img = img.convert("RGBA") + elif num_channels == 3: + if img.mode != "RGB": + img = img.convert("RGB") + else: + raise ValueError( + "num_channels must be 1, 3 or 4. " + f"Received: num_channels={num_channels}" + ) + + with backend.device_scope("cpu"): + img = np.array(img) + if img.ndim == 2: + # If the image is grayscale, expand dims to add channel axis. + # The reason is that `ops.image.resize` expects 3D or 4D tensors. + img = np.expand_dims( + img, axis=-1 if data_format == "channels_last" else 0 + ) + img = ops.convert_to_tensor(img, dtype="float32") + img = ops.image.resize( + img, + size=image_size, + interpolation=interpolation, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + data_format=data_format, + ) + if backend.backend() == "tensorflow": + if data_format == "channels_last": + img.set_shape((image_size[0], image_size[1], num_channels)) + else: + img.set_shape((num_channels, image_size[0], image_size[1])) + return img diff --git a/keras/src/utils/image_dataset_utils_test.py b/keras/src/utils/image_dataset_utils_test.py index e6d006ab7c0e..31251228b86f 100644 --- a/keras/src/utils/image_dataset_utils_test.py +++ b/keras/src/utils/image_dataset_utils_test.py @@ -1,8 +1,10 @@ import os import numpy as np +from absl.testing import parameterized from keras.src import backend +from keras.src import ops from keras.src import testing from keras.src.utils import image_dataset_utils from keras.src.utils import image_utils @@ -66,7 +68,11 @@ def _prepare_directory( i += 1 return temp_dir - def test_image_dataset_from_directory_no_labels(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_no_labels(self, format): # Test retrieving images without labels from a directory and its # subdirs. @@ -77,7 +83,11 @@ def test_image_dataset_from_directory_no_labels(self): img.save(os.path.join(directory, filename)) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=5, image_size=(18, 18), labels=None + directory, + batch_size=5, + image_size=(18, 18), + labels=None, + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [5, 18, 18, 3] @@ -86,8 +96,8 @@ def test_image_dataset_from_directory_no_labels(self): self.assertEqual(dataset.class_names, None) batch = next(iter(dataset)) # We return plain images - self.assertEqual(batch.shape, output_shape) - self.assertEqual(batch.dtype.name, "float32") + self.assertEqual(list(batch.shape), output_shape) + self.assertDType(batch, "float32") # Count samples batch_count = 0 sample_count = 0 @@ -97,10 +107,18 @@ def test_image_dataset_from_directory_no_labels(self): self.assertEqual(batch_count, 2) self.assertEqual(sample_count, 10) - def test_image_dataset_from_directory_binary(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_binary(self, format): directory = self._prepare_directory(num_classes=2) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode="int" + directory, + batch_size=8, + image_size=(18, 18), + label_mode="int", + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [8, 18, 18, 3] @@ -108,33 +126,38 @@ def test_image_dataset_from_directory_binary(self): output_shape = [8, 3, 18, 18] batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") - self.assertEqual(batch[1].shape, (8,)) - self.assertEqual(batch[1].dtype.name, "int32") + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode="binary" + directory, + batch_size=8, + image_size=(18, 18), + label_mode="binary", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") - self.assertEqual(batch[1].shape, (8, 1)) - self.assertEqual(batch[1].dtype.name, "float32") + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 1]) + self.assertDType(batch[1], "float32") dataset = image_dataset_utils.image_dataset_from_directory( directory, batch_size=8, image_size=(18, 18), label_mode="categorical", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") - self.assertEqual(batch[1].shape, (8, 2)) - self.assertEqual(batch[1].dtype.name, "float32") + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 2]) + self.assertDType(batch[1], "float32") def test_static_shape_in_graph(self): directory = self._prepare_directory(num_classes=2) @@ -154,31 +177,51 @@ def symbolic_fn(ds): symbolic_fn(dataset) - def test_sample_count(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_sample_count(self, format): directory = self._prepare_directory(num_classes=4, count=15) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode=None + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, ) sample_count = 0 for batch in dataset: sample_count += batch.shape[0] self.assertEqual(sample_count, 15) - def test_image_dataset_from_directory_multiclass(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_multiclass(self, format): directory = self._prepare_directory(num_classes=4, count=15) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode=None + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [8, 18, 18, 3] else: output_shape = [8, 3, 18, 18] batch = next(iter(dataset)) - self.assertEqual(batch.shape, output_shape) + self.assertEqual(list(batch.shape), output_shape) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode=None + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, ) sample_count = 0 iterator = iter(dataset) @@ -187,32 +230,45 @@ def test_image_dataset_from_directory_multiclass(self): self.assertEqual(sample_count, 15) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), label_mode="int" + directory, + batch_size=8, + image_size=(18, 18), + label_mode="int", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") - self.assertEqual(batch[1].shape, (8,)) - self.assertEqual(batch[1].dtype.name, "int32") + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") dataset = image_dataset_utils.image_dataset_from_directory( directory, batch_size=8, image_size=(18, 18), label_mode="categorical", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (output_shape)) - self.assertEqual(batch[0].dtype.name, "float32") - self.assertEqual(batch[1].shape, (8, 4)) - self.assertEqual(batch[1].dtype.name, "float32") - - def test_image_dataset_from_directory_color_modes(self): + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 4]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_color_modes(self, format): directory = self._prepare_directory(num_classes=4, color_mode="rgba") dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), color_mode="rgba" + directory, + batch_size=8, + image_size=(18, 18), + color_mode="rgba", + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [8, 18, 18, 4] @@ -220,14 +276,18 @@ def test_image_dataset_from_directory_color_modes(self): output_shape = [8, 4, 18, 18] batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") directory = self._prepare_directory( num_classes=4, color_mode="grayscale" ) dataset = image_dataset_utils.image_dataset_from_directory( - directory, batch_size=8, image_size=(18, 18), color_mode="grayscale" + directory, + batch_size=8, + image_size=(18, 18), + color_mode="grayscale", + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [8, 18, 18, 1] @@ -235,10 +295,14 @@ def test_image_dataset_from_directory_color_modes(self): output_shape = [8, 1, 18, 18] batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) - self.assertEqual(batch[0].dtype.name, "float32") - - def test_image_dataset_from_directory_validation_split(self): + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_validation_split(self, format): directory = self._prepare_directory(num_classes=2, count=10) dataset = image_dataset_utils.image_dataset_from_directory( directory, @@ -247,6 +311,7 @@ def test_image_dataset_from_directory_validation_split(self): validation_split=0.2, subset="training", seed=1337, + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) @@ -256,7 +321,7 @@ def test_image_dataset_from_directory_validation_split(self): else: train_output_shape = [8, 3, 18, 18] val_output_shape = [2, 3, 18, 18] - self.assertEqual(batch[0].shape, train_output_shape) + self.assertEqual(list(batch[0].shape), train_output_shape) dataset = image_dataset_utils.image_dataset_from_directory( directory, batch_size=10, @@ -264,10 +329,11 @@ def test_image_dataset_from_directory_validation_split(self): validation_split=0.2, subset="validation", seed=1337, + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, val_output_shape) + self.assertEqual(list(batch[0].shape), val_output_shape) ( train_dataset, @@ -279,15 +345,20 @@ def test_image_dataset_from_directory_validation_split(self): validation_split=0.2, subset="both", seed=1337, + format=format, ) batch = next(iter(train_dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, train_output_shape) + self.assertEqual(list(batch[0].shape), train_output_shape) batch = next(iter(val_dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, val_output_shape) + self.assertEqual(list(batch[0].shape), val_output_shape) - def test_image_dataset_from_directory_manual_labels(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_manual_labels(self, format): # Case: wrong number of labels directory = self._prepare_directory(num_classes=1, count=4) with self.assertRaisesRegex(ValueError, "match the number of files"): @@ -297,6 +368,7 @@ def test_image_dataset_from_directory_manual_labels(self): image_size=(18, 18), labels=[0, 1, 0], shuffle=False, + format=format, ) # Case: single directory @@ -307,6 +379,7 @@ def test_image_dataset_from_directory_manual_labels(self): image_size=(18, 18), labels=[0, 1, 0, 1], shuffle=False, + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [18, 18, 3] @@ -315,7 +388,7 @@ def test_image_dataset_from_directory_manual_labels(self): self.assertEqual(dataset.class_names, ["0", "1"]) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, [4] + output_shape) + self.assertEqual(list(batch[0].shape), [4] + output_shape) self.assertAllClose(batch[1], [0, 1, 0, 1]) # Case: multiple directories @@ -326,14 +399,19 @@ def test_image_dataset_from_directory_manual_labels(self): image_size=(18, 18), labels=[0, 1, 0, 1, 1, 1], shuffle=False, + format=format, ) self.assertEqual(dataset.class_names, ["0", "1"]) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, [6] + output_shape) + self.assertEqual(list(batch[0].shape), [6] + output_shape) self.assertAllClose(batch[1], [0, 1, 0, 1, 1, 1]) - def test_image_dataset_from_directory_follow_links(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_follow_links(self, format): directory = self._prepare_directory( num_classes=2, count=25, nested_dirs=True ) @@ -343,24 +421,36 @@ def test_image_dataset_from_directory_follow_links(self): image_size=(18, 18), label_mode=None, follow_links=True, + format=format, ) sample_count = 0 for batch in dataset: sample_count += batch.shape[0] self.assertEqual(sample_count, 25) - def test_image_dataset_from_directory_no_images(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_no_images(self, format): directory = self._prepare_directory(num_classes=2, count=0) with self.assertRaisesRegex(ValueError, "No images found."): - _ = image_dataset_utils.image_dataset_from_directory(directory) + _ = image_dataset_utils.image_dataset_from_directory( + directory, format=format + ) - def test_image_dataset_from_directory_crop_to_aspect_ratio(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_crop_to_aspect_ratio(self, format): directory = self._prepare_directory(num_classes=2, count=5) dataset = image_dataset_utils.image_dataset_from_directory( directory, batch_size=5, image_size=(18, 18), crop_to_aspect_ratio=True, + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [5, 18, 18, 3] @@ -368,15 +458,20 @@ def test_image_dataset_from_directory_crop_to_aspect_ratio(self): output_shape = [5, 3, 18, 18] batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) + self.assertEqual(list(batch[0].shape), output_shape) - def test_image_dataset_from_directory_pad_to_aspect_ratio(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_pad_to_aspect_ratio(self, format): directory = self._prepare_directory(num_classes=2, count=5) dataset = image_dataset_utils.image_dataset_from_directory( directory, batch_size=5, image_size=(18, 18), pad_to_aspect_ratio=True, + format=format, ) if backend.config.image_data_format() == "channels_last": output_shape = [5, 18, 18, 3] @@ -384,26 +479,30 @@ def test_image_dataset_from_directory_pad_to_aspect_ratio(self): output_shape = [5, 3, 18, 18] batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, output_shape) + self.assertEqual(list(batch[0].shape), output_shape) - def test_image_dataset_from_directory_errors(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_errors(self, format): directory = self._prepare_directory(num_classes=3, count=5) with self.assertRaisesRegex(ValueError, "`labels` argument should be"): _ = image_dataset_utils.image_dataset_from_directory( - directory, labels="other" + directory, labels="other", format=format ) with self.assertRaisesRegex( ValueError, "`label_mode` argument must be" ): _ = image_dataset_utils.image_dataset_from_directory( - directory, label_mode="other" + directory, label_mode="other", format=format ) with self.assertRaisesRegex(ValueError, "`color_mode` must be one of"): _ = image_dataset_utils.image_dataset_from_directory( - directory, color_mode="other" + directory, color_mode="other", format=format ) with self.assertRaisesRegex( @@ -413,6 +512,7 @@ def test_image_dataset_from_directory_errors(self): directory, labels=[0, 0, 1, 1, 1], class_names=["class_0", "class_1", "class_2"], + format=format, ) with self.assertRaisesRegex( @@ -420,26 +520,26 @@ def test_image_dataset_from_directory_errors(self): "Expected the lengths of `labels` to match the number of files", ): _ = image_dataset_utils.image_dataset_from_directory( - directory, labels=[0, 0, 1, 1] + directory, labels=[0, 0, 1, 1], format=format ) with self.assertRaisesRegex( ValueError, "`class_names` passed did not match" ): _ = image_dataset_utils.image_dataset_from_directory( - directory, class_names=["class_0", "wrong_class"] + directory, class_names=["class_0", "wrong_class"], format=format ) with self.assertRaisesRegex(ValueError, "there must be exactly 2"): _ = image_dataset_utils.image_dataset_from_directory( - directory, label_mode="binary" + directory, label_mode="binary", format=format ) with self.assertRaisesRegex( ValueError, "`validation_split` must be between 0 and 1" ): _ = image_dataset_utils.image_dataset_from_directory( - directory, validation_split=2 + directory, validation_split=2, format=format ) with self.assertRaisesRegex( @@ -447,22 +547,32 @@ def test_image_dataset_from_directory_errors(self): '`subset` must be either "training", "validation" or "both"', ): _ = image_dataset_utils.image_dataset_from_directory( - directory, validation_split=0.2, subset="other" + directory, validation_split=0.2, subset="other", format=format ) with self.assertRaisesRegex( ValueError, "`validation_split` must be set" ): _ = image_dataset_utils.image_dataset_from_directory( - directory, validation_split=0.0, subset="training" + directory, + validation_split=0.0, + subset="training", + format=format, ) with self.assertRaisesRegex(ValueError, "must provide a `seed`"): _ = image_dataset_utils.image_dataset_from_directory( - directory, validation_split=0.2, subset="training" + directory, + validation_split=0.2, + subset="training", + format=format, ) - def test_image_dataset_from_directory_not_batched(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_not_batched(self, format): directory = self._prepare_directory(num_classes=2, count=2) dataset = image_dataset_utils.image_dataset_from_directory( directory, @@ -470,11 +580,16 @@ def test_image_dataset_from_directory_not_batched(self): image_size=(18, 18), label_mode=None, shuffle=False, + format=format, ) sample = next(iter(dataset)) self.assertEqual(len(sample.shape), 3) - def test_image_dataset_from_directory_shuffle(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_shuffle(self, format): # TODO: add same test for train/val directory = self._prepare_directory( num_classes=2, count=25, nested_dirs=True @@ -486,14 +601,15 @@ def test_image_dataset_from_directory_shuffle(self): label_mode=None, follow_links=True, shuffle=False, + format=format, ) batches_1 = [] batches_2 = [] for b in dataset: - batches_1.append(b) + batches_1.append(ops.convert_to_numpy(b)) batches_1 = np.concatenate(batches_1, axis=0) for b in dataset: - batches_2.append(b) + batches_2.append(ops.convert_to_numpy(b)) batches_2 = np.concatenate(batches_2, axis=0) self.assertAllClose(batches_1, batches_2, atol=1e-6) @@ -505,16 +621,21 @@ def test_image_dataset_from_directory_shuffle(self): follow_links=True, shuffle=True, seed=1337, + format=format, ) batches_1 = [] batches_2 = [] for b in dataset: - batches_1.append(b) + batches_1.append(ops.convert_to_numpy(b)) batches_1 = np.concatenate(batches_1, axis=0) for b in dataset: - batches_2.append(b) + batches_2.append(ops.convert_to_numpy(b)) batches_2 = np.concatenate(batches_2, axis=0) - self.assertNotAllClose(batches_1, batches_2, atol=1e-6) + if format == "tf": + self.assertNotAllClose(batches_1, batches_2, atol=1e-6) + else: + # Grain shuffles deterministically, so we expect the same batches. + self.assertAllClose(batches_1, batches_2, atol=1e-6) # Test random seed determinism dataset = image_dataset_utils.image_dataset_from_directory( @@ -525,9 +646,10 @@ def test_image_dataset_from_directory_shuffle(self): follow_links=True, shuffle=True, seed=1337, + format=format, ) batches_1_alt = [] for b in dataset: - batches_1_alt.append(b) + batches_1_alt.append(ops.convert_to_numpy(b)) batches_1_alt = np.concatenate(batches_1_alt, axis=0) self.assertAllClose(batches_1, batches_1_alt, atol=1e-6) diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py index d81ec05028e4..286394a99358 100644 --- a/keras/src/utils/module_utils.py +++ b/keras/src/utils/module_utils.py @@ -58,3 +58,4 @@ def __repr__(self): optree = LazyModule("optree") dmtree = LazyModule("tree") tf2onnx = LazyModule("tf2onnx") +grain = LazyModule("grain") diff --git a/keras/src/utils/text_dataset_utils.py b/keras/src/utils/text_dataset_utils.py index a76134818570..79ae687faf11 100644 --- a/keras/src/utils/text_dataset_utils.py +++ b/keras/src/utils/text_dataset_utils.py @@ -2,6 +2,8 @@ from keras.src.api_export import keras_export from keras.src.utils import dataset_utils +from keras.src.utils.grain_utils import make_string_batch +from keras.src.utils.module_utils import grain from keras.src.utils.module_utils import tensorflow as tf @@ -23,6 +25,7 @@ def text_dataset_from_directory( validation_split=None, subset=None, follow_links=False, + format="tf", verbose=True, ): """Generates a `tf.data.Dataset` from text files in a directory. @@ -91,19 +94,34 @@ def text_dataset_from_directory( (the training and validation datasets respectively). follow_links: Whether to visits subdirectories pointed to by symlinks. Defaults to `False`. + format: The format of the return object. Defaults to `"tf"`. Available + options are: + - `"tf"`: returns a `tf.data.Dataset` object. Requires + TensorFlow to be installed. + - `"grain"`: returns a `grain.IterDataset` object. Requires + Grain to be installed. verbose: Whether to display number information on classes and number of files found. Defaults to `True`. Returns: - A `tf.data.Dataset` object. + A `tf.data.Dataset` (`format="tf"`) or `grain.IterDataset` + (`format="grain"`) object. + When `format="tf"`: - If `label_mode` is `None`, it yields `string` tensors of shape `(batch_size,)`, containing the contents of a batch of text files. - Otherwise, it yields a tuple `(texts, labels)`, where `texts` has shape `(batch_size,)` and `labels` follows the format described below. + When `format="grain"`: + - If `label_mode` is `None`, it yields a list of Python strings containing + the contents of a batch of text files. + - Otherwise, it yields a tuple `(texts, labels)`, where `texts` + is a list of Python strings and `labels` follows the format described + below. + Rules regarding labels format: - if `label_mode` is `int`, the labels are an `int32` tensor of shape @@ -137,6 +155,11 @@ def text_dataset_from_directory( '"categorical", "binary", ' f"or None. Received: label_mode={label_mode}" ) + if format not in ("tf", "grain"): + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) if labels is None or label_mode is None: labels = None label_mode = None @@ -199,6 +222,7 @@ def text_dataset_from_directory( shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, seed=seed, + format=format, ) val_dataset = paths_and_labels_to_dataset( file_paths=file_paths_val, @@ -207,14 +231,25 @@ def text_dataset_from_directory( num_classes=len(class_names) if class_names else 0, max_length=max_length, shuffle=False, + format=format, ) - if batch_size is not None: - train_dataset = train_dataset.batch(batch_size) - val_dataset = val_dataset.batch(batch_size) - - train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) - val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + if format == "tf": + if batch_size is not None: + train_dataset = train_dataset.batch(batch_size) + val_dataset = val_dataset.batch(batch_size) + train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) + val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + else: + train_dataset = train_dataset.to_iter_dataset() + val_dataset = val_dataset.to_iter_dataset() + if batch_size is not None: + train_dataset = train_dataset.batch( + batch_size, batch_fn=make_string_batch + ) + val_dataset = val_dataset.batch( + batch_size, batch_fn=make_string_batch + ) # Users may need to reference `class_names`. train_dataset.class_names = class_names @@ -238,10 +273,17 @@ def text_dataset_from_directory( shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, seed=seed, + format=format, ) - if batch_size is not None: - dataset = dataset.batch(batch_size) - dataset = dataset.prefetch(tf.data.AUTOTUNE) + + if format == "tf": + if batch_size is not None: + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + else: + dataset = dataset.to_iter_dataset() + if batch_size is not None: + dataset = dataset.batch(batch_size, batch_fn=make_string_batch) # Users may need to reference `class_names`. dataset.class_names = class_names @@ -257,11 +299,47 @@ def paths_and_labels_to_dataset( shuffle=False, shuffle_buffer_size=None, seed=None, + format="tf", +): + """Constructs a dataset of text strings and labels.""" + if format == "tf": + return _paths_and_labels_to_dataset_tf( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle, + shuffle_buffer_size, + seed, + ) + elif format == "grain": + return _paths_and_labels_to_dataset_grain( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle, + shuffle_buffer_size, + seed, + ) + + +def _paths_and_labels_to_dataset_tf( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle=False, + shuffle_buffer_size=None, + seed=None, ): """Constructs a dataset of text strings and labels.""" path_ds = tf.data.Dataset.from_tensor_slices(file_paths) if label_mode: - label_ds = dataset_utils.labels_to_dataset( + label_ds = dataset_utils.labels_to_dataset_tf( labels, label_mode, num_classes ) ds = tf.data.Dataset.zip((path_ds, label_ds)) @@ -273,19 +351,62 @@ def paths_and_labels_to_dataset( if label_mode: ds = ds.map( - lambda x, y: (path_to_string_content(x, max_length), y), + lambda x, y: (_path_to_string_content_tf(x, max_length), y), num_parallel_calls=tf.data.AUTOTUNE, ) else: ds = ds.map( - lambda x: path_to_string_content(x, max_length), + lambda x: _path_to_string_content_tf(x, max_length), num_parallel_calls=tf.data.AUTOTUNE, ) return ds -def path_to_string_content(path, max_length): +def _path_to_string_content_tf(path, max_length): txt = tf.io.read_file(path) if max_length is not None: txt = tf.strings.substr(txt, 0, max_length) return txt + + +def _paths_and_labels_to_dataset_grain( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a dataset of text strings and labels.""" + path_ds = grain.MapDataset.source(file_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_grain( + labels, label_mode, num_classes + ) + ds = grain.experimental.ZipMapDataset([path_ds, label_ds]) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(seed=seed) + + if label_mode: + ds = ds.map( + lambda data: ( + _path_to_string_content_grain(data[0], max_length), + data[1], + ), + ) + else: + ds = ds.map(lambda x: _path_to_string_content_grain(x, max_length)) + return ds + + +def _path_to_string_content_grain(path, max_length): + with open(path, "r") as f: + txt = f.read() + if max_length is not None: + txt = txt[:max_length] + return txt diff --git a/keras/src/utils/text_dataset_utils_test.py b/keras/src/utils/text_dataset_utils_test.py index 6e59b1bb67a3..cfa5d30b1878 100644 --- a/keras/src/utils/text_dataset_utils_test.py +++ b/keras/src/utils/text_dataset_utils_test.py @@ -2,6 +2,9 @@ import random import string +from absl.testing import parameterized + +from keras.src import backend from keras.src import testing from keras.src.utils import text_dataset_utils @@ -42,7 +45,11 @@ def _prepare_directory( f.write(text) return temp_dir - def test_text_dataset_from_directory_standalone(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_standalone(self, format): # Test retrieving txt files without labels from a directory and its # subdirs. Save a few extra files in the parent directory. directory = self._prepare_directory(count=7, num_classes=2) @@ -55,103 +62,158 @@ def test_text_dataset_from_directory_standalone(self): f.write(text) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=5, label_mode=None, max_length=10 + directory, + batch_size=5, + label_mode=None, + max_length=10, + format=format, ) batch = next(iter(dataset)) # We just return the texts, no labels - self.assertEqual(batch.shape, (5,)) - self.assertEqual(batch.dtype.name, "string") + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch.shape), [5]) + self.assertDType(batch, "string") + else: + self.assertLen(batch, 5) + self.assertIsInstance(batch[0], str) # Count samples batch_count = 0 sample_count = 0 for batch in dataset: batch_count += 1 - sample_count += batch.shape[0] + sample_count += len(batch) self.assertEqual(batch_count, 2) self.assertEqual(sample_count, 10) - def test_text_dataset_from_directory_binary(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_binary(self, format=format): directory = self._prepare_directory(num_classes=2) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode="int", max_length=10 + directory, + batch_size=8, + label_mode="int", + max_length=10, + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) - self.assertEqual(batch[0].dtype.name, "string") - self.assertEqual(len(batch[0].numpy()[0]), 10) # Test max_length - self.assertEqual(batch[1].shape, (8,)) - self.assertEqual(batch[1].dtype.name, "int32") + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(batch[0].shape, (8,)) + self.assertDType(batch[0], "string") + self.assertEqual(len(batch[0].numpy()[0]), 10) # Test max_length + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertLen(batch[0][0], 10) # Test max_length + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode="binary" + directory, + batch_size=8, + label_mode="binary", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) - self.assertEqual(batch[0].dtype.name, "string") - self.assertEqual(batch[1].shape, (8, 1)) - self.assertEqual(batch[1].dtype.name, "float32") + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 1]) + self.assertDType(batch[1], "float32") dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode="categorical" + directory, + batch_size=8, + label_mode="categorical", + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) - self.assertEqual(batch[0].dtype.name, "string") - self.assertEqual(batch[1].shape, (8, 2)) - self.assertEqual(batch[1].dtype.name, "float32") - - def test_sample_count(self): + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 2]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_sample_count(self, format): directory = self._prepare_directory(num_classes=4, count=15) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode=None + directory, batch_size=8, label_mode=None, format=format ) sample_count = 0 for batch in dataset: - sample_count += batch.shape[0] + sample_count += len(batch) self.assertEqual(sample_count, 15) - def test_text_dataset_from_directory_multiclass(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_multiclass(self, format): directory = self._prepare_directory(num_classes=4, count=15) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode=None + directory, batch_size=8, label_mode=None, format=format ) batch = next(iter(dataset)) - self.assertEqual(batch.shape, (8,)) + self.assertLen(batch, 8) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode=None + directory, batch_size=8, label_mode=None, format=format ) sample_count = 0 iterator = iter(dataset) for batch in dataset: - sample_count += next(iterator).shape[0] + sample_count += len(next(iterator)) self.assertEqual(sample_count, 15) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode="int" + directory, batch_size=8, label_mode="int", format=format ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) - self.assertEqual(batch[0].dtype.name, "string") - self.assertEqual(batch[1].shape, (8,)) - self.assertEqual(batch[1].dtype.name, "int32") + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode="categorical" + directory, batch_size=8, label_mode="categorical", format=format ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) - self.assertEqual(batch[0].dtype.name, "string") - self.assertEqual(batch[1].shape, (8, 4)) - self.assertEqual(batch[1].dtype.name, "float32") - - def test_text_dataset_from_directory_validation_split(self): + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 4]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_validation_split(self, format): directory = self._prepare_directory(num_classes=2, count=10) dataset = text_dataset_utils.text_dataset_from_directory( directory, @@ -159,20 +221,22 @@ def test_text_dataset_from_directory_validation_split(self): validation_split=0.2, subset="training", seed=1337, + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) + self.assertLen(batch[0], 8) dataset = text_dataset_utils.text_dataset_from_directory( directory, batch_size=10, validation_split=0.2, subset="validation", seed=1337, + format=format, ) batch = next(iter(dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (2,)) + self.assertLen(batch[0], 2) ( train_dataset, @@ -183,53 +247,76 @@ def test_text_dataset_from_directory_validation_split(self): validation_split=0.2, subset="both", seed=1337, + format=format, ) batch = next(iter(train_dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (8,)) + self.assertLen(batch[0], 8) batch = next(iter(val_dataset)) self.assertLen(batch, 2) - self.assertEqual(batch[0].shape, (2,)) + self.assertLen(batch[0], 2) - def test_text_dataset_from_directory_manual_labels(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_manual_labels(self, format): directory = self._prepare_directory(num_classes=2, count=2) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, labels=[0, 1], shuffle=False + directory, batch_size=8, labels=[0, 1], shuffle=False, format=format ) batch = next(iter(dataset)) self.assertLen(batch, 2) self.assertAllClose(batch[1], [0, 1]) - def test_text_dataset_from_directory_follow_links(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_follow_links(self, format): directory = self._prepare_directory( num_classes=2, count=25, nested_dirs=True ) dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=8, label_mode=None, follow_links=True + directory, + batch_size=8, + label_mode=None, + follow_links=True, + format=format, ) sample_count = 0 for batch in dataset: - sample_count += batch.shape[0] + sample_count += len(batch) self.assertEqual(sample_count, 25) - def test_text_dataset_from_directory_no_files(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_no_files(self, format): directory = self._prepare_directory(num_classes=2, count=0) with self.assertRaisesRegex(ValueError, "No text files found"): - _ = text_dataset_utils.text_dataset_from_directory(directory) + _ = text_dataset_utils.text_dataset_from_directory( + directory, format=format + ) - def test_text_dataset_from_directory_errors(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_errors(self, format): directory = self._prepare_directory(num_classes=3, count=5) with self.assertRaisesRegex(ValueError, "`labels` argument should be"): _ = text_dataset_utils.text_dataset_from_directory( - directory, labels="other" + directory, labels="other", format=format ) with self.assertRaisesRegex( ValueError, "`label_mode` argument must be" ): _ = text_dataset_utils.text_dataset_from_directory( - directory, label_mode="other" + directory, label_mode="other", format=format ) with self.assertRaisesRegex( @@ -239,6 +326,7 @@ def test_text_dataset_from_directory_errors(self): directory, labels=[0, 0, 1, 1, 1], class_names=["class_0", "class_1", "class_2"], + format=format, ) with self.assertRaisesRegex( @@ -246,26 +334,26 @@ def test_text_dataset_from_directory_errors(self): "Expected the lengths of `labels` to match the number of files", ): _ = text_dataset_utils.text_dataset_from_directory( - directory, labels=[0, 0, 1, 1] + directory, labels=[0, 0, 1, 1], format=format ) with self.assertRaisesRegex( ValueError, "`class_names` passed did not match" ): _ = text_dataset_utils.text_dataset_from_directory( - directory, class_names=["class_0", "wrong_class"] + directory, class_names=["class_0", "wrong_class"], format=format ) with self.assertRaisesRegex(ValueError, "there must be exactly 2"): _ = text_dataset_utils.text_dataset_from_directory( - directory, label_mode="binary" + directory, label_mode="binary", format=format ) with self.assertRaisesRegex( ValueError, "`validation_split` must be between 0 and 1" ): _ = text_dataset_utils.text_dataset_from_directory( - directory, validation_split=2 + directory, validation_split=2, format=format ) with self.assertRaisesRegex( @@ -273,26 +361,43 @@ def test_text_dataset_from_directory_errors(self): '`subset` must be either "training", "validation" or "both"', ): _ = text_dataset_utils.text_dataset_from_directory( - directory, validation_split=0.2, subset="other" + directory, validation_split=0.2, subset="other", format=format ) with self.assertRaisesRegex( ValueError, "`validation_split` must be set" ): _ = text_dataset_utils.text_dataset_from_directory( - directory, validation_split=0.0, subset="training" + directory, + validation_split=0.0, + subset="training", + format=format, ) with self.assertRaisesRegex(ValueError, "must provide a `seed`"): _ = text_dataset_utils.text_dataset_from_directory( - directory, validation_split=0.2, subset="training" + directory, + validation_split=0.2, + subset="training", + format=format, ) - def test_text_dataset_from_directory_not_batched(self): + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_not_batched(self, format): directory = self._prepare_directory() dataset = text_dataset_utils.text_dataset_from_directory( - directory, batch_size=None, label_mode=None, follow_links=True + directory, + batch_size=None, + label_mode=None, + follow_links=True, + format=format, ) sample = next(iter(dataset)) - self.assertEqual(len(sample.shape), 0) + if format == "tf": + self.assertEqual(len(sample.shape), 0) + else: + self.assertIsInstance(sample, str) From 17802a0c158a296fae7147cdeb292a26e6ee852a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 17 Aug 2025 18:03:53 +0800 Subject: [PATCH 2/3] Fix channels_first bug. --- keras/src/utils/image_dataset_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py index 35cb802f0e84..3874df3ac99d 100755 --- a/keras/src/utils/image_dataset_utils.py +++ b/keras/src/utils/image_dataset_utils.py @@ -654,14 +654,13 @@ def _load_image_grain( ) with backend.device_scope("cpu"): - img = np.array(img) - if img.ndim == 2: + img = ops.convert_to_tensor(np.array(img), dtype="float32") + if len(img.shape) == 2: # If the image is grayscale, expand dims to add channel axis. # The reason is that `ops.image.resize` expects 3D or 4D tensors. - img = np.expand_dims( - img, axis=-1 if data_format == "channels_last" else 0 - ) - img = ops.convert_to_tensor(img, dtype="float32") + img = ops.expand_dims(img, axis=-1) + if data_format == "channels_first": + img = ops.transpose(img, (2, 0, 1)) img = ops.image.resize( img, size=image_size, From 3d3644cb72885c7dfde708e30aba5bc056470f65 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 17 Aug 2025 22:06:55 +0800 Subject: [PATCH 3/3] Refine the docstrings. --- keras/src/utils/image_dataset_utils.py | 8 ++++++-- keras/src/utils/text_dataset_utils.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py index 3874df3ac99d..a9fe50050187 100755 --- a/keras/src/utils/image_dataset_utils.py +++ b/keras/src/utils/image_dataset_utils.py @@ -51,7 +51,7 @@ def image_dataset_from_directory( format="tf", verbose=True, ): - """Generates a `tf.data.Dataset` from image files in a directory. + """Generates a dataset from image files in a directory. If your directory structure is: @@ -66,13 +66,17 @@ def image_dataset_from_directory( ``` Then calling `image_dataset_from_directory(main_directory, - labels='inferred')` will return a `tf.data.Dataset` that yields batches of + labels='inferred')` will return a dataset that yields batches of images from the subdirectories `class_a` and `class_b`, together with labels 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). Supported image formats: `.jpeg`, `.jpg`, `.png`, `.bmp`, `.gif`. Animated gifs are truncated to the first frame. + By default, this function will return a `tf.data.Dataset` object. You can + set `format="grain"` to return a `grain.IterDataset` object instead, which + removes the TensorFlow dependency. + Args: directory: Directory where the data is located. If `labels` is `"inferred"`, it should contain diff --git a/keras/src/utils/text_dataset_utils.py b/keras/src/utils/text_dataset_utils.py index 79ae687faf11..d329d6944540 100644 --- a/keras/src/utils/text_dataset_utils.py +++ b/keras/src/utils/text_dataset_utils.py @@ -28,7 +28,7 @@ def text_dataset_from_directory( format="tf", verbose=True, ): - """Generates a `tf.data.Dataset` from text files in a directory. + """Generates a dataset from text files in a directory. If your directory structure is: @@ -43,12 +43,16 @@ def text_dataset_from_directory( ``` Then calling `text_dataset_from_directory(main_directory, - labels='inferred')` will return a `tf.data.Dataset` that yields batches of + labels='inferred')` will return a dataset that yields batches of texts from the subdirectories `class_a` and `class_b`, together with labels 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). Only `.txt` files are supported at this time. + By default, this function will return a `tf.data.Dataset` object. You can + set `format="grain"` to return a `grain.IterDataset` object instead, which + removes the TensorFlow dependency. + Args: directory: Directory where the data is located. If `labels` is `"inferred"`, it should contain