@@ -45,10 +45,16 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
4545 state_dict_list = []
4646 # load amax from state dict
4747 for rank in range (world_size ):
48- state_dict_list .append (
49- torch .load (
50- f"{ modelopt_state_root } /amax_dict_rank{ rank } -mp{ world_size } .pt" ,
51- map_location = "cuda:0" ))
48+ amax_file = f"{ modelopt_state_root } /amax_dict_rank{ rank } -mp{ world_size } .pt"
49+ if os .path .exists (amax_file ):
50+ state_dict_list .append (torch .load (amax_file , map_location = "cuda:0" ))
51+ else :
52+ print (f"WARNING: amax file not found: { amax_file } " )
53+
54+ if not state_dict_list :
55+ print ("ERROR: No amax files loaded!" )
56+ return {}
57+
5258 # calculate the max across all TP ranks
5359 merged_state_dict = state_dict_list [0 ]
5460 for rank in range (world_size ):
@@ -232,15 +238,18 @@ def get_file_name(layer):
232238 continue
233239 new_safetensors .update ({key : get_tensor (key )})
234240
241+ # Process activation scales for all ranks
242+ if os .path .isdir (args .act_scales ):
243+ # Extract activation scales
244+ renamed_state_dict = load_and_preprocess_state_dict (
245+ modelopt_state_root = args .act_scales , world_size = 8 )
246+ scales = get_scales_from_amax (start_layer = start_layer ,
247+ end_layer = end_layer ,
248+ renamed_state_dict = renamed_state_dict )
249+ new_safetensors .update (scales )
250+
235251 if args .rank == 0 :
236- if os .path .isdir (args .act_scales ):
237- # Extract activation scales
238- renamed_state_dict = load_and_preprocess_state_dict (
239- modelopt_state_root = args .act_scales , world_size = 8 )
240- get_scales_from_amax (start_layer = start_layer ,
241- end_layer = end_layer ,
242- renamed_state_dict = renamed_state_dict )
243- else :
252+ if not os .path .isdir (args .act_scales ):
244253 input_scales = safe_open (args .act_scales , "pt" )
245254 for k in input_scales .keys ():
246255 new_safetensors .update ({k : input_scales .get_tensor (k )})
@@ -259,7 +268,10 @@ def get_file_name(layer):
259268 ]
260269 for name in names :
261270 shutil .copy (os .path .join (model_dir , name ), output_dir )
262- shutil .copy (args .act_scales , output_dir )
271+ if os .path .isdir (args .act_scales ):
272+ shutil .copytree (args .act_scales , output_dir , dirs_exist_ok = True )
273+ else :
274+ shutil .copy (args .act_scales , output_dir )
263275
264276 # config.json
265277 del config ['quantization_config' ]
0 commit comments