diff --git a/pack_quantized_model.py b/pack_quantized_model.py
index 176d501..356141a 100644
--- a/pack_quantized_model.py
+++ b/pack_quantized_model.py
@@ -46,6 +46,11 @@ def parse_args():
choices=["float16", "bfloat16"],
help="Torch dtype used."
)
+ parser.add_argument(
+ "--load_last_shard",
+ action="store_true",
+ help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)."
+ )
args = parser.parse_args()
return args
@@ -116,6 +121,8 @@ def main():
# Load DeepSeek model
config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+ # Infer quantization format
+ orig_quantization_format = quant_utils.infer_quantization_format(config)
if hasattr(config, "quantization_config"):
delattr(config, "quantization_config")
@@ -148,12 +155,23 @@ def main():
# Prepare directory to save packed weights
os.makedirs(args.packed_model_path, exist_ok=True)
+ # Infer the total number of shards from the safetensors filenames (model-*-of-<num_shards>.safetensors)
+ for path in os.listdir(args.model_name_or_path):
+ if path.endswith(".safetensors"):
+ num_shards_str = path.split(".")[0].split("-")[-1]
+ num_shards = int(num_shards_str)
+ break
+
# Load initial weight shard
weight_dir = args.model_name_or_path
current_input_shard_id = 1
- weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
+ weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
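+ # Optionally preload the final weight shard as well (required for Kimi-K2-Thinking)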
+ if args.load_last_shard:
+ last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
+ last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
+ param_buffer.update(last_shard)
# Save embeddings
current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
@@ -176,11 +194,15 @@ def main():
while not is_subset(block_keys_with_prefix, set(param_buffer.keys())):
current_input_shard_id += 1
- weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
- param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
+ weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
+ param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
+ # Dequantize weights from the current shard
+ if quant_utils.can_dequantize(param_shard, orig_quantization_format):
+ quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
+ param_buffer.update(param_shard)
block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)}
- quant_utils.dequantize_state_dict(block_state_dict, dtype)
+ quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)
for layer_name in quantized_layer_names[block_idx]:
weight_state_dict = torch.load(
@@ -189,7 +211,7 @@ def main():
map_location="cpu"
)
packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size)
- block_state_dict.pop(f"{layer_name}.weight")
+ block_state_dict.pop(f"{layer_name}.weight", None)
block_state_dict.pop(f"{layer_name}.weight_scale_inv", None)
block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()})
@@ -209,9 +231,9 @@ def main():
gc.collect()
# Load final shard
- if current_input_shard_id < 163:
- current_input_shard_id = 163
- weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
+ if current_input_shard_id < num_shards:
+ current_input_shard_id = num_shards
+ weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
# Save lm head
diff --git a/quant.py b/quant.py
index 162f5ed..6fbabad 100644
--- a/quant.py
+++ b/quant.py
@@ -46,7 +46,7 @@ def parse_args():
"--bits",
type=int,
default=4,
- choices=[4],
+ choices=[2, 4],
help="Quantization bitwidth.",
)
parser.add_argument(
@@ -79,6 +79,11 @@ def parse_args():
parser.add_argument(
"--dtype", default="float16", type=str, choices=["float16s", "bfloat16"], help="Torch dtype used."
)
+ parser.add_argument(
+ "--load_last_shard",
+ action="store_true",
+ help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)."
+ )
args = parser.parse_args()
return args
@@ -117,6 +122,8 @@ def main():
# Load DeepSeek model
config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+ # Infer quantization format
+ orig_quantization_format = quant_utils.infer_quantization_format(config)
# Sanity check
assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!"
if hasattr(config, "quantization_config"):
@@ -125,7 +132,10 @@ def main():
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
- config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype
+ config=config,
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2",
+ torch_dtype=dtype
).eval()
model.config.use_cache = False
@@ -141,14 +151,26 @@ def main():
calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank]
dist_utils.barrier(device_ids=[rank])
+ # Infer the total number of shards from the safetensors filenames (model-*-of-<num_shards>.safetensors)
+ for path in os.listdir(args.model_name_or_path):
+ if path.endswith(".safetensors"):
+ num_shards_str = path.split(".")[0].split("-")[-1]
+ num_shards = int(num_shards_str)
+ break
+
# Load initial weight shard
weight_dir = args.model_name_or_path
current_shard_id = 1
- weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
+ weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"
param_buffer = {}
if dist_utils.is_main():
param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
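+ # Optionally preload the final weight shard as well (required for Kimi-K2-Thinking)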
+ if args.load_last_shard:
+ last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
+ last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
+ param_buffer.update(last_shard)
+
dist_utils.barrier(device_ids=[rank])
# Get resume block id
@@ -171,6 +193,9 @@ def main():
# Offload embeddings back to meta
model.model.embed_tokens.to(device="meta")
param_buffer.pop("model.embed_tokens.weight", None)
+ # Pop lm head and final norm in case they were loaded from the last shard
+ param_buffer.pop("lm_head.weight", None)
+ param_buffer.pop("model.norm.weight", None)
for block_idx, block in tqdm(
enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers)
@@ -197,16 +222,20 @@ def main():
if dist_utils.is_main():
can_dequantize = True
# Select weights corresponding to current block
- block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
+ block_state_dict = {k[len(prefix):]: v for k, v in param_buffer.items() if k.startswith(prefix)}
while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize):
current_shard_id += 1
- weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
- param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
+ weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"
+ param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
+ # Dequantize weights from the current shard
+ if quant_utils.can_dequantize(param_shard, orig_quantization_format):
+ quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
+ param_buffer.update(param_shard)
# Update weights corresponding to current block
block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
- can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict)
+ can_dequantize = quant_utils.can_dequantize(block_state_dict, orig_quantization_format)
# Dequantize weights corresponding to current block
- quant_utils.dequantize_state_dict(block_state_dict, dtype)
+ quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)
# Put block onto GPU
block.to_empty(device=device)
@@ -214,7 +243,7 @@ def main():
# Simply load block state dict on master and broadcast
if block_idx < model.config.first_k_dense_replace:
if dist_utils.is_main():
- block.load_state_dict(block_state_dict)
+ block.load_state_dict(block_state_dict, strict=False)
if dist_utils.is_dist_available_and_initialized():
dist_utils.broadcast_parameters(block)
# Send dict with part of experts to target device
@@ -222,7 +251,7 @@ def main():
if dist_utils.is_main():
# Load state dict on master
rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys}
- block.load_state_dict(rank_state_dict)
+ block.load_state_dict(rank_state_dict, strict=False)
# Send to other processes
for i in range(1, world_size):
rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]}
@@ -232,7 +261,7 @@ def main():
rank_state_dict = block.state_dict()
for k in rank_state_dict:
dist.recv(rank_state_dict[k], src=0)
- block.load_state_dict(rank_state_dict)
+ block.load_state_dict(rank_state_dict, strict=False)
del rank_state_dict
# Clear memory before calibration
torch.cuda.empty_cache()
diff --git a/src/data_utils.py b/src/data_utils.py
index 2a6cadf..f9e2373 100644
--- a/src/data_utils.py
+++ b/src/data_utils.py
@@ -1,3 +1,4 @@
+import re
from typing import Optional, List
import torch
@@ -5,6 +6,14 @@
from transformers import AutoTokenizer
+def split_thought_solution(text: str):
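+ """Split an assistant message into its thought and solution parts using the <|begin_of_thought|> / <|begin_of_solution|> markers."""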
+ thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL)
+ solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL)
+
+ thought = thought_re.search(text).group(1).strip()
+ solution = solution_re.search(text).group(1).strip()
+ return thought, solution
+
def prepare_open_thoughts(
tokenizer: AutoTokenizer,
max_sequence_length: int,
@@ -14,6 +23,12 @@ def prepare_open_thoughts(
train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
if num_calibration_samples:
train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples))
+ # Update the chat template so that assistant reasoning_content is rendered before the message content
+ tokenizer.chat_template = tokenizer.chat_template.replace(
+ "{{render_content(message)}}",
+ "{%- set rc = message.get('reasoning_content', '') -%}"
+ "{{rc}}{{render_content(message)}}"
+ )
# Preprocess the data into the format the model is trained with.
def preprocess(example):
messages = []
@@ -21,7 +36,12 @@ def preprocess(example):
messages.append({"role": "system", "content": example['system']})
# add dialogue
for message in example['conversations']:
- messages.append({"role": message["from"], "content": message["value"]})
+ role = message["from"]
+ if role == "user":
+ messages.append({"role": "user", "content": message["value"]})
+ else:
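+ # Assistant turns contain OpenThoughts thought/solution blocks; map them to reasoning_content and content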
+ thought, solution = split_thought_solution(message["value"])
+ messages.append({"role": "assistant", "content": solution, "reasoning_content": thought})
return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
train_dataset_raw = train_dataset_raw.map(preprocess)
# Tokenize the data
diff --git a/src/gptq.py b/src/gptq.py
index ede3647..038b514 100644
--- a/src/gptq.py
+++ b/src/gptq.py
@@ -213,7 +213,7 @@ def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor
return qweight, scale, zero
- def quantize(self, bits: int | float) -> Tensor:
+ def quantize(self, bits: int) -> Tensor:
self.quantization_pre_step()
return self._quantize(bits)
diff --git a/src/gptq_loop.py b/src/gptq_loop.py
index ced958e..f70f5d7 100644
--- a/src/gptq_loop.py
+++ b/src/gptq_loop.py
@@ -31,6 +31,7 @@ def quantize_error_triton_kernel(
maxq = tl.load(maxq_ptr)
dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype
+ # Quantize, dequantize, and measure the reconstruction error
qx = tl_quantize(x, scale, qzero, maxq)
y = tl_dequantize(qx, scale, qzero, dtype)
error = y - x
diff --git a/src/quant_utils.py b/src/quant_utils.py
index f949421..351d99a 100644
--- a/src/quant_utils.py
+++ b/src/quant_utils.py
@@ -6,6 +6,8 @@
import torch.nn.functional as F
import triton
from triton import language as tl
+from transformers import AutoConfig
+from compressed_tensors.compressors import unpack_from_int32
torch.backends.cuda.matmul.allow_tf32 = False
@@ -283,22 +285,80 @@ def dequantize_weight_from_fp8(W, s):
return W
-def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None:
+def dequantize_state_dict(
+ state_dict: dict[str, torch.Tensor],
+ dtype: torch.dtype = torch.float16,
+ quantization_format: str = "fp8"
+) -> None:
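+ """Dequantize all quantized weights in state_dict to dtype in place, removing the packing metadata (scales, packed weights, shapes)."""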
+ assert quantization_format in ["fp8", "int4"]
state_dict_keys = list(state_dict.keys())
- # Dequantize
- for k in state_dict_keys:
- if k.endswith("scale_inv"):
- layer_name, _ = k.rsplit(".", 1)
- W = state_dict[f"{layer_name}.weight"].to(dtype)
- s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)
-
- state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
- del state_dict[f"{layer_name}.weight_scale_inv"]
+ # Original DeepSeek packing
+ if quantization_format == "fp8":
+ # Dequantize
+ for k in state_dict_keys:
+ if k.endswith("scale_inv"):
+ layer_name, _ = k.rsplit(".", 1)
+
+ W = state_dict[f"{layer_name}.weight"].to(dtype)
+ s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)
+
+ state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
+ del state_dict[f"{layer_name}.weight_scale_inv"]
+
+ # Kimi-K2 compressed-tensors packing
+ elif quantization_format == "int4":
+ # Dequantize
+ for k in state_dict_keys:
+ if k.endswith("weight_packed"):
+ layer_name, _ = k.rsplit(".", 1)
+
+ weight_shape = state_dict[f"{layer_name}.weight_shape"]
+
+ qweight_packed = state_dict[f"{layer_name}.weight_packed"]
+ qweight = unpack_from_int32(qweight_packed, num_bits=4, shape=weight_shape)
+
+ scale = state_dict[f"{layer_name}.weight_scale"]
+
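+ # Apply per-group scales: view as (*scale.shape, group_size), scale each group, then restore the original shape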
+ W = (qweight.view(*scale.shape, -1) * scale[..., None]).view_as(qweight).to(dtype)
+
+ state_dict[f"{layer_name}.weight"] = W
+ del state_dict[f"{layer_name}.weight_packed"]
+ del state_dict[f"{layer_name}.weight_shape"]
+ del state_dict[f"{layer_name}.weight_scale"]
+
+
+def can_dequantize(
+ state_dict: dict[str, torch.Tensor],
+ quantization_format: str = "fp8"
+) -> bool:
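+ """Check whether every quantized tensor in state_dict has the metadata required to dequantize it."""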
+ assert quantization_format in ["fp8", "int4"]
+ # Original DeepSeek packing
+ if quantization_format == "fp8":
+ for k, v in state_dict.items():
+ if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
+ return False
+
+ # Kimi-K2 compressed-tensors packing
+ elif quantization_format == "int4":
+ for k, v in state_dict.items():
+ if k.endswith("weight_packed"):
+ layer_name, _ = k.rsplit(".", 1)
+ if f"{layer_name}.weight_shape" not in state_dict or f"{layer_name}.weight_scale" not in state_dict:
+ return False
+ return True
-def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool:
- for k, v in state_dict.items():
- if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
- return False
- return True
+def infer_quantization_format(config: AutoConfig) -> str:
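+ """Map the checkpoint's quantization_config to an internal format: "fp8" (DeepSeek) or "int4" (compressed-tensors)."""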
+ quant_method = config.quantization_config["quant_method"]
+ # DeepSeek format
+ if quant_method == "fp8":
+ return "fp8"
+ # Kimi-K2 format
+ elif quant_method == "compressed-tensors":
+ if config.quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4:
+ return "int4"
+ else:
+ raise ValueError("Only 4-bit quantization is supported.")
+ else:
+ raise ValueError("Unknown or unsupported quantization method.")