Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
run: python3 -c "import jax; print('JAX devices:', jax.devices())"

- name: Test with pytest
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
run: pytest keras_rs/src/layers/

check_format:
name: Check the code format
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ build/
.idea/

venv/
venv_tf/
101 changes: 29 additions & 72 deletions keras_rs/src/layers/embedding/distributed_embedding_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import functools
import math
import os
Expand All @@ -14,6 +13,7 @@
from keras_rs.src import types
from keras_rs.src.layers.embedding import distributed_embedding
from keras_rs.src.layers.embedding import distributed_embedding_config as config
from keras_rs.src.utils import tpu_test_utils

try:
import jax
Expand All @@ -30,28 +30,6 @@
SEQUENCE_LENGTH = 13


class DummyStrategy:
def scope(self):
return contextlib.nullcontext()

@property
def num_replicas_in_sync(self):
return 1

def run(self, fn, args):
return fn(*args)

def experimental_distribute_dataset(self, dataset, options=None):
del options
return dataset


class JaxDummyStrategy(DummyStrategy):
@property
def num_replicas_in_sync(self):
return jax.device_count("tpu")


def ragged_bool_true(self):
return True

Expand All @@ -74,46 +52,11 @@ def setUp(self):
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)

topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, num_replicas=tpu_metadata.num_hosts
)
self._strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment
)
print("### num_replicas", self._strategy.num_replicas_in_sync)
self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
elif keras.backend.backend() == "jax" and self.on_tpu:
self._strategy = JaxDummyStrategy()
else:
self._strategy = DummyStrategy()

self._strategy = tpu_test_utils.get_tpu_strategy(self)
self.batch_size = (
BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
)

def run_with_strategy(self, fn, *args, jit_compile=False):
"""Wrapper for running a function under a strategy."""

if keras.backend.backend() == "tensorflow":

@tf.function(jit_compile=jit_compile)
def tf_function_wrapper(*tf_function_args):
def strategy_fn(*strategy_fn_args):
return fn(*strategy_fn_args)

return self._strategy.run(strategy_fn, args=tf_function_args)

return tf_function_wrapper(*args)
else:
self.assertFalse(jit_compile)
return fn(*args)

def get_embedding_config(self, input_type, placement):
sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH

Expand Down Expand Up @@ -263,7 +206,9 @@ def test_basics(self, input_type, placement):
preprocessed_inputs = layer.preprocess(inputs, weights)
res = layer(preprocessed_inputs)
else:
res = self.run_with_strategy(layer.__call__, inputs, weights)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs, weights
)

if placement == "default_device" or not self.on_tpu:
# verify sublayers and variables are tracked
Expand Down Expand Up @@ -422,14 +367,14 @@ def test_dataset_generator():
model.compile(optimizer="adam", loss="mse")

model_inputs, _ = next(iter(test_dataset))
test_output_before = self.run_with_strategy(
model.__call__, model_inputs
test_output_before = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, model_inputs
)

model.fit(train_dataset, steps_per_epoch=1, epochs=1)

test_output_after = self.run_with_strategy(
model.__call__, model_inputs
test_output_after = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, model_inputs
)

# Verify that the embedding has actually trained.
Expand Down Expand Up @@ -610,10 +555,16 @@ def test_correctness(
preprocessed,
)
else:
res = self.run_with_strategy(layer.__call__, preprocessed)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, preprocessed
)
else:
res = self.run_with_strategy(
layer.__call__, inputs, weights, jit_compile=jit_compile
res = tpu_test_utils.run_with_strategy(
self._strategy,
layer.__call__,
inputs,
weights,
jit_compile=jit_compile,
)

self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
Expand Down Expand Up @@ -686,7 +637,9 @@ def test_shared_table(self):
with self._strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs
)

if self.placement == "default_device":
self.assertLen(layer._flatten_layers(include_self=False), 1)
Expand Down Expand Up @@ -760,7 +713,9 @@ def test_mixed_placement(self):
with self._strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs
)

self.assertEqual(
res["feature1"].shape, (self.batch_size, embedding_output_dim1)
Expand Down Expand Up @@ -793,13 +748,15 @@ def test_save_load_model(self):
keras_outputs = layer(keras_inputs)
model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)

output_before = self.run_with_strategy(model.__call__, inputs)
output_before = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, inputs
)
model.save(path)

with self._strategy.scope():
reloaded_model = keras.models.load_model(path)
output_after = self.run_with_strategy(
reloaded_model.__call__, inputs
output_after = tpu_test_utils.run_with_strategy(
self._strategy, reloaded_model.__call__, inputs
)

if self.placement == "sparsecore":
Expand Down
Loading