Skip to content

Commit b43444d

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Add rng_type to BaseRandomLayer.get_config().
The stateful mode can be enabled either by passing it explicitly to the layer's `__init__` or by using `enable_tf_random_generator`. However, it was not saved in either cases. PiperOrigin-RevId: 775020954
1 parent c79cc0e commit b43444d

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

tf_keras/engine/base_layer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3932,6 +3932,19 @@ def build(self, input_shape):
39323932
super().build(input_shape)
39333933
self._random_generator._maybe_init()
39343934

3935+
def get_config(self):
3936+
base_config = super().get_config()
3937+
if (
3938+
self._random_generator._rng_type
3939+
== backend.RandomGenerator.RNG_LEGACY_STATEFUL
3940+
):
3941+
return base_config
3942+
3943+
config = {
3944+
"rng_type": self._random_generator._rng_type
3945+
}
3946+
return dict(list(base_config.items()) + list(config.items()))
3947+
39353948
def _trackable_children(self, save_type="checkpoint", **kwargs):
39363949
if save_type == "savedmodel":
39373950
cache = kwargs["cache"]

tf_keras/engine/base_layer_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,51 @@ def call(self, inputs):
12721272
self.assertAllEqual(out_false, sample_input)
12731273

12741274

1275+
@test_utils.run_v2_only
1276+
class BaseRandomLayerTest(test_combinations.TestCase):
1277+
def teardown(self):
1278+
backend.disable_tf_random_generator()
1279+
1280+
def test_rng_type_is_saved_in_config(self):
1281+
backend.disable_tf_random_generator()
1282+
1283+
layer = base_layer.BaseRandomLayer(rng_type="stateful")
1284+
config = layer.get_config()
1285+
self.assertEqual(config["rng_type"], "stateful")
1286+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1287+
self.assertEqual(reloaded_layer._random_generator._rng_type, "stateful")
1288+
1289+
layer = base_layer.BaseRandomLayer(rng_type="stateless")
1290+
config = layer.get_config()
1291+
self.assertEqual(config["rng_type"], "stateless")
1292+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1293+
self.assertEqual(reloaded_layer._random_generator._rng_type, "stateless")
1294+
1295+
layer = base_layer.BaseRandomLayer()
1296+
config = layer.get_config()
1297+
self.assertNotIn("rng_type", config)
1298+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1299+
self.assertEqual(reloaded_layer._random_generator._rng_type, "legacy_stateful")
1300+
1301+
layer = base_layer.BaseRandomLayer(rng_type="legacy_stateful")
1302+
config = layer.get_config()
1303+
self.assertNotIn("rng_type", config)
1304+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1305+
self.assertEqual(reloaded_layer._random_generator._rng_type, "legacy_stateful")
1306+
1307+
def test_rng_type_with_tf_random_generator(self):
1308+
# Test `rng_type` is still serialized when global stateful mode is on.
1309+
backend.enable_tf_random_generator()
1310+
1311+
layer = base_layer.BaseRandomLayer()
1312+
config = layer.get_config()
1313+
self.assertEqual(config["rng_type"], "stateful")
1314+
1315+
backend.disable_tf_random_generator()
1316+
reloaded_layer = base_layer.BaseRandomLayer.from_config(config)
1317+
self.assertEqual(reloaded_layer._random_generator._rng_type, "stateful")
1318+
1319+
12751320
@test_utils.run_v2_only
12761321
class SymbolicSupportTest(test_combinations.TestCase):
12771322
def test_using_symbolic_tensors_with_tf_ops(self):

0 commit comments

Comments
 (0)