2 changes: 1 addition & 1 deletion keras/src/utils/audio_dataset_utils.py
@@ -411,7 +411,7 @@ def paths_and_labels_to_dataset(
     """Constructs a fixed-size dataset of audio and labels."""
     path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
     if label_mode:
-        label_ds = dataset_utils.labels_to_dataset(
+        label_ds = dataset_utils.labels_to_dataset_tf(
             labels, label_mode, num_classes
         )
         ds = tf.data.Dataset.zip((path_ds, label_ds))
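For context, the renamed helper feeds the zip above: labels_to_dataset_tf builds a label dataset that is zipped element-wise with the path dataset. A minimal sketch of that pattern (values are illustrative, and a plain from_tensor_slices stands in for the helper):

import tensorflow as tf

# Stand-ins for the real arguments of paths_and_labels_to_dataset.
file_paths = ["a.wav", "b.wav"]
labels = [0, 1]

path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
# labels_to_dataset_tf would build this from `labels`.
label_ds = tf.data.Dataset.from_tensor_slices(
    tf.constant(labels, dtype="int32")
)
ds = tf.data.Dataset.zip((path_ds, label_ds))  # yields (path, label) pairs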
91 changes: 83 additions & 8 deletions keras/src/utils/dataset_utils.py
@@ -8,7 +8,9 @@
 
 from keras.src import tree
 from keras.src.api_export import keras_export
+from keras.src.utils import file_utils
 from keras.src.utils import io_utils
+from keras.src.utils.module_utils import grain
 from keras.src.utils.module_utils import tensorflow as tf
 
 
@@ -299,6 +301,17 @@ def is_torch_dataset(dataset):
     return False
 
 
+def is_grain_dataset(dataset):
+    if hasattr(dataset, "__class__"):
+        for parent in dataset.__class__.__mro__:
+            if parent.__name__ in (
+                "MapDataset",
+                "IterDataset",
+            ) and str(parent.__module__).startswith("grain._src.python"):
+                return True
+    return False
+
+
 def _rescale_dataset_split_sizes(left_size, right_size, total_length):
     """Rescale the dataset split sizes.

@@ -476,6 +489,10 @@ def _get_type_spec(dataset):
         from torch.utils.data import Dataset as TorchDataset
 
         return TorchDataset
+    elif is_grain_dataset(dataset):
+        from grain import MapDataset
+
+        return MapDataset
     else:
         return None

@@ -525,10 +542,17 @@ def index_directory(
         - class_names: names of the classes corresponding to these labels, in
           order.
     """
+    if file_utils.is_remote_path(directory):
+        os_module = tf.io.gfile
+        path_module = tf.io.gfile
+    else:
+        os_module = os
+        path_module = os.path
+
     if labels == "inferred":
         subdirs = []
-        for subdir in sorted(tf.io.gfile.listdir(directory)):
-            if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
+        for subdir in sorted(os_module.listdir(directory)):
+            if path_module.isdir(path_module.join(directory, subdir)):
                 if not subdir.startswith("."):
                     if subdir.endswith("/"):
                         subdir = subdir[:-1]
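The new os_module/path_module indirection routes remote directories through tf.io.gfile, which understands URL-style paths, while local paths keep the cheaper os module. A minimal sketch, assuming file_utils.is_remote_path treats scheme-prefixed paths such as "gs://..." as remote:

import os

from keras.src.utils import file_utils
from keras.src.utils.module_utils import tensorflow as tf

for d in ("gs://bucket/audio", "/tmp/audio"):
    # os.listdir would fail on "gs://bucket/audio"; tf.io.gfile handles it.
    os_module = tf.io.gfile if file_utils.is_remote_path(d) else os
    print(d, "uses plain os:", os_module is os)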
@@ -566,7 +590,7 @@ def index_directory(
     results = []
     filenames = []
 
-    for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
+    for dirpath in (path_module.join(directory, subdir) for subdir in subdirs):
         results.append(
             pool.apply_async(
                 index_subdirectory,
@@ -608,7 +632,7 @@ def index_directory(
         )
     pool.close()
     pool.join()
-    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]
+    file_paths = [path_module.join(directory, fname) for fname in filenames]
 
     if shuffle:
         # Shuffle globally to erase macro-structure
@@ -623,8 +647,10 @@
 
 
 def iter_valid_files(directory, follow_links, formats):
+    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+
     if not follow_links:
-        walk = tf.io.gfile.walk(directory)
+        walk = io_module.walk(directory)
     else:
         walk = os.walk(directory, followlinks=follow_links)
     for root, _, files in sorted(walk, key=lambda x: x[0]):
@@ -648,14 +674,18 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
         paths, and `labels` is a list of integer labels corresponding
         to these files.
     """
+    path_module = (
+        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
+    )
+
     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
     labels = []
     filenames = []
     for root, fname in valid_files:
         labels.append(class_indices[dirname])
-        absolute_path = tf.io.gfile.join(root, fname)
-        relative_path = tf.io.gfile.join(
+        absolute_path = path_module.join(root, fname)
+        relative_path = path_module.join(
             dirname, os.path.relpath(absolute_path, directory)
         )
         filenames.append(relative_path)
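The path arithmetic in this loop, worked through for the local case (path_module is os.path), with illustrative paths:

import os

directory = "/data/train/cats"          # class subdirectory being indexed
dirname = os.path.basename(directory)   # "cats"
absolute_path = "/data/train/cats/batch1/img1.png"

# os.path.relpath strips the class-directory prefix...
rel = os.path.relpath(absolute_path, directory)  # "batch1/img1.png"
# ...and the join puts the class name back in front.
print(os.path.join(dirname, rel))                # "cats/batch1/img1.png"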
@@ -700,7 +730,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
     return samples, labels
 
 
-def labels_to_dataset(labels, label_mode, num_classes):
+def labels_to_dataset_tf(labels, label_mode, num_classes):
     """Create a `tf.data.Dataset` from the list/tuple of labels.
 
     Args:
@@ -730,6 +760,51 @@ def labels_to_dataset(labels, label_mode, num_classes):
     return label_ds
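For reference, the encodings the renamed labels_to_dataset_tf promises, assuming labels = [0, 1, 1] and num_classes = 2; the shapes below follow the documented behavior, not code shown in this diff:

ds = labels_to_dataset_tf([0, 1, 1], "categorical", num_classes=2)
# "int"         -> scalar int32 elements:           0, 1, 1
# "binary"      -> float32 elements of shape (1,):  [0.], [1.], [1.]
# "categorical" -> float32 one-hot of shape (2,):   [1. 0.], [0. 1.], [0. 1.]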


+def labels_to_dataset_grain(labels, label_mode, num_classes):
+    """Create a `grain.MapDataset` from the list/tuple of labels.
+
+    Args:
+        labels: list/tuple of labels to be converted into a `grain.MapDataset`.
+        label_mode: String describing the encoding of `labels`. Options are:
+            - `"binary"` indicates that the labels (there can be only 2) are encoded
+                as `float32` scalars with values 0 or 1
+                (e.g. for `binary_crossentropy`).
+            - `"categorical"` means that the labels are mapped into a categorical
+                vector. (e.g. for `categorical_crossentropy` loss).
+        num_classes: number of classes of labels.
+
+    Returns:
+        A `grain.MapDataset` instance.
+    """
+    from keras.src import backend
+    from keras.src import ops
+
+    if label_mode not in ("binary", "categorical", "int"):
+        raise ValueError(
+            f"Invalid `label_mode`: {label_mode}. "
+            "Expected one of: 'binary', 'categorical', 'int'."
+        )
+
+    def preprocess_labels_in_cpu(label_mode, x, num_classes):
+        with backend.device_scope("cpu"):
+            if label_mode == "binary":
+                return ops.expand_dims(
+                    ops.convert_to_tensor(x, dtype="float32"), axis=-1
+                )
+            elif label_mode == "categorical":
+                return ops.one_hot(
+                    ops.convert_to_tensor(x, dtype="int32"), num_classes
+                )
+            else:
+                return ops.convert_to_tensor(x, dtype="int32")
+
+    label_ds = grain.MapDataset.source(labels)
+    label_ds = label_ds.map(
+        lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes),
+    )
+    return label_ds


 def check_validation_split_arg(validation_split, subset, shuffle, seed):
     """Raise errors in case of invalid argument values.

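A minimal usage sketch of the new Grain path, assuming grain is installed and that grain.MapDataset exposes indexed access and a batch transformation as in recent releases:

label_ds = labels_to_dataset_grain([0, 1, 1], "categorical", num_classes=2)
print(label_ds[0])  # one-hot vector for class 0, computed lazily on CPU

# MapDataset transformations are lazy and index-based, so the map inside
# the helper (and this batch step) only run when elements are requested.
batched = label_ds.batch(batch_size=2)

The device_scope("cpu") in the helper keeps label preprocessing on the host rather than the accelerator, matching how the tf.data variant stages labels.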
33 changes: 33 additions & 0 deletions keras/src/utils/grain_utils.py
@@ -0,0 +1,33 @@
+from keras.src import backend
+from keras.src import tree
+
+
+def make_batch(values):
+    from keras.src import ops
+
+    if not values:
+        raise ValueError("Cannot batch 0 values. Please file a bug.")
+
+    with backend.device_scope("cpu"):
+        return tree.map_structure(lambda *xs: ops.stack(xs), *values)
+
+
+def make_string_batch(values):
+    from keras.src import ops
+
+    if not values:
+        raise ValueError("Cannot batch 0 values. Please file a bug.")
+
+    def batch_fn(*xs):
+        if isinstance(xs[0], str):
+            if backend.backend() == "tensorflow":
+                import tensorflow as tf
+
+                xs = [tf.convert_to_tensor(x, dtype=tf.string) for x in xs]
+                xs = tf.stack(xs)
+            return xs
+        else:
+            return ops.stack(xs)
+
+    with backend.device_scope("cpu"):
+        return tree.map_structure(batch_fn, *values)
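A sketch of how these helpers collate per-sample structures, for example inside a Grain batching stage; the sample data below is illustrative:

import numpy as np

samples = [
    {"x": np.zeros((2,)), "y": np.int32(0)},
    {"x": np.ones((2,)), "y": np.int32(1)},
]
batch = make_batch(samples)
# batch["x"] has shape (2, 2) and batch["y"] has shape (2,): the dict
# structure is preserved and each leaf is stacked along a new axis.

str_batch = make_string_batch([("a.png", 0), ("b.png", 1)])
# On the TensorFlow backend the string leaf becomes a stacked tf.string
# tensor; other backends return the strings untouched, since ops.stack
# cannot handle Python strings.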