
Commit 3c5135c

Merge branch 'master' of github.com:mrT23/pytorch-image-models
2 parents d8b4c34 + b48e88c

File tree: 2 files changed, +98 -0 lines
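
Summary: the merge adds knowledge-distillation (KD) support to the training script. It introduces timm/utils/model_kd.py, with a teacher-model wrapper (build_kd_model) and a KL-divergence loss helper (add_kd_loss), and wires both into train.py behind new CLI flags.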

timm/utils/model_kd.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import logging
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+
+from timm.models import create_model
+
+_logger = logging.getLogger(__name__)
+
+
+class build_kd_model(nn.Module):
+    def __init__(self, args):
+        super(build_kd_model, self).__init__()
+
+        _logger.info(f"Creating KD model: from '{args.kd_model_name}'")
+        in_chans = 3
+        if args.in_chans is not None:
+            in_chans = args.in_chans
+        model_kd = create_model(
+            model_name=args.kd_model_name,
+            num_classes=args.num_classes,
+            pretrained=True,
+            in_chans=in_chans)
+
+        # compile model
+        model_kd.cpu().eval()
+        try:
+            model_kd = torch.compile(model_kd)
+            _logger.info("torch.compile applied successfully to KD model")
+        except Exception as e:
+            _logger.warning(f"torch.compile failed with error {e}, continuing KD model without torch compilation")
+
+        self.model = model_kd.cuda()
+        self.mean_model_kd = model_kd.default_cfg['mean']
+        self.std_model_kd = model_kd.default_cfg['std']
+
+    # handling different normalization of teacher and student
+    def normalize_input(self, input, student_model):
+        if hasattr(student_model, 'module'):
+            model_s = student_model.module
+        else:
+            model_s = student_model
+
+        mean_student = model_s.default_cfg['mean']
+        std_student = model_s.default_cfg['std']
+
+        input_kd = input
+        if mean_student != self.mean_model_kd or std_student != self.std_model_kd:
+            std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1],
+                   self.std_model_kd[2] / std_student[2])
+            transform_std = T.Normalize(mean=(0, 0, 0), std=std)
+
+            mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1],
+                    self.mean_model_kd[2] - mean_student[2])
+            transform_mean = T.Normalize(mean=mean, std=(1, 1, 1))
+
+            input_kd = transform_mean(transform_std(input))
+
+        return input_kd
+
+
+def add_kd_loss(_loss, output, input, model, model_kd, args):
+    # student probability calculation
+    prob_s = torch.nn.functional.log_softmax(output, dim=-1)
+
+    # teacher probability calculation
+    with torch.no_grad():
+        input_kd = model_kd.normalize_input(input, model)
+        out_t = model_kd.model(input_kd.detach())
+        prob_t = torch.nn.functional.softmax(out_t, dim=-1)
+
+    # adding KL loss
+    if not args.use_kd_only_loss:
+        _loss += args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean')
+    else:  # only kd
+        _loss = args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean')
+
+    return _loss
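
A note on normalize_input: by the time the loss is computed, the batch has already been normalized with the student's mean/std (timm's loader derives its normalization from the student's default_cfg), so the two composed T.Normalize transforms re-map those values toward the teacher's statistics; when teacher and student share the same statistics, the branch is skipped entirely. For reference, the exact re-normalization follows from x_raw = x_s * std_s + mean_s and can be folded into a single T.Normalize. The sketch below is my derivation, not part of the commit, and the helper name teacher_renorm is hypothetical:

import torchvision.transforms as T

def teacher_renorm(mean_s, std_s, mean_t, std_t):
    # x_t = (x_raw - mean_t) / std_t  with  x_raw = x_s * std_s + mean_s
    # =>  x_t = (x_s - (mean_t - mean_s) / std_s) / (std_t / std_s)
    mean = tuple((mt - ms) / ss for mt, ms, ss in zip(mean_t, mean_s, std_s))
    std = tuple(st / ss for st, ss in zip(std_t, std_s))
    return T.Normalize(mean=mean, std=std)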

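add_kd_loss follows the argument convention of torch.nn.functional.kl_div: the first argument must be log-probabilities (hence log_softmax for the student) and the target plain probabilities (hence softmax for the teacher), while reduction='batchmean' averages KL(teacher || student) over the batch, which matches the mathematical definition of KL divergence. A toy sketch of the default loss combination, with placeholder values:

import torch
import torch.nn.functional as F

logits_s = torch.randn(2, 4)               # student logits (batch=2, classes=4)
logits_t = torch.randn(2, 4)               # teacher logits

log_p_s = F.log_softmax(logits_s, dim=-1)  # log-probabilities, as kl_div expects
p_t = F.softmax(logits_t, dim=-1)          # probabilities for the target
kd = F.kl_div(log_p_s, p_t, reduction='batchmean')

ce = torch.tensor(1.0)                     # stand-in for the cross-entropy term
total = ce + 5.0 * kd                      # 5.0 mirrors the --alpha-kd default
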
train.py

Lines changed: 21 additions & 0 deletions
@@ -41,6 +41,7 @@
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler
+from timm.utils.model_kd import build_kd_model, add_kd_loss

 try:
     from apex import amp
@@ -415,6 +416,14 @@
 group.add_argument('--naflex-loss-scale', default='linear', type=str,
                    help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')

+# Knowledge Distillation parameters
+parser.add_argument('--kd-model-name', default=None, type=str,
+                    help='Name of teacher model for knowledge distillation')
+parser.add_argument('--alpha-kd', default=5, type=float,
+                    help='Weight for KD loss (default: 5)')
+parser.add_argument('--use-kd-only-loss', action='store_true', default=False,
+                    help='Use only KD loss, without cross-entropy loss')
+

 def _parse_args():
     # Do we have a config file to parse?
@@ -480,6 +489,11 @@ def main():

     utils.random_seed(args.seed, args.rank)

+    # Create the KD teacher model if specified
+    model_kd = None
+    if args.kd_model_name is not None:
+        model_kd = build_kd_model(args)
+
     if args.fuser:
         utils.set_jit_fuser(args.fuser)
     if args.fast_norm:
@@ -1006,6 +1020,7 @@ def main():
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
                 naflex_mode=naflex_mode,
+                model_kd=model_kd,
             )

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1109,6 +1124,7 @@ def train_one_epoch(
         mixup_fn=None,
         num_updates_total=None,
         naflex_mode=False,
+        model_kd=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1155,6 +1171,11 @@ def _forward():
             with amp_autocast():
                 output = model(input)
                 _loss = loss_fn(output, target)
+
+                # KD logic
+                if model_kd is not None:
+                    _loss = add_kd_loss(_loss, output, input, model, model_kd, args)
+
             if accum_steps > 1:
                 _loss /= accum_steps
             return _loss
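
With these flags in place, distillation is enabled by naming a pretrained teacher. A hypothetical invocation (dataset path and model names are placeholders, not from the commit):

python train.py /path/to/imagenet --model resnet50 --kd-model-name convnext_base --alpha-kd 5

Passing --use-kd-only-loss would drop the cross-entropy term and train on the scaled KL loss alone.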
