
Commit 197d4d6

cherry-picked final PR changes
1 parent b941367 commit 197d4d6

File tree

1 file changed: +79 -25 lines changed

modelopt/torch/quantization/algorithms.py

Lines changed: 79 additions & 25 deletions
@@ -28,15 +28,14 @@
 import regex as re
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from tqdm import tqdm
 
 from modelopt.torch.opt.conversion import ModeloptStateManager
 from modelopt.torch.opt.hparam import CustomHPType, Hparam, HPType
 from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict
 from modelopt.torch.opt.utils import get_hparam, named_hparams
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
-from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
+from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master
 
 from . import config as mtq_config
 from . import model_calib
@@ -944,19 +943,72 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
         return best_recipes, is_satisfied
 
 
+# TODO: does torch compile improve speed?
 @torch.compile
-def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor:
-    # TODO: Support TensorParallel
-    prob_unquant = F.softmax(logits_unquant, dim=-1)
-    log_prob_quant = F.log_softmax(logits_quant, dim=-1)
-    return F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False)
+def _get_softmax_dist(
+    logits: torch.Tensor, tp_group, return_log_prob: bool = False
+) -> torch.Tensor:
+    # TODO: test this
+    dtype = logits.dtype
+    max_logits = torch.amax(logits, dim=-1, keepdim=True)
+    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)
+    logits = (logits - max_logits).float()
+    sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True))
+    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    logits = logits - torch.log(sum_exp_logits)
+    if return_log_prob:
+        return logits.to(dtype)
+    else:
+        return torch.exp(logits).to(dtype)
+
+
+@torch.compile
+def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor:
+    # TODO: do we need to do log_softmax in float32?
+    # log_softmax is supposed to be a numerically stable implementation
+    log_prob = torch.log_softmax(logits.float(), dim=-1)
+    if return_log_prob:
+        return log_prob
+    else:
+        return torch.exp(log_prob)
+
+
+@torch.compile
+def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
+    return torch.sum(p * log_q).float()
+
+
+def _get_prob_from_logits(
+    logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None
+) -> torch.Tensor:
+    parallel_state: ParallelState | None = (
+        getattr(lm_head, "parallel_state", None) if lm_head is not None else None
+    )
+    if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized():
+        return _get_softmax_dist(
+            logits, parallel_state.tensor_parallel_group.group, return_log_prob
+        )
+    return _get_softmax(logits, return_log_prob)
+
+
+def _get_kl_div_loss(
+    prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None
+) -> torch.Tensor:
+    log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head)
+    # We don't need to calculate the full KL-div loss here; just get p * log_q
+    return _get_p_log_q(prob_unquant, log_prob_quant)
+
+
+def _get_lm_head(model: nn.Module) -> nn.Module:
+    for name, module in model.named_modules():
+        if name.endswith(("lm_head", "output_layer")):  # HF transformers models or Megatron models
+            return module
+    return None
 
 
 class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher):
     """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation."""
 
-    score_module_rules: list[str | Callable] = [lambda name: ""]
-
     @property
     def default_search_config(self):
         """Get the default config for the searcher."""
@@ -973,9 +1025,10 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         config = config or {}
         for ignored_key in ["score_func", "loss_func", "forward_backward_step"]:
             if ignored_key in config:
-                warnings.warn(
-                    f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
-                )
+                if config[ignored_key] is not None:
+                    warnings.warn(
+                        f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
+                    )
                 config.pop(ignored_key)
         config = super().sanitize_search_config(config)
         assert config["forward_step"] is not None, (
@@ -984,21 +1037,12 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         )
         return config
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def estimate_sensitivity_scores(self):
         """Estimate the sensitivity scores for the model.
 
        Higher score means more sensitive to quantization.
        """
-        # Check if tensor parallelism is being used
-        for name, module in self.model.named_modules():
-            if hasattr(module, "parallel_state"):
-                if hasattr(module.parallel_state, "tensor_parallel_group"):
-                    if module.parallel_state.tensor_parallel_group.is_initialized():
-                        warnings.warn(
-                            "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
-                        )
-                        break
 
         def set_to_unquantized():
             for name, hparam in named_hparams(self.model, unique=True):
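
The decorator change from `@torch.no_grad()` to `@torch.inference_mode()` is a small optimization for this score-only path: both disable gradient tracking, but inference mode additionally skips autograd's version-counter and view-metadata bookkeeping across the many repeated forward passes. The trade-off, sketched below, is that tensors created under inference mode can never participate in a later autograd graph:

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # ordinary tensor with requires_grad=False

with torch.inference_mode():
    z = x * 2  # inference tensor: no version counter, no view tracking

print(y.requires_grad, z.requires_grad)  # False False
# Using z in autograd later raises "Inference tensors cannot be saved for backward",
# which is fine here since estimate_sensitivity_scores only runs forward passes.
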
@@ -1016,17 +1060,27 @@ def set_to_unquantized():
         ):
             set_to_unquantized()
             logits_unquant = self.config["forward_step"](self.model, data)
+            prob_unquant = _get_prob_from_logits(
+                logits_unquant,
+                return_log_prob=False,
+                lm_head=_get_lm_head(self.model),
+            )
 
-            for name, hparam in named_hparams(self.model, configurable=True):
+            for name, hparam in tqdm(
+                list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams"
+            ):
                 if not isinstance(hparam, QuantRecipeHparam):
                     continue
                 for recipe in hparam.choices:
                     if recipe == QuantRecipe(quant_cfg=None):
                         continue
                     hparam.active = recipe
                     logits_quant = self.config["forward_step"](self.model, data)
-                    score = _get_kl_div_loss(logits_unquant, logits_quant)
-                    hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    score = _get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model))
+                    if hparam._importance_dict[recipe][hparam.score_modules[0]] is None:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    else:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] += score
                 hparam.active = QuantRecipe(quant_cfg=None)
 
     def run_search_with_stats(self, max_weight_size, verbose=False):
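
The scoring loop leans on the identity KL(p‖q) = Σ p·log p − Σ p·log q: the entropy term Σ p·log p comes from the unquantized model only, so it is identical for every recipe, and accumulating Σ p·log q (what `_get_p_log_q` returns) ranks recipes exactly as the full divergence would. A small numerical check of that equivalence, using random logits as stand-ins for quantized outputs:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = torch.softmax(torch.randn(4, 16), dim=-1)     # unquantized distribution
recipes = [torch.randn(4, 16) for _ in range(3)]  # stand-ins for quantized logits

full_kl = [
    F.kl_div(torch.log_softmax(q, dim=-1), p, reduction="sum").item() for q in recipes
]
p_log_q = [torch.sum(p * torch.log_softmax(q, dim=-1)).item() for q in recipes]

# Ordering recipes by full KL equals ordering them by -sum(p * log q).
assert sorted(range(3), key=lambda i: full_kl[i]) == sorted(range(3), key=lambda i: -p_log_q[i])

Note the sign: a lower (more negative) Σ p·log q corresponds to a larger KL divergence, i.e. a recipe the model is more sensitive to.
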
