2626from tuning .data import tokenizer_data_utils
2727from tuning .utils .config_utils import get_hf_peft_config
2828from tuning .utils .data_type_utils import get_torch_dtype
29- from tuning .tracker .tracker import Tracker
29+ from tuning .tracker .tracker import Tracker , get_tracker
3030from tuning .tracker .aimstack_tracker import AimStackTracker
3131
3232logger = logging .get_logger ("sft_trainer" )
@@ -92,7 +92,7 @@ def train(
9292 Union [peft_config .LoraConfig , peft_config .PromptTuningConfig ]
9393 ] = None ,
9494 callbacks : Optional [List [TrainerCallback ]] = None ,
95- tracker : Optional [Tracker ] = None ,
95+ tracker : Optional [Tracker ] = None
9696):
9797 """Call the SFTTrainer
9898
@@ -285,6 +285,11 @@ def main(**kwargs):
285285 choices = ["pt" , "lora" , None , "none" ],
286286 default = "pt" ,
287287 )
288+ parser .add_argument (
289+ "--extra_metadata" ,
290+ type = str ,
291+ default = None ,
292+ )
288293 (
289294 model_args ,
290295 data_args ,
@@ -311,14 +316,7 @@ def main(**kwargs):
311316 tracker_config = None
312317
313318 # Initialize the tracker early so we can calculate custom metrics like model_load_time.
314- tracker_name = training_args .tracker
315- if tracker_name == 'aim' :
316- if tracker_config is not None :
317- tracker = AimStackTracker (tracker_config )
318- else :
319- logger .error ("Tracker name is set to " + tracker_name + " but config is None." )
320- else :
321- tracker = Tracker ()
319+ tracker = get_tracker (tracker_name , tracker_config )
322320
323321 # Initialize callbacks
324322 file_logger_callback = FileLoggingCallback (logger )
@@ -329,6 +327,14 @@ def main(**kwargs):
329327 if tracker_callback is not None :
330328 callbacks .append (tracker_callback )
331329
330+ # track extra metadata
331+ if additional .extra_metadata is not None :
332+ try :
333+ metadata = json .loads (additional .extra_metadata )
334+ tracker .track_metadata (metadata )
335+ except :
336+ logger .error ("failed while parsing extra metadata. pass a valid json" )
337+
332338 train (model_args , data_args , training_args , tune_config , callbacks , tracker )
333339
334340if __name__ == "__main__" :
0 commit comments