From 8edd0b0210cf59262ca030bf048b9580f4d72a1f Mon Sep 17 00:00:00 2001
From: Vassilis Vassiliadis
Date: Fri, 16 Feb 2024 14:03:33 +0000
Subject: [PATCH 1/2] feat: custom callbacks for train() and return TrainOutput plus model_load_time

Signed-off-by: Vassilis Vassiliadis
---
 tuning/sft_trainer.py | 39 +++++++++++++++++++++++++++++++--------
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 9ed3109bb..4e96327cf 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,8 +1,9 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import Optional, Union, List, NamedTuple
 import json
 import os
+import time
 
 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -40,6 +41,12 @@ def on_save(self, args, state, control, **kwargs):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
 
+class EnhancedTrainOutput(NamedTuple):
+    train_output: transformers.trainer.TrainOutput
+
+    model_load_time: float
+
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -84,7 +91,8 @@ def train(
     peft_config: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
-):
+    callbacks: Optional[List[TrainerCallback]] = None,
+) -> EnhancedTrainOutput:
     """Call the SFTTrainer
 
     Args:
@@ -92,9 +100,14 @@ def train(
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
         peft_config: peft_config.LoraConfig for Lora tuning | \
-        peft_config.PromptTuningConfig for prompt tuning | \
-        None for fine tuning
+            peft_config.PromptTuningConfig for prompt tuning | \
+            None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: optional callbacks for SFTTrainer
+
+    Returns:
+        An EnhancedTrainOutput containing the TrainOutput of SFTTrainer.train() plus
+        extra metrics such as model_load_time
     """
 
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -116,13 +129,15 @@ def train(
         train_args.fsdp_config = {"xla": False}
 
     task_type = "CAUSAL_LM"
-    model = AutoModelForCausalLM.from_pretrained(
+
+    model_load_time = time.time()
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-
+    model_load_time = time.time() - model_load_time
     peft_config = get_hf_peft_config(task_type, peft_config)
 
     model.gradient_checkpointing_enable()
@@ -212,10 +227,12 @@ def train(
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
         logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
 
+    callbacks = callbacks or []
+
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
     peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks.extend([aim_callback, peft_saving_callback, file_logger_callback])
 
     if train_args.packing:
         logger.info("Packing is set to True")
@@ -260,7 +277,13 @@ def train(
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
-    trainer.train()
+
+    train_output: "transformers.trainer.TrainOutput" = trainer.train()
+
+    return EnhancedTrainOutput(
+        train_output=train_output,
+        model_load_time=model_load_time,
+    )
 
 
 def main(**kwargs):
From 82f99dc83114ee5b215c8ed41577d8b9e020e2c5 Mon Sep 17 00:00:00 2001
From: Vassilis Vassiliadis
Date: Mon, 19 Feb 2024 11:25:40 +0000
Subject: [PATCH 2/2] refactor: run isort on sft_trainer.py

Signed-off-by: Vassilis Vassiliadis
---
 tuning/sft_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 4e96327cf..bd12d6520 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,6 +1,6 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union, List, NamedTuple
+from typing import List, NamedTuple, Optional, Union
 import json
 import os
 import time
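
A minimal usage sketch of the API this series introduces, for reviewers; it is not part of either patch. The positional argument order, the new callbacks parameter, and the model_name_or_path field come from the diff itself, while the DataArguments/TrainingArguments field names (training_data_path, output_dir, num_train_epochs) and all paths are illustrative assumptions.

# Hypothetical caller of tuning.sft_trainer.train() after PATCH 2/2 is applied.
from transformers import TrainerCallback

from tuning import sft_trainer
from tuning.config import configs


class PrintOnBeginCallback(TrainerCallback):
    """Toy custom callback, registered via the new `callbacks` parameter."""

    def on_train_begin(self, args, state, control, **kwargs):
        print("training started")


model_args = configs.ModelArguments(model_name_or_path="/path/to/base-model")
data_args = configs.DataArguments(training_data_path="/path/to/train.jsonl")  # field name assumed
train_args = configs.TrainingArguments(output_dir="/tmp/tuned", num_train_epochs=1)  # fields assumed

result = sft_trainer.train(
    model_args,
    data_args,
    train_args,
    peft_config=None,  # None selects full fine tuning, per the docstring
    callbacks=[PrintOnBeginCallback()],  # new parameter from PATCH 1/2
)

# EnhancedTrainOutput bundles HF's TrainOutput with the new timing metric.
print(f"model_load_time: {result.model_load_time:.1f}s")
print(result.train_output.metrics)

Because train() previously returned None, the new NamedTuple return value is backward compatible for callers that ignored the result; only code that wants model_load_time or the underlying TrainOutput needs to unpack it.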