diff --git a/keras_rs/src/layers/embedding/base_distributed_embedding.py b/keras_rs/src/layers/embedding/base_distributed_embedding.py
index c30993f..88329f2 100644
--- a/keras_rs/src/layers/embedding/base_distributed_embedding.py
+++ b/keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -337,18 +337,33 @@ def step(data):
     embedding_layer = DistributedEmbedding(feature_configs)
 
     # Add preprocessing to a data input pipeline.
-    def train_dataset_generator():
-        for (inputs, weights), labels in iter(train_dataset):
+    def preprocessed_dataset_generator(dataset):
+        for (inputs, weights), labels in iter(dataset):
             yield embedding_layer.preprocess(
                 inputs, weights, training=True
             ), labels
 
-    preprocessed_train_dataset = train_dataset_generator()
+    preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
     ```
 
     This explicit preprocessing stage combines the input and optional
     weights, so the new data can be passed directly into the `inputs`
     argument of the layer or model.
 
+    **NOTE**: When working in a multi-host setting with data parallelism, the
+    data needs to be sharded correctly across hosts. If the original dataset
+    is a `tf.data.Dataset`, it must be sharded manually _before_ applying the
+    preprocessing generator:
+    ```python
+    # Manually shard the dataset across hosts.
+    train_dataset = distribution.distribute_dataset(train_dataset)
+    distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+    # Add a preprocessing stage to the distributed data input pipeline.
+    train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+    pre-sharded across hosts.
+
     #### Usage in a Keras model
 
     Once the global distribution is set and the input preprocessing pipeline
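
For context, a minimal sketch of how the sharding and preprocessing steps introduced in this hunk might fit together in a multi-host, data-parallel setup. The `feature_configs` dict, the `train_dataset` pipeline, and the choice of `keras.distribution.DataParallel` are assumptions for illustration and are not part of this diff:

```python
import keras
import keras_rs

# Assumed to exist (not part of this diff): `feature_configs` describing the
# embedding tables, and `train_dataset`, a tf.data.Dataset yielding
# ((inputs, weights), labels) batches.

# Set a data-parallel distribution before building the layer (assumption:
# DataParallel across all available devices).
distribution = keras.distribution.DataParallel()
keras.distribution.set_distribution(distribution)

embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)

# Multi-host case: shard the tf.data.Dataset across hosts first, then tell
# the distribution not to shard it again.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False  # Dataset is already sharded.

# Explicit preprocessing stage, as shown in the docstring above.
def preprocessed_dataset_generator(dataset):
    for (inputs, weights), labels in iter(dataset):
        yield embedding_layer.preprocess(
            inputs, weights, training=True
        ), labels

preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
```

Per the docstring text, each element of `preprocessed_train_dataset` can then be passed directly as the `inputs` of the layer, or of a model that wraps it.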