21 changes: 20 additions & 1 deletion README.md
@@ -2,7 +2,7 @@
This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249).
The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch).

Now only experiments on CIFAR-10 are available.
Experiments on CIFAR-10 and STL-10 are available.

This repository carefully implements important details of the official implementation to reproduce its results.

@@ -31,19 +31,38 @@ Train the model with 4000 labeled examples from the CIFAR-10 dataset:
python train.py --gpu <gpu_id> --n-labeled 4000 --out cifar10@4000
```

Train on STL-10 (the `--resolution` flag selects the 32-, 48-, or 96-pixel backbone; only `--n-labeled 5000` is supported):

```
python train.py --resolution <32|48|96> --out stl10 --data_root data/stl10 --dataset STL10 --n-labeled 5000
```


### Monitoring training progress
```
tensorboard.sh --port 6006 --logdir cifar10@250
```

## Results (Accuracy)

### CIFAR10

| #Labels | 250 | 500 | 1000 | 2000| 4000 |
|:---|:---:|:---:|:---:|:---:|:---:|
|Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32| 92.97 ± 0.15 |93.76 ± 0.06|
|This code | 88.71 | 88.96 | 90.52 | 92.23 | 93.52 |

(Results of this code are from a single run. Results of 5 runs with different seeds will be added later.)

### STL10

Using the entire 5,000-image labeled training set:

| Resolution | 32 | 48 | 96 |
|:---|:---:|:---:|:---:|
|Paper | - | - | 94.41 |
|This code | 82.69 | 86.41 | 91.33 |

## References
```
@article{berthelot2019mixmatch,
88 changes: 84 additions & 4 deletions dataset/cifar10.py
@@ -4,6 +4,29 @@
import torchvision
import torch

import torchvision.transforms as transforms
from torchvision.datasets import STL10

# dict containing supported datasets with their image resolutions
imsize_dict = {'C10': 32, 'STL10': 96}

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

stl10_mean = (0.4914, 0.4822, 0.4465)
stl10_std = (0.2471, 0.2435, 0.2616)
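# NOTE: stl10_mean above is identical to cifar10_mean, and stl10_std matches the CIFAR-10 std
# used in the original code below; these appear to be reused CIFAR-10 statistics rather than
# values computed on STL-10.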

dataset_stats = {
'C10' : {
'mean': cifar10_mean,
'std': cifar10_std
},
'STL10' : {
'mean': stl10_mean,
'std': stl10_std
},
}

class TransformTwice:
def __init__(self, transform):
self.transform = transform
@@ -27,7 +50,67 @@ def get_cifar10(root, n_labeled,

print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset


def get_stl10(root,
transform_train=None, transform_val=None,
download=True):

training_set = STL10(root, split='train', download=download, transform=transform_train)
dev_set = STL10(root, split='test', download=download, transform=transform_val)
unl_set = STL10(root, split='unlabeled', download=download, transform=transform_train)

print (f"#Labeled: {len(training_set)} #Unlabeled: {len(unl_set)} #Val: {len(dev_set)} #Test: None")
return training_set, unl_set, dev_set, None

def validate_dataset(dataset):
if dataset not in imsize_dict:
raise ValueError("Dataset %s not supported." % dataset)

def get_transforms(dataset, resolution):
dataset_resolution = imsize_dict[dataset]

if dataset == 'STL10':
if resolution == 96:
transform_train = transforms.Compose([
transforms.RandomCrop(96, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(stl10_mean, stl10_std),
])
else:
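# Lower-resolution training: random-crop 96 -> 86 (no padding), then resize down to the requested resolution.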
transform_train = transforms.Compose([
transforms.RandomCrop(86, padding=0),
transforms.Resize(resolution),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(stl10_mean, stl10_std),
])
if dataset_resolution == resolution:
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
])
else:
transform_val = transforms.Compose([
transforms.Resize(resolution),
transforms.ToTensor(),
transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
])
if dataset == 'C10':
# already normalized in the CIFAR10_labeled/CIFAR10_unlabeled class
transform_train = transforms.Compose([
RandomPadandCrop(resolution),
RandomFlip(),
ToTensor(),
])
transform_val = transforms.Compose([
ToTensor(),
])


return transform_train, transform_val
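
# Example (sketch, not part of the training pipeline): how the STL-10 helpers above
# are intended to be combined. train.py performs the same wiring from command-line arguments.
#
#   validate_dataset('STL10')
#   transform_train, transform_val = get_transforms('STL10', resolution=48)
#   labeled_set, unlabeled_set, val_set, _ = get_stl10('data/stl10',
#                                                      transform_train=transform_train,
#                                                      transform_val=transform_val)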



def train_val_split(labels, n_labeled_per_class):
labels = np.array(labels)
@@ -47,9 +130,6 @@ def train_val_split(labels, n_labeled_per_class):

return train_labeled_idxs, train_unlabeled_idxs, val_idxs

cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255

def normalise(x, mean=cifar10_mean, std=cifar10_std):
x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
x -= mean*255
92 changes: 90 additions & 2 deletions models/wideresnet.py
@@ -3,7 +3,6 @@
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
super(BasicBlock, self).__init__()
@@ -16,7 +15,7 @@ def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_planes == out_planes)
self.equalInOut = (in_planes == out_planes) and (stride == 1)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False) or None
self.activate_before_residual = activate_before_residual
@@ -84,4 +83,93 @@ def forward(self, x):
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)

class WideResNet48(nn.Module):
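# Variant of WideResNet for 48x48 inputs: same layout, but the feature map reaches 12x12
# after the two stride-2 stages (48 -> 24 -> 12), so the final global average pool is 12x12.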
def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
super(WideResNet48, self).__init__()
nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
assert((depth - 4) % 6 == 0)
n = (depth - 4) / 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.zero_()

def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 12)
out = out.view(-1, self.nChannels)
return self.fc(out)

class WideResNet96(nn.Module):
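# Variant of WideResNet for 96x96 inputs: adds a fourth (stride-2) block so the feature map
# reaches 12x12 before the global average pool (96 -> 48 -> 24 -> 12).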
def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
super(WideResNet96, self).__init__()
nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor, 64*widen_factor]

Review comment: should this line be "nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor, 128*widen_factor]"?

Author: This change doubles the parameter count from 2.67M to 5.93M. I started training the larger model now since I don't remember the performance difference from a few months ago. I'll post the results soon.

Author (@ilyakava, Mar 25, 2020): @chencjGene It doesn't provide any gains. Last layer 64 = max top1 91.8 vs last layer 128 = max top1 91.76. The larger network trains slightly slower and starts to overfit sooner. wider results, narrower results

assert((depth - 4) % 6 == 0)
n = (depth - 4) / 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
# 4th block
self.block4 = NetworkBlock(n, nChannels[3], nChannels[4], block, 2, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[4], momentum=0.001)
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.fc = nn.Linear(nChannels[4], num_classes)
self.nChannels = nChannels[4]

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.zero_()

def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.block4(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 12)
out = out.view(-1, self.nChannels)
return self.fc(out)
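
A quick way to check the parameter counts quoted in the review thread above (a sketch; it assumes the repository layout of this PR, with WideResNet96 importable from models/wideresnet.py, and that the wider variant is obtained by editing the last nChannels entry to 128*widen_factor):

```
import models.wideresnet as models

# Build the 96x96 backbone as defined in this PR (last nChannels entry = 64*widen_factor)
# and count its parameters; the review thread quotes ~2.67M for this variant and ~5.93M
# after widening the last stage to 128*widen_factor.
model = models.WideResNet96(num_classes=10)
n_params = sum(p.numel() for p in model.parameters())
print(f"WideResNet96 parameters: {n_params / 1e6:.2f}M")
```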
68 changes: 46 additions & 22 deletions train.py
@@ -14,7 +14,6 @@
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F

import models.wideresnet as models
@@ -51,7 +50,12 @@
parser.add_argument('--lambda-u', default=75, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.999, type=float)

parser.add_argument('--resolution', default=32, type=int)
# Data options
parser.add_argument('--data_root', default='data',
help='Data directory')
parser.add_argument('--dataset', default='C10',
help='Dataset name: C10 | STL10')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}
@@ -67,35 +71,45 @@

best_acc = 0 # best test accuracy

which_model = models.WideResNet
if args.resolution == 96:
which_model = models.WideResNet96
if args.resolution == 48:
which_model = models.WideResNet48


def main():
global best_acc

if not os.path.isdir(args.out):
mkdir_p(args.out)

# Data
print(f'==> Preparing cifar10')
transform_train = transforms.Compose([
dataset.RandomPadandCrop(32),
dataset.RandomFlip(),
dataset.ToTensor(),
])

transform_val = transforms.Compose([
dataset.ToTensor(),
])

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
dataset.validate_dataset(args.dataset)
print('==> Preparing %s' % args.dataset)

transform_train, transform_val = dataset.get_transforms(args.dataset, args.resolution)

if args.dataset == 'C10':
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(args.data_root, args.n_labeled, transform_train=transform_train, transform_val=transform_val)
elif args.dataset == 'STL10':
if args.n_labeled != 5000:
raise ValueError("For STL10 the only supported n_labeled is 5000")
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_stl10(args.data_root, transform_train=transform_train, transform_val=transform_val)

labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
if test_set is not None:
test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
else:
test_loader = None

# Model
print("==> creating WRN-28-2")
print("==> creating %s" % which_model.__name__)

def create_model(ema=False):
model = models.WideResNet(num_classes=10)
model = which_model(num_classes=10)
model = model.cuda()

if ema:
@@ -123,7 +137,6 @@ def create_model(ema=False):
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
args.out = os.path.dirname(args.resume)
checkpoint = torch.load(args.resume)
best_acc = checkpoint['best_acc']
start_epoch = checkpoint['epoch']
@@ -146,7 +159,10 @@ def create_model(ema=False):
train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
_, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
if test_loader is not None:
test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
else:
test_loss, test_acc = [-1, -1]

step = args.val_iteration * (epoch + 1)

@@ -206,10 +222,18 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti
inputs_x, targets_x = labeled_train_iter.next()
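
# For STL10 the unlabeled loader is not wrapped in TransformTwice, so two independently
# augmented batches (of different images) are drawn per step instead of two augmentations
# of the same batch.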

try:
(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
if args.dataset == 'STL10':
inputs_u, _ = unlabeled_train_iter.next()
inputs_u2, _ = unlabeled_train_iter.next()
else:
(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
except:
unlabeled_train_iter = iter(unlabeled_trainloader)
(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
if args.dataset == 'STL10':
inputs_u, _ = unlabeled_train_iter.next()
inputs_u2, _ = unlabeled_train_iter.next()
else:
(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()

# measure data loading time
data_time.update(time.time() - end)