Skip to content

Commit 5086918

Browse files
committed
enable tracker to track extra metadata
Signed-off-by: Dushyant Behl <[email protected]>
1 parent b0c170c commit 5086918

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

tuning/sft_trainer.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from tuning.data import tokenizer_data_utils
2727
from tuning.utils.config_utils import get_hf_peft_config
2828
from tuning.utils.data_type_utils import get_torch_dtype
29-
from tuning.tracker.tracker import Tracker
29+
from tuning.tracker.tracker import Tracker, get_tracker
3030
from tuning.tracker.aimstack_tracker import AimStackTracker
3131

3232
logger = logging.get_logger("sft_trainer")
@@ -92,7 +92,7 @@ def train(
9292
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
9393
] = None,
9494
callbacks: Optional[List[TrainerCallback]] = None,
95-
tracker: Optional[Tracker] = None,
95+
tracker: Optional[Tracker] = None
9696
):
9797
"""Call the SFTTrainer
9898
@@ -285,6 +285,11 @@ def main(**kwargs):
285285
choices=["pt", "lora", None, "none"],
286286
default="pt",
287287
)
288+
parser.add_argument(
289+
"--extra_metadata",
290+
type=str,
291+
default=None,
292+
)
288293
(
289294
model_args,
290295
data_args,
@@ -311,14 +316,7 @@ def main(**kwargs):
311316
tracker_config=None
312317

313318
# Initialize the tracker early so we can calculate custom metrics like model_load_time.
314-
tracker_name = training_args.tracker
315-
if tracker_name == 'aim':
316-
if tracker_config is not None:
317-
tracker = AimStackTracker(tracker_config)
318-
else:
319-
logger.error("Tracker name is set to "+tracker_name+" but config is None.")
320-
else:
321-
tracker = Tracker()
319+
tracker = get_tracker(tracker_name, tracker_config)
322320

323321
# Initialize callbacks
324322
file_logger_callback = FileLoggingCallback(logger)
@@ -329,6 +327,14 @@ def main(**kwargs):
329327
if tracker_callback is not None:
330328
callbacks.append(tracker_callback)
331329

330+
# track extra metadata
331+
if additional.extra_metadata is not None:
332+
try:
333+
metadata = json.loads(additional.extra_metadata)
334+
tracker.track_metadata(metadata)
335+
except:
336+
logger.error("failed while parsing extra metadata. pass a valid json")
337+
332338
train(model_args, data_args, training_args, tune_config, callbacks, tracker)
333339

334340
if __name__ == "__main__":

tuning/tracker/tracker.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Generic Tracker API
22

3+
from tuning.tracker.aimstack_tracker import AimStackTracker
4+
35
class Tracker:
46
def __init__(self, tracker_config) -> None:
57
self.config = tracker_config
@@ -8,4 +10,20 @@ def get_hf_callback():
810
return None
911

1012
def track(self, metric, name, stage):
11-
pass
13+
pass
14+
15+
# Metadata passed here is supposed to be a KV object
16+
# Key being the name and value being the metric to track.
17+
def track_metadata(self, metadata=None):
18+
if metadata is None or not isinstance(metadata, dict):
19+
return
20+
for k, v in metadata.items():
21+
self.track(name=k, metric=v)
22+
23+
def get_tracker(tracker_name, tracker_config):
24+
if tracker_name == 'aim':
25+
if tracker_config is not None:
26+
tracker = AimStackTracker(tracker_config)
27+
else:
28+
tracker = Tracker()
29+
return tracker

0 commit comments

Comments
 (0)