Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,13 @@ jobs:
if: ${{ matrix.backend == 'jax'}}
run: python3 -c "import jax; print('JAX devices:', jax.devices())"

- name: Test with pytest
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
- name: Test with pytest (TensorFlow)
if: ${{ matrix.backend == 'tensorflow' }}
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing `--ignore=keras_rs/src/layers/embedding/jax`, can you add this marker to the JAX tests?

@pytest.mark.skipif(
    keras.backend.backend() != "jax",
    reason="JAX specific test",
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

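I think the issue with the JAX tests on the TF TPU backend is that `import jax` itself fails, because JAX is not installed there. The module errors out at import time, before pytest can evaluate a `skipif` marker, so the skip won't work here.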


- name: Test with pytest (JAX)
if: ${{ matrix.backend == 'jax' }}
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

check_format:
name: Check the code format
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ build/
.idea/

venv/
venv_tf/
venv_jax/
101 changes: 29 additions & 72 deletions keras_rs/src/layers/embedding/distributed_embedding_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import functools
import math
import os
Expand All @@ -14,6 +13,7 @@
from keras_rs.src import types
from keras_rs.src.layers.embedding import distributed_embedding
from keras_rs.src.layers.embedding import distributed_embedding_config as config
from keras_rs.src.utils import tpu_test_utils

try:
import jax
Expand All @@ -30,28 +30,6 @@
SEQUENCE_LENGTH = 13


class DummyStrategy:
    """Single-replica stand-in exposing the strategy methods the tests use.

    Every method is a pass-through: nothing is distributed, no scope is
    entered, and functions are invoked directly in the calling context.
    """

    def scope(self):
        """Return a no-op context manager (`contextlib.nullcontext()`)."""
        return contextlib.nullcontext()

    @property
    def num_replicas_in_sync(self):
        """Number of replicas; always 1 for this strategy."""
        return 1

    def run(self, fn, args=(), kwargs=None, options=None):
        """Call `fn` directly and return its result.

        The signature mirrors `tf.distribute.Strategy.run` so that this
        class can be swapped in wherever a real strategy is called with
        only `fn`, or with keyword arguments.

        Args:
            fn: Callable to invoke.
            args: Positional arguments for `fn`. Defaults to no arguments.
            kwargs: Optional dict of keyword arguments for `fn`.
            options: Accepted for interface compatibility and ignored.

        Returns:
            Whatever `fn(*args, **kwargs)` returns.
        """
        del options
        return fn(*args, **(kwargs or {}))

    def experimental_distribute_dataset(self, dataset, options=None):
        """Return `dataset` unchanged; `options` is ignored."""
        del options
        return dataset


class JaxDummyStrategy(DummyStrategy):
    """`DummyStrategy` variant whose replica count comes from JAX."""

    @property
    def num_replicas_in_sync(self):
        """Return the device count JAX reports for the "tpu" backend."""
        tpu_device_count = jax.device_count("tpu")
        return tpu_device_count


def ragged_bool_true(self):
    """Return `True` unconditionally; `self` is accepted but not used.

    NOTE(review): the name and the `self` parameter suggest this is patched
    onto a ragged-tensor type as its truthiness hook -- the call site is not
    visible here, so confirm before relying on that.
    """
    return True

Expand All @@ -74,46 +52,11 @@ def setUp(self):
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)

topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, num_replicas=tpu_metadata.num_hosts
)
self._strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment
)
print("### num_replicas", self._strategy.num_replicas_in_sync)
self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
elif keras.backend.backend() == "jax" and self.on_tpu:
self._strategy = JaxDummyStrategy()
else:
self._strategy = DummyStrategy()

self._strategy = tpu_test_utils.get_tpu_strategy(self)
self.batch_size = (
BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
)

def run_with_strategy(self, fn, *args, jit_compile=False):
    """Run `fn(*args)` under `self._strategy` and return its result.

    On the TensorFlow backend the call is traced by `tf.function`
    (forwarding `jit_compile`) and dispatched with `self._strategy.run`.
    On any other backend `fn` is called directly, and the test asserts
    that `jit_compile` was not requested.

    Args:
        fn: Callable to execute.
        *args: Positional arguments passed through to `fn`.
        jit_compile: Forwarded to `tf.function`; must be falsy on
            non-TensorFlow backends.

    Returns:
        The value produced by the strategy run, or by `fn(*args)` directly
        on non-TensorFlow backends.
    """
    if keras.backend.backend() != "tensorflow":
        self.assertFalse(jit_compile)
        return fn(*args)

    def replica_fn(*replica_args):
        # Plain-function wrapper around `fn`, as in the original code.
        return fn(*replica_args)

    def distributed_call(*call_args):
        return self._strategy.run(replica_fn, args=call_args)

    traced_call = tf.function(distributed_call, jit_compile=jit_compile)
    return traced_call(*args)

def get_embedding_config(self, input_type, placement):
sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH

Expand Down Expand Up @@ -263,7 +206,9 @@ def test_basics(self, input_type, placement):
preprocessed_inputs = layer.preprocess(inputs, weights)
res = layer(preprocessed_inputs)
else:
res = self.run_with_strategy(layer.__call__, inputs, weights)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs, weights
)

if placement == "default_device" or not self.on_tpu:
# verify sublayers and variables are tracked
Expand Down Expand Up @@ -422,14 +367,14 @@ def test_dataset_generator():
model.compile(optimizer="adam", loss="mse")

model_inputs, _ = next(iter(test_dataset))
test_output_before = self.run_with_strategy(
model.__call__, model_inputs
test_output_before = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, model_inputs
)

model.fit(train_dataset, steps_per_epoch=1, epochs=1)

test_output_after = self.run_with_strategy(
model.__call__, model_inputs
test_output_after = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, model_inputs
)

# Verify that the embedding has actually trained.
Expand Down Expand Up @@ -610,10 +555,16 @@ def test_correctness(
preprocessed,
)
else:
res = self.run_with_strategy(layer.__call__, preprocessed)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, preprocessed
)
else:
res = self.run_with_strategy(
layer.__call__, inputs, weights, jit_compile=jit_compile
res = tpu_test_utils.run_with_strategy(
self._strategy,
layer.__call__,
inputs,
weights,
jit_compile=jit_compile,
)

self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
Expand Down Expand Up @@ -686,7 +637,9 @@ def test_shared_table(self):
with self._strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs
)

if self.placement == "default_device":
self.assertLen(layer._flatten_layers(include_self=False), 1)
Expand Down Expand Up @@ -760,7 +713,9 @@ def test_mixed_placement(self):
with self._strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self._strategy, layer.__call__, inputs
)

self.assertEqual(
res["feature1"].shape, (self.batch_size, embedding_output_dim1)
Expand Down Expand Up @@ -793,13 +748,15 @@ def test_save_load_model(self):
keras_outputs = layer(keras_inputs)
model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)

output_before = self.run_with_strategy(model.__call__, inputs)
output_before = tpu_test_utils.run_with_strategy(
self._strategy, model.__call__, inputs
)
model.save(path)

with self._strategy.scope():
reloaded_model = keras.models.load_model(path)
output_after = self.run_with_strategy(
reloaded_model.__call__, inputs
output_after = tpu_test_utils.run_with_strategy(
self._strategy, reloaded_model.__call__, inputs
)

if self.placement == "sparsecore":
Expand Down
Loading