@@ -1106,6 +1106,171 @@ def __init__(self, var1, var2, var3=None, **kwargs):
1106
1106
with self .assertRaises (NotImplementedError ):
1107
1107
config = layer .get_config ()
1108
1108
1109
+ def test_call_context_args_with_custom_layers_propagates_args (self ):
1110
+ class Inner (layers .Layer ):
1111
+ def __init__ (self ):
1112
+ super ().__init__ ()
1113
+ self ._register_call_context_args ("foo_mode" )
1114
+
1115
+ def call (self , x , foo_mode = None ):
1116
+ return x + (1 if foo_mode else 0 )
1117
+
1118
+ class Outer (layers .Layer ):
1119
+ def __init__ (self ):
1120
+ super ().__init__ ()
1121
+ self ._register_call_context_args ("foo_mode" )
1122
+ self .inner = Inner ()
1123
+
1124
+ def call (self , x ):
1125
+ # Outer doesn’t even need to re‑inject explicitly:
1126
+ # our base class will propagate foo_mode automatically
1127
+ return self .inner (x )
1128
+
1129
+ layer = Outer ()
1130
+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 1 )
1131
+ self .assertEqual (int (layer (np .array (0 ))), 0 )
1132
+
1133
+ def test_register_call_context_arguments_success (self ):
1134
+ """Validate that registering call-context args works as expected."""
1135
+
1136
+ class MyLayer (layers .Layer ):
1137
+ def call (self , x ):
1138
+ return x
1139
+
1140
+ layer = MyLayer ()
1141
+
1142
+ layer ._register_call_context_args ("foo_mode" )
1143
+
1144
+ self .assertCountEqual (
1145
+ layer ._call_context_args , ("foo_mode" , "training" )
1146
+ )
1147
+
1148
+ def test_register_call_context_arguments_after_call_raises_error (self ):
1149
+ """Validate that registering call-context args after the layer has
1150
+ been called raises an error."""
1151
+
1152
+ class MyLayer (layers .Layer ):
1153
+ def call (self , x ):
1154
+ return x
1155
+
1156
+ layer = MyLayer ()
1157
+ layer (np .array (0 ))
1158
+ with self .assertRaisesRegex (
1159
+ RuntimeError ,
1160
+ "Cannot add call-context args after the layer has been called." ,
1161
+ ):
1162
+ layer ._register_call_context_args ("foo_mode" )
1163
+
1164
+ def test_nested_context_args_follow_priority_order (self ):
1165
+ """Validate that call-context args are propagated correctly
1166
+ through multiple layers, and that the most specific value is used
1167
+ when multiple values are passed down the call-stack.
1168
+ """
1169
+
1170
+ class Inner (base_layer .Layer ):
1171
+ def __init__ (self ):
1172
+ super ().__init__ (name = "inner_layer" )
1173
+ self ._register_call_context_args ("foo_mode" )
1174
+
1175
+ def call (self , inputs , foo_mode = None ):
1176
+ return inputs + (1 if foo_mode else 0 )
1177
+
1178
+ class Middle (base_layer .Layer ):
1179
+ def __init__ (self ):
1180
+ super ().__init__ (name = "middle_layer" )
1181
+ self ._inner_layer = Inner ()
1182
+
1183
+ def call (self , inputs ):
1184
+ return self ._inner_layer (inputs )
1185
+
1186
+ class Outer (base_layer .Layer ):
1187
+ def __init__ (self ):
1188
+ super ().__init__ (name = "outer_layer" )
1189
+ self ._middle = Middle ()
1190
+
1191
+ def call (self , inputs ):
1192
+ return self ._middle (inputs )
1193
+
1194
+ layer = Outer ()
1195
+ layer ._register_call_context_args ("foo_mode" )
1196
+
1197
+ # The value of foo_mode is set to True in the call to Outer,
1198
+ # so it should automatically propagate to Inner through Middle.
1199
+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 1 )
1200
+ self .assertEqual (int (layer (np .array (0 ))), 0 )
1201
+
1202
+ def test_context_arg_propagation_without_declaration_does_not_resolve (self ):
1203
+ """Validate that layer does not resolve a propagated arg if it is not
1204
+ declared as a call-context arg in the layer itself."""
1205
+
1206
+ class Inner (layers .Layer ):
1207
+ def call (self , x , foo_mode = None ):
1208
+ return x + (1 if foo_mode else 0 )
1209
+
1210
+ class Wrapper (layers .Layer ):
1211
+ def __init__ (self ):
1212
+ super ().__init__ ()
1213
+ self .inner = Inner ()
1214
+
1215
+ def call (self , x ):
1216
+ return self .inner (x )
1217
+
1218
+ layer = Wrapper ()
1219
+ layer ._register_call_context_args ("foo_mode" )
1220
+
1221
+ # The value of foo_mode is set to True in the call to Wrapper,
1222
+ # However, it is not declared as a call-context arg in Inner,
1223
+ # so it should not resolve to True inside Inner (and instead
1224
+ # default to False).
1225
+ self .assertEqual (int (layer (np .array (0 ), foo_mode = True )), 0 )
1226
+
1227
+ def test_call_context_args_with_models_as_layers_propagates_args (self ):
1228
+ """Validate that call-context args are propagated correctly
1229
+ through functional and sequential models when used as layers.
1230
+ """
1231
+
1232
+ class InnerLayer (base_layer .Layer ):
1233
+ def __init__ (self ):
1234
+ super ().__init__ (name = "inner_layer" )
1235
+ self ._register_call_context_args ("foo" )
1236
+
1237
+ def call (self , inputs , foo = None ):
1238
+ if foo :
1239
+ return inputs + 1.0
1240
+ return inputs
1241
+
1242
+ class OuterLayer (base_layer .Layer ):
1243
+ def __init__ (self ):
1244
+ super ().__init__ (name = "outer_layer" )
1245
+ self ._inner_layer = InnerLayer ()
1246
+
1247
+ def call (self , inputs ):
1248
+ return self ._inner_layer (inputs )
1249
+
1250
+ sample_input = tf .constant ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = "float32" )
1251
+
1252
+ # Sequential model
1253
+ seq = sequential .Sequential ([OuterLayer ()])
1254
+ seq ._register_call_context_args ("foo" )
1255
+
1256
+ out_true = seq (sample_input , foo = True )
1257
+ self .assertAllEqual (out_true , sample_input + 1.0 )
1258
+
1259
+ out_false = seq (sample_input , foo = False )
1260
+ self .assertAllEqual (out_false , sample_input )
1261
+
1262
+ # Functional model
1263
+ inp = input_layer .Input ((2 ,))
1264
+ outer = OuterLayer ()(inp )
1265
+ model = training_lib .Model (inputs = [inp ], outputs = [outer ])
1266
+ model ._register_call_context_args ("foo" )
1267
+
1268
+ out_true = model (sample_input , foo = True )
1269
+ self .assertAllEqual (out_true , sample_input + 1.0 )
1270
+
1271
+ out_false = model (sample_input , foo = False )
1272
+ self .assertAllEqual (out_false , sample_input )
1273
+
1109
1274
1110
1275
@test_utils .run_v2_only
1111
1276
class SymbolicSupportTest (test_combinations .TestCase ):
0 commit comments