 
 """
 
-# %matplotlib inline
-
 from functools import partial
 import os
 import tempfile
 from torch.utils.data import random_split
 import torchvision
 import torchvision.transforms as transforms
-# sphinx_gallery_start_ignore
-# Fixes `AttributeError: '_LoggingTee' object has no attribute 'fileno'`.
-# This is only needed to run with sphinx-build.
-import sys
-if not hasattr(sys.stdout, "encoding"):
-    sys.stdout.encoding = "latin1"
-    sys.stdout.fileno = lambda: 0
-# sphinx_gallery_end_ignore
 import ray
 from ray import tune
 from ray.tune import Checkpoint
 from ray.tune.schedulers import ASHAScheduler
-import ray.cloudpickle as pickle
 
 ######################################################################
 # Most of the imports are needed for building the PyTorch model. Only the
@@ -135,10 +124,13 @@ def forward(self, x):
 # ``train_cifar(config, data_dir=None)``. The ``config`` parameter will
 # receive the hyperparameters we would like to train with. The
 # ``data_dir`` specifies the directory where we load and store the data,
-# so that multiple runs can share the same data source. We also load the
-# model and optimizer state at the start of the run, if a checkpoint is
-# provided. Further down in this tutorial you will find information on how
-# to save the checkpoint and what it is used for.
+# so that multiple runs can share the same data source. This is especially
+# useful in cluster environments, where you can mount shared storage
+# (e.g., NFS) to this directory so that the data is not downloaded to each
+# node separately. We also load the model and optimizer state at the start
+# of the run, if a checkpoint is provided. Further down in this tutorial
+# you will find information on how to save the checkpoint and what it is
+# used for.
 #
 # .. code-block:: python
 #
@@ -147,9 +139,8 @@ def forward(self, x):
 #     checkpoint = tune.get_checkpoint()
 #     if checkpoint:
 #         with checkpoint.as_directory() as checkpoint_dir:
-#             data_path = Path(checkpoint_dir) / "data.pkl"
-#             with open(data_path, "rb") as fp:
-#                 checkpoint_state = pickle.load(fp)
+#             checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+#             checkpoint_state = torch.load(checkpoint_path)
 #             start_epoch = checkpoint_state["epoch"]
 #             net.load_state_dict(checkpoint_state["net_state_dict"])
 #             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
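
The checkpoint is just a plain dictionary serialized with ``torch.save`` and read back with ``torch.load``. The standalone sketch below shows that round trip outside of Ray Tune; the model, optimizer, and epoch value are placeholders rather than the tutorial's actual objects:

.. code-block:: python

    import tempfile
    from pathlib import Path

    import torch
    import torch.nn as nn
    import torch.optim as optim

    net = nn.Linear(4, 2)  # stand-in for the tutorial's Net(l1, l2)
    optimizer = optim.SGD(net.parameters(), lr=0.01)

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"

        # Save everything needed to resume training as a single dictionary.
        torch.save(
            {
                "epoch": 3,
                "net_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path,
        )

        # Restore it the same way the snippet above does when a checkpoint exists.
        checkpoint_state = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
        start_epoch = checkpoint_state["epoch"]
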
@@ -213,9 +204,8 @@ def forward(self, x):
213204# "optimizer_state_dict": optimizer.state_dict(),
214205# }
215206# with tempfile.TemporaryDirectory() as checkpoint_dir:
216- # data_path = Path(checkpoint_dir) / "data.pkl"
217- # with open(data_path, "wb") as fp:
218- # pickle.dump(checkpoint_data, fp)
207+ # checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
208+ # torch.save(checkpoint_data, checkpoint_path)
219209#
220210# checkpoint = Checkpoint.from_directory(checkpoint_dir)
221211# tune.report(
@@ -259,9 +249,8 @@ def train_cifar(config, data_dir=None):
     checkpoint = tune.get_checkpoint()
     if checkpoint:
         with checkpoint.as_directory() as checkpoint_dir:
-            data_path = Path(checkpoint_dir) / "data.pkl"
-            with open(data_path, "rb") as fp:
-                checkpoint_state = pickle.load(fp)
+            checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+            checkpoint_state = torch.load(checkpoint_path)
             start_epoch = checkpoint_state["epoch"]
             net.load_state_dict(checkpoint_state["net_state_dict"])
             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
@@ -334,9 +323,8 @@ def train_cifar(config, data_dir=None):
334323 "optimizer_state_dict" : optimizer .state_dict (),
335324 }
336325 with tempfile .TemporaryDirectory () as checkpoint_dir :
337- data_path = Path (checkpoint_dir ) / "data.pkl"
338- with open (data_path , "wb" ) as fp :
339- pickle .dump (checkpoint_data , fp )
326+ checkpoint_path = Path (checkpoint_dir ) / "checkpoint.pt"
327+ torch .save (checkpoint_data , checkpoint_path )
340328
341329 checkpoint = Checkpoint .from_directory (checkpoint_dir )
342330 tune .report (
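
Reporting metrics together with a ``Checkpoint`` is the core of the Tune integration. The toy trainable below is a self-contained sketch of that pattern; the metric, search space, and "training" loop are placeholders, and the calls assume a Ray release that provides ``tune.report``, ``tune.Checkpoint``, and ``tune.Tuner`` as used above:

.. code-block:: python

    import tempfile
    from pathlib import Path

    import torch
    from ray import tune
    from ray.tune import Checkpoint


    def toy_trainable(config):
        score = 0.0
        for epoch in range(3):
            score += config["lr"]  # stand-in for a real training epoch
            checkpoint_data = {"epoch": epoch, "score": score}
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                torch.save(checkpoint_data, Path(checkpoint_dir) / "checkpoint.pt")
                checkpoint = Checkpoint.from_directory(checkpoint_dir)
                # Hand metrics and checkpoint to Tune in a single call, as above.
                tune.report({"score": score}, checkpoint=checkpoint)


    tuner = tune.Tuner(
        toy_trainable,
        param_space={"lr": tune.grid_search([0.01, 0.1])},
    )
    results = tuner.fit()
    print(results.get_best_result(metric="score", mode="max").config)
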
@@ -452,6 +440,7 @@ def test_accuracy(net, device="cpu"):
 
 def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
     print("Starting hyperparameter tuning.")
+    ray.init()
 
     data_dir = os.path.abspath("./data")
     load_data(data_dir)
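
``ray.init()`` starts a local Ray instance (or connects to an existing cluster) before any Tune calls are made; with no arguments it uses all CPUs and GPUs it can detect. A sketch of optional parameters, with purely illustrative values:

.. code-block:: python

    import ray

    ray.init(
        num_cpus=8,               # illustrative cap on schedulable CPUs
        num_gpus=1,               # illustrative cap on schedulable GPUs
        include_dashboard=False,  # skip the web dashboard for headless runs
    )
    # To attach to a running cluster instead, ray.init(address="auto") can be used.
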
@@ -497,9 +486,8 @@ def main(num_trials=10, max_num_epochs=10, gpus_per_trial=2):
 
     best_checkpoint = best_result.checkpoint
     with best_checkpoint.as_directory() as checkpoint_dir:
-        data_path = Path(checkpoint_dir) / "data.pkl"
-        with open(data_path, "rb") as fp:
-            best_checkpoint_data = pickle.load(fp)
+        checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
+        best_checkpoint_data = torch.load(checkpoint_path)
 
     best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
     test_acc = test_accuracy(best_trained_model, device)
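
``test_accuracy`` itself is not part of this diff. For reference, a generic evaluation helper of this kind usually looks like the sketch below; it takes a ``testloader`` argument and uses synthetic data, which is a simplification rather than the tutorial's exact signature:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset


    def test_accuracy(net, testloader, device="cpu"):
        """Fraction of correctly classified samples over a test loader."""
        net.to(device)
        net.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total


    # Synthetic stand-in data: 64 samples, 8 features, 3 classes.
    net = nn.Linear(8, 3)
    data = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))
    print(test_accuracy(net, DataLoader(data, batch_size=16)))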