@@ -498,8 +498,7 @@ def export_compressed_model(
         gptq_config = self.gptq_config if hasattr(self, "gptq_config") else {}

         autoround_config = self.autoround_config if hasattr(self, "autoround_config") else {}
-
-        if gptq_config or (autoround_config and device == "xpu"):
+        if gptq_config:
             for k, v in weight_config.items():
                 logger.debug(f"Compressing {k} on device {device}")
                 if v["dtype"] == "fp32":
@@ -558,19 +557,54 @@ def export_compressed_model(
                     )
                     new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
                     set_module(self.model, k, new_module)
-        elif autoround_config and (device == "cpu" or device == "auto"):
-            from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
+        elif autoround_config:
+            if device == "xpu":
+                for k, v in weight_config.items():
+                    logger.debug(f"Compressing {k} on device {device}")
+                    if v["dtype"] == "fp32":
+                        continue
+                    else:
+                        dtype = v["dtype"]
+                        num_bits = v["bits"]
+                        group_size = v["group_size"]
+                        scheme = v["scheme"]
+                        m = fetch_module(self.model, k)
+                        autoround_conf = autoround_config[k]
+                        fp32_weight = m.weight.data
+                        autoround_scale = torch.tensor(autoround_conf["scale"], dtype=torch.float32)
+                        autoround_zp = None if scheme == "sym" else torch.tensor(autoround_conf["zero"], dtype=torch.int32)
+                        int_weight = quant_weight_w_scale(fp32_weight, autoround_scale, autoround_zp, group_size)
+                        int_weight = int_weight.type(torch.int32)
+                        new_module = WeightOnlyLinear(
+                            m.in_features,
+                            m.out_features,
+                            num_bits,
+                            group_size,
+                            dtype=dtype,
+                            zp=autoround_zp is not None,
+                            bias=m.bias is not None,
+                            g_idx=None,
+                            compression_dtype=compression_dtype,
+                            compression_dim=compression_dim,
+                            scale_dtype=scale_dtype,
+                            device=device,
+                            use_optimum_format=use_optimum_format,
+                        )
+                        new_module.pack(int_weight, autoround_scale, autoround_zp, m.bias, None)
+                        set_module(self.model, k, new_module)
+            else:
+                from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401

-            self.model = pack_model(
-                self.model,
-                weight_config=autoround_config,
-                enable_full_range=enable_full_range,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                device=device,
-                use_optimum_format=use_optimum_format,
-                inplace=True,
-            )
+                self.model = pack_model(
+                    self.model,
+                    weight_config=autoround_config,
+                    enable_full_range=enable_full_range,
+                    compression_dtype=compression_dtype,
+                    compression_dim=compression_dim,
+                    device=device,
+                    use_optimum_format=use_optimum_format,
+                    inplace=True,
+                )
         else:
             for k, v in weight_config.items():
                 logger.debug(f"Compressing {k} on device {device}")
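For context on what the new XPU branch does per layer, below is a self-contained sketch of the group-wise re-quantization step that `quant_weight_w_scale()` is assumed to perform before the result is handed to `WeightOnlyLinear.pack()`. The helper `requantize_with_scale` is illustrative only, not the neural-compressor implementation; the zero-point handling (a `zero` tensor only for the asymmetric scheme) and the final int32 cast mirror what the diff shows.

```python
import torch

def requantize_with_scale(weight, scale, zero_point=None, group_size=32, num_bits=4):
    """Illustrative group-wise re-quantization with a precomputed scale (not the INC helper)."""
    out_features, in_features = weight.shape
    groups = in_features // group_size
    w = weight.reshape(out_features, groups, group_size)
    q = torch.round(w / scale.reshape(out_features, groups, 1))
    if zero_point is not None:
        # asymmetric scheme: shift by the zero point, clamp to the unsigned range
        q = (q + zero_point.reshape(out_features, groups, 1)).clamp(0, 2**num_bits - 1)
    else:
        # symmetric scheme: no zero point, clamp to the signed range
        q = q.clamp(-(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1)
    # pack() expects integer weights, hence the int32 cast seen in the diff
    return q.reshape(out_features, in_features).to(torch.int32)

# Toy check: per-group symmetric scales derived from each group's max magnitude.
w = torch.randn(8, 64)
scale = w.reshape(8, 2, 32).abs().amax(dim=-1) / (2 ** (4 - 1) - 1)
int_w = requantize_with_scale(w, scale, zero_point=None, group_size=32, num_bits=4)
print(int_w.shape, int_w.dtype)  # torch.Size([8, 64]) torch.int32
```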