@@ -1272,6 +1272,51 @@ def call(self, inputs):
1272
1272
self .assertAllEqual (out_false , sample_input )
1273
1273
1274
1274
1275
+ @test_utils .run_v2_only
1276
+ class BaseRandomLayerTest (test_combinations .TestCase ):
1277
+ def teardown (self ):
1278
+ backend .disable_tf_random_generator ()
1279
+
1280
+ def test_rng_type_is_saved_in_config (self ):
1281
+ backend .disable_tf_random_generator ()
1282
+
1283
+ layer = base_layer .BaseRandomLayer (rng_type = "stateful" )
1284
+ config = layer .get_config ()
1285
+ self .assertEqual (config ["rng_type" ], "stateful" )
1286
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1287
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1288
+
1289
+ layer = base_layer .BaseRandomLayer (rng_type = "stateless" )
1290
+ config = layer .get_config ()
1291
+ self .assertEqual (config ["rng_type" ], "stateless" )
1292
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1293
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateless" )
1294
+
1295
+ layer = base_layer .BaseRandomLayer ()
1296
+ config = layer .get_config ()
1297
+ self .assertNotIn ("rng_type" , config )
1298
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1299
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "legacy_stateful" )
1300
+
1301
+ layer = base_layer .BaseRandomLayer (rng_type = "legacy_stateful" )
1302
+ config = layer .get_config ()
1303
+ self .assertNotIn ("rng_type" , config )
1304
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1305
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "legacy_stateful" )
1306
+
1307
+ def test_rng_type_with_tf_random_generator (self ):
1308
+ # Test `rng_type` is still serialized when global stateful mode is on.
1309
+ backend .enable_tf_random_generator ()
1310
+
1311
+ layer = base_layer .BaseRandomLayer ()
1312
+ config = layer .get_config ()
1313
+ self .assertEqual (config ["rng_type" ], "stateful" )
1314
+
1315
+ backend .disable_tf_random_generator ()
1316
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1317
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1318
+
1319
+
1275
1320
@test_utils .run_v2_only
1276
1321
class SymbolicSupportTest (test_combinations .TestCase ):
1277
1322
def test_using_symbolic_tensors_with_tf_ops (self ):
0 commit comments