
Commit ef48201

Add support for STL10 at resolutions 32, 48, and 96
1 parent a738cc9 commit ef48201

4 files changed: +240 −29 lines

README.md

Lines changed: 20 additions & 1 deletion
@@ -2,7 +2,7 @@
 This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249).
 The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch).

-Now only experiments on CIFAR-10 are available.
+Experiments on CIFAR-10 and STL-10 are available.

 This repository carefully implemented important details of the official implementation to reproduce the results.

@@ -31,19 +31,38 @@ Train the model by 4000 labeled data of CIFAR-10 dataset:
 python train.py --gpu <gpu_id> --n-labeled 4000 --out cifar10@4000
 ```

+Train STL-10:
+
+```
+python train.py --resolution <32|48|96> --out stl10 --data_root data/stl10 --dataset STL10 --n-labeled 5000
+```
+
 ### Monitoring training progress
 ```
 tensorboard.sh --port 6006 --logdir cifar10@250
 ```

 ## Results (Accuracy)
+
+### CIFAR10
+
 | #Labels | 250 | 500 | 1000 | 2000| 4000 |
 |:---|:---:|:---:|:---:|:---:|:---:|
 |Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32| 92.97 ± 0.15 |93.76 ± 0.06|
 |This code | 88.71 | 88.96 | 90.52 | 92.23 | 93.52 |

 (Results of this code were evaluated on 1 run. Results of 5 runs with different seeds will be updated later.)

+### STL10
+
+Using the entire 5000 point dataset:
+
+| Resolution | 32 | 48 | 96 |
+|:---|:---:|:---:|:---:|
+|Paper | - | - | 94.41 |
+|This code | 82.69 | 86.41 | 91.33 |
+
 ## References
 ```
 @article{berthelot2019mixmatch,

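A note on the new STL-10 path: `get_stl10` (added below in dataset/cifar10.py) always passes `download=True`, so pointing `--data_root` at an empty directory is enough for torchvision to fetch the binaries on first use. The split sizes also explain why `--n-labeled` must be 5000. A minimal sketch using the standard torchvision API, not code from this commit:

```python
from torchvision.datasets import STL10

# First call downloads the archives; later calls reuse them.
train = STL10('data/stl10', split='train', download=True)      # 5,000 labeled 96x96 images
unlab = STL10('data/stl10', split='unlabeled', download=True)  # 100,000 unlabeled images
test  = STL10('data/stl10', split='test', download=True)       # 8,000 images, used as the val split here
print(len(train), len(unlab), len(test))                       # 5000 100000 8000
```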
dataset/cifar10.py

Lines changed: 84 additions & 4 deletions
@@ -4,6 +4,29 @@
 import torchvision
 import torch

+import torchvision.transforms as transforms
+from torchvision.datasets import STL10
+
+# dict containing supported datasets with their image resolutions
+imsize_dict = {'C10': 32, 'STL10': 96}
+
+cifar10_mean = (0.4914, 0.4822, 0.4465)
+cifar10_std = (0.2023, 0.1994, 0.2010)
+
+stl10_mean = (0.4914, 0.4822, 0.4465)
+stl10_std = (0.2471, 0.2435, 0.2616)
+
+dataset_stats = {
+    'C10' : {
+        'mean': cifar10_mean,
+        'std': cifar10_std
+    },
+    'STL10' : {
+        'mean': stl10_mean,
+        'std': stl10_std
+    },
+}
+
 class TransformTwice:
     def __init__(self, transform):
         self.transform = transform
@@ -27,7 +50,67 @@ def get_cifar10(root, n_labeled,
 
     print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
     return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
-
+
+def get_stl10(root,
+              transform_train=None, transform_val=None,
+              download=True):
+
+    training_set = STL10(root, split='train', download=True, transform=transform_train)
+    dev_set = STL10(root, split='test', download=True, transform=transform_val)
+    unl_set = STL10(root, split='unlabeled', download=True, transform=transform_train)
+
+    print (f"#Labeled: {len(training_set)} #Unlabeled: {len(unl_set)} #Val: {len(dev_set)} #Test: None")
+    return training_set, unl_set, dev_set, None
+
+def validate_dataset(dataset):
+    if dataset not in imsize_dict:
+        raise ValueError("Dataset %s not supported." % dataset)
+
+def get_transforms(dataset, resolution):
+    dataset_resolution = imsize_dict[dataset]
+
+    if dataset == 'STL10':
+        if resolution == 96:
+            transform_train = transforms.Compose([
+                transforms.RandomCrop(96, padding=4),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(stl10_mean, stl10_std),
+            ])
+        else:
+            transform_train = transforms.Compose([
+                transforms.RandomCrop(86, padding=0),
+                transforms.Resize(resolution),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(stl10_mean, stl10_std),
+            ])
+        if dataset_resolution == resolution:
+            transform_val = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
+            ])
+        else:
+            transform_val = transforms.Compose([
+                transforms.Resize(resolution),
+                transforms.ToTensor(),
+                transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
+            ])
+    if dataset == 'C10':
+        # already normalized in the CIFAR10_labeled/CIFAR10_unlabeled class
+        transform_train = transforms.Compose([
+            RandomPadandCrop(resolution),
+            RandomFlip(),
+            ToTensor(),
+        ])
+        transform_val = transforms.Compose([
+            ToTensor(),
+        ])
+
+    return transform_train, transform_val
+
+
 
 def train_val_split(labels, n_labeled_per_class):
     labels = np.array(labels)
@@ -47,9 +130,6 @@ def train_val_split(labels, n_labeled_per_class):
 
     return train_labeled_idxs, train_unlabeled_idxs, val_idxs

-cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
-cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255
-
 def normalise(x, mean=cifar10_mean, std=cifar10_std):
     x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
     x -= mean*255
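For reference, a sketch of how the new helpers compose, mirroring what train.py does for STL-10 at 48×48. This is hypothetical standalone usage, assuming the module is importable as `dataset.cifar10`:

```python
import dataset.cifar10 as dataset

dataset.validate_dataset('STL10')  # raises ValueError for unsupported names

# 48 != imsize_dict['STL10'] (96), so training uses RandomCrop(86) + Resize(48)
# and the val pipeline gains a Resize(48) before ToTensor/Normalize.
t_train, t_val = dataset.get_transforms('STL10', 48)

labeled, unlabeled, val, test = dataset.get_stl10('data/stl10',
                                                  transform_train=t_train,
                                                  transform_val=t_val)
assert test is None  # the STL-10 path returns no separate test split
```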

models/wideresnet.py

Lines changed: 90 additions & 2 deletions
@@ -3,7 +3,6 @@
 import torch.nn as nn
 import torch.nn.functional as F

-
 class BasicBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
         super(BasicBlock, self).__init__()
@@ -16,7 +15,7 @@ def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_
         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                                padding=1, bias=False)
         self.droprate = dropRate
-        self.equalInOut = (in_planes == out_planes)
+        self.equalInOut = (in_planes == out_planes) and (stride == 1)
         self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                padding=0, bias=False) or None
         self.activate_before_residual = activate_before_residual
@@ -84,4 +83,93 @@ def forward(self, x):
         out = self.relu(self.bn1(out))
         out = F.avg_pool2d(out, 8)
         out = out.view(-1, self.nChannels)
+        return self.fc(out)
+
+class WideResNet48(nn.Module):
+    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
+        super(WideResNet48, self).__init__()
+        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
+        assert((depth - 4) % 6 == 0)
+        n = (depth - 4) / 6
+        block = BasicBlock
+        # 1st conv before any network block
+        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        # 1st block
+        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
+        # 2nd block
+        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
+        # 3rd block
+        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
+        # global average pooling and classifier
+        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
+        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.fc = nn.Linear(nChannels[3], num_classes)
+        self.nChannels = nChannels[3]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.block1(out)
+        out = self.block2(out)
+        out = self.block3(out)
+        out = self.relu(self.bn1(out))
+        out = F.avg_pool2d(out, 12)
+        out = out.view(-1, self.nChannels)
+        return self.fc(out)
+
+class WideResNet96(nn.Module):
+    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
+        super(WideResNet96, self).__init__()
+        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor, 64*widen_factor]
+        assert((depth - 4) % 6 == 0)
+        n = (depth - 4) / 6
+        block = BasicBlock
+        # 1st conv before any network block
+        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        # 1st block
+        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
+        # 2nd block
+        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
+        # 3rd block
+        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
+        # 4th block
+        self.block4 = NetworkBlock(n, nChannels[3], nChannels[4], block, 2, dropRate)
+        # global average pooling and classifier
+        self.bn1 = nn.BatchNorm2d(nChannels[4], momentum=0.001)
+        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.fc = nn.Linear(nChannels[4], num_classes)
+        self.nChannels = nChannels[4]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.block1(out)
+        out = self.block2(out)
+        out = self.block3(out)
+        out = self.block4(out)
+        out = self.relu(self.bn1(out))
+        out = F.avg_pool2d(out, 12)
+        out = out.view(-1, self.nChannels)
         return self.fc(out)
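The pooling arithmetic behind the two variants: `WideResNet48` keeps the original three groups (strides 1, 2, 2), so a 48×48 input reaches the classifier head at 12×12, while `WideResNet96` adds a fourth stride-2 group so a 96×96 input also lands at 12×12; both therefore end with `avg_pool2d(out, 12)`. A quick shape check, a sketch assuming `NetworkBlock` and `BasicBlock` behave as in the existing `WideResNet`:

```python
import torch
from models.wideresnet import WideResNet48, WideResNet96

x48 = torch.randn(2, 3, 48, 48)  # 48 -> 48 -> 24 -> 12 across the three groups
x96 = torch.randn(2, 3, 96, 96)  # 96 -> 96 -> 48 -> 24 -> 12 across the four groups
print(WideResNet48(num_classes=10)(x48).shape)  # torch.Size([2, 10])
print(WideResNet96(num_classes=10)(x96).shape)  # torch.Size([2, 10])
```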

train.py

Lines changed: 46 additions & 22 deletions
@@ -14,7 +14,6 @@
 import torch.backends.cudnn as cudnn
 import torch.optim as optim
 import torch.utils.data as data
-import torchvision.transforms as transforms
 import torch.nn.functional as F

 import models.wideresnet as models
@@ -51,7 +50,12 @@
 parser.add_argument('--lambda-u', default=75, type=float)
 parser.add_argument('--T', default=0.5, type=float)
 parser.add_argument('--ema-decay', default=0.999, type=float)
-
+parser.add_argument('--resolution', default=32, type=int)
+# Data options
+parser.add_argument('--data_root', default='data',
+                    help='Data directory')
+parser.add_argument('--dataset', default='C10',
+                    help='Dataset name: C10 | STL10')

 args = parser.parse_args()
 state = {k: v for k, v in args._get_kwargs()}
@@ -67,35 +71,45 @@
 
 best_acc = 0  # best test accuracy

+which_model = models.WideResNet
+if args.resolution == 96:
+    which_model = models.WideResNet96
+if args.resolution == 48:
+    which_model = models.WideResNet48
+
+
 def main():
     global best_acc

     if not os.path.isdir(args.out):
         mkdir_p(args.out)

     # Data
-    print(f'==> Preparing cifar10')
-    transform_train = transforms.Compose([
-        dataset.RandomPadandCrop(32),
-        dataset.RandomFlip(),
-        dataset.ToTensor(),
-    ])
-
-    transform_val = transforms.Compose([
-        dataset.ToTensor(),
-    ])
-
-    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+    dataset.validate_dataset(args.dataset)
+    print(f'==> Preparing %s' % args.dataset)
+
+    transform_train, transform_val = dataset.get_transforms(args.dataset, args.resolution)
+
+    if args.dataset == 'C10':
+        train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(args.data_root, args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+    elif args.dataset == 'STL10':
+        if args.n_labeled != 5000:
+            raise ValueError("For STL10 the only supported n_labeled is 5000")
+        train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_stl10(args.data_root, transform_train=transform_train, transform_val=transform_val)
+
 labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
 unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
-    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
-    test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+    if test_set is not None:
+        test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+    else:
+        test_loader = None

     # Model
-    print("==> creating WRN-28-2")
+    print("==> creating %s" % which_model.__name__)

     def create_model(ema=False):
-        model = models.WideResNet(num_classes=10)
+        model = which_model(num_classes=10)
         model = model.cuda()

         if ema:
@@ -123,7 +137,6 @@ def create_model(ema=False):
         # Load checkpoint.
         print('==> Resuming from checkpoint..')
         assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
-        args.out = os.path.dirname(args.resume)
         checkpoint = torch.load(args.resume)
         best_acc = checkpoint['best_acc']
         start_epoch = checkpoint['epoch']
@@ -146,7 +159,10 @@ def create_model(ema=False):
         train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
         _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
         val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
-        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+        if test_loader is not None:
+            test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+        else:
+            test_loss, test_acc = [-1, -1]

         step = args.val_iteration * (epoch + 1)

@@ -206,10 +222,18 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti
             inputs_x, targets_x = labeled_train_iter.next()

         try:
-            (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+            if args.dataset == 'STL10':
+                inputs_u, _ = unlabeled_train_iter.next()
+                inputs_u2, _ = unlabeled_train_iter.next()
+            else:
+                (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
         except:
             unlabeled_train_iter = iter(unlabeled_trainloader)
-            (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+            if args.dataset == 'STL10':
+                inputs_u, _ = unlabeled_train_iter.next()
+                inputs_u2, _ = unlabeled_train_iter.next()
+            else:
+                (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()

         # measure data loading time
         data_time.update(time.time() - end)
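One behavioral note on the unlabeled branch above: the CIFAR-10 loader wraps its unlabeled set in `TransformTwice`, so `inputs_u` and `inputs_u2` are two augmentations of the same images, whereas the STL-10 path draws two consecutive batches, so the pair holds different images. A hypothetical variant (not what this commit does) that would restore paired views by reusing the existing `TransformTwice`:

```python
# Hypothetical sketch, not this commit's code: wrap the STL-10 unlabeled
# split in TransformTwice so each sample yields two views of the same image,
# matching the CIFAR-10 behavior.
from torchvision.datasets import STL10
from dataset.cifar10 import TransformTwice, get_transforms

t_train, _ = get_transforms('STL10', 96)
unl_set = STL10('data/stl10', split='unlabeled', download=True,
                transform=TransformTwice(t_train))
(u1, u2), _ = unl_set[0]  # two augmentations of the same image
```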
