import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
- import torchvision.transforms as transforms
import torch.nn.functional as F

import models.wideresnet as models
parser.add_argument('--lambda-u', default=75, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.999, type=float)
-
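+ # --resolution selects the WideResNet variant used below (32 by default; 48 and 96 pick the larger-input models)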
+ parser.add_argument('--resolution', default=32, type=int)
+ # Data options
+ parser.add_argument('--data_root', default='data',
+                     help='Data directory')
+ parser.add_argument('--dataset', default='C10',
+                     help='Dataset name: C10 | STL10')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

best_acc = 0  # best test accuracy

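+ # Pick the WideResNet variant that matches the requested input resolution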
+ which_model = models.WideResNet
+ if args.resolution == 96:
+     which_model = models.WideResNet96
+ if args.resolution == 48:
+     which_model = models.WideResNet48
+
+
def main():
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
-     print(f'==> Preparing cifar10')
-     transform_train = transforms.Compose([
-         dataset.RandomPadandCrop(32),
-         dataset.RandomFlip(),
-         dataset.ToTensor(),
-     ])
-
-     transform_val = transforms.Compose([
-         dataset.ToTensor(),
-     ])
-
-     train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
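+     # validate_dataset is expected to reject unrecognised --dataset values before any data is loaded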
+     dataset.validate_dataset(args.dataset)
+     print('==> Preparing %s' % args.dataset)
+
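+     # get_transforms builds the dataset- and resolution-specific train/eval transforms, replacing the hard-coded CIFAR-10 pipeline removed above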
+     transform_train, transform_val = dataset.get_transforms(args.dataset, args.resolution)
+
+     if args.dataset == 'C10':
+         train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(args.data_root, args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+     elif args.dataset == 'STL10':
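+         # STL-10 provides exactly 5,000 labeled training images, so no other labeled split is supported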
+         if args.n_labeled != 5000:
+             raise ValueError("For STL10 the only supported n_labeled is 5000")
+         train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_stl10(args.data_root, transform_train=transform_train, transform_val=transform_val)
+
    labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
-     val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
-     test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+     val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
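+     # Some loaders may return no test split (test_set is None); in that case test evaluation is skipped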
+     if test_set is not None:
+         test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+     else:
+         test_loader = None

    # Model
-     print("==> creating WRN-28-2")
+     print("==> creating %s" % which_model.__name__)

    def create_model(ema=False):
-         model = models.WideResNet(num_classes=10)
+         model = which_model(num_classes=10)
        model = model.cuda()

        if ema:
@@ -123,7 +137,6 @@ def create_model(ema=False):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
-         args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
@@ -146,7 +159,10 @@ def create_model(ema=False):
        train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
        _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
        val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
-         test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+         if test_loader is not None:
+             test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+         else:
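+             # no test split available: log sentinel values instead of test metrics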
+             test_loss, test_acc = [-1, -1]

        step = args.val_iteration * (epoch + 1)

@@ -206,10 +222,18 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti
            inputs_x, targets_x = labeled_train_iter.next()

        try:
-             (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
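+             # The STL-10 unlabeled loader yields one batch of single views per call, so two consecutive batches stand in for the paired two-view batch the CIFAR-10 wrapper returns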
+             if args.dataset == 'STL10':
+                 inputs_u, _ = unlabeled_train_iter.next()
+                 inputs_u2, _ = unlabeled_train_iter.next()
+             else:
+                 (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
        except:
            unlabeled_train_iter = iter(unlabeled_trainloader)
-             (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+             if args.dataset == 'STL10':
+                 inputs_u, _ = unlabeled_train_iter.next()
+                 inputs_u2, _ = unlabeled_train_iter.next()
+             else:
+                 (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()

        # measure data loading time
        data_time.update(time.time() - end)