39 changes: 31 additions & 8 deletions tuning/sft_trainer.py
@@ -1,8 +1,9 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import List, NamedTuple, Optional, Union
 import json
 import os
+import time
 
 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -40,6 +41,12 @@ def on_save(self, args, state, control, **kwargs):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
 
+class EnhancedTrainOutput(NamedTuple):
+    train_output: transformers.trainer.TrainOutput
+
+    model_load_time: float
+
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
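For readers unfamiliar with `NamedTuple`, the new `EnhancedTrainOutput` wrapper is an immutable tuple with named fields, so callers get attribute access, positional unpacking, and easy serialization for free. A minimal self-contained sketch of how it behaves; the `TrainOutput` stand-in below is a simplified placeholder, not the real transformers class (though the real one is also a NamedTuple with `global_step`, `training_loss`, and `metrics` fields):

```python
from typing import NamedTuple

# Simplified stand-in for transformers.trainer.TrainOutput, for illustration only.
class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float
    metrics: dict

class EnhancedTrainOutput(NamedTuple):
    train_output: TrainOutput
    model_load_time: float

result = EnhancedTrainOutput(
    train_output=TrainOutput(global_step=100, training_loss=1.23, metrics={}),
    model_load_time=4.2,
)

# Named access, positional unpacking, and dict conversion all work:
train_output, model_load_time = result
print(result.model_load_time)             # 4.2
print(result.train_output.training_loss)  # 1.23
print(result._asdict())                   # {'train_output': ..., 'model_load_time': 4.2}
```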
@@ -84,17 +91,23 @@ def train(
     peft_config: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
-):
+    callbacks: Optional[List[TrainerCallback]] = None,
+) -> EnhancedTrainOutput:
     """Call the SFTTrainer
 
     Args:
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
         peft_config: peft_config.LoraConfig for Lora tuning | \
-            peft_config.PromptTuningConfig for prompt tuning | \
-            None for fine tuning
+            peft_config.PromptTuningConfig for prompt tuning | \
+            None for fine tuning
+            The peft configuration to pass to trainer
+        callbacks: optional callbacks for SFTTrainer
+
+    Returns:
+        An EnhancedTrainOutput containing the TrainOutput of SFTTrainer.train()
+        plus extra metrics such as model_load_time
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

@@ -116,13 +129,15 @@ def train(
         train_args.fsdp_config = {"xla": False}
 
     task_type = "CAUSAL_LM"
-    model = AutoModelForCausalLM.from_pretrained(
+
+    model_load_time = time.time()
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-
+    model_load_time = time.time() - model_load_time
     peft_config = get_hf_peft_config(task_type, peft_config)
 
     model.gradient_checkpointing_enable()
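The timing above reuses `model_load_time` first as a start timestamp and then as the elapsed duration. A hedged sketch of the same wall-clock measurement factored into a reusable helper; `time.monotonic` is assumed here instead of the diff's `time.time`, since it cannot go backwards if the system clock is adjusted mid-call:

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    time.monotonic() is used rather than time.time() so the measurement
    is unaffected by NTP or manual clock adjustments during the call.
    """
    start = time.monotonic()
    result = fn(*args, **kwargs)
    return result, time.monotonic() - start

# Hypothetical usage mirroring the model-load measurement in the diff:
# model, model_load_time = timed(
#     transformers.AutoModelForCausalLM.from_pretrained,
#     model_args.model_name_or_path,
# )
```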
@@ -212,10 +227,12 @@ def train(
     formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
     logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
 
+    callbacks = callbacks or []
+
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
     peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks.extend([aim_callback, peft_saving_callback, file_logger_callback])
 
     if train_args.packing:
         logger.info("Packing is set to True")
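With `callbacks = callbacks or []` followed by `extend`, caller-supplied callbacks now run alongside the built-in AIM, PEFT-saving, and file-logging callbacks. A sketch of what a caller might pass; the `TimingCallback` below is illustrative only and not part of this PR:

```python
import time

from transformers import TrainerCallback

class TimingCallback(TrainerCallback):
    """Illustrative only: records the wall-clock duration of the training loop."""

    def on_train_begin(self, args, state, control, **kwargs):
        self._start = time.time()

    def on_train_end(self, args, state, control, **kwargs):
        self.train_seconds = time.time() - self._start

# Hypothetical call into the patched API:
# train(model_args, data_args, train_args, peft_config=None,
#       callbacks=[TimingCallback()])
```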
@@ -260,7 +277,13 @@ def train(
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
-    trainer.train()
+
+    train_output: "transformers.trainer.TrainOutput" = trainer.train()
+
+    return EnhancedTrainOutput(
+        train_output=train_output,
+        model_load_time=model_load_time,
+    )
Comment on lines +281 to +286

Contributor Author:
@dushyantbehl I need the train() method to return the TrainOutput + model_load_time.

Collaborator:
Do you have any specific need for that one metric coming out of the train() function?

Could you not fetch it from the tracker?

Contributor Author:

We'd like to programmatically execute a large array of fms-hf-tuning experiments to collect data (things like model performance, system metrics, etc.). Some of these runs may take place on machines that do not have network connectivity. Other runs we may not want to register to AIM at all, as they contain experimental code/models/datasets, and we wouldn't want to pollute the AIM database with data we're not sure we'd like to keep around.

As a result, we need to collect the measured metrics (trainoutput + model_load_time) straight from the return value of the train() method.
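To make that workflow concrete, here is a hedged sketch of such a driver: it calls the patched train() and persists the returned metrics to disk with no AIM server in the loop. The `experiments` list and the output filename are hypothetical; `train` and the `*_args` config objects are the ones from this PR:

```python
import json

# Hypothetical list of (model_args, data_args, train_args) tuples to sweep;
# in practice these would be built from tuning.config.configs dataclasses,
# and train would be imported from tuning.sft_trainer.
experiments = []

results = []
for model_args, data_args, train_args in experiments:
    output = train(model_args, data_args, train_args, peft_config=None)
    results.append(
        {
            "model": model_args.model_name_or_path,
            "model_load_time": output.model_load_time,
            "train_metrics": output.train_output.metrics,
        }
    )

with open("experiment_results.json", "w") as f:
    json.dump(results, f, indent=2)
```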



def main(**kwargs):