Skip to content

Commit 0b36f64

Browse files
committed
Delay entering the trial runner context until the trial is actually being run.
This is part of an attempt to see if we can work around issues with `multiprocessing.Pool` needing to pickle certain objects when forking. For instance, if the Environment is using an SshServer, we need to start an EventLoopContext in the background to handle the SSH connections, and threads are not picklable. Nor are file handles, DB connections, etc., so there may be other things we also need to adjust to make this work. See also microsoft#967.
1 parent 2fbaaca commit 0b36f64

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

mlos_bench/mlos_bench/schedulers/base_scheduler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ def __enter__(self) -> "Scheduler":
200200
_LOG.debug("Scheduler START :: %s", self)
201201
assert self.experiment is None
202202
assert not self._in_context
203-
for trial_runner in self._trial_runners.values():
204-
trial_runner.__enter__()
203+
# NOTE: We delay entering the context of trial_runners until it's time
204+
# to run the trial in order to avoid incompatibilities with
205+
# multiprocessing.Pool.
205206
self._optimizer.__enter__()
206207
# Start new or resume the existing experiment. Verify that the
207208
# experiment configuration is compatible with the previous runs.
@@ -235,7 +236,8 @@ def __exit__(
235236
self._experiment.__exit__(ex_type, ex_val, ex_tb)
236237
self._optimizer.__exit__(ex_type, ex_val, ex_tb)
237238
for trial_runner in self._trial_runners.values():
238-
trial_runner.__exit__(ex_type, ex_val, ex_tb)
239+
# TrialRunners should have already exited their context after running the Trial.
240+
assert not trial_runner._in_context # pylint: disable=protected-access
239241
self._experiment = None
240242
self._in_context = False
241243
return False # Do not suppress exceptions
@@ -267,7 +269,8 @@ def teardown(self) -> None:
267269
if self._do_teardown:
268270
for trial_runner in self._trial_runners.values():
269271
assert not trial_runner.is_running
270-
trial_runner.teardown()
272+
with trial_runner:
273+
trial_runner.teardown()
271274

272275
def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
273276
"""Get the best observation from the optimizer."""

mlos_bench/mlos_bench/schedulers/sync_scheduler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ def run_trial(self, trial: Storage.Trial) -> None:
3939
super().run_trial(trial)
4040
# In the sync scheduler we run each trial on its own TrialRunner in sequence.
4141
trial_runner = self.get_trial_runner(trial)
42-
trial_runner.run_trial(trial, self.global_config)
43-
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)
42+
with trial_runner:
43+
trial_runner.run_trial(trial, self.global_config)
44+
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)

0 commit comments

Comments
 (0)