Commit d705d67

Cleanup distillation code

1 parent 3c5135c commit d705d67

4 files changed: +167 additions, -84 deletions

timm/kd/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+"""Knowledge Distillation module for timm"""
+from .distillation import DistillationTeacher, apply_kd_loss
+
+__all__ = ['DistillationTeacher', 'apply_kd_loss']

timm/kd/distillation.py

Lines changed: 142 additions & 0 deletions

@@ -0,0 +1,142 @@
+"""Knowledge Distillation helpers for training with a teacher model."""
+import logging
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+
+from timm.models import create_model
+
+
+_logger = logging.getLogger(__name__)
+
+
+class DistillationTeacher(nn.Module):
+    """Wrapper for a teacher model used in knowledge distillation.
+
+    Creates and manages a pre-trained teacher model for knowledge distillation,
+    handling model compilation and normalization differences between teacher and student.
+
+    Args:
+        model_name: Name of the teacher model to create
+        num_classes: Number of output classes
+        in_chans: Number of input channels
+        pretrained: Whether to load pretrained weights
+        device: Device to place the model on (default: 'cuda')
+        dtype: Model dtype (default: None, uses float32)
+    """
+
+    def __init__(
+            self,
+            model_name: str,
+            num_classes: int,
+            in_chans: int = 3,
+            device: torch.device = torch.device('cuda'),
+            dtype: torch.dtype = None,
+    ):
+        super().__init__()
+
+        _logger.info(f"Creating KD teacher model: '{model_name}'")
+
+        model_kd = create_model(
+            model_name=model_name,
+            num_classes=num_classes,
+            pretrained=True,
+            in_chans=in_chans,
+        )
+
+        model_kd = model_kd.to(device=device, dtype=dtype)
+        model_kd.eval()
+
+        try:
+            model_kd = torch.compile(model_kd)
+            _logger.info("torch.compile applied successfully to KD teacher model")
+        except Exception as e:
+            _logger.warning(f"torch.compile failed with error {e}, continuing without compilation")
+
+        self.model = model_kd
+        self.mean_model_kd = model_kd.pretrained_cfg['mean']
+        self.std_model_kd = model_kd.pretrained_cfg['std']
+
+    def normalize_input(
+            self,
+            input: torch.Tensor,
+            student_model: nn.Module,
+    ) -> torch.Tensor:
+        """Normalize input to match teacher's expected normalization.
+
+        Handles different normalization between teacher and student models by
+        converting the student's normalized input to the teacher's expected format.
+
+        Args:
+            input: Input tensor (already normalized for student)
+            student_model: Student model to extract normalization params from
+
+        Returns:
+            Input tensor normalized for the teacher model
+        """
+        if hasattr(student_model, 'module'):
+            model_s = student_model.module
+        else:
+            model_s = student_model
+
+        mean_student = model_s.pretrained_cfg['mean']
+        std_student = model_s.pretrained_cfg['std']
+
+        input_kd = input
+        if mean_student != self.mean_model_kd or std_student != self.std_model_kd:
+            # Compute normalized std and mean transformations
+            std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student))
+            transform_std = T.Normalize(mean=(0, 0, 0), std=std)
+
+            mean = tuple(t_mean - s_mean for t_mean, s_mean in zip(self.mean_model_kd, mean_student))
+            transform_mean = T.Normalize(mean=mean, std=(1, 1, 1))
+
+            input_kd = transform_mean(transform_std(input))
+
+        return input_kd
+
+
+def apply_kd_loss(
+        loss: torch.Tensor,
+        student_output: torch.Tensor,
+        input: torch.Tensor,
+        student_model: nn.Module,
+        teacher_model: DistillationTeacher,
+        alpha_kd: float,
+        use_kd_only: bool = False,
+) -> torch.Tensor:
+    """Apply knowledge distillation loss.
+
+    Computes KL divergence between student and teacher outputs and combines
+    with the base loss (or replaces it if use_kd_only is True).
+
+    Args:
+        loss: Base loss (e.g., cross-entropy with labels)
+        student_output: Logits from student model
+        input: Input tensor (already normalized for student)
+        student_model: Student model being trained
+        teacher_model: Teacher model for distillation
+        alpha_kd: Weight for the KD loss component
+        use_kd_only: If True, only use KD loss (ignore base loss)
+
+    Returns:
+        Combined loss with KD component
+    """
+    # Student probability calculation
+    prob_s = torch.nn.functional.log_softmax(student_output, dim=-1)
+
+    # Teacher probability calculation
+    with torch.inference_mode():
+        input_kd = teacher_model.normalize_input(input, student_model)
+        out_t = teacher_model.model(input_kd.detach())
+        prob_t = torch.nn.functional.softmax(out_t, dim=-1)
+
+    # Compute KL divergence loss
+    kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean')
+
+    if use_kd_only:
+        return kd_loss
+    else:
+        return loss + kd_loss
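
For orientation, here is a minimal usage sketch of the two new entry points outside of train.py. The model names, class count, batch shape, and alpha_kd value are illustrative placeholders, not values taken from this commit:

import torch
import torch.nn.functional as F
from timm import create_model
from timm.kd import DistillationTeacher, apply_kd_loss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Student and teacher; the model names and class count are placeholders
student = create_model('resnet18', pretrained=True, num_classes=1000).to(device)
teacher = DistillationTeacher('resnet50d', num_classes=1000, device=device, dtype=torch.float32)

images = torch.randn(8, 3, 224, 224, device=device)   # assumed already normalized for the student
targets = torch.randint(0, 1000, (8,), device=device)

output = student(images)
base_loss = F.cross_entropy(output, targets)

# Combine the supervised loss with the weighted KD term, mirroring the call in _forward()
loss = apply_kd_loss(
    loss=base_loss,
    student_output=output,
    input=images,
    student_model=student,
    teacher_model=teacher,
    alpha_kd=5.0,          # illustrative weight
    use_kd_only=False,
)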

timm/utils/model_kd.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

train.py

Lines changed: 21 additions & 7 deletions

@@ -41,7 +41,7 @@
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler
-from timm.utils.model_kd import build_kd_model, add_kd_loss
+from timm.kd import DistillationTeacher, apply_kd_loss
 
 try:
     from apex import amp
@@ -489,11 +489,6 @@ def main():
 
     utils.random_seed(args.seed, args.rank)
 
-    # Create the KD teacher model if specified
-    model_kd = None
-    if args.kd_model_name is not None:
-        model_kd = build_kd_model(args)
-
     if args.fuser:
         utils.set_jit_fuser(args.fuser)
     if args.fast_norm:
@@ -543,6 +538,17 @@
     if args.grad_checkpointing:
         model.set_grad_checkpointing(enable=True)
 
+    # Create the KD teacher model if specified
+    model_kd = None
+    if args.kd_model_name is not None:
+        model_kd = DistillationTeacher(
+            model_name=args.kd_model_name,
+            num_classes=args.num_classes,
+            in_chans=in_chans,
+            device=device,
+            dtype=model_dtype,
+        )
+
     if utils.is_primary(args):
         _logger.info(
             f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
@@ -1174,7 +1180,15 @@ def _forward():
 
             # KD logic
             if model_kd is not None:
-                _loss= add_kd_loss(_loss, output, input, model, model_kd, args)
+                _loss = apply_kd_loss(
+                    loss=_loss,
+                    student_output=output,
+                    input=input,
+                    student_model=model,
+                    teacher_model=model_kd,
+                    alpha_kd=args.alpha_kd,
+                    use_kd_only=args.use_kd_only_loss,
+                )
 
             if accum_steps > 1:
                 _loss /= accum_steps
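
As a sanity check on the distillation term now flowing through _forward, the snippet below verifies numerically that the quantity apply_kd_loss adds is alpha_kd times the batch-averaged KL divergence KL(teacher || student). The tensors are random and purely illustrative:

import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

# Same quantities apply_kd_loss builds internally
log_p_s = F.log_softmax(student_logits, dim=-1)
p_t = F.softmax(teacher_logits, dim=-1)

# KD term as computed in timm/kd/distillation.py (with alpha_kd = 1 here)
kd = F.kl_div(log_p_s, p_t, reduction='batchmean')

# Manual batch-averaged KL(teacher || student)
manual = (p_t * (p_t.log() - log_p_s)).sum(dim=-1).mean()

print(torch.allclose(kd, manual))  # prints: True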
