38 changes: 30 additions & 8 deletions pack_quantized_model.py
@@ -46,6 +46,11 @@ def parse_args():
choices=["float16", "bfloat16"],
help="Torch dtype used."
)
parser.add_argument(
"--load_last_shard",
action="store_true",
help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)."
)
args = parser.parse_args()
return args

@@ -116,6 +121,8 @@ def main():

# Load DeepSeek model
config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
# Infer quantization format
orig_quantization_format = quant_utils.infer_quantization_format(config)
if hasattr(config, "quantization_config"):
delattr(config, "quantization_config")

@@ -148,12 +155,23 @@ def main():
# Prepare directory to save packed weights
os.makedirs(args.packed_model_path, exist_ok=True)

# Get num shards
for path in os.listdir(args.model_name_or_path):
if path.endswith(".safetensors"):
num_shards_str = path.split(".")[0].split("-")[-1]
num_shards = int(num_shards_str)
break

# Load initial weight shard
weight_dir = args.model_name_or_path
current_input_shard_id = 1
weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"

param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
if args.load_last_shard:
last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
param_buffer.update(last_shard)

# Save embeddings
current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
@@ -176,11 +194,15 @@ def main():

while not is_subset(block_keys_with_prefix, set(param_buffer.keys())):
current_input_shard_id += 1
weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
# Dequantize weights from the current shard
if quant_utils.can_dequantize(param_shard, orig_quantization_format):
quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
param_buffer.update(param_shard)

block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)}
quant_utils.dequantize_state_dict(block_state_dict, dtype)
quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)

for layer_name in quantized_layer_names[block_idx]:
weight_state_dict = torch.load(
@@ -189,7 +211,7 @@ def main():
map_location="cpu"
)
packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size)
block_state_dict.pop(f"{layer_name}.weight")
block_state_dict.pop(f"{layer_name}.weight", None)
block_state_dict.pop(f"{layer_name}.weight_scale_inv", None)
block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()})

@@ -209,9 +231,9 @@ def main():
gc.collect()

# Load final shard
if current_input_shard_id < 163:
current_input_shard_id = 163
weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
if current_input_shard_id < num_shards:
current_input_shard_id = num_shards
weight_path = f"model-{current_input_shard_id:05}-of-{num_shards_str}.safetensors"
param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))

# Save lm head
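For reference, a minimal sketch of the shard-count inference added above: the count is read off the safetensors filenames, so a checkpoint whose shards are named like model-00001-of-000163.safetensors yields num_shards_str == "000163" and num_shards == 163 (the filename below is illustrative).

# Illustrative example of the filename parsing in the "Get num shards" block above.
path = "model-00001-of-000163.safetensors"          # hypothetical shard name
num_shards_str = path.split(".")[0].split("-")[-1]  # "000163"
num_shards = int(num_shards_str)                    # 163
weight_path = f"model-{1:05}-of-{num_shards_str}.safetensors"  # "model-00001-of-000163.safetensors"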
51 changes: 40 additions & 11 deletions quant.py
@@ -46,7 +46,7 @@ def parse_args():
"--bits",
type=int,
default=4,
choices=[4],
choices=[2, 4],
help="Quantization bitwidth.",
)
parser.add_argument(
@@ -79,6 +79,11 @@ def parse_args():
parser.add_argument(
"--dtype", default="float16", type=str, choices=["float16s", "bfloat16"], help="Torch dtype used."
)
parser.add_argument(
"--load_last_shard",
action="store_true",
help="Whether to load the last shard of the model in the beginning (needed from Kimi-K2-Thinking)."
)
args = parser.parse_args()

return args
@@ -117,6 +122,8 @@ def main():

# Load DeepSeek model
config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
# Infer quantization format
orig_quantization_format = quant_utils.infer_quantization_format(config)
# Sanity check
assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!"
if hasattr(config, "quantization_config"):
@@ -125,7 +132,10 @@ def main():

with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype
config=config,
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=dtype
).eval()
model.config.use_cache = False

@@ -141,14 +151,26 @@ def main():
calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank]
dist_utils.barrier(device_ids=[rank])

# Get num shards
for path in os.listdir(args.model_name_or_path):
if path.endswith(".safetensors"):
num_shards_str = path.split(".")[0].split("-")[-1]
num_shards = int(num_shards_str)
break

# Load initial weight shard
weight_dir = args.model_name_or_path
current_shard_id = 1
weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"

param_buffer = {}
if dist_utils.is_main():
param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
if args.load_last_shard:
last_shard_weight_path = f"model-{num_shards:05}-of-{num_shards_str}.safetensors"
last_shard = loading_utils.load_param_shard(weight_dir, last_shard_weight_path)
param_buffer.update(last_shard)

dist_utils.barrier(device_ids=[rank])

# Get resume block id
@@ -171,6 +193,9 @@ def main():
# Offload embeddings back to meta
model.model.embed_tokens.to(device="meta")
param_buffer.pop("model.embed_tokens.weight", None)
# Pop lm head and norm in case they were loaded from the last shard
param_buffer.pop("lm_head.weight", None)
param_buffer.pop("model.norm.weight", None)

for block_idx, block in tqdm(
enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers)
@@ -197,32 +222,36 @@ def main():
if dist_utils.is_main():
can_dequantize = True
# Select weights corresponding to current block
block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
block_state_dict = {k[len(prefix):]: v for k, v in param_buffer.items() if k.startswith(prefix)}
while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize):
current_shard_id += 1
weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
weight_path = f"model-{current_shard_id:05}-of-{num_shards_str}.safetensors"
param_shard = loading_utils.load_param_shard(weight_dir, weight_path)
# Dequantize weights from the current shard
if quant_utils.can_dequantize(param_shard, orig_quantization_format):
quant_utils.dequantize_state_dict(param_shard, dtype, orig_quantization_format)
param_buffer.update(param_shard)
# Update weights corresponding to current block
block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict)
can_dequantize = quant_utils.can_dequantize(block_state_dict, orig_quantization_format)
# Dequantize weights corresponding to current block
quant_utils.dequantize_state_dict(block_state_dict, dtype)
quant_utils.dequantize_state_dict(block_state_dict, dtype, orig_quantization_format)

# Put block onto GPU
block.to_empty(device=device)

# Simply load block state dict on master and broadcast
if block_idx < model.config.first_k_dense_replace:
if dist_utils.is_main():
block.load_state_dict(block_state_dict)
block.load_state_dict(block_state_dict, strict=False)
if dist_utils.is_dist_available_and_initialized():
dist_utils.broadcast_parameters(block)
# Send dict with a subset of experts to the target device
else:
if dist_utils.is_main():
# Load state dict on master
rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys}
block.load_state_dict(rank_state_dict)
block.load_state_dict(rank_state_dict, strict=False)
# Send to other processes
for i in range(1, world_size):
rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]}
@@ -232,7 +261,7 @@ def main():
rank_state_dict = block.state_dict()
for k in rank_state_dict:
dist.recv(rank_state_dict[k], src=0)
block.load_state_dict(rank_state_dict)
block.load_state_dict(rank_state_dict, strict=False)
del rank_state_dict
# Clear memory before calibration
torch.cuda.empty_cache()
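As a tiny illustration of the prefix stripping used when building block_state_dict above (the prefix value and keys below are assumed / illustrative; in the script the prefix names the current transformer block):

prefix = "model.layers.3."
param_buffer = {"model.layers.3.self_attn.q_proj.weight": 0, "model.embed_tokens.weight": 1}
block_state_dict = {k[len(prefix):]: v for k, v in param_buffer.items() if k.startswith(prefix)}
assert block_state_dict == {"self_attn.q_proj.weight": 0}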
22 changes: 21 additions & 1 deletion src/data_utils.py
@@ -1,10 +1,19 @@
import re
from typing import Optional, List

import torch
from datasets import load_dataset
from transformers import AutoTokenizer


def split_thought_solution(text: str):
thought_re = re.compile(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", re.DOTALL)
solution_re = re.compile(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", re.DOTALL)

thought = thought_re.search(text).group(1).strip()
solution = solution_re.search(text).group(1).strip()
return thought, solution

def prepare_open_thoughts(
tokenizer: AutoTokenizer,
max_sequence_length: int,
@@ -14,14 +23,25 @@ def prepare_open_thoughts(
train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
if num_calibration_samples:
train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples))
# Update chat template
tokenizer.chat_template = tokenizer.chat_template.replace(
"<think></think>{{render_content(message)}}",
"{%- set rc = message.get('reasoning_content', '') -%}"
"<think>{{rc}}</think>{{render_content(message)}}"
)
# Preprocess the data into the format the model is trained with.
def preprocess(example):
messages = []
# add system prompt
messages.append({"role": "system", "content": example['system']})
# add dialogue
for message in example['conversations']:
messages.append({"role": message["from"], "content": message["value"]})
role = message["from"]
if role == "user":
messages.append({"role": "user", "content": message["value"]})
else:
thought, solution = split_thought_solution(message["value"])
messages.append({"role": "assistant", "content": solution, "reasoning_content": thought})
return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
train_dataset_raw = train_dataset_raw.map(preprocess)
# Tokenize the data
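As a quick check of the new split_thought_solution helper (the sample text below is hypothetical, but uses the same markers the regexes expect):

sample = (
    "<|begin_of_thought|>First try small cases.<|end_of_thought|>"
    "<|begin_of_solution|>The answer is 42.<|end_of_solution|>"
)
thought, solution = split_thought_solution(sample)
assert thought == "First try small cases."
assert solution == "The answer is 42."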
2 changes: 1 addition & 1 deletion src/gptq.py
@@ -213,7 +213,7 @@ def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor

return qweight, scale, zero

def quantize(self, bits: int | float) -> Tensor:
def quantize(self, bits: int) -> Tensor:
self.quantization_pre_step()
return self._quantize(bits)

1 change: 1 addition & 0 deletions src/gptq_loop.py
@@ -31,6 +31,7 @@ def quantize_error_triton_kernel(
maxq = tl.load(maxq_ptr)
dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype

# Quantize
qx = tl_quantize(x, scale, qzero, maxq)
y = tl_dequantize(qx, scale, qzero, dtype)
error = y - x
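For intuition, a torch-level sketch of the per-element computation in the kernel above, assuming tl_quantize/tl_dequantize follow the usual asymmetric uniform scheme (an assumption for illustration, not a copy of the Triton code):

import torch

def quantize_error(x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: int) -> torch.Tensor:
    qx = torch.clamp(torch.round(x / scale) + qzero, 0, maxq)  # quantize onto the integer grid
    y = scale * (qx - qzero)                                    # dequantize back to floats
    return y - x                                                # reconstruction error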
90 changes: 75 additions & 15 deletions src/quant_utils.py
@@ -6,6 +6,8 @@
import torch.nn.functional as F
import triton
from triton import language as tl
from transformers import AutoConfig
from compressed_tensors.compressors import unpack_from_int32


torch.backends.cuda.matmul.allow_tf32 = False
@@ -283,22 +285,80 @@ def dequantize_weight_from_fp8(W, s):
return W


def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None:
def dequantize_state_dict(
state_dict: dict[str, torch.Tensor],
dtype: torch.dtype = torch.float16,
quantization_format: str = "fp8"
) -> None:
assert quantization_format in ["fp8", "int4"]
state_dict_keys = list(state_dict.keys())
# Dequantize
for k in state_dict_keys:
if k.endswith("scale_inv"):
layer_name, _ = k.rsplit(".", 1)

W = state_dict[f"{layer_name}.weight"].to(dtype)
s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)

state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
del state_dict[f"{layer_name}.weight_scale_inv"]
# Original DeepSeek packing
if quantization_format == "fp8":
# Dequantize
for k in state_dict_keys:
if k.endswith("scale_inv"):
layer_name, _ = k.rsplit(".", 1)

W = state_dict[f"{layer_name}.weight"].to(dtype)
s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)

state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
del state_dict[f"{layer_name}.weight_scale_inv"]

# KIMI-K2 compressed tensors packing
elif quantization_format == "int4":
# Dequantize
for k in state_dict_keys:
if k.endswith("weight_packed"):
layer_name, _ = k.rsplit(".", 1)

weight_shape = state_dict[f"{layer_name}.weight_shape"]

qweight_packed = state_dict[f"{layer_name}.weight_packed"]
qweight = unpack_from_int32(qweight_packed, num_bits=4, shape=weight_shape)

scale = state_dict[f"{layer_name}.weight_scale"]

W = (qweight.view(*scale.shape, -1) * scale[..., None]).view_as(qweight).to(dtype)

state_dict[f"{layer_name}.weight"] = W
del state_dict[f"{layer_name}.weight_packed"]
del state_dict[f"{layer_name}.weight_shape"]
del state_dict[f"{layer_name}.weight_scale"]


def can_dequantize(
state_dict: dict[str, torch.Tensor],
quantization_format: str = "fp8"
) -> bool:
assert quantization_format in ["fp8", "int4"]
# Original DeepSeek packing
if quantization_format == "fp8":
for k, v in state_dict.items():
if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
return False

# KIMI-K2 compressed tensors packing
elif quantization_format == "int4":
for k, v in state_dict.items():
if k.endswith("weight_packed"):
layer_name, _ = k.rsplit(".", 1)
if f"{layer_name}.weight_shape" not in state_dict or f"{layer_name}.weight_scale" not in state_dict:
return False
return True


def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool:
for k, v in state_dict.items():
if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
return False
return True
def infer_quantization_format(config: AutoConfig) -> str:
quant_method = config.quantization_config["quant_method"]
# DeepSeek format
if quant_method == "fp8":
return "fp8"
# KIMI-K2 format
elif quant_method == "compressed-tensors":
if config.quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4:
return "int4"
else:
raise ValueError("Only 4-bit quantization is supported.")
else:
raise ValueError("Unknown or unsupported quantization method.")