9 changes: 7 additions & 2 deletions .github/workflows/actions.yml
@@ -89,8 +89,13 @@ jobs:
if: ${{ matrix.backend == 'jax'}}
run: python3 -c "import jax; print('JAX devices:', jax.devices())"

- name: Test with pytest
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
- name: Test with pytest (TensorFlow)
if: ${{ matrix.backend == 'tensorflow' }}
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax
Collaborator:
Instead of the --ignore=keras_rs/src/layers/embedding/jax, can you do this on the JAX tests?

@pytest.mark.skipif(
    keras.backend.backend() != "jax",
    reason="JAX specific test",
)

Author:
I think the issue with the JAX tests on the TF TPU backend is that import jax itself fails because JAX isn't installed, so the skip won't work here.
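A skip marker alone can't help when the module-level import itself fails; guarding the import is one way around that. A minimal sketch of the pattern (an illustration, not the change made in this PR):

import pytest

# Guard the module-level import so collection doesn't crash when JAX
# is absent; the skipif marker then keys off the sentinel.
try:
    import jax
except ImportError:
    jax = None

@pytest.mark.skipif(jax is None, reason="JAX is not installed")
def test_jax_devices_visible():
    # Hypothetical JAX-only test body.
    assert len(jax.devices()) >= 1

Equivalently, jax = pytest.importorskip("jax") at the top of the module skips every test in the file at collection time.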


- name: Test with pytest (JAX)
if: ${{ matrix.backend == 'jax' }}
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

check_format:
name: Check the code format
16 changes: 16 additions & 0 deletions conftest.py
@@ -0,0 +1,16 @@
import pytest
import os
from keras_rs.src.utils import tpu_test_utils

@pytest.fixture(scope="session", autouse=True)
def prime_shared_tpu_strategy(request):
"""
Eagerly initializes the shared TPU strategy at the beginning of the session
if running on a TPU. This helps catch initialization errors early.
"""
strategy = tpu_test_utils.get_shared_tpu_strategy()
if not strategy:
pytest.fail(
"Failed to initialize shared TPUStrategy for the test session. "
"Check logs for details from create_tpu_strategy."
)
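The tpu_test_utils module referenced above is not part of this section of the diff. A hypothetical sketch of get_shared_tpu_strategy, assuming it caches one strategy per process and reuses the TPU setup and DummyStrategy fallback that the diff removes from distributed_embedding_test.py below:

import contextlib

import tensorflow as tf

_SHARED_STRATEGY = None


class DummyStrategy:
    # No-op stand-in used when no TPU is available (mirrors the class
    # removed from the test file below).
    def scope(self):
        return contextlib.nullcontext()

    @property
    def num_replicas_in_sync(self):
        return 1

    def run(self, fn, args):
        return fn(*args)

    def experimental_distribute_dataset(self, dataset, options=None):
        del options
        return dataset


def get_shared_tpu_strategy():
    # Create the session-wide strategy once; later calls reuse it.
    # The real helper may also build an explicit DeviceAssignment, as
    # the setUp code removed below did.
    global _SHARED_STRATEGY
    if _SHARED_STRATEGY is None:
        try:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            _SHARED_STRATEGY = tf.distribute.TPUStrategy(resolver)
        except (ValueError, RuntimeError):
            # No TPU reachable: fall back to the no-op strategy.
            _SHARED_STRATEGY = DummyStrategy()
    return _SHARED_STRATEGY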
124 changes: 40 additions & 84 deletions keras_rs/src/layers/embedding/distributed_embedding_test.py
@@ -1,4 +1,3 @@
import contextlib
import functools
import math
import os
@@ -14,6 +13,7 @@
from keras_rs.src import types
from keras_rs.src.layers.embedding import distributed_embedding
from keras_rs.src.layers.embedding import distributed_embedding_config as config
from keras_rs.src.utils import tpu_test_utils

try:
import jax
@@ -30,28 +30,6 @@
SEQUENCE_LENGTH = 13


class DummyStrategy:
def scope(self):
return contextlib.nullcontext()

@property
def num_replicas_in_sync(self):
return 1

def run(self, fn, args):
return fn(*args)

def experimental_distribute_dataset(self, dataset, options=None):
del options
return dataset


class JaxDummyStrategy(DummyStrategy):
@property
def num_replicas_in_sync(self):
return jax.device_count("tpu")


def ragged_bool_true(self):
return True

@@ -74,46 +52,10 @@ def setUp(self):
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)

topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, num_replicas=tpu_metadata.num_hosts
)
self._strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment
)
print("### num_replicas", self._strategy.num_replicas_in_sync)
self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
elif keras.backend.backend() == "jax" and self.on_tpu:
self._strategy = JaxDummyStrategy()
else:
self._strategy = DummyStrategy()

self.batch_size = (
BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync
)

def run_with_strategy(self, fn, *args, jit_compile=False):
"""Wrapper for running a function under a strategy."""

if keras.backend.backend() == "tensorflow":

@tf.function(jit_compile=jit_compile)
def tf_function_wrapper(*tf_function_args):
def strategy_fn(*strategy_fn_args):
return fn(*strategy_fn_args)

return self._strategy.run(strategy_fn, args=tf_function_args)

return tf_function_wrapper(*args)
else:
self.assertFalse(jit_compile)
return fn(*args)

def get_embedding_config(self, input_type, placement):
sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH

@@ -252,18 +194,20 @@ def test_basics(self, input_type, placement):

if placement == "sparsecore" and not self.on_tpu:
with self.assertRaisesRegex(Exception, "sparsecore"):
with self._strategy.scope():
with self.strategy.scope():
distributed_embedding.DistributedEmbedding(feature_configs)
return

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(feature_configs)

if keras.backend.backend() == "jax":
preprocessed_inputs = layer.preprocess(inputs, weights)
res = layer(preprocessed_inputs)
else:
res = self.run_with_strategy(layer.__call__, inputs, weights)
res = tpu_test_utils.run_with_strategy(
self.strategy, layer.__call__, inputs, weights
)

if placement == "default_device" or not self.on_tpu:
# verify sublayers and variables are tracked
@@ -332,7 +276,7 @@ def test_model_fit(self, input_type, use_weights):
(test_model_inputs, test_labels)
)

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(feature_configs)

def _create_keras_input(
@@ -403,7 +347,7 @@ def test_dataset_generator():
# New preprocessed data removes the `weights` component.
dataset_has_weights = False
else:
train_dataset = self._strategy.experimental_distribute_dataset(
train_dataset = self.strategy.experimental_distribute_dataset(
train_dataset,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False
@@ -418,18 +362,18 @@ def test_dataset_generator():
inputs=keras_model_inputs, outputs=keras_model_outputs
)

with self._strategy.scope():
with self.strategy.scope():
model.compile(optimizer="adam", loss="mse")

model_inputs, _ = next(iter(test_dataset))
test_output_before = self.run_with_strategy(
model.__call__, model_inputs
test_output_before = tpu_test_utils.run_with_strategy(
self.strategy, model.__call__, model_inputs
)

model.fit(train_dataset, steps_per_epoch=1, epochs=1)

test_output_after = self.run_with_strategy(
model.__call__, model_inputs
test_output_after = tpu_test_utils.run_with_strategy(
self.strategy, model.__call__, model_inputs
)

# Verify that the embedding has actually trained.
@@ -567,7 +511,7 @@ def test_correctness(
if not use_weights:
weights = None

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(feature_config)

if keras.backend.backend() == "jax":
@@ -610,15 +554,21 @@ def test_correctness(
preprocessed,
)
else:
res = self.run_with_strategy(layer.__call__, preprocessed)
res = tpu_test_utils.run_with_strategy(
self.strategy, layer.__call__, preprocessed
)
else:
res = self.run_with_strategy(
layer.__call__, inputs, weights, jit_compile=jit_compile
res = tpu_test_utils.run_with_strategy(
self.strategy,
layer.__call__,
inputs,
weights,
jit_compile=jit_compile,
)

self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))

with self._strategy.scope():
with self.strategy.scope():
tables = layer.get_embedding_tables()

emb = tables["table"]
@@ -683,10 +633,12 @@ def test_shared_table(self):
"dense", embedding_config
)

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self.strategy, layer.__call__, inputs
)

if self.placement == "default_device":
self.assertLen(layer._flatten_layers(include_self=False), 1)
@@ -757,10 +709,12 @@ def test_mixed_placement(self):
"dense", embedding_config
)

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(embedding_config)

res = self.run_with_strategy(layer.__call__, inputs)
res = tpu_test_utils.run_with_strategy(
self.strategy, layer.__call__, inputs
)

self.assertEqual(
res["feature1"].shape, (self.batch_size, embedding_output_dim1)
@@ -786,20 +740,22 @@ def test_save_load_model(self):
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "model.keras")

with self._strategy.scope():
with self.strategy.scope():
layer = distributed_embedding.DistributedEmbedding(
feature_configs
)
keras_outputs = layer(keras_inputs)
model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)

output_before = self.run_with_strategy(model.__call__, inputs)
output_before = tpu_test_utils.run_with_strategy(
self.strategy, model.__call__, inputs
)
model.save(path)

with self._strategy.scope():
with self.strategy.scope():
reloaded_model = keras.models.load_model(path)
output_after = self.run_with_strategy(
reloaded_model.__call__, inputs
output_after = tpu_test_utils.run_with_strategy(
self.strategy, reloaded_model.__call__, inputs
)

if self.placement == "sparsecore":
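The run_with_strategy calls above delegate to a module-level helper in tpu_test_utils that this section doesn't show. A plausible sketch, derived from the run_with_strategy method the diff removes from the test class:

import keras
import tensorflow as tf


def run_with_strategy(strategy, fn, *args, jit_compile=False):
    # Wrapper for running a function under a strategy; non-TensorFlow
    # backends call the function directly and never jit-compile here.
    if keras.backend.backend() == "tensorflow":

        @tf.function(jit_compile=jit_compile)
        def tf_function_wrapper(*tf_function_args):
            return strategy.run(fn, args=tf_function_args)

        return tf_function_wrapper(*args)
    assert not jit_compile
    return fn(*args)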
14 changes: 13 additions & 1 deletion keras_rs/src/layers/embedding/embed_reduce_test.py
@@ -9,8 +9,18 @@
from keras_rs.src import testing
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce

try:
import jax
from jax.experimental import sparse as jax_sparse
except ImportError:
jax = None
jax_sparse = None


class EmbedReduceTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()

@parameterized.named_parameters(
[
(
@@ -172,7 +182,9 @@ def test_symbolic_call(self, input_type, input_rank, use_weights):

def test_predict(self):
input = keras.random.randint((5, 7), minval=0, maxval=10)
model = keras.models.Sequential([EmbedReduce(10, 20)])
with self.strategy.scope():
model = keras.models.Sequential([EmbedReduce(10, 20)])
model.compile(optimizer="adam", loss="mse")
model.predict(input, batch_size=2)

def test_serialization(self):
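Moving model construction and compile under self.strategy.scope() matters because tf.distribute strategies only place variables that are created inside their scope. A minimal standalone sketch of the pattern, assuming a generic strategy object:

import keras


def build_model(strategy):
    # Variables (embedding tables, Dense kernels) are created when the
    # layers are built, so construction must happen inside the scope.
    with strategy.scope():
        model = keras.models.Sequential([keras.layers.Dense(1)])
        model.compile(optimizer="adam", loss="mse")
    return model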
keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py
@@ -1,4 +1,5 @@
import keras
import pytest
import tensorflow as tf
from absl.testing import parameterized

@@ -7,6 +8,10 @@
from keras_rs.src.layers.embedding.tensorflow import config_conversion


@pytest.mark.skipif(
keras.backend.backend() != "tensorflow",
reason="Tensorflow specific test",
)
class ConfigConversionTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
(
27 changes: 18 additions & 9 deletions keras_rs/src/layers/feature_interaction/dot_interaction_test.py
@@ -12,6 +12,8 @@

class DotInteractionTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()

self.input = [
ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]),
ops.array([[2.0, 3.2, -1.0, 0.0, 1.0]]),
@@ -81,7 +83,12 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx):
self_interaction=self_interaction, skip_gather=skip_gather
)
output = layer(self.input)
self.assertAllClose(output, self.exp_outputs[exp_output_idx])
self.assertAllClose(
output,
self.exp_outputs[exp_output_idx],
tpu_atol=1e-2,
tpu_rtol=1e-2,
)

def test_invalid_input_rank(self):
rank_1_input = [ops.ones((3,)), ops.ones((3,))]
@@ -120,14 +127,16 @@ def test_invalid_input_different_shapes(self):
),
)
def test_predict(self, self_interaction, skip_gather):
feature1 = keras.layers.Input(shape=(5,))
feature2 = keras.layers.Input(shape=(5,))
feature3 = keras.layers.Input(shape=(5,))
x = DotInteraction(
self_interaction=self_interaction, skip_gather=skip_gather
)([feature1, feature2, feature3])
x = keras.layers.Dense(units=1)(x)
model = keras.Model([feature1, feature2, feature3], x)
with self.strategy.scope():
feature1 = keras.layers.Input(shape=(5,))
feature2 = keras.layers.Input(shape=(5,))
feature3 = keras.layers.Input(shape=(5,))
x = DotInteraction(
self_interaction=self_interaction, skip_gather=skip_gather
)([feature1, feature2, feature3])
x = keras.layers.Dense(units=1)(x)
model = keras.Model([feature1, feature2, feature3], x)
model.compile(optimizer="adam", loss="mse")

model.predict(self.input, batch_size=2)

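The new tpu_atol/tpu_rtol arguments suggest that the project's testing.TestCase.assertAllClose picks looser tolerances when the test runs on a TPU. A hypothetical sketch of such a helper (the real implementation is not part of this diff):

import numpy as np


def assert_all_close(actual, desired, on_tpu, atol=1e-6, rtol=1e-6,
                     tpu_atol=None, tpu_rtol=None):
    # Widen tolerances on TPU, where lower-precision arithmetic such
    # as bfloat16 matmuls makes CPU-level tolerances too strict.
    if on_tpu:
        atol = tpu_atol if tpu_atol is not None else atol
        rtol = tpu_rtol if tpu_rtol is not None else rtol
    np.testing.assert_allclose(
        np.asarray(actual), np.asarray(desired), atol=atol, rtol=rtol
    )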