Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions optuna/pruners/_successive_halving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import heapq
import math

import optuna
Expand Down Expand Up @@ -228,15 +229,24 @@ def _estimate_min_resource(trials: list["optuna.trial.FrozenTrial"]) -> int | No


def _get_current_rung(trial: "optuna.trial.FrozenTrial") -> int:
# The following loop takes `O(log step)` iterations.
rung = 0
while _completed_rung_key(rung) in trial.system_attrs:
rung += 1
return rung
# This function was optimized to avoid repeated string creation and dict lookups in a loop.
# Since system_attrs keys are of the form 'completed_rung_N', we can count such keys directly.
count = 0
prefix = "completed_rung_"
for k in trial.system_attrs:
if k.startswith(prefix):
try:
int_val = int(k[len(prefix) :])
if int_val >= 0:
count += 1
except ValueError:
continue
return count


def _completed_rung_key(rung: int) -> str:
return "completed_rung_{}".format(rung)
# Use faster f-string formatting.
return f"completed_rung_{rung}"


def _get_competing_values(
Expand All @@ -253,15 +263,22 @@ def _is_trial_promotable_to_next_rung(
reduction_factor: int,
study_direction: StudyDirection,
) -> bool:
promotable_idx = (len(competing_values) // reduction_factor) - 1
# Optimize by using heapq for top-k instead of sorting the whole list.
n = len(competing_values)
promotable_idx = (n // reduction_factor) - 1

if promotable_idx == -1:
# Optuna does not support suspending or resuming ongoing trials. Therefore, for the first
# `eta - 1` trials, this implementation instead promotes the trial if its value is the
# smallest one among the competing values.
promotable_idx = 0

competing_values.sort()
if study_direction == StudyDirection.MAXIMIZE:
return value >= competing_values[-(promotable_idx + 1)]
return value <= competing_values[promotable_idx]
# Find the k-th largest value efficiently
k = promotable_idx + 1
kth_value = heapq.nlargest(k, competing_values)[-1]
return value >= kth_value
else:
# Find the k-th smallest value efficiently
k = promotable_idx + 1
kth_value = heapq.nsmallest(k, competing_values)[-1]
return value <= kth_value
10 changes: 3 additions & 7 deletions optuna/storages/_base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from __future__ import annotations

import abc
from collections.abc import Container
from collections.abc import Sequence
from typing import Any
from typing import cast
from collections.abc import Container, Sequence
from typing import Any, cast

from optuna._typing import JSONSerializable
from optuna.distributions import BaseDistribution
from optuna.exceptions import UpdateFinishedTrialError
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState

from optuna.trial import FrozenTrial, TrialState

DEFAULT_STUDY_NAME_PREFIX = "no-name-"

Expand Down
37 changes: 11 additions & 26 deletions optuna/study/study.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,38 @@
from __future__ import annotations

from collections.abc import Container
from collections.abc import Iterable
from collections.abc import Mapping
import copy
from numbers import Real
import threading
from typing import Any
from typing import Callable
from typing import cast
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union
import warnings
from collections.abc import Container, Iterable, Mapping
from numbers import Real
from typing import TYPE_CHECKING, Any, Callable, Sequence, Union, cast

import numpy as np

import optuna
from optuna import exceptions
from optuna import logging
from optuna import pruners
from optuna import samplers
from optuna import storages
from optuna import exceptions, logging, pruners, samplers, storages
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna._experimental import experimental_func
from optuna._imports import _LazyImport
from optuna._typing import JSONSerializable
from optuna.distributions import _convert_old_distribution_to_new_distribution
from optuna.distributions import BaseDistribution
from optuna.distributions import (
BaseDistribution, _convert_old_distribution_to_new_distribution)
from optuna.storages._heartbeat import is_heartbeat_enabled
from optuna.study._constrained_optimization import _CONSTRAINTS_KEY
from optuna.study._constrained_optimization import _get_feasible_trials
from optuna.study._constrained_optimization import (_CONSTRAINTS_KEY,
_get_feasible_trials)
from optuna.study._multi_objective import _get_pareto_front_trials
from optuna.study._optimize import _optimize
from optuna.study._study_direction import StudyDirection
from optuna.study._study_summary import StudySummary # NOQA
from optuna.study._tell import _get_frozen_trial
from optuna.study._tell import _tell_with_warning
from optuna.trial import create_trial
from optuna.trial import TrialState

from optuna.study._tell import _get_frozen_trial, _tell_with_warning
from optuna.trial import TrialState, create_trial

_dataframe = _LazyImport("optuna.study._dataframe")

if TYPE_CHECKING:
from optuna.study._dataframe import pd
from optuna.trial import FrozenTrial
from optuna.trial import Trial
from optuna.trial import FrozenTrial, Trial


ObjectiveFuncType = Callable[["Trial"], Union[float, Sequence[float]]]
Expand Down