diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5088f571..b125b6d1 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -89,8 +89,13 @@ jobs: if: ${{ matrix.backend == 'jax'}} run: python3 -c "import jax; print('JAX devices:', jax.devices())" - - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + - name: Test with pytest (TensorFlow) + if: ${{ matrix.backend == 'tensorflow' }} + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax + + - name: Test with pytest (JAX) + if: ${{ matrix.backend == 'jax' }} + run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py check_format: name: Check the code format diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..f80e27f6 --- /dev/null +++ b/conftest.py @@ -0,0 +1,26 @@ +from typing import Union + +import pytest +import tensorflow as tf + +from keras_rs.src.utils import tpu_test_utils + +StrategyType = Union[ + tf.distribute.Strategy, + tpu_test_utils.DummyStrategy, + tpu_test_utils.JaxDummyStrategy, +] + + +@pytest.fixture(scope="session", autouse=True) +def prime_shared_tpu_strategy() -> None: + """ + Eagerly initializes the shared TPU strategy at the beginning of the session + if running on a TPU. This helps catch initialization errors early. + """ + strategy = tpu_test_utils.get_shared_tpu_strategy() + if not strategy: + pytest.fail( + "Failed to initialize shared TPUStrategy for the test session. " + "Check logs for details from create_tpu_strategy." + ) diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index cb4df82f..e493931f 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -1,4 +1,3 @@ -import contextlib import functools import math import os @@ -14,6 +13,7 @@ from keras_rs.src import types from keras_rs.src.layers.embedding import distributed_embedding from keras_rs.src.layers.embedding import distributed_embedding_config as config +from keras_rs.src.utils import tpu_test_utils try: import jax @@ -30,28 +30,6 @@ SEQUENCE_LENGTH = 13 -class DummyStrategy: - def scope(self): - return contextlib.nullcontext() - - @property - def num_replicas_in_sync(self): - return 1 - - def run(self, fn, args): - return fn(*args) - - def experimental_distribute_dataset(self, dataset, options=None): - del options - return dataset - - -class JaxDummyStrategy(DummyStrategy): - @property - def num_replicas_in_sync(self): - return jax.device_count("tpu") - - def ragged_bool_true(self): return True @@ -74,46 +52,10 @@ def setUp(self): # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16 # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16 - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - - topology = tf.tpu.experimental.initialize_tpu_system(resolver) - tpu_metadata = resolver.get_tpu_system_metadata() - - device_assignment = tf.tpu.experimental.DeviceAssignment.build( - topology, num_replicas=tpu_metadata.num_hosts - ) - self._strategy = tf.distribute.TPUStrategy( - resolver, experimental_device_assignment=device_assignment - ) - print("### num_replicas", self._strategy.num_replicas_in_sync) - self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) - elif keras.backend.backend() == "jax" and self.on_tpu: - self._strategy = JaxDummyStrategy() - else: - self._strategy = DummyStrategy() - self.batch_size = ( - BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync + BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync ) - def run_with_strategy(self, fn, *args, jit_compile=False): - """Wrapper for running a function under a strategy.""" - - if keras.backend.backend() == "tensorflow": - - @tf.function(jit_compile=jit_compile) - def tf_function_wrapper(*tf_function_args): - def strategy_fn(*strategy_fn_args): - return fn(*strategy_fn_args) - - return self._strategy.run(strategy_fn, args=tf_function_args) - - return tf_function_wrapper(*args) - else: - self.assertFalse(jit_compile) - return fn(*args) - def get_embedding_config(self, input_type, placement): sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH @@ -252,18 +194,20 @@ def test_basics(self, input_type, placement): if placement == "sparsecore" and not self.on_tpu: with self.assertRaisesRegex(Exception, "sparsecore"): - with self._strategy.scope(): + with self.strategy.scope(): distributed_embedding.DistributedEmbedding(feature_configs) return - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": preprocessed_inputs = layer.preprocess(inputs, weights) res = layer(preprocessed_inputs) else: - res = self.run_with_strategy(layer.__call__, inputs, weights) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs, weights + ) if placement == "default_device" or not self.on_tpu: # verify sublayers and variables are tracked @@ -332,7 +276,7 @@ def test_model_fit(self, input_type, use_weights): (test_model_inputs, test_labels) ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_configs) def _create_keras_input( @@ -403,7 +347,7 @@ def test_dataset_generator(): # New preprocessed data removes the `weights` component. dataset_has_weights = False else: - train_dataset = self._strategy.experimental_distribute_dataset( + train_dataset = self.strategy.experimental_distribute_dataset( train_dataset, options=tf.distribute.InputOptions( experimental_fetch_to_device=False @@ -418,18 +362,18 @@ def test_dataset_generator(): inputs=keras_model_inputs, outputs=keras_model_outputs ) - with self._strategy.scope(): + with self.strategy.scope(): model.compile(optimizer="adam", loss="mse") model_inputs, _ = next(iter(test_dataset)) - test_output_before = self.run_with_strategy( - model.__call__, model_inputs + test_output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs ) model.fit(train_dataset, steps_per_epoch=1, epochs=1) - test_output_after = self.run_with_strategy( - model.__call__, model_inputs + test_output_after = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, model_inputs ) # Verify that the embedding has actually trained. @@ -567,7 +511,7 @@ def test_correctness( if not use_weights: weights = None - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(feature_config) if keras.backend.backend() == "jax": @@ -610,15 +554,21 @@ def test_correctness( preprocessed, ) else: - res = self.run_with_strategy(layer.__call__, preprocessed) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, preprocessed + ) else: - res = self.run_with_strategy( - layer.__call__, inputs, weights, jit_compile=jit_compile + res = tpu_test_utils.run_with_strategy( + self.strategy, + layer.__call__, + inputs, + weights, + jit_compile=jit_compile, ) self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)) - with self._strategy.scope(): + with self.strategy.scope(): tables = layer.get_embedding_tables() emb = tables["table"] @@ -683,10 +633,12 @@ def test_shared_table(self): "dense", embedding_config ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs + ) if self.placement == "default_device": self.assertLen(layer._flatten_layers(include_self=False), 1) @@ -757,10 +709,12 @@ def test_mixed_placement(self): "dense", embedding_config ) - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding(embedding_config) - res = self.run_with_strategy(layer.__call__, inputs) + res = tpu_test_utils.run_with_strategy( + self.strategy, layer.__call__, inputs + ) self.assertEqual( res["feature1"].shape, (self.batch_size, embedding_output_dim1) @@ -786,20 +740,22 @@ def test_save_load_model(self): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "model.keras") - with self._strategy.scope(): + with self.strategy.scope(): layer = distributed_embedding.DistributedEmbedding( feature_configs ) keras_outputs = layer(keras_inputs) model = keras.Model(inputs=keras_inputs, outputs=keras_outputs) - output_before = self.run_with_strategy(model.__call__, inputs) + output_before = tpu_test_utils.run_with_strategy( + self.strategy, model.__call__, inputs + ) model.save(path) - with self._strategy.scope(): + with self.strategy.scope(): reloaded_model = keras.models.load_model(path) - output_after = self.run_with_strategy( - reloaded_model.__call__, inputs + output_after = tpu_test_utils.run_with_strategy( + self.strategy, reloaded_model.__call__, inputs ) if self.placement == "sparsecore": diff --git a/keras_rs/src/layers/embedding/embed_reduce_test.py b/keras_rs/src/layers/embedding/embed_reduce_test.py index 1d7fb456..440259a9 100644 --- a/keras_rs/src/layers/embedding/embed_reduce_test.py +++ b/keras_rs/src/layers/embedding/embed_reduce_test.py @@ -9,8 +9,18 @@ from keras_rs.src import testing from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce +try: + import jax + from jax.experimental import sparse as jax_sparse +except ImportError: + jax = None + jax_sparse = None + class EmbedReduceTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + @parameterized.named_parameters( [ ( @@ -172,7 +182,9 @@ def test_symbolic_call(self, input_type, input_rank, use_weights): def test_predict(self): input = keras.random.randint((5, 7), minval=0, maxval=10) - model = keras.models.Sequential([EmbedReduce(10, 20)]) + with self.strategy.scope(): + model = keras.models.Sequential([EmbedReduce(10, 20)]) + model.compile(optimizer="adam", loss="mse") model.predict(input, batch_size=2) def test_serialization(self): diff --git a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py index f314f887..3d8e4daa 100644 --- a/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py +++ b/keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py @@ -1,4 +1,5 @@ import keras +import pytest import tensorflow as tf from absl.testing import parameterized @@ -7,6 +8,10 @@ from keras_rs.src.layers.embedding.tensorflow import config_conversion +@pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="Tensorflow specific test", +) class ConfigConversionTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( ( diff --git a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py index 99c38abc..b5aa1f6c 100644 --- a/keras_rs/src/layers/feature_interaction/dot_interaction_test.py +++ b/keras_rs/src/layers/feature_interaction/dot_interaction_test.py @@ -12,6 +12,8 @@ class DotInteractionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.input = [ ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]), ops.array([[2.0, 3.2, -1.0, 0.0, 1.0]]), @@ -81,7 +83,12 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx): self_interaction=self_interaction, skip_gather=skip_gather ) output = layer(self.input) - self.assertAllClose(output, self.exp_outputs[exp_output_idx]) + self.assertAllClose( + output, + self.exp_outputs[exp_output_idx], + tpu_atol=1e-2, + tpu_rtol=1e-2, + ) def test_invalid_input_rank(self): rank_1_input = [ops.ones((3,)), ops.ones((3,))] @@ -120,14 +127,16 @@ def test_invalid_input_different_shapes(self): ), ) def test_predict(self, self_interaction, skip_gather): - feature1 = keras.layers.Input(shape=(5,)) - feature2 = keras.layers.Input(shape=(5,)) - feature3 = keras.layers.Input(shape=(5,)) - x = DotInteraction( - self_interaction=self_interaction, skip_gather=skip_gather - )([feature1, feature2, feature3]) - x = keras.layers.Dense(units=1)(x) - model = keras.Model([feature1, feature2, feature3], x) + with self.strategy.scope(): + feature1 = keras.layers.Input(shape=(5,)) + feature2 = keras.layers.Input(shape=(5,)) + feature3 = keras.layers.Input(shape=(5,)) + x = DotInteraction( + self_interaction=self_interaction, skip_gather=skip_gather + )([feature1, feature2, feature3]) + x = keras.layers.Dense(units=1)(x) + model = keras.Model([feature1, feature2, feature3], x) + model.compile(optimizer="adam", loss="mse") model.predict(self.input, batch_size=2) diff --git a/keras_rs/src/layers/feature_interaction/feature_cross_test.py b/keras_rs/src/layers/feature_interaction/feature_cross_test.py index 8724ab53..485f776e 100644 --- a/keras_rs/src/layers/feature_interaction/feature_cross_test.py +++ b/keras_rs/src/layers/feature_interaction/feature_cross_test.py @@ -10,6 +10,8 @@ class FeatureCrossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() + self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32") self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32") self.exp_output = ops.array([[0.55, 0.8, 1.05]]) @@ -77,11 +79,13 @@ def test_pre_activation(self): self.assertAllClose(self.x, output) def test_predict(self): - x0 = keras.layers.Input(shape=(3,)) - x1 = FeatureCross(projection_dim=None)(x0, x0) - x2 = FeatureCross(projection_dim=None)(x0, x1) - logits = keras.layers.Dense(units=1)(x2) - model = keras.Model(x0, logits) + with self.strategy.scope(): + x0 = keras.layers.Input(shape=(3,)) + x1 = FeatureCross(projection_dim=None)(x0, x0) + x2 = FeatureCross(projection_dim=None)(x0, x1) + logits = keras.layers.Dense(units=1)(x2) + model = keras.Model(x0, logits) + model.compile(optimizer="adam", loss="mse") model.predict(self.x0, batch_size=2) diff --git a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py index d7ab74d0..69964d98 100644 --- a/keras_rs/src/layers/retrieval/hard_negative_mining_test.py +++ b/keras_rs/src/layers/retrieval/hard_negative_mining_test.py @@ -9,6 +9,9 @@ class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + def create_inputs(self, rank=2): shape_3d = (15, 20, 10) shape = shape_3d[-rank:] @@ -89,12 +92,16 @@ def test_call(self, rank, num_hard_negatives): def test_predict(self): logits, labels = self.create_inputs() - in_logits = keras.layers.Input(shape=logits.shape[1:]) - in_labels = keras.layers.Input(shape=labels.shape[1:]) - out_logits, out_labels = hard_negative_mining.HardNegativeMining( - num_hard_negatives=3 - )(in_logits, in_labels) - model = keras.Model([in_logits, in_labels], [out_logits, out_labels]) + with self.strategy.scope(): + in_logits = keras.layers.Input(shape=logits.shape[1:]) + in_labels = keras.layers.Input(shape=labels.shape[1:]) + out_logits, out_labels = hard_negative_mining.HardNegativeMining( + num_hard_negatives=3 + )(in_logits, in_labels) + model = keras.Model( + [in_logits, in_labels], [out_logits, out_labels] + ) + model.compile(optimizer="adam", loss="mse") model.predict([logits, labels], batch_size=8) diff --git a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py index 8cb4fa71..f436678b 100644 --- a/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py +++ b/keras_rs/src/layers/retrieval/remove_accidental_hits_test.py @@ -9,6 +9,9 @@ class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + def create_inputs(self, logits_rank=2, candidate_ids_rank=1): shape_3d = (15, 20, 10) shape = shape_3d[-logits_rank:] @@ -151,14 +154,16 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2) - layer = remove_accidental_hits.RemoveAccidentalHits() - in_logits = keras.layers.Input(logits.shape[1:]) - in_labels = keras.layers.Input(labels.shape[1:]) - in_candidate_ids = keras.layers.Input(labels.shape[1:]) - out_logits = layer(in_logits, in_labels, in_candidate_ids) - model = keras.Model( - [in_logits, in_labels, in_candidate_ids], out_logits - ) + with self.strategy.scope(): + layer = remove_accidental_hits.RemoveAccidentalHits() + in_logits = keras.layers.Input(logits.shape[1:]) + in_labels = keras.layers.Input(labels.shape[1:]) + in_candidate_ids = keras.layers.Input(labels.shape[1:]) + out_logits = layer(in_logits, in_labels, in_candidate_ids) + model = keras.Model( + [in_logits, in_labels, in_candidate_ids], out_logits + ) + model.compile(optimizer="adam", loss="mse") model.predict([logits, labels, candidate_ids], batch_size=8) diff --git a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py index 8dc8ff73..0ade230a 100644 --- a/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py +++ b/keras_rs/src/layers/retrieval/sampling_probability_correction_test.py @@ -11,6 +11,9 @@ class SamplingProbabilityCorrectionTest( testing.TestCase, parameterized.TestCase ): + def setUp(self): + super().setUp() + def create_inputs(self, logits_rank=2, probs_rank=1): shape_3d = (15, 20, 10) logits_shape = shape_3d[-logits_rank:] @@ -87,11 +90,15 @@ def test_predict(self): # Note: for predict, we test with probabilities that have a batch dim. logits, probs = self.create_inputs(probs_rank=2) - layer = sampling_probability_correction.SamplingProbabilityCorrection() - in_logits = keras.layers.Input(logits.shape[1:]) - in_probs = keras.layers.Input(probs.shape[1:]) - out_logits = layer(in_logits, in_probs) - model = keras.Model([in_logits, in_probs], out_logits) + with self.strategy.scope(): + layer = ( + sampling_probability_correction.SamplingProbabilityCorrection() + ) + in_logits = keras.layers.Input(logits.shape[1:]) + in_probs = keras.layers.Input(probs.shape[1:]) + out_logits = layer(in_logits, in_probs) + model = keras.Model([in_logits, in_probs], out_logits) + model.compile(optimizer="adam", loss="mse") model.predict([logits, probs], batch_size=4) diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py index 3656354b..ebf9a1d1 100644 --- a/keras_rs/src/losses/list_mle_loss_test.py +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -10,6 +10,7 @@ class ListMLELossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array( [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" ) @@ -83,11 +84,19 @@ def test_scalar_sample_weight(self): ) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=ListMLELoss(), optimizer="adam") + return model + + if self.strategy: + with self.strategy.scope(): + model = create_model() + else: + model = create_model() - model.compile(loss=ListMLELoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_hinge_loss_test.py b/keras_rs/src/losses/pairwise_hinge_loss_test.py index f5aedb20..7c782015 100644 --- a/keras_rs/src/losses/pairwise_hinge_loss_test.py +++ b/keras_rs/src/losses/pairwise_hinge_loss_test.py @@ -10,6 +10,7 @@ class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -110,11 +111,19 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseHingeLoss(), optimizer="adam") + return model + + if self.strategy: + with self.strategy.scope(): + model = create_model() + else: + model = create_model() - model.compile(loss=PairwiseHingeLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_logistic_loss_test.py b/keras_rs/src/losses/pairwise_logistic_loss_test.py index ffba4b05..74c383a0 100644 --- a/keras_rs/src/losses/pairwise_logistic_loss_test.py +++ b/keras_rs/src/losses/pairwise_logistic_loss_test.py @@ -10,6 +10,7 @@ class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -110,11 +111,19 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") + return model + + if self.strategy: + with self.strategy.scope(): + model = create_model() + else: + model = create_model() - model.compile(loss=PairwiseLogisticLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_mean_squared_error_test.py b/keras_rs/src/losses/pairwise_mean_squared_error_test.py index 4b93eff9..e1f865c6 100644 --- a/keras_rs/src/losses/pairwise_mean_squared_error_test.py +++ b/keras_rs/src/losses/pairwise_mean_squared_error_test.py @@ -12,6 +12,7 @@ class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -109,11 +110,19 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") + return model + + if self.strategy: + with self.strategy.scope(): + model = create_model() + else: + model = create_model() - model.compile(loss=PairwiseMeanSquaredError(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py index 66e7d634..92ddeae2 100644 --- a/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py +++ b/keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py @@ -12,6 +12,7 @@ class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase): def setUp(self): + super().setUp() self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) @@ -112,11 +113,19 @@ def test_mask_input(self): self.assertAllClose(output, expected_output, atol=1e-5) def test_model_fit(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + def create_model(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") + return model + + if self.strategy: + with self.strategy.scope(): + model = create_model() + else: + model = create_model() - model.compile(loss=PairwiseSoftZeroOneLoss(), optimizer="adam") model.fit( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=2), diff --git a/keras_rs/src/metrics/dcg_test.py b/keras_rs/src/metrics/dcg_test.py index 430214ac..3e2b4b20 100644 --- a/keras_rs/src/metrics/dcg_test.py +++ b/keras_rs/src/metrics/dcg_test.py @@ -8,6 +8,7 @@ from keras_rs.src import testing from keras_rs.src.metrics.dcg import DCG +from keras_rs.src.utils import tpu_test_utils def _compute_dcg(labels, ranks): @@ -19,6 +20,8 @@ def _compute_dcg(labels, ranks): class DCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -345,15 +348,17 @@ def inverse_discount_fn(rank): self.assertAllClose(result, expected_output, rtol=1e-5) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[DCG()], + optimizer="adam", + ) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[DCG()], - optimizer="adam", - ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/mean_average_precision_test.py b/keras_rs/src/metrics/mean_average_precision_test.py index 9c16d25e..7161534a 100644 --- a/keras_rs/src/metrics/mean_average_precision_test.py +++ b/keras_rs/src/metrics/mean_average_precision_test.py @@ -6,10 +6,13 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision +from keras_rs.src.utils import tpu_test_utils class MeanAveragePrecisionTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -276,15 +279,16 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanAveragePrecision()], - optimizer="adam", - ) + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanAveragePrecision()], + optimizer="adam", + ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/mean_reciprocal_rank_test.py b/keras_rs/src/metrics/mean_reciprocal_rank_test.py index 02940c36..3d5264bc 100644 --- a/keras_rs/src/metrics/mean_reciprocal_rank_test.py +++ b/keras_rs/src/metrics/mean_reciprocal_rank_test.py @@ -6,10 +6,13 @@ from keras_rs.src import testing from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank +from keras_rs.src.utils import tpu_test_utils class MeanReciprocalRankTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -248,15 +251,17 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[MeanReciprocalRank()], + optimizer="adam", + ) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[MeanReciprocalRank()], - optimizer="adam", - ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/ndcg_test.py b/keras_rs/src/metrics/ndcg_test.py index 8c86e01c..68fc2c96 100644 --- a/keras_rs/src/metrics/ndcg_test.py +++ b/keras_rs/src/metrics/ndcg_test.py @@ -8,6 +8,7 @@ from keras_rs.src import testing from keras_rs.src.metrics.ndcg import NDCG +from keras_rs.src.utils import tpu_test_utils def _compute_dcg(labels, ranks): @@ -19,6 +20,8 @@ def _compute_dcg(labels, ranks): class NDCGTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -357,15 +360,17 @@ def inverse_discount_fn(rank): self.assertAllClose(result, ndcg, rtol=1e-5) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[NDCG()], + optimizer="adam", + ) - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[NDCG()], - optimizer="adam", - ) model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint((2, 5), minval=0, maxval=4), diff --git a/keras_rs/src/metrics/precision_at_k_test.py b/keras_rs/src/metrics/precision_at_k_test.py index d83c5c9d..62b8348a 100644 --- a/keras_rs/src/metrics/precision_at_k_test.py +++ b/keras_rs/src/metrics/precision_at_k_test.py @@ -6,10 +6,13 @@ from keras_rs.src import testing from keras_rs.src.metrics.precision_at_k import PrecisionAtK +from keras_rs.src.utils import tpu_test_utils class PrecisionAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -228,15 +231,17 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[PrecisionAtK(k=3)], - optimizer="adam", - ) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[PrecisionAtK(k=3)], + optimizer="adam", + ) + model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint( diff --git a/keras_rs/src/metrics/recall_at_k_test.py b/keras_rs/src/metrics/recall_at_k_test.py index 1a6672ce..d397f9cd 100644 --- a/keras_rs/src/metrics/recall_at_k_test.py +++ b/keras_rs/src/metrics/recall_at_k_test.py @@ -6,10 +6,13 @@ from keras_rs.src import testing from keras_rs.src.metrics.recall_at_k import RecallAtK +from keras_rs.src.utils import tpu_test_utils class RecallAtKTest(testing.TestCase, parameterized.TestCase): def setUp(self): + self._strategy = tpu_test_utils.get_tpu_strategy(self) + self.y_true_batched = ops.array( [ [0, 0, 1, 0], @@ -231,15 +234,17 @@ def test_serialization(self): self.assertDictEqual(metric.get_config(), restored.get_config()) def test_model_evaluate(self): - inputs = keras.Input(shape=(20,), dtype="float32") - outputs = keras.layers.Dense(5)(inputs) - model = keras.Model(inputs=inputs, outputs=outputs) - - model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[RecallAtK(k=3)], - optimizer="adam", - ) + with self._strategy.scope(): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[RecallAtK(k=3)], + optimizer="adam", + ) + model.evaluate( x=keras.random.normal((2, 20)), y=keras.random.randint( diff --git a/keras_rs/src/testing/test_case.py b/keras_rs/src/testing/test_case.py index a764abf3..6dff2bdd 100644 --- a/keras_rs/src/testing/test_case.py +++ b/keras_rs/src/testing/test_case.py @@ -1,12 +1,20 @@ import os import tempfile import unittest -from typing import Any +from typing import Any, Optional, Union import keras import numpy as np +import tensorflow as tf from keras_rs.src import types +from keras_rs.src.utils import tpu_test_utils + +StrategyType = Union[ + tf.distribute.Strategy, + tpu_test_utils.DummyStrategy, + tpu_test_utils.JaxDummyStrategy, +] class TestCase(unittest.TestCase): @@ -16,6 +24,25 @@ def setUp(self) -> None: super().setUp() keras.utils.clear_session() keras.config.disable_traceback_filtering() + if keras.backend.backend() == "tensorflow": + tf.debugging.disable_traceback_filtering() + self.on_tpu = "TPU_NAME" in os.environ + self._strategy: Optional[StrategyType] = None + + @property + def strategy(self) -> StrategyType: + strat = tpu_test_utils.get_shared_tpu_strategy() + + if strat is None: + self.fail( + "TPU environment detected, but the shared TPUStrategy is None. " + "Initialization likely failed." + ) + return strat + # if self._strategy is not None: + # return self._strategy + # self._strategy = tpu_test_utils.get_tpu_strategy(self) + # return self._strategy def assertAllClose( self, @@ -23,6 +50,8 @@ def assertAllClose( desired: types.Tensor, atol: float = 1e-6, rtol: float = 1e-6, + tpu_atol: float | None = None, + tpu_rtol: float | None = None, msg: str = "", ) -> None: """Verify that two tensors are close in value element by element. @@ -34,6 +63,11 @@ def assertAllClose( rtol: Relative tolerance. msg: Optional error message. """ + if tpu_atol is not None and self.on_tpu: + atol = tpu_atol + if tpu_rtol is not None and self.on_tpu: + rtol = tpu_rtol + if not isinstance(actual, np.ndarray): actual = keras.ops.convert_to_numpy(actual) if not isinstance(desired, np.ndarray): diff --git a/keras_rs/src/utils/tpu_test_utils.py b/keras_rs/src/utils/tpu_test_utils.py new file mode 100644 index 00000000..7d5be954 --- /dev/null +++ b/keras_rs/src/utils/tpu_test_utils.py @@ -0,0 +1,170 @@ +import contextlib +import os +import threading +from types import ModuleType +from typing import Any, Callable, ContextManager, Optional, Tuple, Union + +import keras +import tensorflow as tf + +jax: Optional[ModuleType] = None + +try: + import jax +except ImportError: + pass + + +class DummyStrategy: + def scope(self) -> ContextManager[None]: + return contextlib.nullcontext() + + @property + def num_replicas_in_sync(self) -> int: + return 1 + + def run(self, fn: Callable[..., Any], args: Tuple[Any, ...]) -> Any: + return fn(*args) + + def experimental_distribute_dataset( + self, dataset: Any, options: Optional[Any] = None + ) -> Any: + del options + return dataset + + +class JaxDummyStrategy(DummyStrategy): + @property + def num_replicas_in_sync(self) -> Any: + if jax is None: + return 0 + return jax.device_count("tpu") + + +StrategyType = Union[tf.distribute.Strategy, DummyStrategy, JaxDummyStrategy] + +_shared_strategy: Optional[StrategyType] = None +_lock = threading.Lock() + + +def create_tpu_strategy() -> Optional[StrategyType]: + """Initializes the TPU system and returns a TPUStrategy.""" + print("Attempting to create TPUStrategy...") + try: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="") + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.TPUStrategy(resolver) + print( + "TPUStrategy created successfully." + "Devices: {strategy.extended.num_replicas_in_sync}" + ) + return strategy + except Exception as e: + print(f"Error creating TPUStrategy: {e}") + return None + + +def get_shared_tpu_strategy() -> Optional[StrategyType]: + """ + Returns a session-wide shared TPUStrategy instance. + Creates the instance on the first call. + Returns None if not in a TPU environment or if creation fails. + """ + global _shared_strategy + if _shared_strategy is not None: + return _shared_strategy + + with _lock: + if _shared_strategy is None: + if "TPU_NAME" not in os.environ: + _shared_strategy = DummyStrategy() + return _shared_strategy + if keras.backend.backend() == "tensorflow": + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + topology = tf.tpu.experimental.initialize_tpu_system(resolver) + tpu_metadata = resolver.get_tpu_system_metadata() + device_assignment = tf.tpu.experimental.DeviceAssignment.build( + topology, num_replicas=tpu_metadata.num_hosts + ) + _shared_strategy = tf.distribute.TPUStrategy( + resolver, experimental_device_assignment=device_assignment + ) + print("### num_replicas", _shared_strategy.num_replicas_in_sync) + elif keras.backend.backend() == "jax": + if jax is None: + raise ImportError( + "JAX backend requires jax to be installed for TPU." + ) + print("### num_replicas", jax.device_count("tpu")) + _shared_strategy = JaxDummyStrategy() + else: + _shared_strategy = DummyStrategy() + if _shared_strategy is None: + print("Failed to create the shared TPUStrategy.") + return _shared_strategy + + +def get_tpu_strategy(test_case: Any) -> StrategyType: + """Get TPU strategy if on TPU, otherwise return DummyStrategy.""" + if "TPU_NAME" not in os.environ: + return DummyStrategy() + if keras.backend.backend() == "tensorflow": + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() + tf.config.experimental_connect_to_cluster(resolver) + topology = tf.tpu.experimental.initialize_tpu_system(resolver) + tpu_metadata = resolver.get_tpu_system_metadata() + device_assignment = tf.tpu.experimental.DeviceAssignment.build( + topology, num_replicas=tpu_metadata.num_hosts + ) + strategy = tf.distribute.TPUStrategy( + resolver, experimental_device_assignment=device_assignment + ) + print("### num_replicas", strategy.num_replicas_in_sync) + test_case.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver) + return strategy + elif keras.backend.backend() == "jax": + if jax is None: + raise ImportError( + "JAX backend requires jax to be installed for TPU." + ) + print("### num_replicas", jax.device_count("tpu")) + return JaxDummyStrategy() + else: + return DummyStrategy() + + +def run_with_strategy( + strategy: Any, + fn: Callable[..., Any], + *args: Any, + jit_compile: bool = False, + **kwargs: Any, +) -> Any: + """ + Final wrapper fix: Flattens allowed kwargs into positional args before + entering tf.function to guarantee a fixed graph signature. + """ + if keras.backend.backend() == "tensorflow": + # Extract sample_weight and treat it as an explicit third positional + # argument. If not present, use a placeholder (None). + sample_weight_value = kwargs.get("sample_weight", None) + all_inputs = args + (sample_weight_value,) + + @tf.function(jit_compile=jit_compile) # type: ignore[misc] + def tf_function_wrapper(input_tuple: Tuple[Any, ...]) -> Any: + num_original_args = len(args) + core_args = input_tuple[:num_original_args] + sw_value = input_tuple[-1] + + if sw_value is not None: + all_positional_args = core_args + (sw_value,) + return strategy.run(fn, args=all_positional_args) + else: + return strategy.run(fn, args=core_args) + + return tf_function_wrapper(all_inputs) + else: + assert not jit_compile + return fn(*args, **kwargs)