Add note about multi-host dataset sharding. #131

Open · wants to merge 1 commit into base: main

`keras_rs/src/layers/embedding/base_distributed_embedding.py` (18 additions, 3 deletions)

@@ -337,18 +337,33 @@ def step(data):
embedding_layer = DistributedEmbedding(feature_configs)

# Add preprocessing to a data input pipeline.
-def train_dataset_generator():
-    for (inputs, weights), labels in iter(train_dataset):
+def preprocessed_dataset_generator(dataset):
+    for (inputs, weights), labels in iter(dataset):
        yield embedding_layer.preprocess(
            inputs, weights, training=True
        ), labels

-preprocessed_train_dataset = train_dataset_generator()
+preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
```
This explicit preprocessing stage combines the input and optional weights,
so the new data can be passed directly into the `inputs` argument of the
layer or model.

+**NOTE**: When working in a multi-host setting with data parallelism, the
+data needs to be sharded properly across hosts. If the original dataset is
+of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+applying the preprocessing generator:
+```python
+# Manually shard the dataset across hosts.
+train_dataset = distribution.distribute_dataset(train_dataset)
+distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+# Add a preprocessing stage to the distributed data input pipeline.
+train_dataset = preprocessed_dataset_generator(train_dataset)
+```
+If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+pre-sharded across hosts.

#### Usage in a Keras model

Once the global distribution is set and the input preprocessing pipeline
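
The snippet in the added note presupposes a `distribution` object that has already been created and installed globally. Below is a minimal sketch of the end-to-end multi-host flow, assuming the Keras 3 distribution API (`keras.distribution.DataParallel`, `keras.distribution.set_distribution`) and reusing the `model`, `train_dataset`, and `preprocessed_dataset_generator` names from the docstring example above:

```python
import keras

# Create and install a data-parallel distribution across all available
# devices (assumed setup; the added note only shows the sharding step).
distribution = keras.distribution.DataParallel()
keras.distribution.set_distribution(distribution)

# Manually shard the tf.data.Dataset across hosts, then disable
# auto-sharding so the dataset is not sharded a second time.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False

# The generator yields (preprocessed_inputs, labels) pairs, so its output
# can be passed directly to `fit` as the model's inputs.
model.fit(preprocessed_dataset_generator(train_dataset), epochs=1)
```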
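
For the non-`tf.data.Dataset` case, pre-sharding means each host must see a disjoint slice of the data before preprocessing. A hedged sketch, assuming a JAX backend where every host runs the same script and `examples` is a hypothetical in-memory sequence of `((inputs, weights), labels)` batches:

```python
import jax

def host_shard(examples):
    # Round-robin sharding: host k keeps batches k, k + n, k + 2n, ...,
    # so the per-host shards are disjoint and together cover the dataset.
    num_hosts = jax.process_count()
    host_id = jax.process_index()
    for i, example in enumerate(examples):
        if i % num_hosts == host_id:
            yield example

# The per-host shard then flows through the preprocessing generator as
# usual; no distribute_dataset / auto_shard_dataset step applies here.
preprocessed_train_dataset = preprocessed_dataset_generator(
    host_shard(examples)
)
```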