This repository was archived by the owner on Apr 17, 2023. It is now read-only.

Commit d666b4c

Change model to make it prunable
1 parent 54ba726

15 files changed: +397 −24 lines

torchreid/apis/training.py

Lines changed: 9 additions & 9 deletions
@@ -31,12 +31,10 @@

 def run_lr_finder(cfg, datamanager, model, optimizer, scheduler, classes,
                   rebuild_model=True, gpu_num=1, split_models=False):
-    if rebuild_model:
-        tmp_model = model
-    else:
-        tmp_model = deepcopy(model)
+    if not rebuild_model:
+        backup_model = deepcopy(model)

-    engine = build_engine(cfg, datamanager, tmp_model, optimizer, scheduler, initial_lr=cfg.train.lr)
+    engine = build_engine(cfg, datamanager, model, optimizer, scheduler, initial_lr=cfg.train.lr)
     lr_finder = LrFinder(engine=engine, **lr_finder_run_kwargs(cfg))
     aux_lr = lr_finder.process()

@@ -54,16 +52,18 @@ def run_lr_finder(cfg, datamanager, model, optimizer, scheduler, classes,
     set_random_seed(cfg.train.seed, cfg.train.deterministic)
     datamanager = build_datamanager(cfg, classes)
     num_train_classes = datamanager.num_train_pids
+
     if rebuild_model:
-        model = torchreid.models.build_model(**model_kwargs(cfg, num_train_classes))
+        backup_model = torchreid.models.build_model(**model_kwargs(cfg, num_train_classes))
         num_aux_models = len(cfg.mutual_learning.aux_configs)
-        model, _ = put_main_model_on_the_device(model, cfg.use_gpu, gpu_num, num_aux_models, split_models)
-    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
+        backup_model, _ = put_main_model_on_the_device(backup_model, cfg.use_gpu, gpu_num, num_aux_models, split_models)
+
+    optimizer = torchreid.optim.build_optimizer(backup_model, **optimizer_kwargs(cfg))
     scheduler = torchreid.optim.build_lr_scheduler(optimizer=optimizer,
                                                    num_iter=datamanager.num_iter,
                                                    **lr_scheduler_kwargs(cfg))

-    return cfg.train.lr
+    return cfg.train.lr, backup_model, optimizer, scheduler


 def run_training(cfg, datamanager, model, optimizer, scheduler, extra_device_ids, init_lr,
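Note on the API change above: run_lr_finder no longer hands the caller's model to the LR-finder engine's rebuild path; it deep-copies or rebuilds a backup_model and now returns it together with a matching optimizer and scheduler. The updated call sites are in parts of this commit not shown on this page, so the following is only a sketch, under that assumption, of how a caller would consume the new 4-tuple:

    # Hedged sketch (not from the commit): unpacking the new return value
    # of run_lr_finder instead of the bare learning rate it used to return.
    from torchreid.apis.training import run_lr_finder

    lr, model, optimizer, scheduler = run_lr_finder(
        cfg, datamanager, model, optimizer, scheduler, classes,
        rebuild_model=True, gpu_num=1, split_models=False)
    cfg.train.lr = lr  # continue training with the freshly rebuilt objects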

torchreid/engine/engine.py

Lines changed: 0 additions & 1 deletion
@@ -326,7 +326,6 @@ def run(
         self.fixbase_epoch = fixbase_epoch
         test_acc = AverageMeter()
         print('=> Start training')
-
         if perf_monitor and not lr_finder: perf_monitor.on_train_begin()
         for self.epoch in range(self.start_epoch, self.max_epoch):
             # change the NumPy's seed at every epoch

torchreid/models/mobilenetv3.py

Lines changed: 5 additions & 6 deletions
@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+from torch.nn import functional as F

 from torchreid.losses import AngleSimpleLinear
 from torchreid.ops import Dropout, EvalModeSetter, rsc
@@ -33,19 +34,17 @@
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=4):
         super(SELayer, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Sequential(
-            nn.Linear(channel, make_divisible(channel // reduction, 8)),
+            nn.Conv2d(channel, make_divisible(channel // reduction, 8), 1),
             nn.ReLU(inplace=True),
-            nn.Linear(make_divisible(channel // reduction, 8), channel),
+            nn.Conv2d(make_divisible(channel // reduction, 8), channel, 1),
             HSigmoid()
         )

     def forward(self, x):
         with no_nncf_se_layer_context():
-            b, c, _, _ = x.size()
-            y = self.avg_pool(x).view(b, c)
-            y = self.fc(y).view(b, c, 1, 1)
+            y = F.adaptive_avg_pool2d(x, 1)
+            y = self.fc(y)
         return x * y
torchreid/utils/torchtools.py

Lines changed: 14 additions & 0 deletions
@@ -285,6 +285,19 @@ def _print_loading_weights_inconsistencies(discarded_layers, unmatched_layers):
     )


+def update_checkpoint_mobilenet_v3(checkpoint):
+    fc = []
+    for k in checkpoint:
+        if 'fc' in k and not 'bias' in k:
+            fc.append(k)
+    for name in fc:
+        w = checkpoint[name]
+        shape = w.shape
+        w_new = w.view(shape + (1, 1))
+        print(name, ': ', checkpoint[name].shape, '->', w_new.shape)
+        checkpoint[name] = w_new
+
+
 def load_pretrained_weights(model, file_path='', pretrained_dict=None):
     r"""Loads pretrianed weights to model.
     Features::
@@ -317,6 +330,7 @@ def _remove_prefix(key, prefix):
     else:
         state_dict = checkpoint

+    update_checkpoint_mobilenet_v3(state_dict)
     model_dict = model.state_dict()
     new_state_dict = OrderedDict()
     matched_layers, discarded_layers = [], []
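The helper added above keeps old checkpoints loadable: weights trained with the Linear-based SE blocks are stored as (out, in) matrices, and update_checkpoint_mobilenet_v3 reshapes every 'fc' weight (biases excluded, since their shape is unchanged) into an (out, in, 1, 1) 1x1-conv kernel before load_pretrained_weights does its shape matching. A small illustration with a made-up key name:

    import torch
    from torchreid.utils.torchtools import update_checkpoint_mobilenet_v3

    # Hypothetical checkpoint entry; any key containing 'fc' but not 'bias'
    # is converted in place.
    state_dict = {'conv.3.se.fc.0.weight': torch.randn(8, 16)}
    update_checkpoint_mobilenet_v3(state_dict)
    print(state_dict['conv.3.se.fc.0.weight'].shape)  # torch.Size([8, 16, 1, 1])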
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+python /home/automation/dlyakhov/training_extensions/external/deep-object-reid/tools/main.py \
+    --config-file /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/mobilenetv3_small.yml \
+    --gpu-num 1 \
+    --custom-roots \
+    /home/automation/dlyakhov/training_extensions/external/deep-object-reid/CIFAR100/train \
+    /home/automation/dlyakhov/training_extensions/external/deep-object-reid/CIFAR100/val \
+    --root _ \
+    model.load_weights /mnt/icv_externalN/dlyakhov/ote-classification-checkpoint/CIFAR100_mobielenet_v3_small/model_0/{name}.pth.tar-142 \
+    data.save_dir /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/output_21_09
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+python /home/automation/dlyakhov/training_extensions/external/deep-object-reid/tools/main.py \
+    --config-file /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/mobilenetv3_small_by_flops.yml \
+    --gpu-num 1 \
+    --custom-roots \
+    /home/automation/dlyakhov/training_extensions/external/deep-object-reid/CIFAR100/train \
+    /home/automation/dlyakhov/training_extensions/external/deep-object-reid/CIFAR100/val \
+    --root _ \
+    model.load_weights /mnt/icv_externalN/dlyakhov/ote-classification-checkpoint/CIFAR100_mobielenet_v3_small/model_0/{name}.pth.tar-142 \
+    data.save_dir /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/output_pruning_0.1_filters_by_flops_13_08

training/pruning_int8/mobilenet_v3_small/compress_sparse.sh

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ python /home/automation/dlyakhov/training_extensions/external/deep-object-reid/t
     /home/automation/dlyakhov/training_extensions/external/deep-object-reid/CIFAR100/val \
     --root _ \
     model.load_weights /mnt/icv_externalN/dlyakhov/ote-classification-checkpoint/CIFAR100_mobielenet_v3_small/model_0/{name}.pth.tar-142 \
-    data.save_dir /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/sparsity_magnitude_24_09
+    data.save_dir /home/automation/dlyakhov/training_extensions/external/deep-object-reid/training/pruning_int8/mobilenet_v3_small/sparsity_23_09

training/pruning_int8/mobilenet_v3_small/mobilenetv3_large_aux.yml

Lines changed: 4 additions & 2 deletions
@@ -35,8 +35,10 @@ train:
   deterministic: True
   patience: 5
   gamma: 0.1
-  sam:
-    rho: 0.05
+
+sam:
+  enable: True
+  rho: 0.05

 test:
   batch_size: 128
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+
+lr_finder:
+  enable: True
+  mode: TPE
+  stop_after: False
+  num_epochs: 6
+  step: 0.001
+  epochs_warmup: 1
+  path_to_savefig: 'lr_finder.jpg'
+  max_lr: 0.029
+  min_lr: 0.005
+  n_trials: 15
+
+model:
+  name: 'mobilenetv3_small'
+  type: 'classification'
+  pretrained: True
+  save_chkpt: True
+  feature_dim: 1024
+
+mutual_learning:
+  aux_configs: ['mobilenetv3_large_aux.yml']
+
+custom_datasets:
+  roots: ['data/CIFAR100/train', 'data/CIFAR100/val']
+  types: ['classification_image_folder', 'classification_image_folder']
+  names: ['CIFAR100_train', 'CIFAR100_val']
+
+data:
+  root: './'
+  sources: ['CIFAR100_train']
+  targets: ['CIFAR100_val']
+  height: 224
+  width: 224
+  norm_mean: [0.485, 0.456, 0.406]
+  norm_std: [0.229, 0.224, 0.225]
+  save_dir: 'output/mobilenetv3_small/log'
+  workers: 6
+  transforms:
+    random_flip:
+      enable: True
+      p: 0.5
+    augmix:
+      enable: True
+      cfg_str: "augmix-m5-w3"
+
+loss:
+  name: 'softmax'
+  softmax:
+    s: 1.0
+    compute_s: False
+
+sampler:
+  train_sampler: 'RandomSampler'
+
+metric_losses:
+  enable: False
+
+train:
+  optim: 'sam'
+  lr: 0.013
+  nbd: True
+  max_epoch: 200
+  weight_decay: 5e-4
+  batch_size: 84
+  lr_scheduler: 'warmup'
+  warmup: 5
+  base_scheduler: 'reduce_on_plateau_delayed'
+  epoch_delay: 40
+  early_stoping: True
+  train_patience: 5
+  lr_decay_factor: 200
+  deterministic: True
+  patience: 5
+  gamma: 0.1
+  ema:
+    enable: True
+    ema_decay: 0.999
+
+sam:
+  enable: True
+  rho: 0.05
+
+test:
+  batch_size: 128
+  evaluate: False
+  eval_freq: 1
+
+nncf:
+  enable: True
+  coeff_decrease_lr_for_nncf: 1.
+  nncf_config_path: 'nncf_config.json'
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+
+lr_finder:
+  enable: True
+  mode: TPE
+  stop_after: False
+  num_epochs: 6
+  step: 0.001
+  epochs_warmup: 1
+  path_to_savefig: 'lr_finder.jpg'
+  max_lr: 0.029
+  min_lr: 0.005
+  n_trials: 15
+
+model:
+  name: 'mobilenetv3_small'
+  type: 'classification'
+  pretrained: True
+  save_chkpt: True
+  feature_dim: 1024
+
+mutual_learning:
+  aux_configs: ['mobilenetv3_large_aux.yml']
+
+custom_datasets:
+  roots: ['data/CIFAR100/train', 'data/CIFAR100/val']
+  types: ['classification_image_folder', 'classification_image_folder']
+  names: ['CIFAR100_train', 'CIFAR100_val']
+
+data:
+  root: './'
+  sources: ['CIFAR100_train']
+  targets: ['CIFAR100_val']
+  height: 224
+  width: 224
+  norm_mean: [0.485, 0.456, 0.406]
+  norm_std: [0.229, 0.224, 0.225]
+  save_dir: 'output/mobilenetv3_small/log'
+  workers: 6
+  transforms:
+    random_flip:
+      enable: True
+      p: 0.5
+    augmix:
+      enable: True
+      cfg_str: "augmix-m5-w3"
+
+loss:
+  name: 'softmax'
+  softmax:
+    s: 1.0
+    compute_s: False
+
+sampler:
+  train_sampler: 'RandomSampler'
+
+metric_losses:
+  enable: False
+
+train:
+  optim: 'sam'
+  lr: 0.013
+  nbd: True
+  max_epoch: 200
+  weight_decay: 5e-4
+  batch_size: 84
+  lr_scheduler: 'warmup'
+  warmup: 5
+  base_scheduler: 'reduce_on_plateau'
+  early_stoping: True
+  train_patience: 5
+  lr_decay_factor: 200
+  deterministic: True
+  patience: 5
+  gamma: 0.1
+  ema:
+    enable: True
+    ema_decay: 0.999
+
+sam:
+  enable: True
+  rho: 0.05
+
+test:
+  batch_size: 128
+  evaluate: False
+  eval_freq: 1
+
+nncf:
+  enable: True
+  coeff_decrease_lr_for_nncf: 1.
+  nncf_config_path: 'nncf_config_by_flops.json'
