Skip to content

Commit 2128d76

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Add rng_type to BaseRandomLayer.get_config().
The stateful mode can be enabled either by passing it explicitly to the layer's `__init__` or by using `enable_tf_random_generator`. However, it was not saved in either cases. PiperOrigin-RevId: 775020954
1 parent c79cc0e commit 2128d76

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

tf_keras/engine/base_layer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3932,6 +3932,17 @@ def build(self, input_shape):
39323932
super().build(input_shape)
39333933
self._random_generator._maybe_init()
39343934

3935+
def get_config(self):
3936+
base_config = super().get_config()
3937+
if (
3938+
self._random_generator._rng_type
3939+
== backend.RandomGenerator.RNG_LEGACY_STATEFUL
3940+
):
3941+
return base_config
3942+
3943+
config = {"rng_type": self._random_generator._rng_type}
3944+
return dict(list(base_config.items()) + list(config.items()))
3945+
39353946
def _trackable_children(self, save_type="checkpoint", **kwargs):
39363947
if save_type == "savedmodel":
39373948
cache = kwargs["cache"]

tf_keras/engine/base_layer_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,57 @@ def call(self, inputs):
12721272
self.assertAllEqual(out_false, sample_input)
12731273

12741274

1275+
@test_utils.run_v2_only
1276+
class BaseRandomLayerTest(test_combinations.TestCase):
1277+
def teardown(self):
1278+
backend.disable_tf_random_generator()
1279+
1280+
def test_rng_type_is_saved_in_config(self):
1281+
backend.disable_tf_random_generator()
1282+
1283+
layer = base_layer.BaseRandomLayer(rng_type="stateful")
1284+
config = layer.get_config()
1285+
self.assertEqual(config["rng_type"], "stateful")
1286+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1287+
self.assertEqual(reloaded_layer._random_generator._rng_type, "stateful")
1288+
1289+
layer = base_layer.BaseRandomLayer(rng_type="stateless")
1290+
config = layer.get_config()
1291+
self.assertEqual(config["rng_type"], "stateless")
1292+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1293+
self.assertEqual(
1294+
reloaded_layer._random_generator._rng_type, "stateless"
1295+
)
1296+
1297+
layer = base_layer.BaseRandomLayer()
1298+
config = layer.get_config()
1299+
self.assertNotIn("rng_type", config)
1300+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1301+
self.assertEqual(
1302+
reloaded_layer._random_generator._rng_type, "legacy_stateful"
1303+
)
1304+
1305+
layer = base_layer.BaseRandomLayer(rng_type="legacy_stateful")
1306+
config = layer.get_config()
1307+
self.assertNotIn("rng_type", config)
1308+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1309+
self.assertEqual(
1310+
reloaded_layer._random_generator._rng_type, "legacy_stateful"
1311+
)
1312+
1313+
def test_rng_type_with_tf_random_generator(self):
1314+
# Test `rng_type` is still serialized when global stateful mode is on.
1315+
backend.enable_tf_random_generator()
1316+
1317+
layer = base_layer.BaseRandomLayer()
1318+
config = layer.get_config()
1319+
self.assertEqual(config["rng_type"], "stateful")
1320+
1321+
backend.disable_tf_random_generator()
1322+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1323+
self.assertEqual(reloaded_layer._random_generator._rng_type, "stateful")
1324+
1325+
12751326
@test_utils.run_v2_only
12761327
class SymbolicSupportTest(test_combinations.TestCase):
12771328
def test_using_symbolic_tensors_with_tf_ops(self):

0 commit comments

Comments
 (0)