33===================================
44
55Hyperparameter tuning can make the difference between an average model
6- and a highly accurate one. Often simple things like choosing a different
7- learning rate or changing a network layer size can have a dramatic
8- impact on your model performance.
6+ and a highly accurate one. Often, simple decisions like choosing a
7+ different learning rate or changing a network layer size can
8+ dramatically impact model performance.
99
1010Fortunately, there are tools that help with finding the best combination
1111of parameters. `Ray Tune <https://docs.ray.io/en/latest/tune.html>`__ is
2121documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__
2222for training a CIFAR10 image classifier.
2323
24- As you will see, we only need to add some slight modifications. In
25- particular, we need to
24+ We only need to make minor modifications:
2625
27261. wrap data loading and training in functions,
28272. make some network parameters configurable,
29283. add checkpointing (optional),
30- 4. and define the search space for the model tuning
31-
32- |
29+ 4. define the search space for the model tuning
3330
3431To run this tutorial, please make sure the following packages are
3532installed:
6259
6360######################################################################
6461# Most of the imports are needed for building the PyTorch model. Only the
65- # last imports are for Ray Tune.
62+ # last few are specific to Ray Tune.
6663#
6764# Data loaders
6865# ------------
6966#
70- # We wrap the data loaders in their own function and pass a global data
71- # directory. This way we can share a data directory between different
72- # trials.
67+ # We wrap the data loaders in a function and pass a global data directory.
68+ # This allows us to share a data directory across different trials.
7369
7470def load_data(data_dir="./data"):
7571    transform = transforms.Compose(
@@ -90,8 +86,8 @@ def load_data(data_dir="./data"):
9086# Configurable neural network
9187# ---------------------------
9288#
93- # We can only tune those parameters that are configurable. In this
94- # example, we can specify the layer sizes of the fully connected layers:
89+ # We can only tune parameters that are configurable. In this example, we
90+ # specify the layer sizes of the fully connected layers:
9591
9692class Net(nn.Module):
9793    def __init__(self, l1=120, l2=84):
@@ -121,14 +117,14 @@ def forward(self, x):
121117# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__.
122118#
123119# We wrap the training script in a function
124- # ``train_cifar(config, data_dir=None)``. The ``config`` parameter will
125- # receive the hyperparameters we would like to train with. The
126- # ``data_dir`` specifies the directory where we load and store the data,
127- # so that multiple runs can share the same data source. This is especially
128- # useful in cluster environments, where you can mount a shared storage
129- # (e.g. NFS) to this directory so that the data is not downloaded to each
120+ # ``train_cifar(config, data_dir=None)``. The ``config`` parameter
121+ # receives the hyperparameters we want to train with. The ``data_dir``
122+ # specifies the directory where we load and store the data, allowing
123+ # multiple runs to share the same data source. This is especially useful
124+ # in cluster environments where you can mount a shared storage (e.g. NFS)
125+ # to this directory, preventing the data from being downloaded to each
130126# node separately. We also load the model and optimizer state at the start
131- # of the run, if a checkpoint is provided. Further down in this tutorial
127+ # of the run if a checkpoint is provided. Further down in this tutorial,
132128# you will find information on how to save the checkpoint and what it is
133129# used for.
134130#
@@ -175,9 +171,9 @@ def forward(self, x):
175171# net = nn.DataParallel(net)
176172# net.to(device)
177173#
178- # By using a ``device`` variable we make sure that training also works
179- # when we have no GPUs available. PyTorch requires us to send our data to
180- # the GPU memory explicitly, like this:
174+ # By using a ``device`` variable, we ensure that training works even
175+ # without a GPU. PyTorch requires us to send our data to the GPU memory
176+ # explicitly:
181177#
182178# .. code-block:: python
183179#
@@ -194,7 +190,9 @@ def forward(self, x):
194190# Communicating with Ray Tune
195191# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
196192#
197- # The most interesting part is the communication with Ray Tune:
193+ # The most interesting part is the communication with Ray Tune. As you’ll
194+ # see, integrating Ray Tune into your training code requires only a few
195+ # additional lines:
198196#
199197# .. code-block:: python
200198#
@@ -215,18 +213,27 @@ def forward(self, x):
215213#
216214# Here we first save a checkpoint and then report some metrics back to Ray
217215# Tune. Specifically, we send the validation loss and accuracy back to Ray
218- # Tune. Ray Tune can then use these metrics to decide which hyperparameter
219- # configuration lead to the best results. These metrics can also be used
220- # to stop bad performing trials early in order to avoid wasting resources
221- # on those trials.
216+ # Tune. Ray Tune uses these metrics to determine the best hyperparameter
217+ # configuration and to stop underperforming trials early, saving
218+ # resources.
222219#
223220# Checkpoint saving is optional; however, it is necessary if we want
224221# to use advanced schedulers like `Population Based
225222# Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`__.
226- # Also, by saving the checkpoint we can later load the trained models and
227- # validate them on a test set. Lastly, saving checkpoints is useful for
228- # fault tolerance, and it allows us to interrupt training and continue
229- # training later.
223+ # Saving the checkpoint also allows us to later load the trained models
224+ # for validation on a test set. Lastly, it provides fault tolerance,
225+ # enabling us to pause and resume training.
226+ #
227+ # To summarize, integrating Ray Tune into your PyTorch training requires
228+ # just a few key additions:
229+ #
230+ # - ``tune.report()`` to report metrics (and optionally checkpoints) to
231+ # Ray Tune
232+ # - ``tune.get_checkpoint()`` to load a model from a checkpoint
233+ # - ``Checkpoint.from_directory()`` to create a checkpoint object from
234+ # saved state
235+ #
236+ # The rest of your training code remains standard PyTorch!
230237#
231238# Full training function
232239# ~~~~~~~~~~~~~~~~~~~~~~
@@ -246,6 +253,7 @@ def train_cifar(config, data_dir=None):
246253    criterion = nn.CrossEntropyLoss()
247254    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
248255
256+     # Load checkpoint if resuming training
249257    checkpoint = tune.get_checkpoint()
250258    if checkpoint:
251259        with checkpoint.as_directory() as checkpoint_dir:
@@ -317,6 +325,7 @@ def train_cifar(config, data_dir=None):
317325                val_loss += loss.cpu().numpy()
318326                val_steps += 1
319327
328+         # Save checkpoint and report metrics
320329        checkpoint_data = {
321330            "epoch": epoch,
322331            "net_state_dict": net.state_dict(),
@@ -331,7 +340,7 @@ def train_cifar(config, data_dir=None):
331340                {"loss": val_loss / val_steps, "accuracy": correct / total},
332341                checkpoint=checkpoint,
333342            )
334-
343+
335344    print("Finished Training")
336345
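The save-and-resume pattern that the new comments label can be tried without Ray or PyTorch. The following is only a sketch under stated assumptions: the helper names `save_state` and `load_state`, the file name `data.pkl`, and the plain dictionary standing in for the model and optimizer state are all hypothetical and chosen for illustration. In the tutorial itself, the directory would be wrapped with `Checkpoint.from_directory()` and passed to `tune.report()`.

```python
import pickle
import tempfile
from pathlib import Path


def save_state(checkpoint_dir, epoch, state):
    """Write the training state to data.pkl inside checkpoint_dir (hypothetical helper)."""
    data_path = Path(checkpoint_dir) / "data.pkl"
    with open(data_path, "wb") as fp:
        pickle.dump({"epoch": epoch, "state": state}, fp)
    return data_path


def load_state(checkpoint_dir):
    """Read the state back and return (next_epoch, state) (hypothetical helper)."""
    data_path = Path(checkpoint_dir) / "data.pkl"
    with open(data_path, "rb") as fp:
        data = pickle.load(fp)
    # Resume from the epoch after the one that was saved.
    return data["epoch"] + 1, data["state"]


with tempfile.TemporaryDirectory() as checkpoint_dir:
    # A plain dict stands in for net.state_dict() and optimizer.state_dict().
    save_state(checkpoint_dir, epoch=3, state={"weights": [0.1, 0.2]})
    start_epoch, state = load_state(checkpoint_dir)
```

Returning the next epoch rather than the saved one means a resumed run does not repeat the epoch that produced the checkpoint.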
337346######################################################################
@@ -390,11 +399,21 @@ def test_accuracy(net, device="cpu"):
390399# 0.0001 and 0.1. Lastly, the batch size is a choice between 2, 4, 8, and
391400# 16.
392401#
393- # At each trial, Ray Tune will now randomly sample a combination of
394- # parameters from these search spaces. It will then train a number of
395- # models in parallel and find the best performing one among these. We also
396- # use the ``ASHAScheduler`` which will terminate bad performing trials
397- # early.
402+ # For each trial, Ray Tune samples a combination of parameters from these
403+ # search spaces according to the search space configuration and search
404+ # strategy. It then trains multiple models in parallel to identify the
405+ # best performing one.
406+ #
407+ # By default, Ray Tune uses random search to pick the next hyperparameter
408+ # configuration to try. However, Ray Tune also provides more sophisticated
409+ # search algorithms that can more efficiently navigate the search space,
410+ # such as
411+ # `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__,
412+ # `HyperOpt <https://docs.ray.io/en/latest/tune/api/suggestion.html#hyperopt>`__,
413+ # and `Bayesian
414+ # Optimization <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__.
415+ #
416+ # We use the ``ASHAScheduler`` to terminate underperforming trials early.
398417#
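As a rough illustration of what random search over such a search space means, here is a plain-Python sketch that does not use Ray Tune. It assumes the range 0.0001 to 0.1 applies to the learning rate and is sampled log-uniformly; the helper names `sample_loguniform` and `sample_config` are hypothetical. It is not how Ray Tune is implemented, only a picture of the sampling idea.

```python
import math
import random


def sample_loguniform(rng, low, high):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng):
    """Draw one hyperparameter configuration at random (hypothetical helper)."""
    return {
        # Assumption: the learning rate lies between 0.0001 and 0.1.
        "lr": sample_loguniform(rng, 1e-4, 1e-1),
        # The batch size is a choice between 2, 4, 8, and 16.
        "batch_size": rng.choice([2, 4, 8, 16]),
    }


rng = random.Random(0)  # fixed seed so the sketch is reproducible
configs = [sample_config(rng) for _ in range(10)]
for config in configs:
    print(config)
```

Sampling on a log scale gives each order of magnitude of the learning rate roughly equal attention, which is usually what you want for a parameter that spans several decades.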
399418# We wrap the ``train_cifar`` function with ``functools.partial`` to set
400419# the constant ``data_dir`` parameter. We can also tell Ray Tune what
@@ -423,20 +442,21 @@ def test_accuracy(net, device="cpu"):
423442# You can specify the number of CPUs, which are then available e.g. to
424443# increase the ``num_workers`` of the PyTorch ``DataLoader`` instances.
425444# The selected number of GPUs are made visible to PyTorch in each trial.
426- # Trials do not have access to GPUs that haven’t been requested for them -
427- # so you don’t have to care about two trials using the same set of
428- # resources.
445+ # Trials do not have access to GPUs that haven’t been requested, so you
446+ # don’t need to worry about resource contention.
429447#
430- # Here we can also specify fractional GPUs, so something like
431- # ``gpus_per_trial=0.5`` is completely valid. The trials will then share
432- # GPUs among each other. You just have to make sure that the models still
433- # fit in the GPU memory.
448+ # You can also specify fractional GPUs (e.g., ``gpus_per_trial=0.5``),
449+ # which allows trials to share a GPU. Just ensure that the models fit
450+ # within the GPU memory.
434451#
435452# After training the models, we will find the best performing one and load
436453# the trained network from the checkpoint file. We then obtain the test
437454# set accuracy and report everything by printing.
438455#
439- # The full main function looks like this:
456+ # The full main function looks like this. Note that the
457+ # ``if __name__ == "__main__":`` block is configured for a quick run (1
458+ # trial, 1 epoch, CPU only) to verify that everything works. You should
459+ # increase these values to perform an actual hyperparameter tuning search.
440460
441461def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
442462    print("Starting hyperparameter tuning.")
@@ -495,7 +515,7 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
495515
496516
497517if __name__ == "__main__":
498-     # You can change the number of GPUs per trial here:
518+     # Set the number of trials, epochs, and GPUs per trial here:
499519    main(num_trials=1, max_num_epochs=1, gpus_per_trial=0)
500520
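The tutorial text mentions wrapping ``train_cifar`` with ``functools.partial`` to fix the constant ``data_dir`` parameter. A minimal sketch of that idea, using a hypothetical stand-in function `train_stub` instead of the real training function:

```python
from functools import partial


def train_stub(config, data_dir=None):
    """Hypothetical stand-in for train_cifar: returns what it received."""
    return {"lr": config["lr"], "data_dir": data_dir}


# Fix data_dir once; the resulting callable only needs the config.
trainable = partial(train_stub, data_dir="./data")
result = trainable({"lr": 0.01})
```

The tuner then only has to supply `config` for each trial, while every trial sees the same data directory.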
501521######################################################################
@@ -524,8 +544,8 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
524544# Best trial final validation accuracy: 0.4761
525545# Best trial test set accuracy: 0.4737
526546#
527- # Most trials have been stopped early in order to avoid wasting resources.
528- # The best performing trial achieved a validation accuracy of about 47%,
547+ # Most trials were stopped early to conserve resources. The best
548+ # performing trial achieved a validation accuracy of approximately 47%,
529549# which could be confirmed on the test set.
530550#
531551# So that’s it! You can now tune the parameters of your PyTorch models.
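To give a feel for why most trials stop early, here is a simplified, synchronous successive-halving sketch in plain Python. It is not Ray Tune's ``ASHAScheduler``, which makes its decisions asynchronously. The function name `successive_halving`, the helper `make_curve`, and the synthetic loss curves are assumptions made purely for illustration.

```python
def make_curve(final_loss):
    """Return a synthetic loss curve that decays toward final_loss."""
    return lambda epoch: final_loss + 1.0 / epoch


def successive_halving(trials, max_epochs, reduction_factor=2):
    """Keep the best 1/reduction_factor of trials at each rung.

    trials maps a trial name to a function epoch -> validation loss.
    Returns the surviving trial names and the epoch at which each
    other trial was stopped.
    """
    alive = dict(trials)
    stopped = {}
    epoch = 1
    while epoch <= max_epochs and len(alive) > 1:
        # Evaluate every surviving trial at the current rung.
        losses = {name: curve(epoch) for name, curve in alive.items()}
        keep = max(1, len(alive) // reduction_factor)
        survivors = sorted(losses, key=losses.get)[:keep]
        for name in alive:
            if name not in survivors:
                stopped[name] = epoch
        alive = {name: alive[name] for name in survivors}
        epoch *= reduction_factor
    return list(alive), stopped


trials = {
    "t1": make_curve(1.0),
    "t2": make_curve(0.5),
    "t3": make_curve(2.0),
    "t4": make_curve(1.5),
}
best, stopped = successive_halving(trials, max_epochs=8)
```

In this toy run, half of the trials are dropped after the first epoch, so most of the compute budget goes to the more promising configurations.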