diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 9ed3109bb..bd12d6520 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,8 +1,9 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import List, NamedTuple, Optional, Union
 import json
 import os
+import time
 
 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -40,6 +41,12 @@ def on_save(self, args, state, control, **kwargs):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
 
+class EnhancedTrainOutput(NamedTuple):
+    train_output: transformers.trainer.TrainOutput
+
+    model_load_time: float
+
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -84,7 +91,8 @@ def train(
     peft_config: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
-):
+    callbacks: Optional[List[TrainerCallback]] = None,
+) -> EnhancedTrainOutput:
     """Call the SFTTrainer
 
     Args:
@@ -92,9 +100,14 @@ def train(
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
         peft_config: peft_config.LoraConfig for Lora tuning | \
-        peft_config.PromptTuningConfig for prompt tuning | \
-        None for fine tuning
+            peft_config.PromptTuningConfig for prompt tuning | \
+            None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: optional callbacks for SFTTrainer
+
+    Returns:
+        An EnhancedTrainOutput containing the TrainOutput of SFTTrainer.train() plus extra metrics
+        such as model_load_time
     """
 
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -116,13 +129,15 @@ def train(
         train_args.fsdp_config = {"xla": False}
 
     task_type = "CAUSAL_LM"
-    model = AutoModelForCausalLM.from_pretrained(
+
+    model_load_time = time.time()
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-
+    model_load_time = time.time() - model_load_time
     peft_config = get_hf_peft_config(task_type, peft_config)
 
     model.gradient_checkpointing_enable()
@@ -212,10 +227,12 @@ def train(
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
         logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
 
+    callbacks = callbacks or []
+
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
     peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks.extend([aim_callback, peft_saving_callback, file_logger_callback])
 
     if train_args.packing:
         logger.info("Packing is set to True")
@@ -260,7 +277,13 @@ def train(
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
-    trainer.train()
+
+    train_output: "transformers.trainer.TrainOutput" = trainer.train()
+
+    return EnhancedTrainOutput(
+        train_output=train_output,
+        model_load_time=model_load_time,
+    )
 
 
 def main(**kwargs):
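
A minimal caller-side sketch of how the new return value might be consumed. Only train()'s signature and the EnhancedTrainOutput fields are taken from the diff above; the run_tuning wrapper is hypothetical, and model_args, data_args, and train_args are assumed to be the tuning.config.configs objects described in the docstring, built elsewhere (e.g. via transformers.HfArgumentParser).

# Hypothetical caller (not part of this patch).
from tuning import sft_trainer


def run_tuning(model_args, data_args, train_args):
    result = sft_trainer.train(
        model_args,
        data_args,
        train_args,
        peft_config=None,  # None selects full fine tuning, per the docstring
        callbacks=[],      # the new optional parameter added by this change
    )
    # EnhancedTrainOutput wraps the usual TrainOutput and adds the wall-clock
    # time (in seconds) spent loading the model with from_pretrained().
    print(result.train_output.metrics)
    print(f"model_load_time: {result.model_load_time:.2f} seconds")
    return result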