Skip to content

Commit 0c6cc06

Browse files
author
deeperlearner
committed
Merge branch 'develop', version v4.1.0
Add new run_args: `--mp` and `--n_jobs`
2 parents 8c03a93 + db27000 commit 0c6cc06

File tree

13 files changed

+132
-55
lines changed

13 files changed

+132
-55
lines changed

logger/logger.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def setup_logging(
1919
if log_config.is_file():
2020
config = read_json(log_config)
2121
# modify logging paths based on run config
22-
for _, handler in config["handlers"].items():
23-
if "filename" in handler:
24-
if filename is None:
25-
handler["filename"] = str(save_dir / handler["filename"])
22+
for handler_k, handler_v in config["handlers"].items():
23+
if "filename" in handler_v:
24+
if filename is None or handler_k != "info_file_handler":
25+
handler_v["filename"] = str(save_dir / handler_v["filename"])
2626
else:
27-
handler["filename"] = str(save_dir / filename)
27+
handler_v["filename"] = str(save_dir / filename)
2828

2929
logging.config.dictConfig(config)
3030
else:

logger/logger_config_mp.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@
3030
},
3131
"root": {
3232
"level": "INFO",
33-
"handlers": ["info_file_handler"]
33+
"handlers": ["warning_file_handler"]
3434
}
3535
}

mains/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .bootstrap import bootstrapping
12
from .cross_validation import Cross_Valid
2-
from .train import train
3+
from .multiprocess import Multiprocessor
4+
from .train import train, train_mp
35
from .test import test
File renamed without changes.

mains/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
sys.path.insert(1, os.path.join(sys.path[0], ".."))
1010
from logger import get_logger
1111
from parse_config import ConfigParser
12-
from mains import train, test
12+
from mains import train, train_mp, test
1313
from utils import msg_box, consuming_time
1414

1515

@@ -18,13 +18,13 @@
1818

1919
# crutial args executed in scripts
2020
run_args = args.add_argument_group("run_args")
21+
run_args.add_argument("--optuna", action="store_true")
22+
run_args.add_argument("--mp", action="store_true", help="multiprocessing")
23+
run_args.add_argument("--n_jobs", default=2, type=int, help="number of jobs running at the same time")
2124
run_args.add_argument("-c", "--config", default="configs/config.json", type=str)
2225
run_args.add_argument("--mode", default="train", type=str)
23-
run_args.add_argument("--optuna", action="store_true")
2426
run_args.add_argument("--resume", default=None, type=str)
2527
run_args.add_argument("--run_id", default=None, type=str)
26-
run_args.add_argument("--log_name", default=None, type=str)
27-
run_args.add_argument("--mp", action="store_true", help="multiprocessing")
2828

2929
# custom cli options to modify configuration from default values given in json file.
3030
mod_args = args.add_argument_group("mod_args")
@@ -55,6 +55,7 @@
5555
test_args.add_argument("--output_path", default=None, type=str)
5656

5757
config = ConfigParser.from_args(args, options)
58+
config.set_log()
5859
logger = get_logger("main")
5960
mode = config.run_args.mode
6061
msg = msg_box(mode.upper())
@@ -66,6 +67,8 @@
6667
objective = config.init_obj(["optuna"])
6768
n_trials = config["optuna"]["n_trials"]
6869

70+
config.set_log(log_name="optuna.log")
71+
logger = get_logger("optuna")
6972
optuna.logging.enable_propagation()
7073
optuna.logging.disable_default_handler()
7174
direction = "maximize" if max_min == "max" else "minimize"
@@ -81,6 +84,9 @@
8184
msg += f"\nBest hyperparameters: {study.best_params}"
8285
logger.info(msg)
8386
else:
84-
train(config)
87+
if config.run_args.mp:
88+
train_mp(config)
89+
else:
90+
train(config)
8591
elif mode == "test":
8692
test(config)

mains/multiprocess.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from multiprocessing import Process, Queue
2+
3+
4+
# ref: https://stackoverflow.com/a/45829852/8380054
5+
class Multiprocessor():
6+
7+
def __init__(self):
8+
self.processes = []
9+
self.queue = Queue()
10+
11+
@staticmethod
12+
def _wrapper(func, queue, args, kwargs):
13+
ret = func(*args, **kwargs)
14+
queue.put(ret)
15+
16+
def run(self, func, *args, **kwargs):
17+
args2 = [func, self.queue, args, kwargs]
18+
p = Process(target=self._wrapper, args=args2)
19+
self.processes.append(p)
20+
p.start()
21+
22+
def wait(self):
23+
rets = []
24+
for p in self.processes:
25+
ret = self.queue.get()
26+
rets.append(ret)
27+
for p in self.processes:
28+
p.join()
29+
return rets

mains/test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
from tqdm import tqdm
99

1010
from logger import get_logger
11-
from mains import Cross_Valid
11+
from mains import Cross_Valid, bootstrapping
1212
import models.metric as module_metric
1313
from models.metric import MetricTracker
1414
from utils import prepare_device, get_by_path, msg_box, consuming_time
15-
from utils.bootstrap import bootstrapping
1615

1716
# fix random seeds for reproducibility
1817
SEED = 123

mains/train.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.utils.class_weight import compute_class_weight
77

88
from logger import get_logger
9-
from mains import Cross_Valid
9+
from mains import Cross_Valid, Multiprocessor
1010
import models.metric as module_metric
1111
from utils import (
1212
prepare_device,
@@ -19,10 +19,36 @@
1919
if is_apex_available():
2020
from apex import amp
2121

22-
logger = get_logger("train")
2322

23+
def train_mp(config):
24+
k_fold = config["cross_validation"]["k_fold"]
25+
do_mp = config.run_args.mp
26+
n_jobs = config.run_args.n_jobs
27+
assert n_jobs <= k_fold, "n_jobs can not be more than k_fold."
28+
29+
results = []
30+
fold_idx = 0
31+
while fold_idx < k_fold:
32+
mp = Multiprocessor()
33+
job_idx = 0
34+
while job_idx < n_jobs and fold_idx < k_fold:
35+
mp.run(train, config, do_mp, fold_idx)
36+
job_idx += 1
37+
fold_idx += 1
38+
ret = mp.wait() # get results of processes
39+
results.extend(ret)
40+
41+
return results
42+
43+
44+
def train(config, do_mp=False, fold_idx=0):
45+
# different logging when multiprocessing
46+
if do_mp:
47+
config.set_log(log_name=f"fold_{fold_idx}.log")
48+
else:
49+
config.set_log()
50+
logger = get_logger("train")
2451

25-
def train(config):
2652
# setup GPU device if available, move model into configured device
2753
device, device_ids = prepare_device(config["n_gpu"])
2854

@@ -83,12 +109,14 @@ def train(config):
83109
k_fold = config["cross_validation"]["k_fold"]
84110

85111
results = pd.DataFrame()
86-
Cross_Valid.create_CV(repeat_time, k_fold)
112+
Cross_Valid.create_CV(repeat_time, k_fold, fold_idx=fold_idx)
87113
start = time.time()
88114
for t in range(repeat_time):
89115
if k_fold > 1: # cross validation enabled
90116
train_datasets["data"].split_cv_indexes(k_fold)
91-
for k in range(k_fold):
117+
# 1 loop for multi-process; k_fold loops for single-process
118+
k_time = 1 if do_mp else k_fold
119+
for k in range(k_time):
92120
# data_loaders
93121
train_data_loaders = dict()
94122
valid_data_loaders = dict()
@@ -165,7 +193,7 @@ def train(config):
165193
train_log = trainer.train()
166194
results = pd.concat((results, train_log), axis=1)
167195

168-
if k_fold > 1:
196+
if k_time > 1:
169197
Cross_Valid.next_fold()
170198

171199
if repeat_time > 1:
@@ -184,4 +212,7 @@ def train(config):
184212

185213
logger.info(msg)
186214

215+
max_min, mnt_metric = config["trainer"]["kwargs"]["monitor"].split()
216+
result = result.at[mnt_metric, "mean"]
217+
187218
return result

mains/train_mp.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

parse_config.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class to parse configuration json file. Handles hyperparameters for training,
3131
if modification is None:
3232
modification = {}
3333
modification.update(self.mod_args)
34+
if self.run_args.mp: # lower trainer verbosity when multiprocessing
35+
modification.update({"trainer;kwargs;verbosity": 0})
3436
self._config = _update_config(config, modification)
3537

3638
# test_args: self.test_args
@@ -49,20 +51,9 @@ class to parse configuration json file. Handles hyperparameters for training,
4951
ensure_dir(dir_path)
5052
self.save_dir[dir_name] = dir_path
5153

52-
log_config = {}
5354
if self.run_args.mode == "train":
54-
if self.run_args.mp: # multiprocessing
55-
log_config.update({"log_config": "logger/logger_config_mp.json"})
5655
self.backup()
5756

58-
# configure logging module
59-
setup_logging(
60-
self.save_dir["log"],
61-
root_dir=self.root_dir,
62-
filename=self.run_args.log_name,
63-
**log_config
64-
)
65-
6657
@classmethod
6758
def from_args(cls, parser, options=""):
6859
"""
@@ -133,6 +124,14 @@ def __getitem__(self, name):
133124
"""Access items like ordinary dict."""
134125
return self.config[name]
135126

127+
# configure logging module
128+
def set_log(self, log_name=None):
129+
setup_logging(
130+
self.save_dir["log"],
131+
root_dir=self.root_dir,
132+
filename=log_name,
133+
)
134+
136135
# read-only attributes
137136
@property
138137
def config(self):

0 commit comments

Comments
 (0)