|
28 | 28 | import regex as re |
29 | 29 | import torch |
30 | 30 | import torch.nn as nn |
| 31 | +import torch.nn.functional as F |
31 | 32 | from tqdm import tqdm |
32 | 33 |
|
33 | 34 | from modelopt.torch.opt.conversion import ModeloptStateManager |
@@ -943,8 +944,175 @@ def run_search_with_stats(self, max_weight_size, verbose=False): |
943 | 944 | return best_recipes, is_satisfied |
944 | 945 |
|
945 | 946 |
|
946 | | -class AutoQuantizeLossSearcher(_AutoQuantizeBaseSearcher): |
947 | | - """A searcher for AutoQuantize algorithm that uses loss based score estimation.""" |
| 947 | +@torch.compile |
| 948 | +def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor: |
| 949 | + # TODO: Support TensorParallel |
| 950 | + prob_unquant = F.softmax(logits_unquant, dim=-1) |
| 951 | + log_prob_quant = F.log_softmax(logits_quant, dim=-1) |
| 952 | + return F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False) |
| 953 | + |
| 954 | + |
| 955 | +class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher): |
| 956 | + """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation.""" |
| 957 | + |
| 958 | + score_module_rules: list[str | Callable] = [lambda name: ""] |
| 959 | + |
| 960 | + @property |
| 961 | + def default_search_config(self): |
| 962 | + """Get the default config for the searcher.""" |
| 963 | + config = super().default_search_config |
| 964 | + config.update( |
| 965 | + { |
| 966 | + "forward_step": None, |
| 967 | + } |
| 968 | + ) |
| 969 | + return config |
| 970 | + |
| 971 | + def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: |
| 972 | + """Sanitize the search config dict.""" |
| 973 | + config = config or {} |
| 974 | + for ignored_key in ["score_func", "loss_func", "forward_backward_step"]: |
| 975 | + if ignored_key in config: |
| 976 | + warnings.warn( |
| 977 | + f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`." |
| 978 | + ) |
| 979 | + config.pop(ignored_key) |
| 980 | + config = super().sanitize_search_config(config) |
| 981 | + assert config["forward_step"] is not None, ( |
| 982 | + "`forward_step` must be provided for KL-Divergence loss based `auto_quantize`. " |
| 983 | + "`forward_step(model, data)` should return model logits." |
| 984 | + ) |
| 985 | + return config |
| 986 | + |
| 987 | + @torch.no_grad() |
| 988 | + def estimate_sensitivity_scores(self): |
| 989 | + """Estimate the sensitivity scores for the model. |
| 990 | +
|
| 991 | + Higher score means more sensitive to quantization. |
| 992 | + """ |
| 993 | + # Check if tensor parallelism is being used |
| 994 | + for name, module in self.model.named_modules(): |
| 995 | + if hasattr(module, "parallel_state"): |
| 996 | + if hasattr(module.parallel_state, "tensor_parallel_group"): |
| 997 | + if module.parallel_state.tensor_parallel_group.is_initialized(): |
| 998 | + warnings.warn( |
| 999 | + "Tensor Parallel is not supported for KL-Divergence based auto_quantize. " |
| 1000 | + ) |
| 1001 | + break |
| 1002 | + |
| 1003 | + def set_to_unquantized(): |
| 1004 | + for name, hparam in named_hparams(self.model, unique=True): |
| 1005 | + if not isinstance(hparam, QuantRecipeHparam): |
| 1006 | + continue |
| 1007 | + if hparam.is_configurable: |
| 1008 | + hparam.active = QuantRecipe(quant_cfg=None) |
| 1009 | + |
| 1010 | + self.model.eval() |
| 1011 | + num_iters = self.config["num_score_steps"] |
| 1012 | + for _, data in tqdm( |
| 1013 | + zip(range(num_iters), self.config["data_loader"]), |
| 1014 | + desc="Estimating KLDivergence loss", |
| 1015 | + total=num_iters, |
| 1016 | + ): |
| 1017 | + set_to_unquantized() |
| 1018 | + logits_unquant = self.config["forward_step"](self.model, data) |
| 1019 | + |
| 1020 | + for name, hparam in named_hparams(self.model, configurable=True): |
| 1021 | + if not isinstance(hparam, QuantRecipeHparam): |
| 1022 | + continue |
| 1023 | + for recipe in hparam.choices: |
| 1024 | + if recipe == QuantRecipe(quant_cfg=None): |
| 1025 | + continue |
| 1026 | + hparam.active = recipe |
| 1027 | + logits_quant = self.config["forward_step"](self.model, data) |
| 1028 | + score = _get_kl_div_loss(logits_unquant, logits_quant) |
| 1029 | + hparam._importance_dict[recipe][hparam.score_modules[0]] = score |
| 1030 | + hparam.active = QuantRecipe(quant_cfg=None) |
| 1031 | + |
| 1032 | + def run_search_with_stats(self, max_weight_size, verbose=False): |
| 1033 | + """Run threshold-based binary search for KLDivergence loss based auto_quantize. |
| 1034 | +
|
| 1035 | + We use binary search to minimize the max(per-layer score) while meeting the constraint. |
| 1036 | + """ |
| 1037 | + # Collect all sensitivity scores to determine initial threshold bounds |
| 1038 | + all_scores = [ |
| 1039 | + score for name in self.candidate_stats for score in self.candidate_stats[name]["scores"] |
| 1040 | + ] |
| 1041 | + |
| 1042 | + if not all_scores: |
| 1043 | + warnings.warn("No scores available for threshold-based search!") |
| 1044 | + is_satisfied = False |
| 1045 | + return {}, is_satisfied |
| 1046 | + |
| 1047 | + # Initialize binary search bounds |
| 1048 | + min_score = min(all_scores) |
| 1049 | + max_score = max(all_scores) |
| 1050 | + threshold = (min_score + max_score) / 2.0 |
| 1051 | + lower_bound = min_score |
| 1052 | + upper_bound = max_score |
| 1053 | + |
| 1054 | + # Run for fixed number of iterations |
| 1055 | + max_iterations = 100 |
| 1056 | + |
| 1057 | + if verbose: |
| 1058 | + print_rank_0("AutoQuantize: Starting threshold-based binary search") |
| 1059 | + print_rank_0(f" Score range: [{min_score:.6e}, {max_score:.6e}]") |
| 1060 | + print_rank_0(f" Target weight size: {max_weight_size:.2f}") |
| 1061 | + |
| 1062 | + for iteration in range(max_iterations): |
| 1063 | + # Select recipes based on current threshold |
| 1064 | + best_recipes = {} |
| 1065 | + total_weight_size = 0.0 |
| 1066 | + |
| 1067 | + for name in self.candidate_stats: |
| 1068 | + formats = self.candidate_stats[name]["formats"] |
| 1069 | + scores = self.candidate_stats[name]["scores"] |
| 1070 | + costs = self.candidate_stats[name]["costs"] |
| 1071 | + |
| 1072 | + selected_idx = 0 |
| 1073 | + for idx in range(len(formats)): |
| 1074 | + if scores[idx] <= threshold: |
| 1075 | + selected_idx = idx |
| 1076 | + break |
| 1077 | + |
| 1078 | + best_recipes[name] = { |
| 1079 | + "format": formats[selected_idx], |
| 1080 | + "costs": costs[selected_idx], |
| 1081 | + "scores": scores[selected_idx], |
| 1082 | + } |
| 1083 | + total_weight_size += costs[selected_idx] |
| 1084 | + |
| 1085 | + # Check if we meet the constraint |
| 1086 | + meets_constraint = total_weight_size <= max_weight_size |
| 1087 | + |
| 1088 | + if verbose: |
| 1089 | + print_rank_0( |
| 1090 | + f" Iteration {iteration + 1}: threshold={threshold:.6e}, " |
| 1091 | + f"weight_size={total_weight_size:.2f}, " |
| 1092 | + f"meets_constraint={meets_constraint}" |
| 1093 | + ) |
| 1094 | + |
| 1095 | + # Update binary search bounds |
| 1096 | + if meets_constraint: |
| 1097 | + upper_bound = threshold # Threshold was too aggressive, relax it |
| 1098 | + else: |
| 1099 | + lower_bound = threshold # Threshold was too lax, tighten it |
| 1100 | + |
| 1101 | + # Update threshold for next iteration |
| 1102 | + threshold = (lower_bound + upper_bound) / 2.0 |
| 1103 | + |
| 1104 | + # Final check if constraint is satisfied |
| 1105 | + is_satisfied = total_weight_size <= max_weight_size |
| 1106 | + |
| 1107 | + if verbose: |
| 1108 | + print_rank_0( |
| 1109 | + f"AutoQuantize: Search complete. " |
| 1110 | + f"Final weight size: {total_weight_size:.2f} " |
| 1111 | + f"(target: {max_weight_size:.2f}), " |
| 1112 | + f"constraint satisfied: {is_satisfied}" |
| 1113 | + ) |
| 1114 | + |
| 1115 | + return best_recipes, is_satisfied |
948 | 1116 |
|
949 | 1117 |
|
950 | 1118 | # Backward compatibility alias (defaults to gradient-based searcher) |
|
0 commit comments