
Commit 4595239

Interleave optimizer variable creation to restore backward-compatibility. (#21247)
Added optimizer variables were originally interleaved during `build` prior to #21232, e.g. `{momentum0, velocity0, momentum1, velocity1, ...}`. In #21232, the order was changed to non-interleaved for some optimizers, e.g. `{momentum0, momentum1, ..., velocity0, velocity1, ...}`. This broke custom checkpoint serialization code that relied on the order of variables remaining consistent. Here we modify the base method `add_optimizer_variables(...)` to support creating multiple optimizer variables per trainable variable, and interleave their creation to restore backward compatibility.
1 parent df36b8e commit 4595239
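
To make the ordering difference concrete, here is a small schematic (plain Python, not the Keras implementation; the variable and slot names are invented for illustration):

# Schematic of slot-variable creation order for two slots per model variable.
model_vars = ["dense/kernel", "dense/bias"]
slots = ["momentum", "velocity"]

# Interleaved order (pre-#21232 behavior, restored by this commit):
interleaved = [f"{slot}/{var}" for var in model_vars for slot in slots]
# ['momentum/dense/kernel', 'velocity/dense/kernel',
#  'momentum/dense/bias', 'velocity/dense/bias']

# Grouped (non-interleaved) order introduced in #21232 for some optimizers:
grouped = [f"{slot}/{var}" for slot in slots for var in model_vars]
# ['momentum/dense/kernel', 'momentum/dense/bias',
#  'velocity/dense/kernel', 'velocity/dense/bias']

Checkpoint code that restores optimizer slots by position sees a different sequence in the grouped case, which is what broke compatibility.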

File tree: 8 files changed (+80, −35 lines)


keras/src/optimizers/adadelta.py

Lines changed: 4 additions & 5 deletions
@@ -75,11 +75,10 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulated_grads = self.add_optimizer_variables(
-            var_list, "accumulated_grad"
-        )
-        self._accumulated_delta_vars = self.add_optimizer_variables(
-            var_list, "accumulated_delta_var"
+        self._accumulated_grads, self._accumulated_delta_vars = (
+            self.add_optimizer_variables(
+                var_list, ["accumulated_grad", "accumulated_delta_var"]
+            )
         )
 
     def update_step(self, grad, variable, learning_rate):

keras/src/optimizers/adafactor.py

Lines changed: 11 additions & 5 deletions
@@ -1,3 +1,4 @@
+from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -96,11 +97,16 @@ def build(self, var_list):
         self._c = []
         self._v = []
         for var in var_list:
-            if (
-                self._overwrite_variable_with_gradient(var)
-                or len(var.shape) < 2
-            ):
-                # Don't factor if variable is of dimension < 2.
+            if len(var.shape) < 2:
+                # Don't factor if variable is of dimension < 2, but we still
+                # need to create dummy variables as placeholder.
+                self._r.append(
+                    backend.Variable(0, name=var.name, trainable=False)
+                )
+                self._c.append(
+                    backend.Variable(0, name=var.name, trainable=False)
+                )
+            elif self._overwrite_variable_with_gradient(var):
                 self._r.append(None)
                 self._c.append(None)
             else:
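
For reference, the rank rule deciding which branch a variable takes in the Adafactor hunk above can be sketched as follows (the names and shapes are made-up examples; the actual check is `len(var.shape) < 2`):

# Adafactor keeps factored row/column second-moment statistics only for
# rank >= 2 variables; rank 0/1 variables now get dummy placeholder r/c
# variables instead of None.
example_shapes = {
    "dense/kernel": (128, 64),  # rank 2 -> factored r/c statistics
    "dense/bias": (64,),        # rank 1 -> dummy placeholders
    "scale": (),                # rank 0 -> dummy placeholders
}
for name, shape in example_shapes.items():
    print(name, "factored" if len(shape) >= 2 else "placeholder")

The placeholders appear to be recreated here so that the optimizer's variable list keeps the same length and order it had before #21232, in line with the checkpoint-compatibility goal of this commit.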

keras/src/optimizers/adam.py

Lines changed: 3 additions & 2 deletions
@@ -90,8 +90,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
 
         if self.amsgrad:
             self._velocity_hats = self.add_optimizer_variables(

keras/src/optimizers/adamax.py

Lines changed: 3 additions & 2 deletions
@@ -98,8 +98,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._m = self.add_optimizer_variables(var_list, "momentum")
-        self._u = self.add_optimizer_variables(var_list, "norm")
+        self._m, self._u = self.add_optimizer_variables(
+            var_list, ["momentum", "norm"]
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/base_optimizer.py

Lines changed: 49 additions & 14 deletions
@@ -338,29 +338,64 @@ def add_optimizer_variables(
         Args:
             trainable_variables: `keras.Variable`, the corresponding model
                 variable to the optimizer variable to be created.
-            name: The name prefix of the optimizer variable to be created. The
-                variable name will follow the pattern
+            name: The name prefix(es) of the optimizer variable(s) to be
+                created. Can be a single string or list of strings. If a
+                list of strings, will create an optimizer variable for each
+                prefix. The variable name will follow the pattern
                 `{variable_name}_{trainable_variable.name}`, e.g.,
-                `momemtum/dense_1`. Defaults to `None`.
-            initializer: Initializer object to use to populate the initial
-                variable value, or string name of a built-in initializer (e.g.
-                `"random_normal"`). If unspecified, defaults to `"zeros"`.
+                `momemtum/dense_1`.
+            initializer: Initializer object(s) to use to populate the initial
+                variable value(s), or string name of a built-in initializer
+                (e.g. `"random_normal"`). If unspecified, defaults to
+                `"zeros"`.
 
         Returns:
             A list of optimizer variables, in the format of `keras.Variable`s.
+            If multiple names are provide, returns a tuple of lists.
         """
-        optimizer_variables = []
+        name_list = name
+        initializer_list = initializer
+        if isinstance(name, str):
+            # Single name/initializer.
+            name_list = [name]
+            initializer_list = [initializer]
+        else:
+            # Multiple names/initializers.
+            # If there is only one initializer, use it for all names.
+            if isinstance(initializer, str) or isinstance(
+                initializer, initializers.Initializer
+            ):
+                initializer_list = [initializer] * len(name_list)
+
+        if len(name_list) != len(initializer_list):
+            raise ValueError(
+                f"The number of provided names must match the number of "
+                f"provided initializers. Received name='{name}', "
+                f"initializer='{initializer}'"
+            )
+
+        # Build up lists of optimizer variables.
+        optimizer_variables = tuple([] for _ in name_list)
         for variable in trainable_variables:
+            # Interleaves adding variables for backward-compatibility.
             if not self._overwrite_variable_with_gradient(variable):
-                optimizer_variables.append(
-                    self.add_variable_from_reference(
-                        variable,
-                        name=name,
-                        initializer=initializer,
+                for i, (var_name, var_init) in enumerate(
+                    zip(name_list, initializer_list)
+                ):
+                    optimizer_variables[i].append(
+                        self.add_variable_from_reference(
+                            variable,
+                            name=var_name,
+                            initializer=var_init,
+                        )
                     )
-                )
             else:
-                optimizer_variables.append(None)
+                for i in range(len(name_list)):
+                    optimizer_variables[i].append(None)
+
+        # If single input name, return the single list.
+        if isinstance(name, str):
+            return optimizer_variables[0]
 
         return optimizer_variables
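
Taken together, the helper now accepts either a single name prefix (old behavior, returning one list) or a list of prefixes (returning a tuple of lists whose variables are created interleaved per model variable). A minimal sketch of a custom optimizer using it, assuming the public `keras.optimizers.Optimizer` base class; the class, slot names, and update rule are invented for illustration:

import keras
from keras import ops


class MyMomentumOptimizer(keras.optimizers.Optimizer):
    """Hypothetical optimizer used only to illustrate the updated helper."""

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # A list of name prefixes creates one slot list per prefix; slots are
        # created interleaved per model variable, restoring the pre-#21232
        # checkpoint ordering.
        self._momentums, self._velocities = self.add_optimizer_variables(
            var_list, ["momentum", "velocity"]
        )
        # A single string still returns a single list, as before.
        self._velocity_hats = self.add_optimizer_variables(
            var_list, "velocity_hat"
        )

    def update_step(self, gradient, variable, learning_rate):
        # Placeholder SGD-style update, just to keep the sketch complete.
        self.assign_sub(variable, ops.multiply(learning_rate, gradient))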

keras/src/optimizers/ftrl.py

Lines changed: 4 additions & 3 deletions
@@ -162,10 +162,11 @@ def build(self, var_list):
         accumulator_initializer = initializers.Constant(
             self.initial_accumulator_value,
         )
-        self._accumulators = self.add_optimizer_variables(
-            var_list, "accumulator", initializer=accumulator_initializer
+        self._accumulators, self._linears = self.add_optimizer_variables(
+            var_list,
+            ["accumulator", "linear"],
+            initializer=[accumulator_initializer, "zeros"],
         )
-        self._linears = self.add_optimizer_variables(var_list, "linear")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/lamb.py

Lines changed: 3 additions & 2 deletions
@@ -82,8 +82,9 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/nadam.py

Lines changed: 3 additions & 2 deletions
@@ -87,8 +87,9 @@ def build(self, var_list):
         else:
             dtype = backend.floatx()
         super().build(var_list)
-        self._momentums = self.add_optimizer_variables(var_list, "momentum")
-        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+        self._momentums, self._velocities = self.add_optimizer_variables(
+            var_list, ["momentum", "velocity"]
+        )
         self._u_product = backend.Variable(1.0, dtype=dtype)
 
     def _backend_update_step(self, grads, trainable_variables, learning_rate):
