2828import regex as re
2929import torch
3030import torch .nn as nn
31- import torch .nn .functional as F
3231from tqdm import tqdm
3332
3433from modelopt .torch .opt .conversion import ModeloptStateManager
3534from modelopt .torch .opt .hparam import CustomHPType , Hparam , HPType
3635from modelopt .torch .opt .searcher import LPS , BaseSearcher , SearchConfig , SearchStateDict
3736from modelopt .torch .opt .utils import get_hparam , named_hparams
3837from modelopt .torch .utils import create_param_grad_clear_hook , print_rank_0 , report_memory
39- from modelopt .torch .utils .distributed import DistributedProcessGroup , is_master
38+ from modelopt .torch .utils .distributed import DistributedProcessGroup , ParallelState , is_master
4039
4140from . import config as mtq_config
4241from . import model_calib
@@ -944,19 +943,72 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
944943 return best_recipes , is_satisfied
945944
946945
# TODO: does torch.compile improve speed here?
@torch.compile
def _get_softmax_dist(
    logits: torch.Tensor, tp_group, return_log_prob: bool = False
) -> torch.Tensor:
    """Compute a softmax over ``dim=-1`` whose normalization spans ``tp_group``.

    The per-row maximum and the per-row sum of exponentials are combined across
    the ranks of ``tp_group`` with all-reduce, so every rank normalizes its local
    logits by the same global statistics.
    NOTE(review): presumably each rank holds a shard of the vocabulary dimension
    (tensor-parallel output layer) -- confirm against the caller.

    Args:
        logits: Local logits; the softmax is taken over the last dimension.
        tp_group: Process group passed to ``torch.distributed.all_reduce``.
        return_log_prob: If True, return log-probabilities instead of probabilities.

    Returns:
        The local (log-)probabilities, cast back to the dtype of ``logits``.
    """
    # TODO: test this
    dtype = logits.dtype
    # Upcast before any arithmetic. Subtracting the maximum in a reduced-precision
    # dtype (bf16/fp16) rounds the shifted logits and skews the resulting
    # probabilities; float32 keeps the shift exact for such inputs.
    logits = logits.float()
    max_logits = torch.amax(logits, dim=-1, keepdim=True)
    # In-place all-reduce: after this call every rank holds the global row maximum.
    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)
    logits = logits - max_logits
    # Local sum of exponentials, then summed across ranks to get the global denominator.
    sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True))
    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)
    log_prob = logits - torch.log(sum_exp_logits)
    if return_log_prob:
        return log_prob.to(dtype)
    return torch.exp(log_prob).to(dtype)
963+
964+
@torch.compile
def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor:
    """Compute float32 (log-)probabilities over the last dimension of ``logits``.

    Args:
        logits: Input logits of any floating dtype; they are upcast to float32.
        return_log_prob: If True, return log-probabilities; otherwise probabilities.

    Returns:
        A float32 tensor with the same shape as ``logits``.
    """
    # TODO: do we need to do log_softmax in float32?
    # log_softmax is supposed to be a numerically stable implementation.
    log_probabilities = logits.float().log_softmax(dim=-1)
    # Probabilities are recovered by exponentiating the log-probabilities.
    return log_probabilities if return_log_prob else log_probabilities.exp()
974+
975+
@torch.compile
def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
    """Return ``sum(p * log_q)`` over all elements as a float32 scalar tensor.

    Args:
        p: Probabilities of the reference distribution.
        log_q: Log-probabilities of the compared distribution, broadcastable with ``p``.

    Returns:
        A zero-dimensional float32 tensor.
    """
    # Upcast the operands, not just the final scalar: with bf16/fp16 inputs the
    # element-wise product and the reduction would otherwise be rounded to the
    # low-precision dtype. For float32 inputs ``.float()`` is a no-op.
    return torch.sum(p.float() * log_q.float())
979+
980+
def _get_prob_from_logits(
    logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None
) -> torch.Tensor:
    """Turn logits into (log-)probabilities, using the distributed softmax when needed.

    Args:
        logits: Logits to normalize over the last dimension.
        return_log_prob: If True, return log-probabilities; otherwise probabilities.
        lm_head: Optional module whose ``parallel_state`` attribute (if present) is
            inspected for an initialized tensor-parallel group.

    Returns:
        The result of ``_get_softmax_dist`` when ``lm_head.parallel_state`` exposes an
        initialized tensor-parallel group, otherwise the result of ``_get_softmax``.
    """
    state = None if lm_head is None else getattr(lm_head, "parallel_state", None)
    # Without parallel-state information the plain single-process softmax applies.
    if state is None:
        return _get_softmax(logits, return_log_prob)

    tp_group = state.tensor_parallel_group
    if not tp_group.is_initialized():
        return _get_softmax(logits, return_log_prob)

    # Initialized tensor-parallel group: normalize across its ranks.
    return _get_softmax_dist(logits, tp_group.group, return_log_prob)
992+
993+
def _get_kl_div_loss(
    prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None
) -> torch.Tensor:
    """Return the part of ``KL(p_unquant || q_quant)`` that depends on the quantized model.

    ``KL(p || q) = sum(p * log p) - sum(p * log q)``. The first term does not depend
    on the quantized logits, so it is the same for every candidate and is skipped;
    only the cross-entropy term ``-sum(p * log q)`` is computed.

    Args:
        prob_unquant: Probabilities produced from the unquantized model's logits.
        logits_quant: Logits produced by the quantized model.
        lm_head: Optional output module forwarded to ``_get_prob_from_logits`` so it
            can detect tensor parallelism.

    Returns:
        ``-sum(prob_unquant * log_softmax(logits_quant))``; larger values mean the
        quantized distribution is further from the unquantized one.
    """
    log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head)
    # We don't need the full KL divergence here, just -sum(p * log_q). The minus sign
    # matters: sum(p * log_q) alone *decreases* as the divergence grows, which would
    # invert a "higher score = more sensitive" ranking.
    return -_get_p_log_q(prob_unquant, log_prob_quant)
1000+
1001+
def _get_lm_head(model: nn.Module) -> nn.Module:
    """Find the first submodule whose qualified name ends with a known output-head name.

    Args:
        model: Module whose ``named_modules()`` are scanned in iteration order.

    Returns:
        The first module whose name ends with ``"lm_head"`` or ``"output_layer"``,
        or ``None`` when no such module exists.
    """
    # HF transformers models or Megatron models
    head_suffixes = ("lm_head", "output_layer")
    matching_modules = (
        submodule
        for qualified_name, submodule in model.named_modules()
        if qualified_name.endswith(head_suffixes)
    )
    return next(matching_modules, None)
9531007
9541008
9551009class AutoQuantizeKLDivSearcher (_AutoQuantizeBaseSearcher ):
9561010 """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation."""
9571011
958- score_module_rules : list [str | Callable ] = [lambda name : "" ]
959-
9601012 @property
9611013 def default_search_config (self ):
9621014 """Get the default config for the searcher."""
@@ -973,9 +1025,10 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
9731025 config = config or {}
9741026 for ignored_key in ["score_func" , "loss_func" , "forward_backward_step" ]:
9751027 if ignored_key in config :
976- warnings .warn (
977- f"`{ ignored_key } ` is ignored for KL-Divergence loss based `auto_quantize`."
978- )
1028+ if config [ignored_key ] is not None :
1029+ warnings .warn (
1030+ f"`{ ignored_key } ` is ignored for KL-Divergence loss based `auto_quantize`."
1031+ )
9791032 config .pop (ignored_key )
9801033 config = super ().sanitize_search_config (config )
9811034 assert config ["forward_step" ] is not None , (
@@ -984,21 +1037,12 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
9841037 )
9851038 return config
9861039
987- @torch .no_grad ()
1040+ @torch .inference_mode ()
9881041 def estimate_sensitivity_scores (self ):
9891042 """Estimate the sensitivity scores for the model.
9901043
9911044 Higher score means more sensitive to quantization.
9921045 """
993- # Check if tensor parallelism is being used
994- for name , module in self .model .named_modules ():
995- if hasattr (module , "parallel_state" ):
996- if hasattr (module .parallel_state , "tensor_parallel_group" ):
997- if module .parallel_state .tensor_parallel_group .is_initialized ():
998- warnings .warn (
999- "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
1000- )
1001- break
10021046
10031047 def set_to_unquantized ():
10041048 for name , hparam in named_hparams (self .model , unique = True ):
@@ -1016,17 +1060,27 @@ def set_to_unquantized():
10161060 ):
10171061 set_to_unquantized ()
10181062 logits_unquant = self .config ["forward_step" ](self .model , data )
1063+ prob_unquant = _get_prob_from_logits (
1064+ logits_unquant ,
1065+ return_log_prob = False ,
1066+ lm_head = _get_lm_head (self .model ),
1067+ )
10191068
1020- for name , hparam in named_hparams (self .model , configurable = True ):
1069+ for name , hparam in tqdm (
1070+ list (named_hparams (self .model , configurable = True )), desc = "Evaluating hparams"
1071+ ):
10211072 if not isinstance (hparam , QuantRecipeHparam ):
10221073 continue
10231074 for recipe in hparam .choices :
10241075 if recipe == QuantRecipe (quant_cfg = None ):
10251076 continue
10261077 hparam .active = recipe
10271078 logits_quant = self .config ["forward_step" ](self .model , data )
1028- score = _get_kl_div_loss (logits_unquant , logits_quant )
1029- hparam ._importance_dict [recipe ][hparam .score_modules [0 ]] = score
1079+ score = _get_kl_div_loss (prob_unquant , logits_quant , _get_lm_head (self .model ))
1080+ if hparam ._importance_dict [recipe ][hparam .score_modules [0 ]] is None :
1081+ hparam ._importance_dict [recipe ][hparam .score_modules [0 ]] = score
1082+ else :
1083+ hparam ._importance_dict [recipe ][hparam .score_modules [0 ]] += score
10301084 hparam .active = QuantRecipe (quant_cfg = None )
10311085
10321086 def run_search_with_stats (self , max_weight_size , verbose = False ):
0 commit comments