
Commit e360448

Author: Ricardo Decal (committed)
finalize the hyperparameter tuning tutorial
1 parent 1f11769 commit e360448

File tree

3 files changed: +160, -99 lines


beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 158 additions & 97 deletions
@@ -2,34 +2,38 @@
 Hyperparameter tuning with Ray Tune
 ===================================
 
-Hyperparameter tuning can make the difference between an average model
-and a highly accurate one. Often, simple decisions like choosing a
-different learning rate or changing a network layer size can
-dramatically impact model performance.
-
-This page shows how to integrate `Ray
-Tune <https://docs.ray.io/en/latest/tune.html>`__ into your PyTorch
-training workflow for distributed hyperparameter tuning. It extends the
-PyTorch tutorial for training a CIFAR10 image classifier in the `CIFAR10
-tutorial (PyTorch
+This tutorial shows how to integrate Ray Tune into your PyTorch training
+workflow to perform scalable and efficient hyperparameter tuning.
+
+`Ray <https://docs.ray.io/en/latest/index.html>`__, a project of the
+PyTorch Foundation, is an open-source unified framework for scaling AI
+and Python applications. It helps run distributed workloads by handling
+the complexity of distributed computing. `Ray
+Tune <https://docs.ray.io/en/latest/tune/index.html>`__ is a library
+built on Ray for hyperparameter tuning that enables you to scale a
+hyperparameter sweep from your machine to a large cluster with no code
+changes.
+
+This tutorial extends the PyTorch tutorial for training a CIFAR10 image
+classifier in the `CIFAR10 tutorial (PyTorch
 documentation) <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__.
+Only minor modifications are needed to adapt the PyTorch tutorial for
+Ray Tune. Specifically, this tutorial wraps the data loading and
+training in functions, makes some network parameters configurable, adds
+optional checkpointing, and defines the search space for model tuning.
 
-Only minor modifications are needed. Specifically, this example wraps
-data loading and training in functions, makes some network parameters
-configurable, adds optional checkpointing, and defines the search space
-for model tuning.
+Setup
+-----
 
-To run this tutorial, install the following prerequisites:
+To run this tutorial, install the dependencies:
 
-- ``ray[tune]`` – Distributed hyperparameter tuning library
-- ``torchvision`` – Data transforms for computer vision datasets
-
-Setup and imports
------------------
+"""
 
-Let’s start with the imports:
+# %%bash
+# pip install "ray[tune]" torchvision
 
-"""
+######################################################################
+# Then start with the imports:
 
 from functools import partial
 import os
@@ -42,20 +46,18 @@
 from torch.utils.data import random_split
 import torchvision
 import torchvision.transforms as transforms
+# New: imports for Ray Tune
 import ray
 from ray import tune
 from ray.tune import Checkpoint
 from ray.tune.schedulers import ASHAScheduler
 
 ######################################################################
-# Most of the imports are needed for building the PyTorch model. Only the
-# last few are specific to Ray Tune.
-#
-# Data loaders
-# ------------
+# How to use PyTorch data loaders with Ray Tune
+# ---------------------------------------------
 #
-# We wrap the data loaders in a function and pass a global data directory.
-# This allows us to share a data directory across different trials.
+# Wrap the data loaders in a constructor function. Pass a global data
+# directory here to reuse the dataset across different trials.
 
 def load_data(data_dir="./data"):
     transform = transforms.Compose(
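
For reference, the data-loading constructor this hunk introduces looks roughly like the following once complete. This is a sketch that assumes the same CIFAR10 normalization constants as the original blitz tutorial.

.. code-block:: python

    import torchvision
    import torchvision.transforms as transforms


    def load_data(data_dir="./data"):
        # Normalization constants follow the original CIFAR10 blitz tutorial.
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        # Both splits share data_dir, so concurrent trials reuse one download.
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform
        )
        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform
        )
        return trainset, testset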
@@ -73,15 +75,15 @@ def load_data(data_dir="./data"):
     return trainset, testset
 
 ######################################################################
-# Configurable neural network
-# ---------------------------
+# Configure the hyperparameters
+# -----------------------------
 #
 # In this example, we specify the layer sizes of the fully connected
 # layers.
 
 class Net(nn.Module):
     def __init__(self, l1=120, l2=84):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(6, 16, 5)
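
The configurable network described in this hunk, shown in full as a sketch: the convolution and pooling dimensions follow the original tutorial, so the flattened feature size is 16 * 5 * 5, and only ``l1`` and ``l2`` are tunable.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class Net(nn.Module):
        def __init__(self, l1=120, l2=84):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # l1 and l2 are the tunable fully connected layer sizes.
            self.fc1 = nn.Linear(16 * 5 * 5, l1)
            self.fc2 = nn.Linear(l1, l2)
            self.fc3 = nn.Linear(l2, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1)  # flatten everything except the batch dimension
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)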
@@ -99,23 +101,23 @@ def forward(self, x):
         return x
 
 ######################################################################
-# Train function
-# --------------
+# Use a train function with Ray Tune
+# ----------------------------------
 #
 # Now it gets interesting, because we introduce some changes to the
-# example from the `CIFAR10 tutorial (PyTorch
-# documentation) <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__.
+# example `from the PyTorch
+# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__.
 #
 # We wrap the training script in a function
 # ``train_cifar(config, data_dir=None)``. The ``config`` parameter
 # receives the hyperparameters we want to train with. The ``data_dir``
 # specifies the directory where we load and store the data, allowing
 # multiple runs to share the same data source. This is especially useful
 # in cluster environments where you can mount shared storage (for example
-# NFS), preventing the data from being downloaded to each node separately.
+# NFS) to prevent the data from being downloaded to each node separately.
 # We also load the model and optimizer state at the start of the run if a
 # checkpoint is provided. Further down in this tutorial, you will find
-# information on how to save the checkpoint and what it is used for.
+# information on how to save the checkpoint and how it is used.
 #
 # .. code-block:: python
 #
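
A condensed sketch of the wrapper signature and the optional checkpoint restore described above. The training loop is elided; the checkpoint file name (``checkpoint.pt``) and the state-dict keys are illustrative, and ``Net`` and ``load_data`` come from earlier in this file.

.. code-block:: python

    import os

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from ray import tune


    def train_cifar(config, data_dir=None):
        # Hyperparameters arrive through the config dict that Ray Tune supplies.
        net = Net(config["l1"], config["l2"])
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

        # Resume model and optimizer state if Ray Tune hands us a checkpoint.
        checkpoint = tune.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
                net.load_state_dict(state["net_state_dict"])
                optimizer.load_state_dict(state["optimizer_state_dict"])

        trainset, _ = load_data(data_dir)
        # ... training/validation loop and tune.report() as in the full function.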
@@ -143,12 +145,12 @@ def forward(self, x):
 # the remaining 20%. The batch sizes with which we iterate through the
 # training and test sets are configurable as well.
 #
-# Adding (multi) GPU support with DataParallel
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Add multi-GPU support with DataParallel
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# Image classification benefits largely from GPUs. Luckily, we can
-# continue to use PyTorch’s tools in Ray Tune. Thus, we can wrap our model
-# in ``nn.DataParallel`` to support data parallel training on multiple
+# Image classification benefits largely from GPUs. Luckily, you can
+# continue to use PyTorch tools in Ray Tune. Thus, you can wrap the model
+# in ``nn.DataParallel`` to support data-parallel training on multiple
 # GPUs:
 #
 # .. code-block:: python
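
The device selection and ``nn.DataParallel`` wrap described above, as a small sketch; ``net`` is assumed to be the model built from the trial's config.

.. code-block:: python

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            # Replicate the model across every GPU visible to this trial.
            net = nn.DataParallel(net)
    net.to(device)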
@@ -206,7 +208,7 @@ def forward(self, x):
 # configuration and to stop underperforming trials early, saving
 # resources.
 #
-# The checkpoint saving is optional, however, it is necessary if we wanted
+# The checkpoint saving is optional. However, it is necessary if we wanted
 # to use advanced schedulers like `Population Based
 # Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`__.
 # Saving the checkpoint also allows us to later load the trained models
@@ -218,7 +220,7 @@ def forward(self, x):
 # optionally checkpoints) to Ray Tune, ``tune.get_checkpoint()`` to load a
 # model from a checkpoint, and ``Checkpoint.from_directory()`` to create a
 # checkpoint object from saved state. The rest of your training code
-# remains standard PyTorch!
+# remains standard PyTorch.
 #
 # Full training function
 # ~~~~~~~~~~~~~~~~~~~~~~
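
The per-epoch reporting pattern described here, sketched with an illustrative temporary directory and ``checkpoint.pt`` file name; ``val_loss`` and ``val_accuracy`` stand in for the metrics computed in the validation loop.

.. code-block:: python

    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        torch.save(
            {
                "net_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            os.path.join(tmp_dir, "checkpoint.pt"),
        )
        checkpoint = Checkpoint.from_directory(tmp_dir)
        # Send this epoch's validation metrics (and the checkpoint) to Ray Tune.
        tune.report({"loss": val_loss, "accuracy": val_accuracy}, checkpoint=checkpoint)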
@@ -332,8 +334,8 @@ def train_cifar(config, data_dir=None):
 # As you can see, most of the code is adapted directly from the original
 # example.
 #
-# Test set accuracy
-# -----------------
+# Compute test set accuracy
+# -------------------------
 #
 # Commonly the performance of a machine learning model is tested on a
 # held-out test set with data that has not been used for training the
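
A compact sketch of such a held-out evaluation helper, assuming the ``load_data`` constructor from earlier in this file:

.. code-block:: python

    def test_accuracy(net, device="cpu"):
        _, testset = load_data()
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=4, shuffle=False, num_workers=2
        )

        correct = 0
        total = 0
        with torch.no_grad():  # gradients are not needed for evaluation
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total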
@@ -360,58 +362,95 @@ def test_accuracy(net, device="cpu"):
     return correct / total
 
 ######################################################################
-# The function also expects a ``device`` parameter so we can do the test
+# The function also expects a ``device`` parameter so you can run the test
 # set validation on a GPU.
 #
-# Search space configuration
+# Configure the search space
 # --------------------------
 #
-# Lastly, we need to define Ray Tune’s search space. Here is an example:
+# Lastly, we need to define Ray Tune’s search space. Ray Tune offers a
+# variety of `search space
+# distributions <https://docs.ray.io/en/latest/tune/api/search_space.html>`__
+# to suit different parameter types: ``loguniform``, ``uniform``,
+# ``choice``, ``randint``, ``grid``, and more. It also lets you express
+# complex dependencies between parameters with `conditional search
+# spaces <https://docs.ray.io/en/latest/tune/tutorials/tune-search-spaces.html#how-to-use-custom-and-conditional-search-spaces-in-tune>`__.
+#
+# Here is an example:
 #
 # .. code-block:: python
 #
 #     config = {
-#         "l1": tune.choice([2 ** i for i in range(9)]),
-#         "l2": tune.choice([2 ** i for i in range(9)]),
+#         "l1": tune.choice([2**i for i in range(9)]),
+#         "l2": tune.choice([2**i for i in range(9)]),
 #         "lr": tune.loguniform(1e-4, 1e-1),
-#         "batch_size": tune.choice([2, 4, 8, 16])
+#         "batch_size": tune.choice([2, 4, 8, 16]),
 #     }
 #
 # The ``tune.choice()`` accepts a list of values that are uniformly
-# sampled from. In this example, the ``l1`` and ``l2`` parameters should
-# be powers of 2 between 1 and 256: 1, 2, 4, 8, 16, 32, 64, 128, or 256.
-# The ``lr`` (learning rate) should be uniformly sampled between 0.0001
-# and 0.1. Lastly, the batch size is a choice between 2, 4, 8, and 16.
-#
-# For each trial, Ray Tune samples a combination of parameters from these
-# search spaces according to the search space configuration and search
-# strategy. It then trains multiple models in parallel to identify the
-# best performing one.
-#
-# By default, Ray Tune uses random search to pick the next hyperparameter
-# configuration to try. However, Ray Tune also provides more sophisticated
-# search algorithms that can more efficiently navigate the search space,
-# such as
-# `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__,
-# `HyperOpt <https://docs.ray.io/en/latest/tune/api/suggestion.html#hyperopt>`__,
-# and `Bayesian
-# Optimization <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__.
-#
-# We use the ``ASHAScheduler`` to terminate underperforming trials early.
-#
-# We wrap the ``train_cifar`` function with ``functools.partial`` to set
-# the constant ``data_dir`` parameter. We can also tell Ray Tune what
-# resources should be available for each trial using
+# sampled from. In this example, the ``l1`` and ``l2`` parameter values
+# will be powers of 2 between 1 and 256. The learning rate is sampled on a
+# log scale between 0.0001 and 0.1. Sampling on a log scale ensures that
+# the search space is explored efficiently across different magnitudes.
+#
+# Smarter sampling and scheduling
+# -------------------------------
+#
+# To make the hyperparameter search process efficient, Ray Tune provides
+# two main controls:
+#
+# 1. It can intelligently pick the next set of hyperparameters to test
+#    based on previous results using `advanced search
+#    algorithms <https://docs.ray.io/en/latest/tune/api/suggestion.html>`__
+#    such as
+#    `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__
+#    or
+#    ```bayesopt`` <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__,
+#    instead of relying only on random or grid search.
+# 2. It can detect underperforming trials and stop them early using
+#    `schedulers <https://docs.ray.io/en/latest/tune/key-concepts.html#tune-schedulers>`__,
+#    enabling you to explore the parameter space more on the same compute
+#    budget.
+#
+# In this tutorial, we use the ``ASHAScheduler``, which aggressively
+# terminates low-performing trials to save computational resources.
+#
+# Configure the resources
+# -----------------------
+#
+# Tell Ray Tune what resources should be available for each trial using
 # ``tune.with_resources``:
 #
 # .. code-block:: python
 #
-#     gpus_per_trial = 2
-#     # ...
+#     tune.with_resources(
+#         partial(train_cifar, data_dir=data_dir),
+#         resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
+#     )
+#
+# This tells Ray Tune to allocate ``cpus_per_trial`` CPUs and
+# ``gpus_per_trial`` GPUs for each trial. Ray Tune automatically manages
+# the placement of these trials and ensures they are isolated, so you
+# don’t need to manually assign GPUs to processes.
+#
+# For example, if you are running this experiment on a cluster of 20
+# machines, each with 8 GPUs, you can set ``gpus_per_trial = 0.5`` to
+# schedule 2 concurrent trials per GPU. This configuration runs 320 trials
+# in parallel across the cluster.
+#
+# Putting it together
+# -------------------
+#
+# The Ray Tune API is designed to be modular and composable: you pass your
+# configurations to the ``tune.Tuner`` class to create a tuner object,
+# then execute ``tuner.fit()`` to start training:
+#
+# .. code-block:: python
+#
 #     tuner = tune.Tuner(
 #         tune.with_resources(
 #             partial(train_cifar, data_dir=data_dir),
-#             resources={"cpu": 8, "gpu": gpus_per_trial}
+#             resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
 #         ),
 #         tune_config=tune.TuneConfig(
 #             metric="loss",
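
Combining the search space, the ASHA scheduler, and the per-trial resources into one runnable sketch; the ``num_samples``, epoch budget (``max_t``), and resource numbers are illustrative, and ``data_dir`` is assumed to point at the shared dataset directory.

.. code-block:: python

    config = {
        "l1": tune.choice([2**i for i in range(9)]),
        "l2": tune.choice([2**i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    # ASHA compares trials on the reported "loss" and stops the weak ones early.
    scheduler = ASHAScheduler(max_t=10, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        tune.with_resources(
            partial(train_cifar, data_dir=data_dir),
            resources={"cpu": 2, "gpu": 0},  # illustrative per-trial resources
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=10,
        ),
        param_space=config,
    )
    results = tuner.fit()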
@@ -423,24 +462,9 @@ def test_accuracy(net, device="cpu"):
 #     )
 #     results = tuner.fit()
 #
-# Specify the number of CPUs, which are then available, for example to
-# increase the ``num_workers`` of the PyTorch ``DataLoader`` instances.
-# The selected number of GPUs are made visible to PyTorch in each trial.
-# Trials do not have access to GPUs that have not been requested, so you
-# don’t need to worry about resource contention.
-#
-# You can specify fractional GPUs (for example, ``gpus_per_trial=0.5``),
-# which allows trials to share a GPU. Just ensure that the models fit
-# within the GPU memory.
-#
 # After training the models, we will find the best performing one and load
 # the trained network from the checkpoint file. We then obtain the test
-# set accuracy and report everything by printing.
-#
-# The full main function looks like this. Note that the
-# ``if __name__ == "__main__":`` block is configured for a quick run (1
-# trial, 1 epoch, CPU only) to verify that everything works. You should
-# increase these values to perform an actual hyperparameter tuning search.
+# set accuracy and report the results.
 
 def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
     print("Starting hyperparameter tuning.")
@@ -500,10 +524,11 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 
 if __name__ == "__main__":
     # Set the number of trials, epochs, and GPUs per trial here:
+    # The following configuration is for a quick run (1 trial, 1 epoch, CPU only) for demonstration purposes.
     main(num_trials=1, max_num_epochs=1, gpus_per_trial=0)
 
 ######################################################################
-# Your output will look something like this:
+# Your Ray Tune trial summary output will look something like this:
 #
 # .. code-block:: bash
 #
@@ -533,3 +558,39 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 # which could be confirmed on the test set.
 #
 # You can now tune the parameters of your PyTorch models.
+#
+# Observability
+# -------------
+#
+# When running large-scale experiments, monitoring is crucial. Ray
+# provides a
+# `Dashboard <https://docs.ray.io/en/latest/ray-observability/getting-started.html>`__
+# that lets you view the status of your trials, check cluster resource
+# utilization, and inspect logs in real-time.
+#
+# For debugging, Ray also offers `Distributed
+# Debugging <https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/ray-debugger.html>`__
+# tools that let you attach a debugger to running trials across the
+# cluster.
+#
+# Conclusion
+# ----------
+#
+# In this tutorial, you learned how to tune the hyperparameters of a
+# PyTorch model using Ray Tune. You saw how to integrate Ray Tune into
+# your PyTorch training loop, define a search space for your
+# hyperparameters, use an efficient scheduler like ASHA to terminate bad
+# trials early, save checkpoints and report metrics to Ray Tune, and run
+# the hyperparameter search and analyze the results.
+#
+# Ray Tune makes it easy to scale your experiments from a single machine
+# to a large cluster, helping you find the best model configuration
+# efficiently.
+#
+# Further reading
+# ---------------
+#
+# - `Ray Tune
+#   documentation <https://docs.ray.io/en/latest/tune/index.html>`__
+# - `Ray Tune
+#   examples <https://docs.ray.io/en/latest/tune/examples/index.html>`__
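
If you want the dashboard mentioned above while running locally, starting Ray explicitly before calling ``tuner.fit()`` logs the dashboard address at startup (http://127.0.0.1:8265 by default when the dashboard component is installed); a minimal sketch:

.. code-block:: python

    import ray

    # Start (or connect to) a Ray instance; the dashboard URL is logged at startup.
    ray.init()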

ecosystem.rst

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ to production deployment.
    :card_description: Learn how to use Ray Tune to find the best performing set of hyperparameters for your model.
    :image: _static/img/ray-tune.png
    :link: beginner/hyperparameter_tuning_tutorial.html
-   :tags: Model-Optimization,Best-Practice,Ecosystem
+   :tags: Model-Optimization,Best-Practice,Ecosystem,Ray-Distributed,Parallel-and-Distributed-Training
 
 .. customcarditem::
    :header: Multi-Objective Neural Architecture Search with Ax
