@VassilisVassiliadis
Contributor

This PR resolves #33

Changes:

  • support for custom callbacks
  • record time it took to load model
  • train() method now returns output of SFTTrainer.train() as well as model_load_time
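A minimal sketch of the return shape described above, not the code in this PR. It assumes `EnhancedTrainOutput` is a simple named tuple. `load_model`, `make_trainer`, and `_FakeTrainer` are hypothetical stand-ins, and the local `TrainOutput` only mirrors the fields of the `transformers` class so that the example runs without the library installed.

```python
import time
from typing import Callable, Dict, NamedTuple


class TrainOutput(NamedTuple):
    """Stand-in with the same fields as transformers' TrainOutput."""
    global_step: int
    training_loss: float
    metrics: Dict[str, float]


class EnhancedTrainOutput(NamedTuple):
    """Assumed shape: the trainer's output plus the measured load time."""
    train_output: TrainOutput
    model_load_time: float


def train(load_model: Callable[[], object],
          make_trainer: Callable[[object], object]) -> EnhancedTrainOutput:
    # Time only the model loading step.
    start = time.perf_counter()
    model = load_model()
    model_load_time = time.perf_counter() - start

    # Hypothetical factory; in the PR this role is played by SFTTrainer.
    trainer = make_trainer(model)
    train_output = trainer.train()

    return EnhancedTrainOutput(
        train_output=train_output,
        model_load_time=model_load_time,
    )


class _FakeTrainer:
    """Hypothetical trainer used only to exercise the sketch."""

    def __init__(self, model: object) -> None:
        self.model = model

    def train(self) -> TrainOutput:
        return TrainOutput(global_step=10, training_loss=0.5,
                           metrics={"train_runtime": 1.0})


result = train(load_model=lambda: object(), make_trainer=_FakeTrainer)
```

Callers can then read `result.train_output` and `result.model_load_time` directly, with no tracker involved.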

@dushyantbehl
Collaborator

@VassilisVassiliadis, can this be closed if you agree with the design and direction of PR #50?

Comment on lines +281 to +286
train_output: "transformers.trainer.TrainOutput" = trainer.train()

return EnhancedTrainOutput(
    train_output=train_output,
    model_load_time=model_load_time,
)
Contributor Author

@dushyantbehl I need the train() method to return the TrainOutput + model_load_time.

Collaborator

Do you have a specific need for that one metric to come out of the train() function?

Could you not fetch it from the tracker?

Contributor Author

We'd like to programmatically execute a large array of fms-hf-tuning experiments to collect data (model performance, system metrics, etc.). Some of these runs may take place on machines that do not have network connectivity. Others we may not want to register in AIM at all, because they contain experimental code, models, or datasets and we do not want to pollute the AIM database with data we are not sure we want to keep.

As a result, we need to collect the measured metrics (TrainOutput + model_load_time) straight from the return value of the train() method.
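To illustrate the use case, here is a rough sketch of an offline sweep that reads metrics straight from the return value and keeps them in memory, with no tracker or network access. `collect_runs`, `fake_train`, and the `batch_size` config key are hypothetical names made up for this example. The only assumption about the real API is that the returned object exposes `train_output` and `model_load_time`.

```python
import json
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterable, List


def collect_runs(configs: Iterable[Dict[str, Any]],
                 train_fn: Callable[[Dict[str, Any]], Any]) -> List[Dict[str, Any]]:
    """Run each experiment and flatten its returned metrics into one row."""
    rows = []
    for config in configs:
        result = train_fn(config)
        row = dict(config)  # keep the experiment parameters with the metrics
        row["model_load_time"] = result.model_load_time
        row["training_loss"] = result.train_output.training_loss
        row.update(result.train_output.metrics)
        rows.append(row)
    return rows


def fake_train(config: Dict[str, Any]) -> SimpleNamespace:
    """Hypothetical stand-in for train(); returns fixed, made-up numbers."""
    train_output = SimpleNamespace(
        global_step=100,
        training_loss=1.0 / config["batch_size"],
        metrics={"train_runtime": 2.0},
    )
    return SimpleNamespace(train_output=train_output, model_load_time=0.25)


rows = collect_runs([{"batch_size": 2}, {"batch_size": 4}], fake_train)
print(json.dumps(rows, indent=2))
```

The rows can be written to a local file or discarded after inspection, so experimental runs never reach the AIM database.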

Development

Successfully merging this pull request may close these issues.

Add support for collecting metrics programmatically