"""Knowledge Distillation helpers for training with a teacher model."""
import logging
from typing import Optional

import torch
import torch.nn as nn
import torchvision.transforms as T

from timm.models import create_model


_logger = logging.getLogger(__name__)


class DistillationTeacher(nn.Module):
    """Wrapper for a teacher model used in knowledge distillation.

    Creates and manages a pre-trained teacher model for knowledge distillation,
    handling model compilation and normalization differences between teacher and student.

    Args:
        model_name: Name of the teacher model to create (pretrained weights are always loaded)
        num_classes: Number of output classes
        in_chans: Number of input channels
        device: Device to place the model on (default: 'cuda')
        dtype: Model dtype (default: None, uses float32)
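
    Example:
        A minimal construction sketch; the model name and dtype below are
        illustrative placeholders, not defaults of this class::

            teacher = DistillationTeacher(
                model_name='convnext_large',
                num_classes=1000,
                device=torch.device('cuda'),
                dtype=torch.float16,
            )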
    """

    def __init__(
            self,
            model_name: str,
            num_classes: int,
            in_chans: int = 3,
            device: torch.device = torch.device('cuda'),
            dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()

        _logger.info(f"Creating KD teacher model: '{model_name}'")

        model_kd = create_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=True,
            in_chans=in_chans,
        )

        model_kd = model_kd.to(device=device, dtype=dtype)
        model_kd.eval()

        try:
            model_kd = torch.compile(model_kd)
            _logger.info("torch.compile applied successfully to KD teacher model")
        except Exception as e:
            _logger.warning(f"torch.compile failed with error {e}, continuing without compilation")

        self.model = model_kd
        self.mean_model_kd = model_kd.pretrained_cfg['mean']
        self.std_model_kd = model_kd.pretrained_cfg['std']

    def normalize_input(
        self,
        input: torch.Tensor,
        student_model: nn.Module,
    ) -> torch.Tensor:
        """Normalize input to match teacher's expected normalization.

        Handles different normalization between teacher and student models by
        converting the student's normalized input to the teacher's expected format.

        Args:
            input: Input tensor (already normalized for student)
            student_model: Student model to extract normalization params from

        Returns:
            Input tensor normalized for the teacher model
        """
        if hasattr(student_model, 'module'):
            model_s = student_model.module
        else:
            model_s = student_model

        mean_student = model_s.pretrained_cfg['mean']
        std_student = model_s.pretrained_cfg['std']

        input_kd = input
        if mean_student != self.mean_model_kd or std_student != self.std_model_kd:
            # The student input is x_s = (x - mean_s) / std_s; the teacher expects
            # x_t = (x - mean_t) / std_t = x_s * std_s / std_t + (mean_s - mean_t) / std_t.
            # First rescale by std_s / std_t ...
            std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student))
            transform_std = T.Normalize(mean=(0, 0, 0), std=std)

            # ... then shift by (mean_s - mean_t) / std_t, i.e. subtract (mean_t - mean_s) / std_t.
            mean = tuple(
                (t_mean - s_mean) / t_std
                for t_mean, s_mean, t_std in zip(self.mean_model_kd, mean_student, self.std_model_kd)
            )
            transform_mean = T.Normalize(mean=mean, std=(1, 1, 1))

            input_kd = transform_mean(transform_std(input))

        return input_kd


def apply_kd_loss(
        loss: torch.Tensor,
        student_output: torch.Tensor,
        input: torch.Tensor,
        student_model: nn.Module,
        teacher_model: DistillationTeacher,
        alpha_kd: float,
        use_kd_only: bool = False,
) -> torch.Tensor:
    """Apply knowledge distillation loss.

    Computes KL divergence between student and teacher outputs and combines
    with the base loss (or replaces it if use_kd_only is True).

    Args:
        loss: Base loss (e.g., cross-entropy with labels)
        student_output: Logits from student model
        input: Input tensor (already normalized for student)
        student_model: Student model being trained
        teacher_model: Teacher model for distillation
        alpha_kd: Weight for the KD loss component
        use_kd_only: If True, only use KD loss (ignore base loss)

    Returns:
        Combined loss with KD component
    """
    # Student log-probabilities (log_softmax, as expected for the input of F.kl_div)
    prob_s = torch.nn.functional.log_softmax(student_output, dim=-1)

    # Teacher probabilities. Use no_grad rather than inference_mode so that prob_t
    # can be saved for the backward pass of the KL term below.
    with torch.no_grad():
        input_kd = teacher_model.normalize_input(input, student_model)
        out_t = teacher_model.model(input_kd.detach())
        prob_t = torch.nn.functional.softmax(out_t, dim=-1)

    # KL divergence between student and teacher distributions, weighted by alpha_kd
    kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean')

    if use_kd_only:
        return kd_loss
    else:
        return loss + kd_loss
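

# Illustrative usage sketch (not part of the module API): how these helpers could be
# wired into a training step. The model names, alpha_kd value, data loader, and CUDA
# device below are placeholders chosen for the example, not values prescribed here.
#
#     student = create_model('resnet50', num_classes=1000).to('cuda')
#     teacher = DistillationTeacher('convnext_large', num_classes=1000)
#     optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
#
#     for input, target in loader:
#         input, target = input.cuda(), target.cuda()
#         output = student(input)
#         loss = torch.nn.functional.cross_entropy(output, target)
#         loss = apply_kd_loss(
#             loss, output, input, student, teacher,
#             alpha_kd=5.0, use_kd_only=False,
#         )
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()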