- 
                Notifications
    
You must be signed in to change notification settings  - Fork 58
 
add autoround._generate_recipe() #758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
d8b831e
              a78f99c
              62f81e5
              79323d6
              086eae2
              b67c79a
              715e2a1
              5a631f1
              2a0e0b8
              f02331d
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -430,6 +430,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: | |
| self.enable_torch_compile = False | ||
| logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") | ||
| 
     | 
||
| self.recipe_mode = False | ||
| self.recipe_results = {} | ||
| 
     | 
||
| def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None: | ||
| """Sets the device map for specific blocks in the model. | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -1433,6 +1436,8 @@ def quantize(self): | |
| m.tmp_name = n | ||
| self._check_compatibility() | ||
| self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config) | ||
| if not self.recipe_mode: | ||
| self._dump_average_bits() # leverage updated self.layer_config | ||
| if not hasattr(self, "formats"): | ||
| logger.warning("this API is deprecated, please use `quantize_and_save` instead") | ||
| else: | ||
| 
          
            
          
           | 
    @@ -1549,6 +1554,8 @@ def quantize(self): | |
| f"Expected exactly one packing format when 'is_packing_immediate' is True, " | ||
| f"but got {len(self.formats)} formats." | ||
| ) | ||
| if self.recipe_mode: | ||
| return | ||
| 
     | 
||
| self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -2439,7 +2446,10 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
| 
     | 
||
| modules = block.modules() | ||
| for module in modules: | ||
| update_fused_layer_global_scales(module) | ||
| try: | ||
| update_fused_layer_global_scales(module) | ||
| except: | ||
| pass # mix-precision may cause error, since q,k,v are not the same dtype. | ||
| round_params = [] | ||
| minmax_params = [] | ||
| for n, m in block.named_modules(): | ||
| 
          
            
          
           | 
    @@ -2561,7 +2571,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
| logger.info(f"{unquantized_layer_names} have not been quantized") | ||
| with torch.no_grad(): | ||
| unwrapper_block(block, best_params) | ||
| if self.enable_quanted_input: | ||
| if self.enable_quanted_input and hasattr(self, "formats"): | ||
                
      
                  xin3he marked this conversation as resolved.
               
          
            Show resolved
            Hide resolved
         | 
||
| if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): | ||
| from auto_round.utils import set_amax_for_all_moe_layers | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -2616,6 +2626,26 @@ def quantize_blocks( | |
| clear_memory() | ||
| input_ids = to_device(input_ids, self.cache_device) | ||
| input_others = to_device(input_others, self.cache_device) | ||
| if self.recipe_mode: | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to wrap this new code into a function and call it as early as possible.  | 
||
| pbar = tqdm(range(0, len(block_names), nblocks)) | ||
| for i in range(0, len(block_names), nblocks): | ||
| if i != 0: | ||
| pbar.update(1) | ||
| if nblocks == 1: | ||
| n = block_names[i] | ||
| pbar.set_description(f"[Recipe Mode] Processing {n}") | ||
| block = get_module(model, n) | ||
| else: | ||
| names = block_names[i : min(i + nblocks, len(block_names))] | ||
| pbar.set_description( | ||
| f"[Recipe Mode] Processing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}" | ||
| ) | ||
| modules = [get_module(model, n) for n in names] | ||
| block = WrapperMultiblock(modules) | ||
| block_recipe_results = self._generate_block_recipe(block, input_ids, input_others) | ||
| for result in block_recipe_results: | ||
| self.recipe_results.update({block_names[i] + "." + result: self.recipe_mp_dtype}) | ||
| return | ||
| ## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| for i in range(len(input_ids)): | ||
| 
          
            
          
           | 
    @@ -2954,6 +2984,141 @@ def sampling_inputs(cls, input_ids, input_others, indices, seqlen, batch_dim=0, | |
| 
     | 
||
| return current_input_ids, current_input_others | ||
| 
     | 
||
| def _generate_recipe( | ||
| self, | ||
| # same data type config as before | ||
| mp_dtype={ | ||
                
      
                  xin3he marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| "data_type": "mx_fp8", | ||
| "act_data_type": "mx_fp8", | ||
| }, | ||
| # special mix-precision configuration | ||
| mp_config={ | ||
| "mp_ratio": 1 / 3, | ||
| "loss_weight": 2.0, | ||
| "numel_weight": 1.0, | ||
| }, | ||
| ): | ||
| self.recipe_mode = True | ||
| self.recipe_mp_dtype = mp_dtype | ||
| self.recipe_mp_config = mp_config | ||
| self.quantize() | ||
| recipe_layer_config = copy.deepcopy(self.layer_config) | ||
| recipe_layer_config.update(self.recipe_results) | ||
| self._dump_average_bits(layer_config=recipe_layer_config) | ||
| self.recipe_mode = False | ||
| return recipe_layer_config | ||
| 
     | 
||
| def _generate_block_recipe(self, block, input_ids, input_others): | ||
| from itertools import combinations | ||
| 
     | 
||
| from auto_round.utils import ( | ||
| DTYPE_INFO_MAPPING, | ||
| create_mp_block, | ||
| get_best_combination, | ||
| get_numel, | ||
| recover_mp_block, | ||
| ) | ||
| 
     | 
||
| # fetch mix-precision recipe configuration | ||
| sample_num = self.recipe_mp_config.get("sample_num", 8) | ||
| mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 7) | ||
| loss_weight = float(self.recipe_mp_config.get("loss_weight", 2.0)) | ||
| numel_weight = float(self.recipe_mp_config.get("numel_weight", 1.0)) | ||
| loss_numel_ratio = loss_weight / numel_weight | ||
| 
     | 
||
| # calculate the number of layers to use mix-precision | ||
| quantizable_layers = [n for n, m in block.named_modules() if isinstance(m, SUPPORTED_LAYER_TYPES)] | ||
| quantizable_num = int(mp_ratio * len(quantizable_layers)) # It's ceiling | ||
| # fetch raw low-bits dtype of block for recovering mix-precision block | ||
| layer = get_module(block, quantizable_layers[0]) | ||
| raw_dtype = { | ||
| "data_type": layer.data_type, | ||
| "bits": layer.bits, | ||
| "sym": layer.sym, | ||
| "act_data_type": layer.act_data_type, | ||
| "act_bits": layer.act_bits, | ||
| "act_sym": layer.act_sym, | ||
| } | ||
| # update self.recipe_mp_dtype | ||
| self.recipe_mp_dtype.update( | ||
| { | ||
| "bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["bits"], | ||
| "group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["group_size"], | ||
| "sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["sym"], | ||
| "act_bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["bits"], | ||
| "act_group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["group_size"], | ||
| "act_sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["sym"], | ||
| } | ||
| ) | ||
| 
     | 
||
| # generate reference output of sample input_ids | ||
| reference_output = self.get_block_outputs( | ||
| block, | ||
| input_ids[:sample_num], | ||
| input_others, | ||
| bs=self.batch_size, | ||
| device=self.device, | ||
| cache_device=self.cache_device, | ||
| save_output=True, | ||
| ) | ||
| 
     | 
||
| # generate q_output of sample input_ids and get loss | ||
| def get_loss(q_block): | ||
| q_output = self.get_block_outputs( | ||
| q_block, | ||
| input_ids[:sample_num], | ||
| input_others, | ||
| bs=self.batch_size, | ||
| device=self.device, | ||
| cache_device=self.cache_device, | ||
| save_output=True, | ||
| ) | ||
| total_loss = 0 | ||
| mse_loss = torch.nn.MSELoss(reduction="sum").to(self.device) | ||
| for i in range(len(q_output)): | ||
| loss = mse_loss( # pylint: disable=not-callable | ||
| q_output[i].to(torch.float32), reference_output[i].to(torch.float32) | ||
| ) | ||
| total_loss += loss | ||
| if is_optimum_habana_available(): | ||
| htcore.mark_step() | ||
| return loss | ||
| 
     | 
||
| combination_list = [] | ||
| numel_list = [] | ||
| loss_list = [] | ||
| for hp_layers in combinations(quantizable_layers, quantizable_num): | ||
| combination_list.append(hp_layers) | ||
| # get numel | ||
| numel = get_numel(block, hp_layers) | ||
| numel_list.append(numel) | ||
| # get loss | ||
| block = create_mp_block(block, hp_layers, self.recipe_mp_dtype) | ||
| loss = get_loss(block) | ||
| loss_list.append(loss) | ||
| block = recover_mp_block(block, hp_layers, raw_dtype) | ||
| if is_optimum_habana_available(): | ||
| htcore.mark_step() | ||
| logger.debug(f"{hp_layers}, {loss}, {numel}") | ||
| 
     | 
||
| hp_layers = get_best_combination(combination_list, numel_list, loss_list, loss_numel_ratio) | ||
| logger.info(f"final hp layers: {hp_layers}") | ||
| return hp_layers | ||
| 
     | 
||
| def _dump_average_bits(self, layer_config=None): | ||
                
       | 
||
| total_numel = 0 | ||
| total_bits = 0 | ||
| for n, m in self.model.named_modules(): | ||
| if isinstance(m, SUPPORTED_LAYER_TYPES): | ||
| m_numel = m.weight.numel() | ||
| layer_config = self.layer_config if layer_config is None else layer_config | ||
| m_bits = layer_config[n]["bits"] if n in layer_config else self.bits | ||
| total_numel += m_numel | ||
| total_bits += m_numel * m_bits | ||
| avg_bits = round(total_bits / total_numel, 3) | ||
| logger.info(f"current average bits of model: {avg_bits}") | ||
| return avg_bits | ||
| 
     | 
||
| 
     | 
||
| class AutoRoundAdam(AutoRound): | ||
| """Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model. | ||
| 
          
            
          
           | 
    ||
Uh oh!
There was an error while loading. Please reload this page.