
Commit 0a08e48

siddhaka committed
Modernized typehints in core/lifecycle.py
Signed-off-by: siddhaka <[email protected]>
1 parent 7d6c87f commit 0a08e48

File tree

24 files changed: +102, -92 lines changed


examples/awq/qwen3_coder_moe_example.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 def get_calib_dataset(tokenizer):
     ds = load_dataset(
         DATASET_ID,
-        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
+        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES * 10}]",
     )
 
     def preprocess(example):
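Note: the split string in this hunk uses the Hugging Face datasets slice syntax, so the only behavioural content is "take the first NUM_CALIBRATION_SAMPLES * 10 rows"; the added spaces around the "*" are purely cosmetic. A minimal, hedged sketch of the same call shape, using a placeholder dataset and an illustrative sample count rather than this example's settings:

from datasets import load_dataset

NUM_CALIBRATION_SAMPLES = 256  # illustrative value only

# The "[:N]" suffix asks the hub for only the first N rows of the split.
ds = load_dataset(
    "wikitext",              # placeholder dataset ID
    "wikitext-2-raw-v1",     # placeholder config name
    split=f"train[:{NUM_CALIBRATION_SAMPLES * 10}]",
)
print(len(ds))  # 2560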

src/llmcompressor/modeling/fuse.py

Lines changed: 3 additions & 2 deletions
@@ -47,8 +47,9 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])
     for linear in linears:
         # NOTE: spinquant does this op in float64
         exec_device = get_execution_device(norm)
-        with align_module_device(norm, exec_device), align_module_device(
-            linear, exec_device
+        with (
+            align_module_device(norm, exec_device),
+            align_module_device(linear, exec_device),
         ):
             weight_dtype = linear.weight.dtype
             new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
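This hunk, and most of the hunks below, rewrite chained context managers into the parenthesized with-statement form. A minimal sketch of the pattern using stand-in context managers from contextlib (not the project's own helpers); the parenthesized form is only valid on Python 3.10+ and is the layout recent Black/Ruff releases tend to emit:

from contextlib import nullcontext

# Before: chaining on one line, which forces awkward wrapping once it gets long.
with nullcontext("a") as x, nullcontext("b") as y:
    print(x, y)

# After: one context manager per line inside parentheses, with a trailing comma.
with (
    nullcontext("a") as x,
    nullcontext("b") as y,
):
    print(x, y)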

src/llmcompressor/modifiers/awq/base.py

Lines changed: 18 additions & 16 deletions
@@ -155,9 +155,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
         for group in config.config_groups.values()
         if group.weights is not None
     )
-    assert (
-        len(num_bits_set) == 1
-    ), "In AWQ, all config groups must use the same configuration for num_bits"
+    assert len(num_bits_set) == 1, (
+        "In AWQ, all config groups must use the same configuration for num_bits"
+    )
 
     model._num_bits = next(iter(num_bits_set))
 
@@ -166,9 +166,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
         for group in config.config_groups.values()
         if group.weights is not None
     )
-    assert (
-        len(symmetric_set) == 1
-    ), "In AWQ, all config groups must use the same configuration for symmetric"
+    assert len(symmetric_set) == 1, (
+        "In AWQ, all config groups must use the same configuration for symmetric"
+    )
 
     model._symmetric = next(iter(symmetric_set))
 
@@ -177,9 +177,9 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
         for group in config.config_groups.values()
         if group.weights is not None
     )
-    assert (
-        len(group_size_set) == 1
-    ), "In AWQ, all config groups must use the same configuration for group_size"
+    assert len(group_size_set) == 1, (
+        "In AWQ, all config groups must use the same configuration for group_size"
+    )
 
     model._group_size = next(iter(group_size_set))
 
@@ -316,7 +316,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
             )
         ):
             pbar.set_description(
-                f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                f"Resolving mapping {mapping_idx + 1}/{len(self.mappings)}"
                 f" ({num_skipped_mappings} skipped)"
             )
 
@@ -452,9 +452,11 @@ def _apply_smoothing(self, model: Module) -> None:
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent
 
-            with align_modules(
-                [parent_module, smooth_layer, *balance_layers]
-            ), calibration_forward_context(model), HooksMixin.disable_hooks():
+            with (
+                align_modules([parent_module, smooth_layer, *balance_layers]),
+                calibration_forward_context(model),
+                HooksMixin.disable_hooks(),
+            ):
                 # [STEP 1]: Compute per-channel mean of normalised weights
                 # All layer weights are concatted together
                 weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
@@ -653,9 +655,9 @@ def _compute_best_scale(
             "https://github.com/vllm-project/llm-compressor/issues"
         )
 
-        assert (
-            torch.isnan(best_scales).sum() == 0
-        ), f"Nan found in scales: {best_scales}"
+        assert torch.isnan(best_scales).sum() == 0, (
+            f"Nan found in scales: {best_scales}"
+        )
 
         return best_scales.detach().cpu()
 
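Several of the hunks above only move parentheses in assert statements: the condition now stays on the assert line and the long message is wrapped instead. A small sketch of the two layouts with a placeholder value (this mirrors the layout newer formatters such as Ruff and recent Black versions tend to emit; behaviour is unchanged):

num_bits_set = {4}  # placeholder value

# Old layout: the condition is parenthesized and wrapped.
assert (
    len(num_bits_set) == 1
), "In AWQ, all config groups must use the same configuration for num_bits"

# New layout: the condition stays inline; only the message is parenthesized.
assert len(num_bits_set) == 1, (
    "In AWQ, all config groups must use the same configuration for num_bits"
)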

src/llmcompressor/modifiers/pruning/sparsegpt/base.py

Lines changed: 5 additions & 3 deletions
@@ -123,9 +123,11 @@ def compress_modules(self):
             num_samples = self._num_samples[module]
 
             logger.info(f"Sparsifying {name} using {num_samples} samples")
-            with torch.no_grad(), align_module_device(module), CompressionLogger(
-                module
-            ) as comp_logger:
+            with (
+                torch.no_grad(),
+                align_module_device(module),
+                CompressionLogger(module) as comp_logger,
+            ):
                 loss, sparsified_weight = sparsify_weight(
                     module=module,
                     hessians_dict=self._hessians,

src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def setup_mask_for_param(param: Parameter, mask: torch.Tensor) -> torch.Tensor:
 
     if mask.shape != param.data.shape:
         raise ValueError(
-            f"Mask shape {mask.shape} does not match " f"param shape {param.data.shape}"
+            f"Mask shape {mask.shape} does not match param shape {param.data.shape}"
         )
 
     if mask.dtype != torch.bool:

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 4 additions & 2 deletions
@@ -108,8 +108,10 @@ def compress_modules(self):
             num_samples = self._num_samples[module]
 
             logger.info(f"Sparsifying {name} using {num_samples} samples")
-            with torch.no_grad(), align_module_device(module), CompressionLogger(
-                module
+            with (
+                torch.no_grad(),
+                align_module_device(module),
+                CompressionLogger(module),
             ):
                 sparsified_weight = sparsify_weight(
                     module=module,

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 6 additions & 5 deletions
@@ -249,11 +249,12 @@ def compress_modules(self):
             quant_args = getattr_chain(module, "quantization_scheme.weights")
 
             logger.info(f"Quantizing {name} using {num_samples} samples")
-            with torch.no_grad(), align_module_device(
-                module
-            ), self._maybe_onload_hessian(module), CompressionLogger(
-                module
-            ) as comp_logger:
+            with (
+                torch.no_grad(),
+                align_module_device(module),
+                self._maybe_onload_hessian(module),
+                CompressionLogger(module) as comp_logger,
+            ):
                 loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                     module=module,
                     quant_args=quant_args,

src/llmcompressor/recipe/utils.py

Lines changed: 1 addition & 3 deletions
@@ -48,9 +48,7 @@ def _parse_recipe_from_md(file_path, yaml_str):
     else:
         # fail if we know whe should have extracted front matter out
         raise RuntimeError(
-            "Could not extract YAML front matter from recipe card:" " {}".format(
-                file_path
-            )
+            "Could not extract YAML front matter from recipe card: {}".format(file_path)
         )
     return yaml_str
 

src/llmcompressor/utils/dev.py

Lines changed: 6 additions & 3 deletions
@@ -70,9 +70,12 @@ def patched(cls, *args, **kwargs):
 
         return model
 
-    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
-        model_class, "from_pretrained", patched
-    ), skip_weights_initialize(), patch_transformers_logger_level():
+    with (
+        tempfile.TemporaryDirectory() as tmp_dir,
+        patch_attr(model_class, "from_pretrained", patched),
+        skip_weights_initialize(),
+        patch_transformers_logger_level(),
+    ):
         yield
 
 

src/llmcompressor/utils/helpers.py

Lines changed: 5 additions & 2 deletions
@@ -1049,8 +1049,11 @@ def calibration_forward_context(model: torch.nn.Module):
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
     """
-    with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels(
-        model
+    with (
+        torch.no_grad(),
+        disable_cache(model),
+        eval_context(model),
+        disable_hf_kernels(model),
     ):
         yield
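On interpreters older than 3.10 the parenthesized with-statement is a syntax error, so stacking this many context managers would instead need chaining or contextlib.ExitStack. A hedged sketch of the ExitStack alternative, with nullcontext standing in for the project's helpers (torch.no_grad, disable_cache, and so on):

from contextlib import ExitStack, nullcontext

def calibration_context_sketch():
    # ExitStack enters each context manager as it is registered and
    # unwinds them in reverse order when the block exits.
    with ExitStack() as stack:
        stack.enter_context(nullcontext())  # stand-in for torch.no_grad()
        stack.enter_context(nullcontext())  # stand-in for disable_cache(model)
        stack.enter_context(nullcontext())  # stand-in for eval_context(model)
        print("calibration body runs here")

calibration_context_sketch()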
