66from sklearn .utils .class_weight import compute_class_weight
77
88from logger import get_logger
9- from mains import Cross_Valid
9+ from mains import Cross_Valid , Multiprocessor
1010import models .metric as module_metric
1111from utils import (
1212 prepare_device ,
1919if is_apex_available ():
2020 from apex import amp
2121
22- logger = get_logger ("train" )
2322
23+ def train_mp (config ):
24+ k_fold = config ["cross_validation" ]["k_fold" ]
25+ do_mp = config .run_args .mp
26+ n_jobs = config .run_args .n_jobs
27+ assert n_jobs <= k_fold , "n_jobs can not be more than k_fold."
28+
29+ results = []
30+ fold_idx = 0
31+ while fold_idx < k_fold :
32+ mp = Multiprocessor ()
33+ job_idx = 0
34+ while job_idx < n_jobs and fold_idx < k_fold :
35+ mp .run (train , config , do_mp , fold_idx )
36+ job_idx += 1
37+ fold_idx += 1
38+ ret = mp .wait () # get results of processes
39+ results .extend (ret )
40+
41+ return results
42+
43+
44+ def train (config , do_mp = False , fold_idx = 0 ):
45+ # different logging when multiprocessing
46+ if do_mp :
47+ config .set_log (log_name = f"fold_{ fold_idx } .log" )
48+ else :
49+ config .set_log ()
50+ logger = get_logger ("train" )
2451
25- def train (config ):
2652 # setup GPU device if available, move model into configured device
2753 device , device_ids = prepare_device (config ["n_gpu" ])
2854
@@ -83,12 +109,14 @@ def train(config):
83109 k_fold = config ["cross_validation" ]["k_fold" ]
84110
85111 results = pd .DataFrame ()
86- Cross_Valid .create_CV (repeat_time , k_fold )
112+ Cross_Valid .create_CV (repeat_time , k_fold , fold_idx = fold_idx )
87113 start = time .time ()
88114 for t in range (repeat_time ):
89115 if k_fold > 1 : # cross validation enabled
90116 train_datasets ["data" ].split_cv_indexes (k_fold )
91- for k in range (k_fold ):
117+ # 1 loop for multi-process; k_fold loops for single-process
118+ k_time = 1 if do_mp else k_fold
119+ for k in range (k_time ):
92120 # data_loaders
93121 train_data_loaders = dict ()
94122 valid_data_loaders = dict ()
@@ -165,7 +193,7 @@ def train(config):
165193 train_log = trainer .train ()
166194 results = pd .concat ((results , train_log ), axis = 1 )
167195
168- if k_fold > 1 :
196+ if k_time > 1 :
169197 Cross_Valid .next_fold ()
170198
171199 if repeat_time > 1 :
@@ -184,4 +212,7 @@ def train(config):
184212
185213 logger .info (msg )
186214
215+ max_min , mnt_metric = config ["trainer" ]["kwargs" ]["monitor" ].split ()
216+ result = result .at [mnt_metric , "mean" ]
217+
187218 return result
0 commit comments