
Commit a941761

Ricardo Decal authored and committed: Polish hparam tutorial a bit
1 parent ad79672 commit a941761

1 file changed, +70 −50 lines

beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 70 additions & 50 deletions
@@ -3,9 +3,9 @@
 ===================================
 
 Hyperparameter tuning can make the difference between an average model
-and a highly accurate one. Often simple things like choosing a different
-learning rate or changing a network layer size can have a dramatic
-impact on your model performance.
+and a highly accurate one. Often, simple decisions like choosing a
+different learning rate or changing a network layer size can
+dramatically impact model performance.
 
 Fortunately, there are tools that help with finding the best combination
 of parameters. `Ray Tune <https://docs.ray.io/en/latest/tune.html>`__ is
@@ -21,15 +21,12 @@
 documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__
 for training a CIFAR10 image classifier.
 
-As you will see, we only need to add some slight modifications. In
-particular, we need to
+We only need to make minor modifications:
 
 1. wrap data loading and training in functions,
 2. make some network parameters configurable,
 3. add checkpointing (optional),
-4. and define the search space for the model tuning
-
-|
+4. define the search space for the model tuning
 
 To run this tutorial, please make sure the following packages are
 installed:
@@ -62,14 +59,13 @@
 
 ######################################################################
 # Most of the imports are needed for building the PyTorch model. Only the
-# last imports are for Ray Tune.
+# last few are specific to Ray Tune.
 #
 # Data loaders
 # ------------
 #
-# We wrap the data loaders in their own function and pass a global data
-# directory. This way we can share a data directory between different
-# trials.
+# We wrap the data loaders in a function and pass a global data directory.
+# This allows us to share a data directory across different trials.
 
 def load_data(data_dir="./data"):
     transform = transforms.Compose(
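For context, the hunk above cuts off inside ``load_data``. A minimal sketch of how the full helper typically looks in this tutorial, assuming the standard torchvision CIFAR10 datasets and normalization values (not copied from this diff):

# Sketch only: standard torchvision CIFAR10 setup assumed.
import torchvision
import torchvision.transforms as transforms


def load_data(data_dir="./data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    # Both splits land in the shared data_dir, so all trials can reuse them
    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )
    return trainset, testset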
@@ -90,8 +86,8 @@ def load_data(data_dir="./data"):
 # Configurable neural network
 # ---------------------------
 #
-# We can only tune those parameters that are configurable. In this
-# example, we can specify the layer sizes of the fully connected layers:
+# We can only tune parameters that are configurable. In this example, we
+# specify the layer sizes of the fully connected layers:
 
 class Net(nn.Module):
     def __init__(self, l1=120, l2=84):
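To make the role of ``l1`` and ``l2`` concrete, here is a hedged sketch of how the configurable widths thread through the fully connected layers, following the CIFAR10 blitz architecture the tutorial builds on (the convolutional part is assumed, not shown in this hunk):

# Sketch: l1 and l2 set the widths of the two hidden fully connected layers.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)  # tunable width
        self.fc2 = nn.Linear(l1, l2)          # tunable width
        self.fc3 = nn.Linear(l2, 10)          # 10 CIFAR10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)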
@@ -121,14 +117,14 @@ def forward(self, x):
 # documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__.
 #
 # We wrap the training script in a function
-# ``train_cifar(config, data_dir=None)``. The ``config`` parameter will
-# receive the hyperparameters we would like to train with. The
-# ``data_dir`` specifies the directory where we load and store the data,
-# so that multiple runs can share the same data source. This is especially
-# useful in cluster environments, where you can mount a shared storage
-# (e.g. NFS) to this directory so that the data is not downloaded to each
+# ``train_cifar(config, data_dir=None)``. The ``config`` parameter
+# receives the hyperparameters we want to train with. The ``data_dir``
+# specifies the directory where we load and store the data, allowing
+# multiple runs to share the same data source. This is especially useful
+# in cluster environments where you can mount a shared storage (e.g. NFS)
+# to this directory, preventing the data from being downloaded to each
 # node separately. We also load the model and optimizer state at the start
-# of the run, if a checkpoint is provided. Further down in this tutorial
+# of the run if a checkpoint is provided. Further down in this tutorial,
 # you will find information on how to save the checkpoint and what it is
 # used for.
 #
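A condensed sketch of how ``config`` and ``data_dir`` are consumed inside the training function (abridged and renamed so it is not confused with the tutorial's full ``train_cifar``, which follows further down; ``Net`` and ``load_data`` are the helpers defined earlier):

# Abridged sketch, not the tutorial's full training function.
import torch
import torch.optim as optim


def train_cifar_sketch(config, data_dir=None):
    net = Net(config["l1"], config["l2"])  # configurable layer sizes
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
    trainset, _ = load_data(data_dir)      # shared data directory across trials
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=2
    )
    # ... training loop, validation, checkpointing, and tune.report() follow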
@@ -175,9 +171,9 @@ def forward(self, x):
 #     net = nn.DataParallel(net)
 #     net.to(device)
 #
-# By using a ``device`` variable we make sure that training also works
-# when we have no GPUs available. PyTorch requires us to send our data to
-# the GPU memory explicitly, like this:
+# By using a ``device`` variable, we ensure that training works even
+# without a GPU. PyTorch requires us to send our data to the GPU memory
+# explicitly:
 #
 # .. code-block:: python
 #
@@ -194,7 +190,9 @@ def forward(self, x):
 # Communicating with Ray Tune
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# The most interesting part is the communication with Ray Tune:
+# The most interesting part is the communication with Ray Tune. As you’ll
+# see, integrating Ray Tune into your training code requires only a few
+# additional lines:
 #
 # .. code-block:: python
 #
@@ -215,18 +213,27 @@ def forward(self, x):
 #
 # Here we first save a checkpoint and then report some metrics back to Ray
 # Tune. Specifically, we send the validation loss and accuracy back to Ray
-# Tune. Ray Tune can then use these metrics to decide which hyperparameter
-# configuration lead to the best results. These metrics can also be used
-# to stop bad performing trials early in order to avoid wasting resources
-# on those trials.
+# Tune. Ray Tune uses these metrics to determine the best hyperparameter
+# configuration and to stop underperforming trials early, saving
+# resources.
 #
 # The checkpoint saving is optional, however, it is necessary if we wanted
 # to use advanced schedulers like `Population Based
 # Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`__.
-# Also, by saving the checkpoint we can later load the trained models and
-# validate them on a test set. Lastly, saving checkpoints is useful for
-# fault tolerance, and it allows us to interrupt training and continue
-# training later.
+# Saving the checkpoint also allows us to later load the trained models
+# for validation on a test set. Lastly, it provides fault tolerance,
+# enabling us to pause and resume training.
+#
+# To summarize, integrating Ray Tune into your PyTorch training requires
+# just a few key additions:
+#
+# - ``tune.report()`` to report metrics (and optionally checkpoints) to
+#   Ray Tune
+# - ``tune.get_checkpoint()`` to load a model from a checkpoint
+# - ``Checkpoint.from_directory()`` to create a checkpoint object from
+#   saved state
+#
+# The rest of your training code remains standard PyTorch!
 #
 # Full training function
 # ~~~~~~~~~~~~~~~~~~~~~~
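A minimal, self-contained sketch of the save-and-report pattern summarized above. The calls ``Checkpoint.from_directory()``, ``tune.report()``, and ``tune.get_checkpoint()`` come from the tutorial itself; the import path for ``Checkpoint`` and the file name ``checkpoint.pt`` are assumptions for recent Ray releases (older 2.x versions expose ``Checkpoint`` under ``ray.train``):

# Sketch of the reporting pattern; adjust imports to your Ray version.
import os
import tempfile

import torch
from ray import tune
from ray.tune import Checkpoint  # assumption: recent Ray; older releases use ray.train


def report_with_checkpoint(net, optimizer, epoch, val_loss, val_acc):
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        # Persist the current training state into the temporary directory
        torch.save(
            {
                "epoch": epoch,
                "net_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            os.path.join(checkpoint_dir, "checkpoint.pt"),
        )
        # Wrap the directory in a Checkpoint and hand it to Ray Tune with the metrics
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        tune.report({"loss": val_loss, "accuracy": val_acc}, checkpoint=checkpoint)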
@@ -246,6 +253,7 @@ def train_cifar(config, data_dir=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
 
+    # Load checkpoint if resuming training
     checkpoint = tune.get_checkpoint()
     if checkpoint:
         with checkpoint.as_directory() as checkpoint_dir:
@@ -317,6 +325,7 @@ def train_cifar(config, data_dir=None):
                 val_loss += loss.cpu().numpy()
                 val_steps += 1
 
+        # Save checkpoint and report metrics
         checkpoint_data = {
             "epoch": epoch,
             "net_state_dict": net.state_dict(),
@@ -331,7 +340,7 @@ def train_cifar(config, data_dir=None):
                 {"loss": val_loss / val_steps, "accuracy": correct / total},
                 checkpoint=checkpoint,
             )
-
+
     print("Finished Training")
 
 ######################################################################
@@ -390,11 +399,21 @@ def test_accuracy(net, device="cpu"):
 # 0.0001 and 0.1. Lastly, the batch size is a choice between 2, 4, 8, and
 # 16.
 #
-# At each trial, Ray Tune will now randomly sample a combination of
-# parameters from these search spaces. It will then train a number of
-# models in parallel and find the best performing one among these. We also
-# use the ``ASHAScheduler`` which will terminate bad performing trials
-# early.
+# For each trial, Ray Tune samples a combination of parameters from these
+# search spaces according to the search space configuration and search
+# strategy. It then trains multiple models in parallel to identify the
+# best performing one.
+#
+# By default, Ray Tune uses random search to pick the next hyperparameter
+# configuration to try. However, Ray Tune also provides more sophisticated
+# search algorithms that can more efficiently navigate the search space,
+# such as
+# `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__,
+# `HyperOpt <https://docs.ray.io/en/latest/tune/api/suggestion.html#hyperopt>`__,
+# and `Bayesian
+# Optimization <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__.
+#
+# We use the ``ASHAScheduler`` to terminate underperforming trials early.
 #
 # We wrap the ``train_cifar`` function with ``functools.partial`` to set
 # the constant ``data_dir`` parameter. We can also tell Ray Tune what
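As a concrete illustration of the search space and scheduler described above, a hedged sketch follows. The learning-rate and batch-size spaces match the values quoted in the hunk; the ``l1``/``l2`` choices are illustrative, and ``train_cifar`` refers to the tutorial's own training function:

# Sketch of a search space plus the ASHA early-stopping scheduler.
from functools import partial

from ray import tune
from ray.tune.schedulers import ASHAScheduler

config = {
    "l1": tune.choice([2**i for i in range(2, 9)]),  # 4 ... 256 units (illustrative)
    "l2": tune.choice([2**i for i in range(2, 9)]),
    "lr": tune.loguniform(1e-4, 1e-1),               # sampled on a log scale
    "batch_size": tune.choice([2, 4, 8, 16]),
}

# ASHA terminates underperforming trials early; the metric and mode are set
# where the tuner (or tune.run) is configured.
scheduler = ASHAScheduler(max_t=10, grace_period=1, reduction_factor=2)

# Fix the constant data_dir argument so Ray Tune only has to supply `config`
trainable = partial(train_cifar, data_dir="./data")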
@@ -423,20 +442,21 @@ def test_accuracy(net, device="cpu"):
 # You can specify the number of CPUs, which are then available e.g. to
 # increase the ``num_workers`` of the PyTorch ``DataLoader`` instances.
 # The selected number of GPUs are made visible to PyTorch in each trial.
-# Trials do not have access to GPUs that haven’t been requested for them -
-# so you don’t have to care about two trials using the same set of
-# resources.
+# Trials do not have access to GPUs that haven’t been requested, so you
+# don’t need to worry about resource contention.
 #
-# Here we can also specify fractional GPUs, so something like
-# ``gpus_per_trial=0.5`` is completely valid. The trials will then share
-# GPUs among each other. You just have to make sure that the models still
-# fit in the GPU memory.
+# You can also specify fractional GPUs (e.g., ``gpus_per_trial=0.5``),
+# which allows trials to share a GPU. Just ensure that the models fit
+# within the GPU memory.
 #
 # After training the models, we will find the best performing one and load
 # the trained network from the checkpoint file. We then obtain the test
 # set accuracy and report everything by printing.
 #
-# The full main function looks like this:
+# The full main function looks like this. Note that the
+# ``if __name__ == "__main__":`` block is configured for a quick run (1
+# trial, 1 epoch, CPU only) to verify that everything works. You should
+# increase these values to perform an actual hyperparameter tuning search.
 
 def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
     print("Starting hyperparameter tuning.")
@@ -495,7 +515,7 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 
 
 if __name__ == "__main__":
-    # You can change the number of GPUs per trial here:
+    # Set the number of trials, epochs, and GPUs per trial here:
     main(num_trials=1, max_num_epochs=1, gpus_per_trial=0)
 
 ######################################################################
@@ -524,8 +544,8 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 # Best trial final validation accuracy: 0.4761
 # Best trial test set accuracy: 0.4737
 #
-# Most trials have been stopped early in order to avoid wasting resources.
-# The best performing trial achieved a validation accuracy of about 47%,
+# Most trials were stopped early to conserve resources. The best
+# performing trial achieved a validation accuracy of approximately 47%,
 # which could be confirmed on the test set.
 #
 # So that’s it! You can now tune the parameters of your PyTorch models.
