diff --git a/pack_quantized_model.py b/pack_quantized_model.py
index 176d501..356141a 100644
--- a/pack_quantized_model.py
+++ b/pack_quantized_model.py
@@ -46,6 +46,11 @@ def parse_args():
         choices=["float16", "bfloat16"],
         help="Torch dtype used."
     )
+    parser.add_argument(
+        "--load_last_shard",
+        action="store_true",
+        help="Whether to load the last shard of the model in the beginning (needed for Kimi-K2-Thinking)."
+    )
     args = parser.parse_args()
     return args
@@ -116,6 +121,8 @@ def main():
     # Load DeepSeek model
     config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+    # Infer quantization format
+    orig_quantization_format = quant_utils.infer_quantization_format(config)
     if hasattr(config, "quantization_config"):
         delattr(config, "quantization_config")
@@ -148,12 +155,23 @@ def main():
     # Prepare directory to save packed weights
     os.makedirs(args.packed_model_path, exist_ok=True)

+    # Get num shards
+    for path in os.listdir(args.model_name_or_path):
+        if path.endswith(".safetensors"):
+            num_shards_str = path.split(".")[0].split("-")[-1]
+            num_shards = int(num_shards_str)
+            break
+
     # Load initial weight shard
     weight_dir = args.model_name_or_path
     current_input_shard_id = 1
-    weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
+    weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
     param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
+    if args.load_last_shard:
+        last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
+        last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
+        param_buffer.update(last_shard)

     # Save embeddings
     current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
@@ -176,11 +194,15 @@ def main():
         while not is_subset(block_keys_with_prefix, set(param_buffer.keys())):
             current_input_shard_id += 1
-            weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
-            param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
+            weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
+            param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
+            # Dequantize weights from the current shard
+            if quant_utils.can_dequantize(param_shard, orig_quantization_format):
+                quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
+            param_buffer.update(param_shard)

         block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)}
-        quant_utils.dequantize_state_dict(block_state_dict, dtype)
+        quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)

         for layer_name in quantized_layer_names[block_idx]:
             weight_state_dict = torch.load(
@@ -189,7 +211,7 @@ def main():
                 map_location="cpu"
             )
             packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size)
-            block_state_dict.pop(f"{layer_name}.weight")
+            block_state_dict.pop(f"{layer_name}.weight", None)
             block_state_dict.pop(f"{layer_name}.weight_scale_inv", None)
             block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()})
@@ -209,9 +231,9 @@ def main():
         gc.collect()

     # Load final shard
-    if current_input_shard_id < 163:
-        current_input_shard_id = 163
-        weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
+    if current_input_shard_id < num_shards:
+        current_input_shard_id = num_shards
+        weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
         param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))

     # Save lm head
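
Aside (not part of the patch): the change above replaces the hard-coded 163-shard count with one inferred from the safetensors filenames. A minimal standalone sketch of that logic, assuming the usual "model-XXXXX-of-XXXXXX.safetensors" naming; get_num_shards is a hypothetical helper, not a function in this repo.

import os

def get_num_shards(model_dir: str) -> tuple[int, str]:
    # Shard files look like "model-00001-of-000163.safetensors"; the last
    # dash-separated field of the stem is the zero-padded shard count.
    for name in os.listdir(model_dir):
        if name.endswith(".safetensors"):
            num_shards_str = name.split(".")[0].split("-")[-1]
            return int(num_shards_str), num_shards_str
    raise FileNotFoundError(f"no .safetensors shards found in {model_dir}")
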
diff --git a/quant.py b/quant.py
index 162f5ed..6fbabad 100644
--- a/quant.py
+++ b/quant.py
@@ -46,7 +46,7 @@ def parse_args():
         "--bits",
         type=int,
         default=4,
-        choices=[4],
+        choices=[2, 4],
         help="Quantization bitwidth.",
     )
     parser.add_argument(
@@ -79,6 +79,11 @@ def parse_args():
     parser.add_argument(
         "--dtype", default="float16", type=str, choices=["float16", "bfloat16"], help="Torch dtype used."
     )
+    parser.add_argument(
+        "--load_last_shard",
+        action="store_true",
+        help="Whether to load the last shard of the model in the beginning (needed for Kimi-K2-Thinking)."
+    )
     args = parser.parse_args()
     return args
@@ -117,6 +122,8 @@ def main():
     # Load DeepSeek model
     config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+    # Infer quantization format
+    orig_quantization_format = quant_utils.infer_quantization_format(config)
     # Sanity check
     assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!"
     if hasattr(config, "quantization_config"):
@@ -125,7 +132,10 @@ def main():
     with init_empty_weights():
         model = AutoModelForCausalLM.from_config(
-            config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype
+            config=config,
+            trust_remote_code=True,
+            attn_implementation="flash_attention_2",
+            torch_dtype=dtype
         ).eval()
     model.config.use_cache = False
@@ -141,14 +151,26 @@ def main():
         calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank]
     dist_utils.barrier(device_ids=[rank])

+    # Get num shards
+    for path in os.listdir(args.model_name_or_path):
+        if path.endswith(".safetensors"):
+            num_shards_str = path.split(".")[0].split("-")[-1]
+            num_shards = int(num_shards_str)
+            break
+
     # Load initial weight shard
     weight_dir = args.model_name_or_path
     current_shard_id = 1
-    weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
+    weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"
     param_buffer = {}
     if dist_utils.is_main():
         param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
+        if args.load_last_shard:
+            last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
+            last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
+            param_buffer.update(last_shard)
+
     dist_utils.barrier(device_ids=[rank])

     # Get resume block id
@@ -171,6 +193,9 @@ def main():
     # Offload embeddings back to meta
     model.model.embed_tokens.to(device="meta")
     param_buffer.pop("model.embed_tokens.weight", None)
+    # Pop lm head and norm in case they were loaded from the last shard
+    param_buffer.pop("lm_head.weight", None)
+    param_buffer.pop("model.norm.weight", None)

     for block_idx, block in tqdm(
         enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers)
@@ -197,16 +222,20 @@ def main():
         if dist_utils.is_main():
             can_dequantize = True
             # Select weights corresponding to current block
-            block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
+            block_state_dict = {k[len(prefix):]: v for k, v in param_buffer.items() if k.startswith(prefix)}
             while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize):
                 current_shard_id += 1
-                weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
-                param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
+                weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"
+                param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
+                # Dequantize weights from the current shard
+                if quant_utils.can_dequantize(param_shard, orig_quantization_format):
+                    quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
+                param_buffer.update(param_shard)
                 # Update weights corresponding to current block
                 block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
-                can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict)
+                can_dequantize = quant_utils.can_dequantize(block_state_dict, orig_quantization_format)
             # Dequantize weights corresponding to current block
-            quant_utils.dequantize_state_dict(block_state_dict, dtype)
+            quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)

         # Put block onto GPU
         block.to_empty(device=device)
@@ -214,7 +243,7 @@ def main():
         # Simply load block state dict on master and broadcast
         if block_idx < model.config.first_k_dense_replace:
             if dist_utils.is_main():
-                block.load_state_dict(block_state_dict)
+                block.load_state_dict(block_state_dict, strict=False)
             if dist_utils.is_dist_available_and_initialized():
                 dist_utils.broadcast_parameters(block)
         # Send dict with part of experts to target device
@@ -222,7 +251,7 @@ def main():
             if dist_utils.is_main():
                 # Load state dict on master
                 rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys}
-                block.load_state_dict(rank_state_dict)
+                block.load_state_dict(rank_state_dict, strict=False)
                 # Send to other processes
                 for i in range(1, world_size):
                     rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]}
@@ -232,7 +261,7 @@ def main():
                 rank_state_dict = block.state_dict()
                 for k in rank_state_dict:
                     dist.recv(rank_state_dict[k], src=0)
-                block.load_state_dict(rank_state_dict, strict=False)
                 del rank_state_dict
         # Clear memory before calibration
         torch.cuda.empty_cache()
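
Background for the strict=False changes above: torch.nn.Module.load_state_dict(..., strict=False) does not raise on missing or unexpected keys, it just reports them. A generic illustration (the module and the extra key are made up, not taken from this repo):

import torch
from torch import nn

m = nn.Linear(4, 4, bias=False)
state = {"weight": torch.zeros(4, 4), "weight_scale": torch.ones(1)}  # extra, non-parameter key
result = m.load_state_dict(state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['weight_scale']
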
diff --git a/src/data_utils.py b/src/data_utils.py
index 2a6cadf..f9e2373 100644
--- a/src/data_utils.py
+++ b/src/data_utils.py
@@ -1,3 +1,4 @@
+import re
 from typing import Optional, List

 import torch
@@ -5,6 +6,14 @@
 from transformers import AutoTokenizer


+def split_thought_solution(text: str):
+    thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL)
+    solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL)
+
+    thought = thought_re.search(text).group(1).strip()
+    solution = solution_re.search(text).group(1).strip()
+    return thought, solution
+
 def prepare_open_thoughts(
     tokenizer: AutoTokenizer,
     max_sequence_length: int,
@@ -14,6 +23,12 @@
     train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
     if num_calibration_samples:
         train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples))
+    # Update chat template
+    tokenizer.chat_template = tokenizer.chat_template.replace(
+        "{{render_content(message)}}",
+        "{%- set rc = message.get('reasoning_content', '') -%}"
+        "{{rc}}{{render_content(message)}}"
+    )
     # Preprocess the data into the format the model is trained with.
     def preprocess(example):
         messages = []
@@ -21,7 +36,12 @@
             messages.append({"role": "system", "content": example['system']})
         # add dialogue
         for message in example['conversations']:
-            messages.append({"role": message["from"], "content": message["value"]})
+            role = message["from"]
+            if role == "user":
+                messages.append({"role": "user", "content": message["value"]})
+            else:
+                thought, solution = split_thought_solution(message["value"])
+                messages.append({"role": "assistant", "content": solution, "reasoning_content": thought})
         return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
     train_dataset_raw = train_dataset_raw.map(preprocess)
     # Tokenize the data
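
For reference, split_thought_solution relies on the thought/solution markers matched by the two regexes above. A toy example (the input string is illustrative, not a real dataset sample):

import re

text = (
    "<|begin_of_thought|>\nLet me check: 2 + 2 = 4.\n<|end_of_thought|>\n"
    "<|begin_of_solution|>\n4\n<|end_of_solution|>"
)
thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL)
solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL)
print(thought_re.search(text).group(1).strip())   # Let me check: 2 + 2 = 4.
print(solution_re.search(text).group(1).strip())  # 4
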
diff --git a/src/gptq.py b/src/gptq.py
index ede3647..038b514 100644
--- a/src/gptq.py
+++ b/src/gptq.py
@@ -213,7 +213,7 @@ def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor

         return qweight, scale, zero

-    def quantize(self, bits: int | float) -> Tensor:
+    def quantize(self, bits: int) -> Tensor:
         self.quantization_pre_step()
         return self._quantize(bits)
diff --git a/src/gptq_loop.py b/src/gptq_loop.py
index ced958e..f70f5d7 100644
--- a/src/gptq_loop.py
+++ b/src/gptq_loop.py
@@ -31,6 +31,7 @@ def quantize_error_triton_kernel(
     maxq = tl.load(maxq_ptr)
     dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype

+    # Quantize
     qx = tl_quantize(x, scale, qzero, maxq)
     y = tl_dequantize(qx, scale, qzero, dtype)
     error = y - x
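
The kernel above computes the quantize-then-dequantize reconstruction error. A plain-PyTorch sketch of that structure, assuming standard uniform affine quantization (the exact behaviour of tl_quantize/tl_dequantize may differ):

import torch

def quantize(x, scale, zero, maxq):
    # Round onto the integer grid [0, maxq]
    return torch.clamp(torch.round(x / scale) + zero, 0, maxq)

def dequantize(q, scale, zero):
    return scale * (q - zero)

x = torch.randn(8)
scale, zero, maxq = torch.tensor(0.1), torch.tensor(8.0), 15  # 4-bit grid
qx = quantize(x, scale, zero, maxq)
error = dequantize(qx, scale, zero) - x  # mirrors `error = y - x` in the kernel
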
diff --git a/src/quant_utils.py b/src/quant_utils.py
index f949421..351d99a 100644
--- a/src/quant_utils.py
+++ b/src/quant_utils.py
@@ -6,6 +6,8 @@
 import torch.nn.functional as F
 import triton
 from triton import language as tl
+from transformers import AutoConfig
+from compressed_tensors.compressors import unpack_from_int32

 torch.backends.cuda.matmul.allow_tf32 = False
@@ -283,22 +285,80 @@ def dequantize_weight_from_fp8(W, s):
     return W


-def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None:
+def dequantize_state_dict(
+    state_dict: dict[str, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
+    quantization_format: str = "fp8"
+) -> None:
+    assert quantization_format in ["fp8", "int4"]
     state_dict_keys = list(state_dict.keys())
-    # Dequantize
-    for k in state_dict_keys:
-        if k.endswith("scale_inv"):
-            layer_name, _ = k.rsplit(".", 1)
-            W = state_dict[f"{layer_name}.weight"].to(dtype)
-            s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)
-
-            state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
-            del state_dict[f"{layer_name}.weight_scale_inv"]
+    # Original DeepSeek packing
+    if quantization_format == "fp8":
+        # Dequantize
+        for k in state_dict_keys:
+            if k.endswith("scale_inv"):
+                layer_name, _ = k.rsplit(".", 1)
+
+                W = state_dict[f"{layer_name}.weight"].to(dtype)
+                s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)
+
+                state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
+                del state_dict[f"{layer_name}.weight_scale_inv"]
+
+    # KIMI-K2 compressed tensors packing
+    elif quantization_format == "int4":
+        # Dequantize
+        for k in state_dict_keys:
+            if k.endswith("weight_packed"):
+                layer_name, _ = k.rsplit(".", 1)
+
+                weight_shape = state_dict[f"{layer_name}.weight_shape"]
+
+                qweight_packed = state_dict[f"{layer_name}.weight_packed"]
+                qweight = unpack_from_int32(qweight_packed, num_bits=4, shape=weight_shape)
+
+                scale = state_dict[f"{layer_name}.weight_scale"]
+
+                W = (qweight.view(*scale.shape, -1) * scale[..., None]).view_as(qweight).to(dtype)
+
+                state_dict[f"{layer_name}.weight"] = W
+                del state_dict[f"{layer_name}.weight_packed"]
+                del state_dict[f"{layer_name}.weight_shape"]
+                del state_dict[f"{layer_name}.weight_scale"]
+
+
+def can_dequantize(
+    state_dict: dict[str, torch.Tensor],
+    quantization_format: str = "fp8"
+) -> bool:
+    assert quantization_format in ["fp8", "int4"]
+    # Original DeepSeek packing
+    if quantization_format == "fp8":
+        for k, v in state_dict.items():
+            if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
+                return False
+
+    # KIMI-K2 compressed tensors packing
+    elif quantization_format == "int4":
+        for k, v in state_dict.items():
+            if k.endswith("weight_packed"):
+                layer_name, _ = k.rsplit(".", 1)
+                if f"{layer_name}.weight_shape" not in state_dict or f"{layer_name}.weight_scale" not in state_dict:
+                    return False
+    return True


-def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool:
-    for k, v in state_dict.items():
-        if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
-            return False
-    return True
+def infer_quantization_format(config: AutoConfig) -> str:
+    quant_method = config.quantization_config["quant_method"]
+    # DeepSeek format
+    if quant_method == "fp8":
+        return "fp8"
+    # KIMI-K2 format
+    elif quant_method == "compressed-tensors":
+        if config.quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4:
+            return "int4"
+        else:
+            raise ValueError("Only 4-bit quantization is supported.")
+    else:
+        raise ValueError("Unknown or unsupported quantization method.")
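
For intuition, the int4 branch rescales the unpacked weight group-wise: it views the weight as (rows, n_groups, group_size) and multiplies by the per-group scale before flattening back. A toy-shaped sketch (shapes are illustrative; the real path first unpacks the int32-packed values via unpack_from_int32):

import torch

qweight = torch.randint(-8, 8, (4, 8))  # stand-in for unpacked symmetric int4 values
scale = torch.rand(4, 2)                # 4 rows, 2 groups of size 4
W = (qweight.view(*scale.shape, -1) * scale[..., None]).view_as(qweight).to(torch.float16)
assert W.shape == (4, 8)
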