@@ -1272,6 +1272,57 @@ def call(self, inputs):
1272
1272
self .assertAllEqual (out_false , sample_input )
1273
1273
1274
1274
1275
+ @test_utils .run_v2_only
1276
+ class BaseRandomLayerTest (test_combinations .TestCase ):
1277
+ def teardown (self ):
1278
+ backend .disable_tf_random_generator ()
1279
+
1280
+ def test_rng_type_is_saved_in_config (self ):
1281
+ backend .disable_tf_random_generator ()
1282
+
1283
+ layer = base_layer .BaseRandomLayer (rng_type = "stateful" )
1284
+ config = layer .get_config ()
1285
+ self .assertEqual (config ["rng_type" ], "stateful" )
1286
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1287
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1288
+
1289
+ layer = base_layer .BaseRandomLayer (rng_type = "stateless" )
1290
+ config = layer .get_config ()
1291
+ self .assertEqual (config ["rng_type" ], "stateless" )
1292
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1293
+ self .assertEqual (
1294
+ reloaded_layer ._random_generator ._rng_type , "stateless"
1295
+ )
1296
+
1297
+ layer = base_layer .BaseRandomLayer ()
1298
+ config = layer .get_config ()
1299
+ self .assertNotIn ("rng_type" , config )
1300
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1301
+ self .assertEqual (
1302
+ reloaded_layer ._random_generator ._rng_type , "legacy_stateful"
1303
+ )
1304
+
1305
+ layer = base_layer .BaseRandomLayer (rng_type = "legacy_stateful" )
1306
+ config = layer .get_config ()
1307
+ self .assertNotIn ("rng_type" , config )
1308
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1309
+ self .assertEqual (
1310
+ reloaded_layer ._random_generator ._rng_type , "legacy_stateful"
1311
+ )
1312
+
1313
+ def test_rng_type_with_tf_random_generator (self ):
1314
+ # Test `rng_type` is still serialized when global stateful mode is on.
1315
+ backend .enable_tf_random_generator ()
1316
+
1317
+ layer = base_layer .BaseRandomLayer ()
1318
+ config = layer .get_config ()
1319
+ self .assertEqual (config ["rng_type" ], "stateful" )
1320
+
1321
+ backend .disable_tf_random_generator ()
1322
+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1323
+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1324
+
1325
+
1275
1326
@test_utils .run_v2_only
1276
1327
class SymbolicSupportTest (test_combinations .TestCase ):
1277
1328
def test_using_symbolic_tensors_with_tf_ops (self ):
0 commit comments