Commit b364dcc

Updated internal links with external github links in Optimizer and Learning Rate Scheduler
PiperOrigin-RevId: 639932638
1 parent a400f08 commit b364dcc


official/nlp/docs/optimization.md

Lines changed: 174 additions & 0 deletions

# Optimizer and Learning Rate Scheduler

This page describes the
[optimization package](https://github.com/tensorflow/models/tree/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization)
for TensorFlow Official Models (TFM), which includes optimizers and learning
rate schedulers.

## Building Optimizer and LR Scheduler

We use an `OptimizerFactory` class to manage optimizer and learning rate
creation. The factory takes a config as input and provides member functions
for building the optimizer and the learning rate schedule. To create an
optimizer and an LR schedule through `OptimizerFactory`, you need to do the
following:

1. Define the optimization config, which includes the optimizer and the
   learning rate schedule.
2. Initialize the `OptimizerFactory` instance using the optimization config.
3. Build the learning rate and the optimizer using the class member functions.

The following is an example of creating an SGD optimizer with a stepwise LR
schedule and linear warmup:

```python
from official.modeling import optimization

params = {'optimizer': {'type': 'sgd',
                        'sgd': {'momentum': 0.9}},
          'learning_rate': {'type': 'stepwise',
                            'stepwise': {
                                'boundaries': [10000, 20000],
                                'values': [0.1, 0.01, 0.001]}},
          'warmup': {'type': 'linear',
                     'linear': {'warmup_steps': 500,
                                'warmup_learning_rate': 0.01}}}
# Defines optimization config from a dictionary.
opt_config = optimization.OptimizationConfig(params)
# Initializes an optimization factory from optimization config.
opt_factory = optimization.OptimizerFactory(opt_config)
# Builds the desired learning rate scheduling instance.
lr = opt_factory.build_learning_rate()
# Builds the optimizer instance with the desired learning rate schedule.
optimizer = opt_factory.build_optimizer(lr)
```

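The returned `optimizer` for the `sgd` type is a standard Keras optimizer, so it
can be used wherever one is expected. A minimal usage sketch (the toy model
below is purely illustrative and not part of the library):

```python
import tensorflow as tf

# Hypothetical toy model, used only to show that the built optimizer plugs
# into the usual Keras APIs.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=optimizer, loss='mse')
```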

To initialize an `OptimizerFactory`, the `optimizer` and `learning_rate` fields
must be defined, while `warmup` is an optional field. The `type` field selects
the type of each optimization component. The set of available types is
explained in detail in the following sections.

In the following sections, we explain how to create different optimizers,
learning rate schedules, and warmup schedules. We also explain how to add new
optimizers or learning rate schedulers.

## Optimizers

The list of supported optimizers can be found
[here](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/optimizer_factory.py#L31).

```python
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop
}
```

You can specify the type of optimizer to be one of the above using the
[oneof](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/hyperparams/oneof.py)
config. The available config fields can be found
[here](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/configs/optimizer_config.py).

All optimizers support gradient clipping by value, by norm, and by global norm.
To specify which method to use, set the appropriate field listed
[here](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/configs/optimizer_config.py#L34).

### Example

We will specify an RMSprop optimizer with a discounting factor (rho) of 0.9 and
global norm gradient clipping of 10.0. Below is the config to be used.

```python
params = {'optimizer': {'type': 'rmsprop',
                        'rmsprop': {'rho': 0.9,
                                    'global_clipnorm': 10.0}}}
```

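As with the SGD example above, this optimizer config can be passed through the
factory once a learning rate is also configured. A minimal sketch, reusing the
stepwise schedule from the earlier example (the schedule values are
illustrative):

```python
from official.modeling import optimization

# The factory also needs a `learning_rate`; the stepwise schedule from the
# earlier example is reused here purely for illustration.
params = {'optimizer': {'type': 'rmsprop',
                        'rmsprop': {'rho': 0.9,
                                    'global_clipnorm': 10.0}},
          'learning_rate': {'type': 'stepwise',
                            'stepwise': {'boundaries': [10000, 20000],
                                         'values': [0.1, 0.01, 0.001]}}}
opt_config = optimization.OptimizationConfig(params)
opt_factory = optimization.OptimizerFactory(opt_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
```
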
### Adding a New Optimizer

To add a new optimizer, you need to do the following (a sketch of these pieces
follows the list):

1. Create a
   [custom subclass](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#creating_a_custom_optimizer_2)
   of tf.keras.optimizers.Optimizer.
2. Add the required config fields under
   [optimization/configs/optimizer_config.py](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/configs/optimizer_config.py).
3. Add the optimizer class to the list of available optimizer classes in
   [optimizer_factory](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/optimizer_factory.py).

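A rough sketch of what those three steps look like. All names here
(`MyOptimizer`, `MyOptimizerConfig`, the `'my_optimizer'` key) are hypothetical,
and the config base classes in `optimizer_config.py` should be checked against
the pinned revision:

```python
import dataclasses
import tensorflow as tf

# Step 1: a custom optimizer. Subclassing SGD here is only a stand-in for a
# real tf.keras.optimizers.Optimizer implementation.
class MyOptimizer(tf.keras.optimizers.SGD):
  """Hypothetical custom optimizer."""


# Step 2: a config dataclass, added to optimization/configs/optimizer_config.py
# alongside the existing optimizer configs (assumed to follow the same pattern
# as the configs already defined there).
@dataclasses.dataclass
class MyOptimizerConfig:
  momentum: float = 0.9


# Step 3: register the class in optimizer_factory.py, e.g. by adding an entry
# to the OPTIMIZERS_CLS dictionary shown above:
#   'my_optimizer': MyOptimizer,
```
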
## Learning Rate and Warmup Schedules

A learning rate with an optional warmup can be configured by specifying the
`learning_rate` and `warmup` fields in the optimization config. `learning_rate`
is a required field, while `warmup` is an optional one. The list of supported
`learning_rate` and `warmup` schedules can be found
[here](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/modeling/optimization/optimizer_factory.py#L59).

```python
LR_CLS = {
    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
    'cosine': tf.keras.experimental.CosineDecay,
    'power': lr_schedule.DirectPowerDecay,
}

WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp
}
```

In addition, a `constant` learning rate can be specified.

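For example, a constant learning rate of 0.01 might be configured as follows (a
minimal sketch; the exact field name under `constant` is assumed from the config
definitions linked above):

```python
params = {'learning_rate': {'type': 'constant',
                            'constant': {'learning_rate': 0.01}}}
```
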
## How Learning Rate Works

The learning rate schedule takes `step` as an input and returns the learning
rate value. As training progresses, the learning rate value usually decays. A
warmup schedule is often used to stabilize training: it starts from a low
learning rate value and gradually increases it until it reaches the initial
value of the regular learning rate decay schedule. We combine the
`learning_rate` (lr) and `warmup` (warmup) schedules as follows:

* Steps [0, warmup_steps): `learning_rate = warmup(step)`
* Steps [warmup_steps, train_steps): `learning_rate = lr(step)`
* We designed the warmup schedule such that the final warmup learning rate is
  inferred from the learning rate schedule (i.e.
  `learning_rate(warmup_steps) = warmup(warmup_steps)`). Note that the warmup
  schedule doesn't delay the regular learning rate decay by `warmup_steps`;
  instead, it replaces it for those steps, as the sketch after this list
  illustrates.

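A minimal sketch of how the combined schedule can be inspected, reusing the
stepwise-plus-linear-warmup config from the first example (the printed values
depend on the configured schedule):

```python
from official.modeling import optimization

opt_config = optimization.OptimizationConfig({
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'stepwise',
                      'stepwise': {'boundaries': [10000, 20000],
                                   'values': [0.1, 0.01, 0.001]}},
    'warmup': {'type': 'linear',
               'linear': {'warmup_steps': 500,
                          'warmup_learning_rate': 0.01}}})
lr = optimization.OptimizerFactory(opt_config).build_learning_rate()

# The schedule is callable with a step: warmup for the first 500 steps, then
# the regular stepwise decay takes over.
for step in [0, 250, 500, 10000, 20000]:
  print(step, float(lr(step)))
```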

The learning rate value is logged every
[summary_interval](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/core/config_definitions.py#L262).
If `warmup_steps` is less than the `summary_interval`, you won't be able to see
the warmup values.

### Example

We want to specify a cosine learning rate decay with `decay_steps` of 20000 and
a linear warmup schedule for the first 500 steps.

```python
params = {'learning_rate': {'type': 'cosine',
                            'cosine': {'decay_steps': 20000}},
          'warmup': {'type': 'linear',
                     'linear': {'warmup_steps': 500}}}
```

## Customizing Optimizer Inside Task

The optimizer and learning rate are created inside the
[task](https://github.com/tensorflow/models/blob/28d972a0b30b628cbb7f67a090ea564c3eda99ea/official/core/base_task.py#L99).
If different optimizers or learning rate schedulers are needed, they can be
defined by overriding that class method.

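A minimal sketch of such an override, assuming the base task exposes a
`create_optimizer` classmethod at the linked line (the exact signature should be
checked against the pinned revision):

```python
from official.core import base_task


class MyTask(base_task.Task):
  """Hypothetical task that customizes optimizer creation."""

  @classmethod
  def create_optimizer(cls, optimizer_config, runtime_config=None, **kwargs):
    # Adjust the optimization config (or build the optimizer directly) before
    # delegating to the default factory-based implementation.
    return super().create_optimizer(optimizer_config, runtime_config, **kwargs)
```
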
## Important Factors To Consider

* Batch size: Changing the batch size usually requires scaling the learning
  rate values and the number of training steps. Make sure that you change the
  appropriate values as the batch size changes (see the sketch after this
  list).
* Train steps: The number of train steps is highly correlated with fields such
  as `decay_steps` for cosine learning rate decay. Changing one without
  changing the other might result in undesired behavior.

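For example, under the commonly used linear scaling heuristic (an assumption
about common practice, not something the optimization package enforces),
doubling the batch size is often paired with doubling the learning rate values
and halving step-based fields such as `boundaries` or `decay_steps`:

```python
# Illustrative numbers only: a stepwise schedule tuned for batch size 256,
# rescaled for batch size 512 using the linear scaling heuristic.
base_batch_size = 256
new_batch_size = 512
scale = new_batch_size / base_batch_size  # 2.0

values = [0.1, 0.01, 0.001]
boundaries = [10000, 20000]

scaled_values = [v * scale for v in values]               # [0.2, 0.02, 0.002]
scaled_boundaries = [int(b / scale) for b in boundaries]  # [5000, 10000]
```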
