@@ -155,9 +155,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
             for group in config.config_groups.values()
             if group.weights is not None
         )
-        assert (
-            len(num_bits_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for num_bits"
+        assert len(num_bits_set) == 1, (
+            "In AWQ, all config groups must use the same configuration for num_bits"
+        )
 
         model._num_bits = next(iter(num_bits_set))
 
@@ -166,9 +166,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
             for group in config.config_groups.values()
             if group.weights is not None
         )
-        assert (
-            len(symmetric_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for symmetric"
+        assert len(symmetric_set) == 1, (
+            "In AWQ, all config groups must use the same configuration for symmetric"
+        )
 
         model._symmetric = next(iter(symmetric_set))
 
@@ -177,9 +177,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
             for group in config.config_groups.values()
             if group.weights is not None
         )
-        assert (
-            len(group_size_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for group_size"
+        assert len(group_size_set) == 1, (
+            "In AWQ, all config groups must use the same configuration for group_size"
+        )
 
         model._group_size = next(iter(group_size_set))
 
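The three hunks above are purely stylistic: the condition moves out of the parentheses and the message gets wrapped instead. Both forms are semantically identical, and both avoid the classic pitfall of parenthesizing the condition and message together. A minimal standalone sketch (variable names are illustrative, not tied to this codebase):

# Old style: parenthesized condition, message after the closing paren.
num_bits_set = {4}
assert (
    len(num_bits_set) == 1
), "all config groups must use the same num_bits"

# New style: bare condition, parenthesized (wrapped) message.
assert len(num_bits_set) == 1, (
    "all config groups must use the same num_bits"
)

# Pitfall both styles avoid: wrapping (condition, message) in one pair of
# parentheses creates a two-element tuple, which is always truthy, so the
# assert can never fire. CPython emits a SyntaxWarning for this.
# assert (len(num_bits_set) == 1, "never raises")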
@@ -316,7 +316,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
             )
         ):
             pbar.set_description(
-                f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                f"Resolving mapping {mapping_idx + 1}/{len(self.mappings)}"
                 f" ({num_skipped_mappings} skipped)"
             )
 
@@ -452,9 +452,11 @@ def _apply_smoothing(self, model: Module) -> None:
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent
 
-            with align_modules(
-                [parent_module, smooth_layer, *balance_layers]
-            ), calibration_forward_context(model), HooksMixin.disable_hooks():
+            with (
+                align_modules([parent_module, smooth_layer, *balance_layers]),
+                calibration_forward_context(model),
+                HooksMixin.disable_hooks(),
+            ):
                 # [STEP 1]: Compute per-channel mean of normalised weights
                 # All layer weights are concatted together
                 weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
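The rewritten hunk uses the parenthesized with-statement form, which is official grammar since Python 3.10: one context manager per line, trailing comma allowed. A minimal sketch with stand-in context managers (ctx here is a placeholder, not the project's align_modules or calibration helpers):

from contextlib import contextmanager

@contextmanager
def ctx(name):
    # Stand-in for align_modules / calibration_forward_context / disable_hooks.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# Parenthesized form (Python 3.10+). Managers enter top to bottom and
# exit in reverse order, exactly like the comma-chained one-line form.
with (
    ctx("align"),
    ctx("calibrate"),
    ctx("hooks"),
):
    print("body")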
@@ -653,9 +655,9 @@ def _compute_best_scale(
653655 "https://github.com/vllm-project/llm-compressor/issues"
654656 )
655657
656- assert (
657- torch . isnan ( best_scales ). sum () == 0
658- ), f"Nan found in scales: { best_scales } "
658+ assert torch . isnan ( best_scales ). sum () == 0 , (
659+ f"Nan found in scales: { best_scales } "
660+ )
659661
660662 return best_scales .detach ().cpu ()
661663
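The NaN guard in the last hunk follows the same assert pattern. As a quick sketch of what torch.isnan(...).sum() == 0 evaluates to (the tensor values are made up for illustration):

import torch

best_scales = torch.tensor([1.0, 0.5, float("nan")])

# isnan() marks NaN entries; .sum() counts them; comparing to 0 yields a
# 0-dim bool tensor, which assert truth-tests like a plain Python bool.
assert torch.isnan(best_scales).sum() == 0, (
    f"Nan found in scales: {best_scales}"
)  # raises AssertionError here, since one entry is NaN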