
Commit ad79672

Ricardo Decal committed:
Clean up hparam tuning tutorial, modernize checkpointing

1 parent: 10566e3


beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 18 additions & 30 deletions
@@ -44,8 +44,6 @@
 
 """
 
-# %matplotlib inline
-
 from functools import partial
 import os
 import tempfile
@@ -57,19 +55,10 @@
 from torch.utils.data import random_split
 import torchvision
 import torchvision.transforms as transforms
-# sphinx_gallery_start_ignore
-# Fixes `AttributeError: '_LoggingTee' object has no attribute 'fileno'`.
-# This is only needed to run with sphinx-build.
-import sys
-if not hasattr(sys.stdout, "encoding"):
-    sys.stdout.encoding = "latin1"
-    sys.stdout.fileno = lambda: 0
-# sphinx_gallery_end_ignore
 import ray
 from ray import tune
 from ray.tune import Checkpoint
 from ray.tune.schedulers import ASHAScheduler
-import ray.cloudpickle as pickle
 
 ######################################################################
 # Most of the imports are needed for building the PyTorch model. Only the
@@ -135,10 +124,13 @@ def forward(self, x):
 # ``train_cifar(config, data_dir=None)``. The ``config`` parameter will
 # receive the hyperparameters we would like to train with. The
 # ``data_dir`` specifies the directory where we load and store the data,
-# so that multiple runs can share the same data source. We also load the
-# model and optimizer state at the start of the run, if a checkpoint is
-# provided. Further down in this tutorial you will find information on how
-# to save the checkpoint and what it is used for.
+# so that multiple runs can share the same data source. This is especially
+# useful in cluster environments, where you can mount shared storage
+# (e.g. NFS) to this directory so that the data is not downloaded to each
+# node separately. We also load the model and optimizer state at the start
+# of the run, if a checkpoint is provided. Further down in this tutorial
+# you will find information on how to save the checkpoint and what it is
+# used for.
 #
 # .. code-block:: python
 #
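The shared data directory described above can be wired up as follows. The mount point is hypothetical, while load_data and train_cifar are the tutorial's own helpers:

    import os
    from functools import partial

    # Hypothetical shared mount; any directory visible to all nodes works.
    data_dir = os.path.abspath("/mnt/shared/cifar-data")
    load_data(data_dir)  # CIFAR-10 is downloaded once, into shared storage

    # Every trial receives the same data_dir, so nodes skip redundant downloads.
    trainable = partial(train_cifar, data_dir=data_dir)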
@@ -147,9 +139,8 @@ def forward(self, x):
 #     checkpoint = tune.get_checkpoint()
 #     if checkpoint:
 #         with checkpoint.as_directory() as checkpoint_dir:
-#             data_path = Path(checkpoint_dir) / "data.pkl"
-#             with open(data_path, "rb") as fp:
-#                 checkpoint_state = pickle.load(fp)
+#             checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+#             checkpoint_state = torch.load(checkpoint_path)
 #             start_epoch = checkpoint_state["epoch"]
 #             net.load_state_dict(checkpoint_state["net_state_dict"])
 #             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
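A rough end-to-end sketch of the new load path, assuming it runs inside a Tune trainable (tune.get_checkpoint() returns the latest reported checkpoint there, or None on a fresh start; net and optimizer are the tutorial's model and optimizer):

    from pathlib import Path

    import torch
    from ray import tune

    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # torch.load replaces the old pickle.load call; checkpoint.pt
            # is the file written by torch.save on the save path.
            checkpoint_state = torch.load(Path(checkpoint_dir) / "checkpoint.pt")
        start_epoch = checkpoint_state["epoch"]
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0  # fresh trial, nothing to restore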
@@ -213,9 +204,8 @@ def forward(self, x):
 #         "optimizer_state_dict": optimizer.state_dict(),
 #     }
 #     with tempfile.TemporaryDirectory() as checkpoint_dir:
-#         data_path = Path(checkpoint_dir) / "data.pkl"
-#         with open(data_path, "wb") as fp:
-#             pickle.dump(checkpoint_data, fp)
+#         checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+#         torch.save(checkpoint_data, checkpoint_path)
 #
 #         checkpoint = Checkpoint.from_directory(checkpoint_dir)
 #         tune.report(
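Because the checkpoint payload is now a plain dict of tensors and ints, the save/load round trip can be verified without Ray at all. A minimal runnable sketch, with stand-in model and optimizer instead of the tutorial's:

    import tempfile
    from pathlib import Path

    import torch
    import torch.nn as nn

    # Stand-ins for the tutorial's net and optimizer, to keep this self-contained.
    net = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

    checkpoint_data = {
        "epoch": 3,
        "net_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
        torch.save(checkpoint_data, checkpoint_path)
        restored = torch.load(checkpoint_path)

    assert restored["epoch"] == 3
    net.load_state_dict(restored["net_state_dict"])  # keys and shapes match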
@@ -259,9 +249,8 @@ def train_cifar(config, data_dir=None):
     checkpoint = tune.get_checkpoint()
     if checkpoint:
         with checkpoint.as_directory() as checkpoint_dir:
-            data_path = Path(checkpoint_dir) / "data.pkl"
-            with open(data_path, "rb") as fp:
-                checkpoint_state = pickle.load(fp)
+            checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+            checkpoint_state = torch.load(checkpoint_path)
             start_epoch = checkpoint_state["epoch"]
             net.load_state_dict(checkpoint_state["net_state_dict"])
             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
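The restored start_epoch is what makes pause-and-resume under ASHA cheap: a resumed trial continues from the checkpointed epoch instead of recomputing earlier ones. Sketched against the tutorial's epoch loop (the budget of 10 is the tutorial's, not something this commit changes):

    for epoch in range(start_epoch, 10):
        # one pass of training and validation per iteration
        ...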
@@ -334,9 +323,8 @@ def train_cifar(config, data_dir=None):
             "optimizer_state_dict": optimizer.state_dict(),
         }
         with tempfile.TemporaryDirectory() as checkpoint_dir:
-            data_path = Path(checkpoint_dir) / "data.pkl"
-            with open(data_path, "wb") as fp:
-                pickle.dump(checkpoint_data, fp)
+            checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+            torch.save(checkpoint_data, checkpoint_path)
 
             checkpoint = Checkpoint.from_directory(checkpoint_dir)
             tune.report(
@@ -452,6 +440,7 @@ def test_accuracy(net, device="cpu"):
 
 def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
     print("Starting hyperparameter tuning.")
+    ray.init()
 
     data_dir = os.path.abspath("./data")
     load_data(data_dir)
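The explicit ray.init() makes initialization visible rather than leaving Tune to start Ray implicitly on first use. Called with no arguments it starts a local Ray instance; on a multi-node setup you would typically attach to the running cluster instead:

    import ray

    # Local run: start a fresh Ray instance on this machine.
    ray.init()

    # Cluster run (alternative): attach to an already-running Ray cluster.
    # ray.init(address="auto")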
@@ -497,9 +486,8 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 
     best_checkpoint = best_result.checkpoint
     with best_checkpoint.as_directory() as checkpoint_dir:
-        data_path = Path(checkpoint_dir) / "data.pkl"
-        with open(data_path, "rb") as fp:
-            best_checkpoint_data = pickle.load(fp)
+        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+        best_checkpoint_data = torch.load(checkpoint_path)
 
     best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
     test_acc = test_accuracy(best_trained_model, device)
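One caveat: torch.load restores tensors to the devices they were saved on, so a checkpoint written by a GPU worker may fail to load on a CPU-only host. A hedged variant of the restore above (map_location is standard torch.load, not something this commit adds):

    best_checkpoint_data = torch.load(
        checkpoint_path,
        map_location="cpu",  # remap GPU tensors for CPU-only hosts
    )
    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])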
