Skip to content

Commit 2e0081b

Browse files
authored
[NVIDIA#6530][fix] Fix script when using calibration tensors from modelopt (NVIDIA#6803)
Signed-off-by: Aurelien Chartier <[email protected]>
1 parent f68e03e commit 2e0081b

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

examples/quantization/quantize_mixed_precision_moe.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,16 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
     state_dict_list = []
     # load amax from state dict
     for rank in range(world_size):
-        state_dict_list.append(
-            torch.load(
-                f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt",
-                map_location="cuda:0"))
+        amax_file = f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt"
+        if os.path.exists(amax_file):
+            state_dict_list.append(torch.load(amax_file, map_location="cuda:0"))
+        else:
+            print(f"WARNING: amax file not found: {amax_file}")
+
+    if not state_dict_list:
+        print("ERROR: No amax files loaded!")
+        return {}
+
     # calculate the max across all TP ranks
     merged_state_dict = state_dict_list[0]
     for rank in range(world_size):
@@ -232,15 +238,18 @@ def get_file_name(layer):
            continue
        new_safetensors.update({key: get_tensor(key)})

+    # Process activation scales for all ranks
+    if os.path.isdir(args.act_scales):
+        # Extract activation scales
+        renamed_state_dict = load_and_preprocess_state_dict(
+            modelopt_state_root=args.act_scales, world_size=8)
+        scales = get_scales_from_amax(start_layer=start_layer,
+                                      end_layer=end_layer,
+                                      renamed_state_dict=renamed_state_dict)
+        new_safetensors.update(scales)
+
    if args.rank == 0:
-        if os.path.isdir(args.act_scales):
-            # Extract activation scales
-            renamed_state_dict = load_and_preprocess_state_dict(
-                modelopt_state_root=args.act_scales, world_size=8)
-            get_scales_from_amax(start_layer=start_layer,
-                                 end_layer=end_layer,
-                                 renamed_state_dict=renamed_state_dict)
-        else:
+        if not os.path.isdir(args.act_scales):
            input_scales = safe_open(args.act_scales, "pt")
            for k in input_scales.keys():
                new_safetensors.update({k: input_scales.get_tensor(k)})
@@ -259,7 +268,10 @@ def get_file_name(layer):
        ]
        for name in names:
            shutil.copy(os.path.join(model_dir, name), output_dir)
-        shutil.copy(args.act_scales, output_dir)
+        if os.path.isdir(args.act_scales):
+            shutil.copytree(args.act_scales, output_dir, dirs_exist_ok=True)
+        else:
+            shutil.copy(args.act_scales, output_dir)

        # config.json
        del config['quantization_config']

0 commit comments

Comments
 (0)