Skip to content

Commit 8661cc0

Browse files
committed
change default output path for aim run export
Signed-off-by: Dushyant Behl <[email protected]>
1 parent a28aa7b commit 8661cc0

File tree

5 files changed

+33
-17
lines changed

5 files changed

+33
-17
lines changed

tuning/sft_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from tuning.data import tokenizer_data_utils
2727
from tuning.utils.config_utils import get_hf_peft_config
2828
from tuning.utils.data_type_utils import get_torch_dtype
29-
from tuning.tracker.tracker import Tracker, TrackerFactory
29+
from tuning.trackers.tracker import Tracker
30+
from tuning.trackers.tracker_factory import get_tracker
3031

3132
logger = logging.get_logger("sft_trainer")
3233

@@ -331,7 +332,7 @@ def main(**kwargs):
331332
callbacks = [peft_saving_callback, file_logger_callback]
332333

333334
# Initialize the tracker
334-
tracker = TrackerFactory.get_tracker(tracker_name, tracker_config)
335+
tracker = get_tracker(tracker_name, tracker_config)
335336
tracker_callback = tracker.get_hf_callback()
336337
if tracker_callback is not None:
337338
callbacks.append(tracker_callback)
File renamed without changes.

tuning/tracker/aimstack_tracker.py renamed to tuning/trackers/aimstack_tracker.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback):
1111

1212
# A path to export run hash generated by Aim
1313
# This is used to link back to the expriments from outside aimstack
14-
aim_run_hash_export_path = None
14+
run_hash_export_path = None
1515

1616
def on_init_end(self, args, state, control, **kwargs):
1717

@@ -20,9 +20,18 @@ def on_init_end(self, args, state, control, **kwargs):
2020

2121
self.setup() # initializes the run_hash
2222

23-
# store the run hash
24-
if self.aim_run_hash_export_path:
25-
with open(self.aim_run_hash_export_path, 'w') as f:
23+
# Store the run hash
24+
# Change default run hash path to output directory
25+
if self.run_hash_export_path is None:
26+
if args and args.output_dir:
27+
# args.output_dir/.aim_run_hash
28+
self.run_hash_export_path = os.path.join(
29+
args.output_dir,
30+
'.aim_run_hash'
31+
)
32+
33+
if self.run_hash_export_path:
34+
with open(self.run_hash_export_path, 'w') as f:
2635
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
2736

2837
def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -60,7 +69,7 @@ def get_hf_callback(self):
6069
else:
6170
aim_callback = CustomAimCallback(experiment=exp)
6271

63-
aim_callback.aim_run_hash_export_path = hash_export_path
72+
aim_callback.run_hash_export_path = hash_export_path
6473
self.hf_callback = aim_callback
6574
return self.hf_callback
6675

tuning/tracker/tracker.py renamed to tuning/trackers/tracker.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None:
99
else:
1010
self._name = name
1111

12-
def get_hf_callback():
12+
# we use args here to denote any argument.
13+
def get_hf_callback(self):
1314
return None
1415

1516
def track(self, metric, name, stage):
@@ -18,12 +19,4 @@ def track(self, metric, name, stage):
1819
# Object passed here is supposed to be a KV object
1920
# for the parameters to be associated with a run
2021
def set_params(self, params, name):
21-
pass
22-
23-
class TrackerFactory:
24-
def get_tracker(tracker_name, tracker_config):
25-
for T in Tracker.__subclasses__():
26-
if T._name == tracker_name:
27-
return T(tracker_config)
28-
else:
29-
return Tracker()
22+
pass

tuning/trackers/tracker_factory.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .tracker import Tracker
2+
from .aimstack_tracker import AimStackTracker
3+
4+
REGISTERED_TRACKERS = {
5+
"aim" : AimStackTracker
6+
}
7+
8+
def get_tracker(tracker_name, tracker_config):
9+
if tracker_name in REGISTERED_TRACKERS:
10+
T = REGISTERED_TRACKERS[tracker_name]
11+
return T(tracker_config)
12+
else:
13+
return Tracker()

0 commit comments

Comments
 (0)