Skip to content

Commit 803835f

Browse files
Adds support for Call-Context Arguments
Create an argument propagation flow for call-context arguments. Currently, keras uses the `training` argument to infer whether a layer should be called in training/inference mode. This change introduces a general flow of propagating arguments from a parent call to a child call (using call_context), so that we can add new control flow arguments in the future using a generic framework. This change does the following things: * Adds a `call_context_args` dictionary in the call_context object to store call-context arguments being propagated. * Changes the current layer implementation to use the general propagation flow instead of using the hardcoded `training`. * Adds utilities to query and set these context arguments in the `Layer` class. PiperOrigin-RevId: 761325027
1 parent 65c3548 commit 803835f

File tree

11 files changed

+525
-122
lines changed

11 files changed

+525
-122
lines changed

tf_keras/engine/base_layer.py

Lines changed: 223 additions & 96 deletions
Large diffs are not rendered by default.

tf_keras/engine/base_layer_test.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,171 @@ def __init__(self, var1, var2, var3=None, **kwargs):
11061106
with self.assertRaises(NotImplementedError):
11071107
config = layer.get_config()
11081108

1109+
def test_call_context_args_with_custom_layers_propagates_args(self):
1110+
class Inner(layers.Layer):
1111+
def __init__(self):
1112+
super().__init__()
1113+
self._register_call_context_args("foo_mode")
1114+
1115+
def call(self, x, foo_mode=None):
1116+
return x + (1 if foo_mode else 0)
1117+
1118+
class Outer(layers.Layer):
1119+
def __init__(self):
1120+
super().__init__()
1121+
self._register_call_context_args("foo_mode")
1122+
self.inner = Inner()
1123+
1124+
def call(self, x):
1125+
# Outer doesn't even need to re-inject explicitly:
1126+
# our base class will propagate foo_mode automatically
1127+
return self.inner(x)
1128+
1129+
layer = Outer()
1130+
self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
1131+
self.assertEqual(int(layer(np.array(0))), 0)
1132+
1133+
def test_register_call_context_arguments_success(self):
1134+
"""Validate that registering call-context args works as expected."""
1135+
1136+
class MyLayer(layers.Layer):
1137+
def call(self, x):
1138+
return x
1139+
1140+
layer = MyLayer()
1141+
1142+
layer._register_call_context_args("foo_mode")
1143+
1144+
self.assertCountEqual(
1145+
layer._call_context_args, ("foo_mode", "training")
1146+
)
1147+
1148+
def test_register_call_context_arguments_after_call_raises_error(self):
1149+
"""Validate that registering call-context args after the layer has
1150+
been called raises an error."""
1151+
1152+
class MyLayer(layers.Layer):
1153+
def call(self, x):
1154+
return x
1155+
1156+
layer = MyLayer()
1157+
layer(np.array(0))
1158+
with self.assertRaisesRegex(
1159+
RuntimeError,
1160+
"Cannot add call-context args after the layer has been called.",
1161+
):
1162+
layer._register_call_context_args("foo_mode")
1163+
1164+
def test_nested_context_args_follow_priority_order(self):
1165+
"""Validate that call-context args are propagated correctly
1166+
through multiple layers, and that the most specific value is used
1167+
when multiple values are passed down the call-stack.
1168+
"""
1169+
1170+
class Inner(base_layer.Layer):
1171+
def __init__(self):
1172+
super().__init__(name="inner_layer")
1173+
self._register_call_context_args("foo_mode")
1174+
1175+
def call(self, inputs, foo_mode=None):
1176+
return inputs + (1 if foo_mode else 0)
1177+
1178+
class Middle(base_layer.Layer):
1179+
def __init__(self):
1180+
super().__init__(name="middle_layer")
1181+
self._inner_layer = Inner()
1182+
1183+
def call(self, inputs):
1184+
return self._inner_layer(inputs)
1185+
1186+
class Outer(base_layer.Layer):
1187+
def __init__(self):
1188+
super().__init__(name="outer_layer")
1189+
self._middle = Middle()
1190+
1191+
def call(self, inputs):
1192+
return self._middle(inputs)
1193+
1194+
layer = Outer()
1195+
layer._register_call_context_args("foo_mode")
1196+
1197+
# The value of foo_mode is set to True in the call to Outer,
1198+
# so it should automatically propagate to Inner through Middle.
1199+
self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
1200+
self.assertEqual(int(layer(np.array(0))), 0)
1201+
1202+
def test_context_arg_propagation_without_declaration_does_not_resolve(self):
1203+
"""Validate that layer does not resolve a propagated arg if it is not
1204+
declared as a call-context arg in the layer itself."""
1205+
1206+
class Inner(layers.Layer):
1207+
def call(self, x, foo_mode=None):
1208+
return x + (1 if foo_mode else 0)
1209+
1210+
class Wrapper(layers.Layer):
1211+
def __init__(self):
1212+
super().__init__()
1213+
self.inner = Inner()
1214+
1215+
def call(self, x):
1216+
return self.inner(x)
1217+
1218+
layer = Wrapper()
1219+
layer._register_call_context_args("foo_mode")
1220+
1221+
# The value of foo_mode is set to True in the call to Wrapper,
1222+
# However, it is not declared as a call-context arg in Inner,
1223+
# so it should not resolve to True inside Inner (and instead
1224+
# default to False).
1225+
self.assertEqual(int(layer(np.array(0), foo_mode=True)), 0)
1226+
1227+
def test_call_context_args_with_models_as_layers_propagates_args(self):
1228+
"""Validate that call-context args are propagated correctly
1229+
through functional and sequential models when used as layers.
1230+
"""
1231+
1232+
class InnerLayer(base_layer.Layer):
1233+
def __init__(self):
1234+
super().__init__(name="inner_layer")
1235+
self._register_call_context_args("foo")
1236+
1237+
def call(self, inputs, foo=None):
1238+
if foo:
1239+
return inputs + 1.0
1240+
return inputs
1241+
1242+
class OuterLayer(base_layer.Layer):
1243+
def __init__(self):
1244+
super().__init__(name="outer_layer")
1245+
self._inner_layer = InnerLayer()
1246+
1247+
def call(self, inputs):
1248+
return self._inner_layer(inputs)
1249+
1250+
sample_input = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
1251+
1252+
# Sequential model
1253+
seq = sequential.Sequential([OuterLayer()])
1254+
seq._register_call_context_args("foo")
1255+
1256+
out_true = seq(sample_input, foo=True)
1257+
self.assertAllEqual(out_true, sample_input + 1.0)
1258+
1259+
out_false = seq(sample_input, foo=False)
1260+
self.assertAllEqual(out_false, sample_input)
1261+
1262+
# Functional model
1263+
inp = input_layer.Input((2,))
1264+
outer = OuterLayer()(inp)
1265+
model = training_lib.Model(inputs=[inp], outputs=[outer])
1266+
model._register_call_context_args("foo")
1267+
1268+
out_true = model(sample_input, foo=True)
1269+
self.assertAllEqual(out_true, sample_input + 1.0)
1270+
1271+
out_false = model(sample_input, foo=False)
1272+
self.assertAllEqual(out_false, sample_input)
1273+
11091274

11101275
@test_utils.run_v2_only
11111276
class SymbolicSupportTest(test_combinations.TestCase):

tf_keras/engine/base_layer_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ class CallContext:
480480
layer: The `Layer` whose `call` is currently active.
481481
inputs: The inputs to the currently active `Layer`.
482482
build_graph: Whether currently inside a Graph or FuncGraph.
483-
training: Whether currently executing in training or inference mode.
483+
call_context_args: The call-context arguments being propagated through
484+
the call-stack.
484485
saving: Whether currently saving to SavedModel.
485486
frozen: Whether currently executing inside a `Layer` with `trainable` set
486487
to `False`.
@@ -495,21 +496,25 @@ def __init__(self):
495496
"layer": None,
496497
"inputs": None,
497498
"build_graph": False,
499+
"call_context_args": dict(),
498500
"training": None,
499501
"saving": None,
500502
}
501503
# TODO(b/150169018): This logic can be replaced after the Functional API
502504
# refactor.
503505
self._in_keras_graph = False
504506

505-
def enter(self, layer, inputs, build_graph, training, saving=None):
507+
def enter(
508+
self, layer, inputs, build_graph, call_context_args=dict(), saving=None
509+
):
506510
"""Push a Layer and its inputs and state onto the current call context.
507511
508512
Args:
509513
layer: The `Layer` whose `call` is currently active.
510514
inputs: The inputs to the currently active `Layer`.
511515
build_graph: Whether currently inside a Graph or FuncGraph.
512-
training: Whether currently executing in training or inference mode.
516+
call_context_args: The call-context arguments being propagated through
517+
the call-stack.
513518
saving: Whether currently saving to SavedModel.
514519
515520
Returns:
@@ -519,7 +524,7 @@ def enter(self, layer, inputs, build_graph, training, saving=None):
519524
"layer": layer,
520525
"inputs": inputs,
521526
"build_graph": build_graph,
522-
"training": training,
527+
"call_context_args": call_context_args,
523528
"saving": saving,
524529
}
525530
return CallContextManager(self, state)
@@ -538,7 +543,14 @@ def build_graph(self):
538543

539544
@property
540545
def training(self):
541-
return self._state["training"]
546+
return self.call_context_args.get("training", None)
547+
548+
@property
549+
def call_context_args(self):
550+
return self._state["call_context_args"]
551+
552+
def get_call_context_arg(self, arg_name):
553+
return self.call_context_args.get(arg_name, None)
542554

543555
@property
544556
def saving(self):

tf_keras/engine/base_layer_v1.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
133133
):
134134
self._instrument_layer_creation()
135+
self._called = False
135136

136137
# These properties should be set by the user via keyword arguments.
137138
# note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -165,6 +166,8 @@ def __init__(
165166
self._input_spec = None
166167
self.supports_masking = False
167168

169+
self._call_context_args = {"training"}
170+
168171
self._init_set_name(name)
169172
self._activity_regularizer = regularizers.get(
170173
kwargs.pop("activity_regularizer", None)
@@ -705,6 +708,7 @@ def __call__(self, *args, **kwargs):
705708
RuntimeError: if `super().__init__()` was not called in the
706709
constructor.
707710
"""
711+
self._called = True
708712
self._assert_built_as_v1()
709713

710714
if not hasattr(self, "_thread_local"):
@@ -803,7 +807,12 @@ def _convert_non_tensor(x):
803807
if build_graph and base_layer_utils.needs_keras_history(inputs):
804808
base_layer_utils.create_keras_history(inputs)
805809

806-
with call_context.enter(self, inputs, build_graph, training_value):
810+
with call_context.enter(
811+
self,
812+
inputs,
813+
build_graph,
814+
call_context_args={"training": training_value},
815+
):
807816
# Check input assumptions set after layer building, e.g. input
808817
# shape.
809818
if build_graph:

tf_keras/layers/core/tf_op_layer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ def _call_wrapper(*args, **kwargs):
259259

260260
self._call_spec.expects_training_arg = False
261261
self._call_spec.expects_mask_arg = False
262+
# Clear the call-context arguments for the layer's call method.
263+
# Otherwise, Keras ends up injecting context arguments into the op-call
264+
# when the call method accepts kwargs.
265+
self._call_spec._expected_context_args.clear()
262266

263267
def _call_wrapper(self, *args, **kwargs):
264268
created_variables = []

tf_keras/layers/rnn/base_rnn_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,9 @@ def test_stacked_rnn_attributes(self):
639639
cells[0].kernel, tf.ones_like(cells[0].kernel)
640640
)
641641
# TODO(b/128682878): Remove when RNNCells are __call__'d.
642-
with base_layer_utils.call_context().enter(layer, x, True, None):
642+
with base_layer_utils.call_context().enter(
643+
layer, x, {"training": True}, None
644+
):
643645
cells[0].add_update(update_1)
644646
cells[0].add_update(update_2)
645647
self.assertEqual(len(layer.updates), 2)

tf_keras/layers/rnn/bidirectional_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,9 @@ def test_Bidirectional_updates(self):
472472
_ = layer(x)
473473
assert not layer.updates
474474
# TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
475-
with base_layer_utils.call_context().enter(layer, x, True, None):
475+
with base_layer_utils.call_context().enter(
476+
layer, x, {"training": True}, None
477+
):
476478
layer.forward_layer.add_update(x_reachable_update)
477479
layer.forward_layer.add_update(1)
478480
layer.backward_layer.add_update(x_reachable_update)

tf_keras/layers/rnn/cell_wrappers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,21 @@ def __init__(self, cell, *args, **kwargs):
5252
super().__init__(*args, **kwargs)
5353
self.cell = cell
5454
cell_call_spec = tf_inspect.getfullargspec(cell.call)
55+
accepts_kwargs = cell_call_spec.varkw is not None
56+
5557
self._call_spec.expects_training_arg = (
5658
"training" in cell_call_spec.args
57-
) or (cell_call_spec.varkw is not None)
59+
) or accepts_kwargs
60+
61+
# Filter _expects_context_arg. An argument is kept if:
62+
# 1. It's an explicit argument in cell_call_spec.args OR
63+
# 2. The cell accepts arbitrary keyword arguments (**kwargs),
64+
# meaning it could potentially handle the context argument.
65+
self._call_spec._expected_context_args = {
66+
arg
67+
for arg in self._call_spec._expected_context_args
68+
if (arg in cell_call_spec.args) or accepts_kwargs
69+
}
5870

5971
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
6072
"""Calls the wrapped cell and performs the wrapping logic.

0 commit comments

Comments
 (0)