From ca1bb1767aeb6a4ea940808c3c59ee6426031999 Mon Sep 17 00:00:00 2001 From: blackadder Date: Fri, 7 Nov 2025 17:26:39 +0100 Subject: [PATCH 1/4] Patching for Kimi-K2 --- quant.py | 51 +++++++--- src/gptq.py | 8 +- src/gptq_loop.py | 47 ++++++++-- src/quant_utils.py | 90 +++++++++++++++--- src/sparsegpt.py | 226 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 385 insertions(+), 37 deletions(-) create mode 100644 src/sparsegpt.py diff --git a/quant.py b/quant.py index 162f5ed..6fbabad 100644 --- a/quant.py +++ b/quant.py @@ -46,7 +46,7 @@ def parse_args(): "--bits", type=int, default=4, - choices=[4], + choices=[2, 4], help="Quantization bitwidth.", ) parser.add_argument( @@ -79,6 +79,11 @@ def parse_args(): parser.add_argument( "--dtype", default="float16", type=str, choices=["float16s", "bfloat16"], help="Torch dtype used." ) + parser.add_argument( + "--load_last_shard", + action="store_true", + help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)." + ) args = parser.parse_args() return args @@ -117,6 +122,8 @@ def main(): # Load DeepSeek model config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) + # Infer quantization format + orig_quantization_format = quant_utils.infer_quantization_format(config) # Sanity check assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!" if hasattr(config, "quantization_config"): @@ -125,7 +132,10 @@ def main(): with init_empty_weights(): model = AutoModelForCausalLM.from_config( - config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype + config=config, + trust_remote_code=True, + attn_implementation="flash_attention_2", + torch_dtype=dtype ).eval() model.config.use_cache = False @@ -141,14 +151,26 @@ def main(): calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank] dist_utils.barrier(device_ids=[rank]) + # Get num shards + for path in os.listdir(args.model_name_or_path): + if path.endswith(".safetensors"): + num_shards_str = path.split(".")[0].split("-")[-1] + num_shards = int(num_shards_str) + break + # Load initial weight shard weight_dir = args.model_name_or_path current_shard_id = 1 - weight_path = f"model-{current_shard_id:05}-of-000163.safetensors" + weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors" param_buffer = {} if dist_utils.is_main(): param_buffer = loading_utils.load_param_shard(weight_dir, weight_path) + if args.load_last_shard: + last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors" + last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path) + param_buffer.update(last_shard) + dist_utils.barrier(device_ids=[rank]) # Get resume block id @@ -171,6 +193,9 @@ def main(): # Offload embeddings back to meta model.model.embed_tokens.to(device="meta") param_buffer.pop("model.embed_tokens.weight", None) + # Pop lm head and norm in case it is loaded from last shard + param_buffer.pop("lm_head.weight", None) + param_buffer.pop("model.norm.weight", None) for block_idx, block in tqdm( enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers) @@ -197,16 +222,20 @@ def main(): if dist_utils.is_main(): can_dequantize = True # Select weights corresponding to current block - block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)} + block_state_dict = {k[len(prefix):]: v for k, v in 
param_buffer.items() if k.startswith(prefix)} while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize): current_shard_id += 1 - weight_path = f"model-{current_shard_id:05}-of-000163.safetensors" - param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) + weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors" + param_shard = loading_utils.load_param_shard(weight_dir, weight_path) + # Dequantize weights from a current shard + if quant_utils.can_dequantize(param_shard, orig_quantization_format): + quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format) + param_buffer.update(param_shard) # Update weights corresponding to current block block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)} - can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict) + can_dequantize = quant_utils.can_dequantize(block_state_dict, orig_quantization_format) # Dequantize weights corresponding to current block - quant_utils.dequantize_state_dict(block_state_dict, dtype) + quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format) # Put block onto GPU block.to_empty(device=device) @@ -214,7 +243,7 @@ def main(): # Simply load block state dict on master and broadcast if block_idx < model.config.first_k_dense_replace: if dist_utils.is_main(): - block.load_state_dict(block_state_dict) + block.load_state_dict(block_state_dict, strict=False) if dist_utils.is_dist_available_and_initialized(): dist_utils.broadcast_parameters(block) # Send dict with part of expets to target device @@ -222,7 +251,7 @@ def main(): if dist_utils.is_main(): # Load state dict on master rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys} - block.load_state_dict(rank_state_dict) + block.load_state_dict(rank_state_dict, strict=False) # Send to other processes for i in range(1, world_size): rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]} @@ -232,7 +261,7 @@ def main(): rank_state_dict = block.state_dict() for k in rank_state_dict: dist.recv(rank_state_dict[k], src=0) - block.load_state_dict(rank_state_dict) + block.load_state_dict(rank_state_dict, strict=False) del rank_state_dict # Clear memory before calibration torch.cuda.empty_cache() diff --git a/src/gptq.py b/src/gptq.py index ede3647..6134751 100644 --- a/src/gptq.py +++ b/src/gptq.py @@ -154,7 +154,7 @@ def quantization_pre_step(self) -> None: self.pre_step_completed = True @torch.no_grad() - def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Quantize the weight matrix using GPTQ """ @@ -196,6 +196,8 @@ def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor maxq=maxq, dtype=dtype, gptq_block_size=block_size, + sparse_n=sparse_n, + sparse_m=sparse_m, )[perm_inv].transpose(-2, -1).contiguous().to(torch.uint8) # Remove scale and zero replication scale = scale[:, ::group_size].to(dtype) @@ -213,9 +215,9 @@ def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor return qweight, scale, zero - def quantize(self, bits: int | float) -> Tensor: + def quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tensor: self.quantization_pre_step() - return self._quantize(bits) + return self._quantize(bits, sparse_n, sparse_m) @torch.no_grad() def 
_get_hessian_inverse(self): diff --git a/src/gptq_loop.py b/src/gptq_loop.py index ced958e..71939c7 100644 --- a/src/gptq_loop.py +++ b/src/gptq_loop.py @@ -10,7 +10,7 @@ @triton.jit -def quantize_error_triton_kernel( +def sparse_quantize_error_triton_kernel( x_ptr, qx_ptr, error_ptr, @@ -19,6 +19,8 @@ def quantize_error_triton_kernel( maxq_ptr, dtype_ptr, n_elements: int, + sparse_n: tl.constexpr, + sparse_m: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) @@ -31,7 +33,24 @@ def quantize_error_triton_kernel( maxq = tl.load(maxq_ptr) dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype - qx = tl_quantize(x, scale, qzero, maxq) + # Prune + if sparse_m > 0: + sparse_offsets = tl.arange(0, sparse_m) + sx = x.reshape(BLOCK_SIZE // sparse_m, sparse_m) + score = sx.abs() + sparse_mask = tl.zeros(score.shape, dtype=tl.int1) + for i in range(sparse_n): + max_idx = tl.argmin(score, axis=1, keep_dims=True) + sparse_mask_i = sparse_offsets == max_idx + sparse_mask = sparse_mask | sparse_mask_i + score = tl.where(sparse_mask_i, float("inf"), score) + sx = tl.where(sparse_mask, 0, sx) + sx = sx.reshape(BLOCK_SIZE) + else: + sx = x + + # Quantize + qx = tl_quantize(sx, scale, qzero, maxq) y = tl_dequantize(qx, scale, qzero, dtype) error = y - x @@ -40,19 +59,23 @@ def quantize_error_triton_kernel( tl.store(error_ptr + offsets, error, mask=mask) -def quantize_error_triton( +def sparse_quantize_error_triton( x: torch.Tensor, qx: torch.Tensor, error: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor, + sparse_n: int = 0, + sparse_m: int = 0, dtype: torch.dtype = None, ) -> None: + if sparse_m > 0: + assert 0 < sparse_n < sparse_m, "sparse_n must be in (0, sparse_m)" n_elements: int = x.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - quantize_error_triton_kernel[grid]( + sparse_quantize_error_triton_kernel[grid]( x, qx, error, @@ -61,6 +84,8 @@ def quantize_error_triton( maxq, torch.empty(0, dtype=dtype) if dtype is not None else None, n_elements, + sparse_n, + sparse_m, BLOCK_SIZE=128, ) @@ -115,6 +140,8 @@ def gptq_loop_graph( error_block: torch.Tensor = None, dtype: torch.dtype = None, gptq_block_size: int = 128, + sparse_n: int = 0, + sparse_m: int = 0, direct: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -141,8 +168,8 @@ def gptq_loop_graph( for i1 in range(0, n_columns, gptq_block_size): i2: int = min(i1 + gptq_block_size, n_columns) for j in range(i1, i2): - quantize_error_triton( - weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, dtype, + sparse_quantize_error_triton( + weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, sparse_n, sparse_m, dtype, ) addvv_triton(hessian_inv[j, j + 1 : i2], error_block[j - i1], weight[j + 1 : i2]) weight[i2:].addmm_(hessian_inv[i1:i2, i2:].t(), error_block[: i2 - i1], beta=1, alpha=1) @@ -169,10 +196,10 @@ def gptq_loop_graph( s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(n_warmups): - gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) + gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, sparse_n=sparse_n, sparse_m=sparse_m, direct=True) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(graph): - gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) + gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, sparse_n=sparse_n, 
sparse_m=sparse_m, direct=True) gptq_loop_graph.graph_info[graph_key] = {"graph": graph, "tensors": graph_tensors} graph, graph_tensors = ( @@ -198,6 +225,8 @@ def gptq_loop( maxq: torch.Tensor, dtype: torch.dtype, gptq_block_size: int = 128, + sparse_n: int = 0, + sparse_m: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize weight tensor with GPTQ algorithm @@ -220,6 +249,8 @@ def gptq_loop( maxq=maxq, dtype=dtype, gptq_block_size=gptq_block_size, + sparse_n=sparse_n, + sparse_m=sparse_m, direct=False, ) return qweight # (C, R) diff --git a/src/quant_utils.py b/src/quant_utils.py index f949421..351d99a 100644 --- a/src/quant_utils.py +++ b/src/quant_utils.py @@ -6,6 +6,8 @@ import torch.nn.functional as F import triton from triton import language as tl +from transformers import AutoConfig +from compressed_tensors.compressors import unpack_from_int32 torch.backends.cuda.matmul.allow_tf32 = False @@ -283,22 +285,80 @@ def dequantize_weight_from_fp8(W, s): return W -def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None: +def dequantize_state_dict( + state_dict: dict[str, torch.Tensor], + dtype: torch.dtype = torch.float16, + quantization_format: str = "fp8" +) -> None: + assert quantization_format in ["fp8", "int4"] state_dict_keys = list(state_dict.keys()) - # Dequantize - for k in state_dict_keys: - if k.endswith("scale_inv"): - layer_name, _ = k.rsplit(".", 1) - W = state_dict[f"{layer_name}.weight"].to(dtype) - s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype) - - state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s) - del state_dict[f"{layer_name}.weight_scale_inv"] + # Original DeepSeek packing + if quantization_format == "fp8": + # Dequantize + for k in state_dict_keys: + if k.endswith("scale_inv"): + layer_name, _ = k.rsplit(".", 1) + + W = state_dict[f"{layer_name}.weight"].to(dtype) + s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype) + + state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s) + del state_dict[f"{layer_name}.weight_scale_inv"] + + # KIMI-K2 compressed tensors packing + elif quantization_format == "int4": + # Dequantize + for k in state_dict_keys: + if k.endswith("weight_packed"): + layer_name, _ = k.rsplit(".", 1) + + weight_shape = state_dict[f"{layer_name}.weight_shape"] + + qweight_packed = state_dict[f"{layer_name}.weight_packed"] + qweight = unpack_from_int32(qweight_packed, num_bits=4, shape=weight_shape) + + scale = state_dict[f"{layer_name}.weight_scale"] + + W = (qweight.view(*scale.shape, -1) * scale[..., None]).view_as(qweight).to(dtype) + + state_dict[f"{layer_name}.weight"] = W + del state_dict[f"{layer_name}.weight_packed"] + del state_dict[f"{layer_name}.weight_shape"] + del state_dict[f"{layer_name}.weight_scale"] + + +def can_dequantize( + state_dict: dict[str, torch.Tensor], + quantization_format: str = "fp8" +) -> bool: + assert quantization_format in ["fp8", "int4"] + # Original DeepSeek packing + if quantization_format == "fp8": + for k, v in state_dict.items(): + if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict: + return False + + # KIMI-K2 compressed tensors packing + elif quantization_format == "int4": + for k, v in state_dict.items(): + if k.endswith("weight_packed"): + layer_name, _ = k.rsplit(".", 1) + if f"{layer_name}.weight_shape" not in state_dict or f"{layer_name}.weight_scale" not in state_dict: + return False + return True -def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool: - for k, v 
in state_dict.items(): - if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict: - return False - return True +def infer_quantization_format(config: AutoConfig) -> str: + quant_method = config.quantization_config["quant_method"] + # DeepSeek format + if quant_method == "fp8": + return "fp8" + # KIMI-K2 format + elif quant_method == "compressed-tensors": + if config.quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4: + return "int4" + else: + raise ValueError("Only 4-bit quantization is supported.") + else: + raise ValueError("Unknown or unsupported quantization method.") diff --git a/src/sparsegpt.py b/src/sparsegpt.py new file mode 100644 index 0000000..901e5c3 --- /dev/null +++ b/src/sparsegpt.py @@ -0,0 +1,226 @@ +from enum import Enum +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor +from torch.nn.modules.conv import _ConvNd + +from src import dist_utils, model_utils, linalg_utils + + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class SparseGPT: + + def __init__( + self, + layer: nn.Module, + rel_damp: float = 1e-2, + block_size: int = None, + is_distributed: bool = False, + tied_sparsegpt_handle: Optional["SparseGPT"] = None + ): + self._validate_layer(layer) + self.layer = layer + self.W = self.layer.weight + self.d_row, self.d_col = model_utils.get_number_of_rows_and_cols(layer) + # SparseGPT hyperparameters + self.rel_damp = rel_damp + self.block_size = block_size or self.d_col + # backup layer properties + self.W_device = self.W.device + self.W_dtype = self.W.dtype + self.W_shape = self.W.shape + # init hessian + self.H = None + self.num_samples = 0 + self.is_distributed = is_distributed + self.tied_sparsegpt_handle = tied_sparsegpt_handle + self.num_tied_handles = 0 + if tied_sparsegpt_handle is not None: + tied_sparsegpt_handle.num_tied_handles += 1 + # Flags indicating issues + self.issue_zero_samples = False + self.issue_nan_hessian = False + self.issue_non_invertible = False + + @staticmethod + def _validate_layer(layer): + assert isinstance(layer, (nn.Linear, _ConvNd)), "OBC supports only linear and convolutional layers." + + def has_hessian_issues(self) -> bool: + return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible]) + + # preparatory methods + @torch.no_grad() + def update(self, input: Tensor) -> None: + """ + Update the estimate of Hessian matrix from a batch of data. 
+ + Args: + input: batch of layer inputs + """ + # init hessian + if self.H is None: + self.H = torch.zeros((self.d_col, self.d_col), device=input.device, dtype=torch.float32) + # input reshaping + if isinstance(self.layer, nn.Linear): + input = input.reshape(-1, input.shape[-1]) + else: + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) + # output size (batch_size, channels * \prod kernel_size, num_patches) + input = unfold(input) + input = input.transpose(1, 2).flatten(0, 1) + input = input.float() + # get number of samples (tokens) in batch + num_new_samples = input.shape[0] + # hessian update + beta = self.num_samples / (self.num_samples + num_new_samples) + alpha = 2.0 / (self.num_samples + num_new_samples) + self.H.addmm_(input.T, input, beta=beta, alpha=alpha) + # update number of collected samples + self.num_samples += num_new_samples + + @property + def tokens_collected(self) -> int: + return self.num_samples + + def reset(self) -> None: + self.W = self.layer.weight + if self.num_tied_handles == 0: + self.H = None + elif self.tied_sparsegpt_handle: + self.tied_sparsegpt_handle.num_tied_handles -= 1 + if self.tied_sparsegpt_handle.num_tied_handles == 0: + self.tied_sparsegpt_handle.H = None + self.num_samples = 0 + torch.cuda.empty_cache() + + @torch.no_grad() + def sparsification_pre_step(self) -> None: + """ + Preparatory step with hessian regularization and weight reshaping. + """ + # 1) Hessian preparation + reduce_if_needed = True + if self.H is None: + if self.tied_sparsegpt_handle: + self.H = self.tied_sparsegpt_handle.H + else: + self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) + self.issue_zero_samples = True + # No need to reduce + reduce_if_needed = False + # synchronize Hessians + if self.is_distributed and reduce_if_needed and dist_utils.is_dist_available_and_initialized(): + dist.all_reduce(self.H, op=dist.ReduceOp.AVG) + # Replace matrix by identity in case of NaNs + if torch.isnan(self.H).any().item(): + self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) + self.issue_nan_hessian = True + # get ids of pruned channels + pruned_ids = torch.diag(self.H) == 0 + self.H[pruned_ids, pruned_ids] = 1 + # 2) Weight preparation + # copy weight, flatten + self.W = self.W.clone().float() + if isinstance(self.layer, _ConvNd): + self.W = self.W.flatten(1, -1) + self.W[:, pruned_ids] = 0 + # flag pre step as completed + self.pre_step_completed = True + + @torch.no_grad() + def _prune(self, n: int = 2, m: int = 4) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prune the layer according to the given sparsity. 
+ """ + # 1) Define constants and chunk + d_row, d_col, block_size, device, dtype = self.d_row, self.d_col, self.block_size, self.W_device, self.W_dtype + + is_main_sparsegpt_process = dist_utils.is_main() or not self.is_distributed + + if is_main_sparsegpt_process: + w = self.W + # Get hessian inverse + hessian_inv = self._get_hessian_inverse() + # Get hessian inverse + for c1 in range(0, d_col, block_size): + c2 = min(c1 + block_size, d_col) + ncols = c2 - c1 # number of columns + w_blk = w[:, c1:c2].clone() # column-wise weight slice + res = torch.zeros_like(w_blk) + errs = torch.zeros_like(w_blk) + losses_blk = torch.zeros_like(w_blk) + hessian_inv_blk = hessian_inv[c1:c2, c1:c2] + mask = torch.zeros_like(w_blk, dtype=torch.bool) + # 2) iterate over block + for i in range(ncols): + if i % m == 0: + scores = w_blk[:, i: (i + m)].pow(2) / hessian_inv_blk.diag()[i: (i + m)].view(1, -1).pow(2) + thr, _ = torch.kthvalue(scores, k=n, dim=-1, keepdim=True) + mask[:, i: (i + m)] = scores > thr + + w_ci = w_blk[:, i] + d = hessian_inv_blk[i, i] + + q = w_ci.clone() + q[~mask[:, i]] = 0 + + res[:, i] = q + err = (w_ci - q) / d + losses_blk[:, i] = err ** 2 + + w_blk[:, i:].addr_(err, hessian_inv_blk[i, i:], alpha=-1) + errs[:, i] = err + # 3) update the weights after block + w[:, c1:c2] = res + w[:, c2:].addmm_(errs, hessian_inv[c1:c2, c2:], alpha=-1) + + sweight = w.to(dtype=dtype) + else: + sweight = torch.empty(d_row, d_col, device=device, dtype=dtype) + + if self.is_distributed and dist_utils.is_dist_available_and_initialized(): + dist.barrier() + dist.broadcast(sweight, src=0) + + return sweight + + def prune(self, n: int = 2, m: int = 4) -> Tensor: + self.sparsification_pre_step() + return self._prune(n, m) + + @torch.no_grad() + def _get_hessian_inverse(self): + w = self.W + # Get columns with all zeros + zero_cols = torch.nonzero(w.eq(0).all(dim=0)) + H = self.H + # Regularize Hessian before sparsification + if not self.tied_sparsegpt_handle: + # Mask rows with zero input channels + H[zero_cols, :] = 0 + H[:, zero_cols] = 0 + H[zero_cols, zero_cols] = 1 + # Hessian regularization + damp = self.rel_damp * torch.diag(self.H).mean() + self.H[range(self.d_col), range(self.d_col)] += damp + # Invert + try: + H = linalg_utils.inv_sym(H) + H_inv_cho = torch.linalg.cholesky(H, upper=True) + except: + H_inv_cho = torch.eye(self.d_col, device=H.device, dtype=torch.float32) + # Divide Hessian inverse by diagonal (in order to not divide on it later) + H_inv_cho.div_(H_inv_cho.diag()[:, None]) + return H_inv_cho From 6020edb6f3f49f0f58980052829567977eab22b0 Mon Sep 17 00:00:00 2001 From: blackadder Date: Sat, 8 Nov 2025 13:02:08 +0100 Subject: [PATCH 2/4] Edited packing script for Kimi-K2 --- pack_quantized_model.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/pack_quantized_model.py b/pack_quantized_model.py index 176d501..356141a 100644 --- a/pack_quantized_model.py +++ b/pack_quantized_model.py @@ -46,6 +46,11 @@ def parse_args(): choices=["float16", "bfloat16"], help="Torch dtype used." ) + parser.add_argument( + "--load_last_shard", + action="store_true", + help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)." 
+ ) args = parser.parse_args() return args @@ -116,6 +121,8 @@ def main(): # Load DeepSeek model config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) + # Infer quantization format + orig_quantization_format = quant_utils.infer_quantization_format(config) if hasattr(config, "quantization_config"): delattr(config, "quantization_config") @@ -148,12 +155,23 @@ def main(): # Prepare directory to save packed weights os.makedirs(args.packed_model_path, exist_ok=True) + # Get num shards + for path in os.listdir(args.model_name_or_path): + if path.endswith(".safetensors"): + num_shards_str = path.split(".")[0].split("-")[-1] + num_shards = int(num_shards_str) + break + # Load initial weight shard weight_dir = args.model_name_or_path current_input_shard_id = 1 - weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" + weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors" param_buffer = loading_utils.load_param_shard(weight_dir, weight_path) + if args.load_last_shard: + last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors" + last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path) + param_buffer.update(last_shard) # Save embeddings current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors" @@ -176,11 +194,15 @@ def main(): while not is_subset(block_keys_with_prefix, set(param_buffer.keys())): current_input_shard_id += 1 - weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" - param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) + weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors" + param_shard = loading_utils.load_param_shard(weight_dir, weight_path) + # Dequantize weights from a current shard + if quant_utils.can_dequantize(param_shard, orig_quantization_format): + quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format) + param_buffer.update(param_shard) block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)} - quant_utils.dequantize_state_dict(block_state_dict, dtype) + quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format) for layer_name in quantized_layer_names[block_idx]: weight_state_dict = torch.load( @@ -189,7 +211,7 @@ def main(): map_location="cpu" ) packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size) - block_state_dict.pop(f"{layer_name}.weight") + block_state_dict.pop(f"{layer_name}.weight", None) block_state_dict.pop(f"{layer_name}.weight_scale_inv", None) block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()}) @@ -209,9 +231,9 @@ def main(): gc.collect() # Load final shard - if current_input_shard_id < 163: - current_input_shard_id = 163 - weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" + if current_input_shard_id < num_shards: + current_input_shard_id = num_shards + weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors" param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) # Save lm head From cad41367ef385c88128c4489b8a0fd6a6bdb2bb5 Mon Sep 17 00:00:00 2001 From: blackadder Date: Sat, 8 Nov 2025 13:26:03 +0100 Subject: [PATCH 3/4] Removed sparsity from methods --- src/gptq.py | 8 +- src/gptq_loop.py | 46 ++-------- src/sparsegpt.py | 226 ----------------------------------------------- 3 files changed, 
11 insertions(+), 269 deletions(-) delete mode 100644 src/sparsegpt.py diff --git a/src/gptq.py b/src/gptq.py index 6134751..038b514 100644 --- a/src/gptq.py +++ b/src/gptq.py @@ -154,7 +154,7 @@ def quantization_pre_step(self) -> None: self.pre_step_completed = True @torch.no_grad() - def _quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Quantize the weight matrix using GPTQ """ @@ -196,8 +196,6 @@ def _quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tuple[to maxq=maxq, dtype=dtype, gptq_block_size=block_size, - sparse_n=sparse_n, - sparse_m=sparse_m, )[perm_inv].transpose(-2, -1).contiguous().to(torch.uint8) # Remove scale and zero replication scale = scale[:, ::group_size].to(dtype) @@ -215,9 +213,9 @@ def _quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tuple[to return qweight, scale, zero - def quantize(self, bits: int, sparse_n: int = 0, sparse_m: int = 0) -> Tensor: + def quantize(self, bits: int) -> Tensor: self.quantization_pre_step() - return self._quantize(bits, sparse_n, sparse_m) + return self._quantize(bits) @torch.no_grad() def _get_hessian_inverse(self): diff --git a/src/gptq_loop.py b/src/gptq_loop.py index 71939c7..f70f5d7 100644 --- a/src/gptq_loop.py +++ b/src/gptq_loop.py @@ -10,7 +10,7 @@ @triton.jit -def sparse_quantize_error_triton_kernel( +def quantize_error_triton_kernel( x_ptr, qx_ptr, error_ptr, @@ -19,8 +19,6 @@ def sparse_quantize_error_triton_kernel( maxq_ptr, dtype_ptr, n_elements: int, - sparse_n: tl.constexpr, - sparse_m: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) @@ -33,24 +31,8 @@ def sparse_quantize_error_triton_kernel( maxq = tl.load(maxq_ptr) dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype - # Prune - if sparse_m > 0: - sparse_offsets = tl.arange(0, sparse_m) - sx = x.reshape(BLOCK_SIZE // sparse_m, sparse_m) - score = sx.abs() - sparse_mask = tl.zeros(score.shape, dtype=tl.int1) - for i in range(sparse_n): - max_idx = tl.argmin(score, axis=1, keep_dims=True) - sparse_mask_i = sparse_offsets == max_idx - sparse_mask = sparse_mask | sparse_mask_i - score = tl.where(sparse_mask_i, float("inf"), score) - sx = tl.where(sparse_mask, 0, sx) - sx = sx.reshape(BLOCK_SIZE) - else: - sx = x - # Quantize - qx = tl_quantize(sx, scale, qzero, maxq) + qx = tl_quantize(x, scale, qzero, maxq) y = tl_dequantize(qx, scale, qzero, dtype) error = y - x @@ -59,23 +41,19 @@ def sparse_quantize_error_triton_kernel( tl.store(error_ptr + offsets, error, mask=mask) -def sparse_quantize_error_triton( +def quantize_error_triton( x: torch.Tensor, qx: torch.Tensor, error: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor, - sparse_n: int = 0, - sparse_m: int = 0, dtype: torch.dtype = None, ) -> None: - if sparse_m > 0: - assert 0 < sparse_n < sparse_m, "sparse_n must be in (0, sparse_m)" n_elements: int = x.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - sparse_quantize_error_triton_kernel[grid]( + quantize_error_triton_kernel[grid]( x, qx, error, @@ -84,8 +62,6 @@ def sparse_quantize_error_triton( maxq, torch.empty(0, dtype=dtype) if dtype is not None else None, n_elements, - sparse_n, - sparse_m, BLOCK_SIZE=128, ) @@ -140,8 +116,6 @@ def gptq_loop_graph( error_block: torch.Tensor = None, dtype: torch.dtype = None, gptq_block_size: int = 128, - sparse_n: int = 0, - sparse_m: int = 0, direct: 
bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -168,8 +142,8 @@ def gptq_loop_graph( for i1 in range(0, n_columns, gptq_block_size): i2: int = min(i1 + gptq_block_size, n_columns) for j in range(i1, i2): - sparse_quantize_error_triton( - weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, sparse_n, sparse_m, dtype, + quantize_error_triton( + weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, dtype, ) addvv_triton(hessian_inv[j, j + 1 : i2], error_block[j - i1], weight[j + 1 : i2]) weight[i2:].addmm_(hessian_inv[i1:i2, i2:].t(), error_block[: i2 - i1], beta=1, alpha=1) @@ -196,10 +170,10 @@ def gptq_loop_graph( s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(n_warmups): - gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, sparse_n=sparse_n, sparse_m=sparse_m, direct=True) + gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(graph): - gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, sparse_n=sparse_n, sparse_m=sparse_m, direct=True) + gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) gptq_loop_graph.graph_info[graph_key] = {"graph": graph, "tensors": graph_tensors} graph, graph_tensors = ( @@ -225,8 +199,6 @@ def gptq_loop( maxq: torch.Tensor, dtype: torch.dtype, gptq_block_size: int = 128, - sparse_n: int = 0, - sparse_m: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize weight tensor with GPTQ algorithm @@ -249,8 +221,6 @@ def gptq_loop( maxq=maxq, dtype=dtype, gptq_block_size=gptq_block_size, - sparse_n=sparse_n, - sparse_m=sparse_m, direct=False, ) return qweight # (C, R) diff --git a/src/sparsegpt.py b/src/sparsegpt.py deleted file mode 100644 index 901e5c3..0000000 --- a/src/sparsegpt.py +++ /dev/null @@ -1,226 +0,0 @@ -from enum import Enum -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.distributed as dist -from torch import Tensor -from torch.nn.modules.conv import _ConvNd - -from src import dist_utils, model_utils, linalg_utils - - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - - -class SparseGPT: - - def __init__( - self, - layer: nn.Module, - rel_damp: float = 1e-2, - block_size: int = None, - is_distributed: bool = False, - tied_sparsegpt_handle: Optional["SparseGPT"] = None - ): - self._validate_layer(layer) - self.layer = layer - self.W = self.layer.weight - self.d_row, self.d_col = model_utils.get_number_of_rows_and_cols(layer) - # SparseGPT hyperparameters - self.rel_damp = rel_damp - self.block_size = block_size or self.d_col - # backup layer properties - self.W_device = self.W.device - self.W_dtype = self.W.dtype - self.W_shape = self.W.shape - # init hessian - self.H = None - self.num_samples = 0 - self.is_distributed = is_distributed - self.tied_sparsegpt_handle = tied_sparsegpt_handle - self.num_tied_handles = 0 - if tied_sparsegpt_handle is not None: - tied_sparsegpt_handle.num_tied_handles += 1 - # Flags indicating issues - self.issue_zero_samples = False - self.issue_nan_hessian = False - self.issue_non_invertible = False - - @staticmethod - def _validate_layer(layer): - assert isinstance(layer, (nn.Linear, _ConvNd)), "OBC supports only linear and convolutional layers." 
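
(Aside on the SparseGPT code added in patch 1 and removed again here: its update method keeps a streaming estimate of the layer Hessian, H ≈ (2/N)·XᵀX, by re-weighting the old accumulator with beta and the new batch with alpha on every call. A minimal standalone sketch of that accumulation — dimensions and batch sizes are illustrative, not taken from the repo:)

import torch

torch.manual_seed(0)
d_col = 8                          # layer input dimension (illustrative)
H = torch.zeros(d_col, d_col)      # running Hessian estimate
num_samples = 0
collected = []

for _ in range(4):                 # pretend calibration batches
    x = torch.randn(16, d_col)     # flattened layer inputs: (tokens, d_col)
    collected.append(x)
    num_new = x.shape[0]
    beta = num_samples / (num_samples + num_new)   # weight of the old estimate
    alpha = 2.0 / (num_samples + num_new)          # weight of the new batch
    # same effect as H.addmm_(x.T, x, beta=beta, alpha=alpha) in the patch
    H = beta * H + alpha * (x.T @ x)
    num_samples += num_new

# The incremental update reproduces 2/N * X^T X over all calibration tokens.
X = torch.cat(collected)
assert torch.allclose(H, 2.0 / num_samples * X.T @ X, atol=1e-5)
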
- - def has_hessian_issues(self) -> bool: - return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible]) - - # preparatory methods - @torch.no_grad() - def update(self, input: Tensor) -> None: - """ - Update the estimate of Hessian matrix from a batch of data. - - Args: - input: batch of layer inputs - """ - # init hessian - if self.H is None: - self.H = torch.zeros((self.d_col, self.d_col), device=input.device, dtype=torch.float32) - # input reshaping - if isinstance(self.layer, nn.Linear): - input = input.reshape(-1, input.shape[-1]) - else: - unfold = nn.Unfold( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride, - ) - # output size (batch_size, channels * \prod kernel_size, num_patches) - input = unfold(input) - input = input.transpose(1, 2).flatten(0, 1) - input = input.float() - # get number of samples (tokens) in batch - num_new_samples = input.shape[0] - # hessian update - beta = self.num_samples / (self.num_samples + num_new_samples) - alpha = 2.0 / (self.num_samples + num_new_samples) - self.H.addmm_(input.T, input, beta=beta, alpha=alpha) - # update number of collected samples - self.num_samples += num_new_samples - - @property - def tokens_collected(self) -> int: - return self.num_samples - - def reset(self) -> None: - self.W = self.layer.weight - if self.num_tied_handles == 0: - self.H = None - elif self.tied_sparsegpt_handle: - self.tied_sparsegpt_handle.num_tied_handles -= 1 - if self.tied_sparsegpt_handle.num_tied_handles == 0: - self.tied_sparsegpt_handle.H = None - self.num_samples = 0 - torch.cuda.empty_cache() - - @torch.no_grad() - def sparsification_pre_step(self) -> None: - """ - Preparatory step with hessian regularization and weight reshaping. - """ - # 1) Hessian preparation - reduce_if_needed = True - if self.H is None: - if self.tied_sparsegpt_handle: - self.H = self.tied_sparsegpt_handle.H - else: - self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) - self.issue_zero_samples = True - # No need to reduce - reduce_if_needed = False - # synchronize Hessians - if self.is_distributed and reduce_if_needed and dist_utils.is_dist_available_and_initialized(): - dist.all_reduce(self.H, op=dist.ReduceOp.AVG) - # Replace matrix by identity in case of NaNs - if torch.isnan(self.H).any().item(): - self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) - self.issue_nan_hessian = True - # get ids of pruned channels - pruned_ids = torch.diag(self.H) == 0 - self.H[pruned_ids, pruned_ids] = 1 - # 2) Weight preparation - # copy weight, flatten - self.W = self.W.clone().float() - if isinstance(self.layer, _ConvNd): - self.W = self.W.flatten(1, -1) - self.W[:, pruned_ids] = 0 - # flag pre step as completed - self.pre_step_completed = True - - @torch.no_grad() - def _prune(self, n: int = 2, m: int = 4) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Prune the layer according to the given sparsity. 
- """ - # 1) Define constants and chunk - d_row, d_col, block_size, device, dtype = self.d_row, self.d_col, self.block_size, self.W_device, self.W_dtype - - is_main_sparsegpt_process = dist_utils.is_main() or not self.is_distributed - - if is_main_sparsegpt_process: - w = self.W - # Get hessian inverse - hessian_inv = self._get_hessian_inverse() - # Get hessian inverse - for c1 in range(0, d_col, block_size): - c2 = min(c1 + block_size, d_col) - ncols = c2 - c1 # number of columns - w_blk = w[:, c1:c2].clone() # column-wise weight slice - res = torch.zeros_like(w_blk) - errs = torch.zeros_like(w_blk) - losses_blk = torch.zeros_like(w_blk) - hessian_inv_blk = hessian_inv[c1:c2, c1:c2] - mask = torch.zeros_like(w_blk, dtype=torch.bool) - # 2) iterate over block - for i in range(ncols): - if i % m == 0: - scores = w_blk[:, i: (i + m)].pow(2) / hessian_inv_blk.diag()[i: (i + m)].view(1, -1).pow(2) - thr, _ = torch.kthvalue(scores, k=n, dim=-1, keepdim=True) - mask[:, i: (i + m)] = scores > thr - - w_ci = w_blk[:, i] - d = hessian_inv_blk[i, i] - - q = w_ci.clone() - q[~mask[:, i]] = 0 - - res[:, i] = q - err = (w_ci - q) / d - losses_blk[:, i] = err ** 2 - - w_blk[:, i:].addr_(err, hessian_inv_blk[i, i:], alpha=-1) - errs[:, i] = err - # 3) update the weights after block - w[:, c1:c2] = res - w[:, c2:].addmm_(errs, hessian_inv[c1:c2, c2:], alpha=-1) - - sweight = w.to(dtype=dtype) - else: - sweight = torch.empty(d_row, d_col, device=device, dtype=dtype) - - if self.is_distributed and dist_utils.is_dist_available_and_initialized(): - dist.barrier() - dist.broadcast(sweight, src=0) - - return sweight - - def prune(self, n: int = 2, m: int = 4) -> Tensor: - self.sparsification_pre_step() - return self._prune(n, m) - - @torch.no_grad() - def _get_hessian_inverse(self): - w = self.W - # Get columns with all zeros - zero_cols = torch.nonzero(w.eq(0).all(dim=0)) - H = self.H - # Regularize Hessian before sparsification - if not self.tied_sparsegpt_handle: - # Mask rows with zero input channels - H[zero_cols, :] = 0 - H[:, zero_cols] = 0 - H[zero_cols, zero_cols] = 1 - # Hessian regularization - damp = self.rel_damp * torch.diag(self.H).mean() - self.H[range(self.d_col), range(self.d_col)] += damp - # Invert - try: - H = linalg_utils.inv_sym(H) - H_inv_cho = torch.linalg.cholesky(H, upper=True) - except: - H_inv_cho = torch.eye(self.d_col, device=H.device, dtype=torch.float32) - # Divide Hessian inverse by diagonal (in order to not divide on it later) - H_inv_cho.div_(H_inv_cho.diag()[:, None]) - return H_inv_cho From 1bb0dd75d245428e409b06806ef6a61a716b488b Mon Sep 17 00:00:00 2001 From: blackadder Date: Thu, 11 Dec 2025 19:07:39 +0100 Subject: [PATCH 4/4] Open-thoughts patch --- src/data_utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/data_utils.py b/src/data_utils.py index 2a6cadf..f9e2373 100644 --- a/src/data_utils.py +++ b/src/data_utils.py @@ -1,3 +1,4 @@ +import re from typing import Optional, List import torch @@ -5,6 +6,14 @@ from transformers import AutoTokenizer +def split_thought_solution(text: str): + thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL) + solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL) + + thought = thought_re.search(text).group(1).strip() + solution = solution_re.search(text).group(1).strip() + return thought, solution + def prepare_open_thoughts( tokenizer: AutoTokenizer, max_sequence_length: int, @@ -14,6 +23,12 @@ def 
prepare_open_thoughts( train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train") if num_calibration_samples: train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples)) + # Update chat template + tokenizer.chat_template = tokenizer.chat_template.replace( + "{{render_content(message)}}", + "{%- set rc = message.get('reasoning_content', '') -%}" + "{{rc}}{{render_content(message)}}" + ) # Preprocess the data into the format the model is trained with. def preprocess(example): messages = [] @@ -21,7 +36,12 @@ def preprocess(example): messages.append({"role": "system", "content": example['system']}) # add dialogue for message in example['conversations']: - messages.append({"role": message["from"], "content": message["value"]}) + role = message["from"] + if role == "user": + messages.append({"role": "user", "content": message["value"]}) + else: + thought, solution = split_thought_solution(message["value"]) + messages.append({"role": "assistant", "content": solution, "reasoning_content": thought}) return {"text": tokenizer.apply_chat_template(messages, tokenize=False)} train_dataset_raw = train_dataset_raw.map(preprocess) # Tokenize the data
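
(Aside on PATCH 4/4: the preprocessing now pulls the <|begin_of_thought|>…<|end_of_thought|> span out as reasoning_content and the <|begin_of_solution|>…<|end_of_solution|> span as the visible answer. A self-contained toy run of the helper — the example string below is invented, the function body is the one introduced by the patch:)

import re

def split_thought_solution(text: str):
    thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL)
    solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL)
    thought = thought_re.search(text).group(1).strip()
    solution = solution_re.search(text).group(1).strip()
    return thought, solution

example = (
    "<|begin_of_thought|>2 + 2: add the units digits.<|end_of_thought|>"
    "<|begin_of_solution|>4<|end_of_solution|>"
)
thought, solution = split_thought_solution(example)
assert thought == "2 + 2: add the units digits."
assert solution == "4"

# The assistant turn is then emitted as
# {"role": "assistant", "content": solution, "reasoning_content": thought},
# and the patched chat template renders reasoning_content just before the answer.
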