Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 3, 2025

📄 10% (0.10x) speedup for _get_timeline_plot in optuna/visualization/matplotlib/_timeline.py

⏱️ Runtime : 517 milliseconds 471 milliseconds (best of 5 runs)

📝 Explanation and details

The optimization replaces an inefficient nested loop with a more efficient set-based lookup for legend construction.

Key Change:

  • Before: For each state in _cm, the code used any(_get_state_name(b) == state_name for b in info.bars) - this creates a nested O(n*k) loop where n is the number of bars and k is the number of possible states (6).
  • After: The code precomputes present_state_names = set(_get_state_name(b) for b in info.bars) once, then uses if state_name in present_state_names for O(1) lookups.

Why This is Faster:
The original approach has O(n*k) complexity because for each of the 6 possible states, it potentially scans through all n bars to check if that state exists. With the optimization, we scan the bars once to build a set (O(n)), then do 6 constant-time set lookups (O(k)), resulting in O(n+k) total complexity.

Performance Impact:
The line profiler shows the legend check (if any(_get_state_name(b) == state_name for b in info.bars)) took 31.6ms in the original vs the set-based approach taking only 6.5ms + 0.04ms = 6.54ms total - a ~79% reduction in that specific operation.

Best For:
This optimization is particularly effective for test cases with many trials (like the 999-trial test showing 9.6% speedup) where the nested loop penalty becomes significant, while still providing consistent 8-11% improvements across all test scenarios.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 7 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 96.3%
🌀 Generated Regression Tests and Runtime
import datetime
from types import SimpleNamespace

# imports
import pytest
from optuna.visualization.matplotlib._timeline import _get_timeline_plot


# Minimal mock implementations for required classes/enums
class TrialState:
    COMPLETE = SimpleNamespace(name="COMPLETE")
    FAIL = SimpleNamespace(name="FAIL")
    PRUNED = SimpleNamespace(name="PRUNED")
    RUNNING = SimpleNamespace(name="RUNNING")
    WAITING = SimpleNamespace(name="WAITING")

# Minimal _TimelineBarInfo and _TimelineInfo mocks
class _TimelineBarInfo:
    def __init__(self, number, start, complete, state, infeasible=False):
        self.number = number
        self.start = start
        self.complete = complete
        self.state = state
        self.infeasible = infeasible

class _TimelineInfo:
    def __init__(self, bars):
        self.bars = bars

# Minimal matplotlib mock for testing (no actual plotting)
class DummyPatch:
    def __init__(self, color, label):
        self.color = color
        self.label = label

class DummyAxes:
    def __init__(self):
        self._calls = []
        self._legend_handles = []
        self._title = None
        self._xlabel = None
        self._ylabel = None
        self._barh_args = None
        self._xlim = None
        self._major_locator_set = False
        self._major_formatter_set = False

    def set_title(self, title):
        self._title = title
        self._calls.append(('set_title', title))

    def set_xlabel(self, label):
        self._xlabel = label
        self._calls.append(('set_xlabel', label))

    def set_ylabel(self, label):
        self._ylabel = label
        self._calls.append(('set_ylabel', label))

    def barh(self, y, width, left, color):
        self._barh_args = (y, width, left, color)
        self._calls.append(('barh', y, width, left, color))

    def legend(self, handles, loc, bbox_to_anchor):
        self._legend_handles = handles
        self._calls.append(('legend', handles, loc, bbox_to_anchor))

    def set_xlim(self, right, left):
        self._xlim = (right, left)
        self._calls.append(('set_xlim', right, left))

    class yaxis:
        @staticmethod
        def set_major_locator(locator):
            DummyAxes._major_locator_set = True

    class xaxis:
        @staticmethod
        def set_major_formatter(formatter):
            DummyAxes._major_formatter_set = True

class DummyFig:
    def tight_layout(self):
        pass

def dummy_autofmt_xdate():
    pass

class DummyMatplotlib:
    class patches:
        Patch = DummyPatch
    class ticker:
        class MaxNLocator:
            def __init__(self, integer):
                self.integer = integer

matplotlib = DummyMatplotlib()
plt = SimpleNamespace(
    style=SimpleNamespace(use=lambda style: None),
    subplots=lambda: (DummyFig(), DummyAxes()),
    gcf=lambda: SimpleNamespace(autofmt_xdate=dummy_autofmt_xdate)
)
DateFormatter = lambda fmt: fmt

# Function under test
_INFEASIBLE_KEY = "INFEASIBLE"
from optuna.visualization.matplotlib._timeline import _get_timeline_plot

# --- Unit Tests ---

# Basic Test Cases



def test_empty_bars():
    # No bars (edge case)
    info = _TimelineInfo([])
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 4.53ms -> 4.05ms (12.1% faster)

# Edge Test Cases










#------------------------------------------------
import datetime
from types import SimpleNamespace

# imports
import pytest
from optuna.visualization.matplotlib._timeline import _get_timeline_plot


# Simulate required Optuna classes/enums for testing
class TrialState:
    COMPLETE = SimpleNamespace(name="COMPLETE")
    FAIL = SimpleNamespace(name="FAIL")
    PRUNED = SimpleNamespace(name="PRUNED")
    RUNNING = SimpleNamespace(name="RUNNING")
    WAITING = SimpleNamespace(name="WAITING")

# Dummy matplotlib classes for testing (no actual plotting)
class DummyPatch:
    def __init__(self, color, label):
        self.color = color
        self.label = label

class DummyLocator:
    def __init__(self, integer=None):
        self.integer = integer

class DummyFormatter:
    def __init__(self, fmt):
        self.fmt = fmt

class DummyYAxis:
    def set_major_locator(self, locator):
        self.locator = locator

class DummyXAxis:
    def set_major_formatter(self, formatter):
        self.formatter = formatter

class DummyFigure:
    def tight_layout(self):
        self.tight = True
    def autofmt_xdate(self):
        self.autofmt = True

class DummyAxes:
    def __init__(self):
        self.title = None
        self.xlabel = None
        self.ylabel = None
        self.bars = []
        self.legend_handles = []
        self.xlim = None
        self.yaxis = DummyYAxis()
        self.xaxis = DummyXAxis()
        self.figure = DummyFigure()
    def set_title(self, title):
        self.title = title
    def set_xlabel(self, xlabel):
        self.xlabel = xlabel
    def set_ylabel(self, ylabel):
        self.ylabel = ylabel
    def barh(self, y, width, left, color):
        self.bars = list(zip(y, width, left, color))
    def legend(self, handles, loc, bbox_to_anchor):
        self.legend_handles = handles
    def set_xlim(self, right, left):
        self.xlim = (right, left)

class DummyMatplotlib:
    class patches:
        Patch = DummyPatch
    class ticker:
        MaxNLocator = DummyLocator

# Simulate plt and matplotlib for testing
class DummyPlt:
    @staticmethod
    def style_use(style):
        pass
    @staticmethod
    def subplots():
        return DummyFigure(), DummyAxes()
    @staticmethod
    def gcf():
        return DummyFigure()

# Simulate DateFormatter for testing
def DummyDateFormatter(fmt):
    return DummyFormatter(fmt)

# Simulate _TimelineBarInfo and _TimelineInfo
class _TimelineBarInfo:
    def __init__(self, number, start, complete, state, infeasible=False):
        self.number = number
        self.start = start
        self.complete = complete
        self.state = state
        self.infeasible = infeasible

class _TimelineInfo:
    def __init__(self, bars):
        self.bars = bars

# Patch the function to use our dummy matplotlib
matplotlib = DummyMatplotlib
plt = DummyPlt
DateFormatter = DummyDateFormatter

_INFEASIBLE_KEY = "INFEASIBLE"
from optuna.visualization.matplotlib._timeline import _get_timeline_plot

# unit tests

# Basic Test Cases



def test_infeasible_trial():
    # Trial is COMPLETE but infeasible
    start = datetime.datetime(2024, 6, 1, 10, 0, 0)
    complete = datetime.datetime(2024, 6, 1, 10, 10, 0)
    bars = [_TimelineBarInfo(0, start, complete, TrialState.COMPLETE, infeasible=True)]
    info = _TimelineInfo(bars)
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 36.5ms -> 33.8ms (8.12% faster)

def test_empty_bars():
    # No trials
    info = _TimelineInfo([])
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 4.41ms -> 4.07ms (8.30% faster)

# Edge Test Cases

def test_zero_duration_trial():
    # Trial with zero duration
    start = datetime.datetime(2024, 6, 1, 10, 0, 0)
    bars = [_TimelineBarInfo(0, start, start, TrialState.COMPLETE)]
    info = _TimelineInfo(bars)
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 41.9ms -> 38.6ms (8.69% faster)

def test_overlapping_trials():
    # Trials with overlapping time intervals
    start = datetime.datetime(2024, 6, 1, 10, 0, 0)
    bars = [
        _TimelineBarInfo(0, start, start + datetime.timedelta(minutes=10), TrialState.COMPLETE),
        _TimelineBarInfo(1, start + datetime.timedelta(minutes=5), start + datetime.timedelta(minutes=15), TrialState.FAIL),
    ]
    info = _TimelineInfo(bars)
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 29.8ms -> 27.0ms (10.5% faster)




def test_many_trials():
    # 100 trials, alternating states
    start = datetime.datetime(2024, 6, 1, 10, 0, 0)
    bars = []
    for i in range(100):
        state = [TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED, TrialState.RUNNING, TrialState.WAITING][i % 5]
        bars.append(_TimelineBarInfo(i, start + datetime.timedelta(minutes=i), start + datetime.timedelta(minutes=i+1), state))
    info = _TimelineInfo(bars)
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 63.9ms -> 57.6ms (10.9% faster)
    # Each state should appear in legend
    for label in ["COMPLETE", "FAIL", "PRUNED", "RUNNING", "WAITING"]:
        pass



def test_max_trials_limit():
    # Test at the upper limit (999 trials)
    start = datetime.datetime(2024, 6, 1, 10, 0, 0)
    bars = []
    for i in range(999):
        bars.append(_TimelineBarInfo(i, start + datetime.timedelta(minutes=i), start + datetime.timedelta(minutes=i+1), TrialState.COMPLETE))
    info = _TimelineInfo(bars)
    codeflash_output = _get_timeline_plot(info); ax = codeflash_output # 335ms -> 306ms (9.61% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-_get_timeline_plot-mhjm20zt and push.

Codeflash Static Badge

The optimization replaces an inefficient nested loop with a more efficient set-based lookup for legend construction. 

**Key Change:**
- **Before**: For each state in `_cm`, the code used `any(_get_state_name(b) == state_name for b in info.bars)` - this creates a nested O(n*k) loop where n is the number of bars and k is the number of possible states (6).
- **After**: The code precomputes `present_state_names = set(_get_state_name(b) for b in info.bars)` once, then uses `if state_name in present_state_names` for O(1) lookups.

**Why This is Faster:**
The original approach has O(n*k) complexity because for each of the 6 possible states, it potentially scans through all n bars to check if that state exists. With the optimization, we scan the bars once to build a set (O(n)), then do 6 constant-time set lookups (O(k)), resulting in O(n+k) total complexity.

**Performance Impact:**
The line profiler shows the legend check (`if any(_get_state_name(b) == state_name for b in info.bars)`) took 31.6ms in the original vs the set-based approach taking only 6.5ms + 0.04ms = 6.54ms total - a ~79% reduction in that specific operation.

**Best For:**
This optimization is particularly effective for test cases with many trials (like the 999-trial test showing 9.6% speedup) where the nested loop penalty becomes significant, while still providing consistent 8-11% improvements across all test scenarios.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 3, 2025 20:47
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant