diff --git a/src/optimagic/visualization/backends.py b/src/optimagic/visualization/backends.py index fe9a598b6..541a39b30 100644 --- a/src/optimagic/visualization/backends.py +++ b/src/optimagic/visualization/backends.py @@ -38,7 +38,7 @@ def _line_plot_plotly( legend_properties: dict[str, Any] | None, ) -> go.Figure: if template is None: - template = "plotly" + template = "simple_white" fig = go.Figure() @@ -49,6 +49,7 @@ def _line_plot_plotly( name=line.name, line_color=line.color, mode="lines", + showlegend=line.show_in_legend, ) fig.add_trace(trace) @@ -123,6 +124,8 @@ def line_plot( Args: lines: List of objects each containing data for a line in the plot. + The order of lines in the list determines the order in which they are + plotted, with later lines being rendered on top of earlier ones. backend: The backend to use for plotting. title: Title of the plot. xlabel: Label for the x-axis. diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 2dcaa55da..d3aa440cd 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -2,13 +2,12 @@ import itertools from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Callable, Literal import numpy as np -import plotly.graph_objects as go from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten -from optimagic.config import DEFAULT_PALETTE, PLOTLY_TEMPLATE +from optimagic.config import DEFAULT_PALETTE from optimagic.logging.logger import LogReader, SQLiteLogOptions from optimagic.optimization.algorithm import Algorithm from optimagic.optimization.history import History @@ -18,7 +17,7 @@ from optimagic.visualization.backends import line_plot from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle -BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = { +BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = { "plotly": { "yanchor": "top", "xanchor": "right", @@ -77,10 +76,11 @@ def criterion_plot( # ================================================================================== # Extract backend-agnostic plotting data from results - list_of_optimize_data = _retrieve_optimization_data( + list_of_optimize_data = _retrieve_optimization_data_from_results( results=dict_of_optimize_results_or_paths, stack_multistart=stack_multistart, show_exploration=show_exploration, + plot_name="criterion_plot", ) lines, multistart_lines = _extract_criterion_plot_lines( @@ -95,14 +95,12 @@ def criterion_plot( # Generate the figure fig = line_plot( - lines=lines + multistart_lines, + lines=multistart_lines + lines, backend=backend, xlabel="No. of criterion evaluations", ylabel="Criterion value", template=template, - legend_properties=BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES.get( - backend, None - ), + legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None), ) return fig @@ -153,92 +151,65 @@ def _convert_key_to_str(key: Any) -> str: def params_plot( - result, - selector=None, - max_evaluations=None, - template=PLOTLY_TEMPLATE, - show_exploration=False, -): + result: ResultOrPath, + selector: Callable[[PyTree], PyTree] | None = None, + max_evaluations: int | None = None, + backend: Literal["plotly", "matplotlib"] = "plotly", + template: str | None = None, + palette: list[str] | str = DEFAULT_PALETTE, + show_exploration: bool = False, +) -> Any: """Plot the params history of an optimization. Args: - result (Union[OptimizeResult, pathlib.Path, str]): An optimization results with - collected history. If dict, then the key is used as the name in a legend. - selector (callable): A callable that takes params and returns a subset - of params. If provided, only the selected subset of params is plotted. - max_evaluations (int): Clip the criterion history after that many entries. - template (str): The template for the figure. Default is "plotly_white". - show_exploration (bool): If True, exploration samples of a multistart - optimization are visualized. Default is False. + result: An optimization result with collected history, or path to it. + If dict, then the key is used as the name in the legend. + selector: A callable that takes params and returns a subset of params. + If provided, only the selected subset of params is plotted. + max_evaluations: Clip the criterion history after that many entries. + backend: The backend to use for plotting. Default is "plotly". + template: The template for the figure. If not specified, the default template of + the backend is used. + palette: The coloring palette for traces. Default is the D3 qualitative palette. + show_exploration: If True, exploration samples of a multistart optimization are + visualized. Default is False. Returns: - plotly.graph_objs._figure.Figure: The figure. + The figure object containing the params plot. """ # ================================================================================== # Process inputs - # ================================================================================== - if isinstance(result, OptimizeResult): - data = _retrieve_optimization_data_from_results_object( - result, - stack_multistart=True, - show_exploration=show_exploration, - plot_name="params_plot", - ) - start_params = result.start_params - elif isinstance(result, (str, Path)): - data = _retrieve_optimization_data_from_database( - result, - stack_multistart=True, - show_exploration=show_exploration, - ) - start_params = data.start_params - else: - raise TypeError("result must be an OptimizeResult or a path to a log file.") - - if data.stacked_local_histories is not None: - history = data.stacked_local_histories.params - else: - history = data.history.params + palette_cycle = get_palette_cycle(palette) # ================================================================================== - # Create figure - # ================================================================================== - - fig = go.Figure() - - registry = get_registry(extended=True) - - hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T - names = leaf_names(start_params, registry=registry) + # Extract backend-agnostic plotting data from results - if selector is not None: - flat, treedef = tree_flatten(start_params, registry=registry) - helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) - selected = np.array(tree_just_flatten(selector(helper), registry=registry)) - names = [names[i] for i in selected] - hist_arr = hist_arr[selected] + optimize_data = _retrieve_optimization_data_from_single_result( + result=result, + stack_multistart=True, + show_exploration=show_exploration, + plot_name="params_plot", + ) - for name, data in zip(names, hist_arr, strict=False): - if max_evaluations is not None and len(data) > max_evaluations: - plot_data = data[:max_evaluations] - else: - plot_data = data + lines = _extract_params_plot_lines( + data=optimize_data, + selector=selector, + max_evaluations=max_evaluations, + palette_cycle=palette_cycle, + ) - trace = go.Scatter( - x=np.arange(len(plot_data)), - y=plot_data, - mode="lines", - name=name, - ) - fig.add_trace(trace) + # ================================================================================== + # Generate the figure - fig.update_layout( + fig = line_plot( + lines=lines, + backend=backend, + xlabel="No. of criterion evaluations", + ylabel="Parameter value", template=template, - xaxis_title_text="No. of criterion evaluations", - yaxis_title_text="Parameter value", - legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, + legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None), ) return fig @@ -261,66 +232,90 @@ class _PlottingMultistartHistory: stacked_local_histories: History | None -def _retrieve_optimization_data( +def _retrieve_optimization_data_from_results( results: dict[str, ResultOrPath], stack_multistart: bool, show_exploration: bool, + plot_name: str, ) -> list[_PlottingMultistartHistory]: - """Retrieve data for criterion plot from results (OptimizeResult or database). + # Retrieves data from multiple results by iterating over the results dictionary + # and calling the single result retrieval function. + + data = [] + for name, res in results.items(): + _data = _retrieve_optimization_data_from_single_result( + result=res, + stack_multistart=stack_multistart, + show_exploration=show_exploration, + plot_name=plot_name, + res_name=name, + ) + + data.append(_data) + + return data + + +def _retrieve_optimization_data_from_single_result( + result: ResultOrPath, + stack_multistart: bool, + show_exploration: bool, + plot_name: str, + res_name: str | None = None, +) -> _PlottingMultistartHistory: + """Retrieve data from a single result (OptimizeResult or database). Args: - results: A dict of optimization results with collected history. - The key is used as the name in a legend. + result: An optimization result with collected history, or path to it. stack_multistart: Whether to combine multistart histories into a single history. Default is False. show_exploration: If True, exploration samples of a multistart optimization are visualized. Default is False. + plot_name: Name of the plotting function that calls this function. Used for + raising errors. + res_name: Name of the result. Returns: - A list of objects containing the history, metadata, and local histories of each + A data object containing the history, metadata, and local histories of the optimization result. """ - data = [] - for name, res in results.items(): - if isinstance(res, OptimizeResult): - _data = _retrieve_optimization_data_from_results_object( - res=res, - stack_multistart=stack_multistart, - show_exploration=show_exploration, - plot_name="criterion_plot", - res_name=name, - ) - elif isinstance(res, (str, Path)): - _data = _retrieve_optimization_data_from_database( - res=res, - stack_multistart=stack_multistart, - show_exploration=show_exploration, - res_name=name, - ) - else: - msg = ( - "results must be (or contain) an OptimizeResult or a path to a log " - f"file, but is type {type(res)}." - ) - raise TypeError(msg) - - data.append(_data) + if isinstance(result, OptimizeResult): + data = _retrieve_optimization_data_from_result_object( + res=result, + stack_multistart=stack_multistart, + show_exploration=show_exploration, + plot_name=plot_name, + res_name=res_name, + ) + elif isinstance(result, (str, Path)): + data = _retrieve_optimization_data_from_database( + res=result, + stack_multistart=stack_multistart, + show_exploration=show_exploration, + res_name=res_name, + ) + else: + msg = ( + "result must be an OptimizeResult or a path to a log file, " + f"but is type {type(result)}." + ) + raise TypeError(msg) return data -def _retrieve_optimization_data_from_results_object( +def _retrieve_optimization_data_from_result_object( res: OptimizeResult, stack_multistart: bool, show_exploration: bool, plot_name: str, res_name: str | None = None, ) -> _PlottingMultistartHistory: - """Retrieve optimization data from results object. + """Retrieve optimization data from result object. Args: - res: An optimization results object. + res: An optimization result object. stack_multistart: Whether to combine multistart histories into a single history. Default is False. show_exploration: If True, exploration samples of a multistart optimization are @@ -550,3 +545,60 @@ def _extract_criterion_plot_lines( lines.append(line_data) return lines, multistart_lines + + +def _extract_params_plot_lines( + data: _PlottingMultistartHistory, + selector: Callable[[PyTree], PyTree] | None, + max_evaluations: int | None, + palette_cycle: "itertools.cycle[str]", +) -> list[LineData]: + """Extract lines for params plot from data. + + Args: + data: Data retrieved from results or database. + selector: A callable that takes params and returns a subset of params. + If provided, only the selected subset of params is plotted. + max_evaluations: Clip the criterion history after that many entries. + palette_cycle: Cycle of colors for plotting. + + Returns: + lines: Parameter histories. + + """ + if data.stacked_local_histories is not None: + history = data.stacked_local_histories.params + else: + history = data.history.params + start_params = data.start_params + + registry = get_registry(extended=True) + + hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T + names = leaf_names(start_params, registry=registry) + + if selector is not None: + flat, treedef = tree_flatten(start_params, registry=registry) + helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) + selected = np.array(tree_just_flatten(selector(helper), registry=registry)) + names = [names[i] for i in selected] + hist_arr = hist_arr[selected] + + lines: list[LineData] = [] + + for name, _data in zip(names, hist_arr, strict=False): + if max_evaluations is not None and len(_data) > max_evaluations: + plot_data = _data[:max_evaluations] + else: + plot_data = _data + + line_data = LineData( + x=np.arange(len(plot_data)), + y=plot_data, + color=next(palette_cycle), + name=name, + show_in_legend=True, + ) + lines.append(line_data) + + return lines diff --git a/tests/optimagic/visualization/test_history_plots.py b/tests/optimagic/visualization/test_history_plots.py index 3c157115c..bad212fc8 100644 --- a/tests/optimagic/visualization/test_history_plots.py +++ b/tests/optimagic/visualization/test_history_plots.py @@ -15,9 +15,11 @@ from optimagic.visualization.history_plots import ( LineData, _extract_criterion_plot_lines, + _extract_params_plot_lines, _harmonize_inputs_to_dict, _PlottingMultistartHistory, - _retrieve_optimization_data, + _retrieve_optimization_data_from_results, + _retrieve_optimization_data_from_single_result, criterion_plot, params_plot, ) @@ -154,6 +156,12 @@ def test_criterion_plot_different_backends(minimize_result, backend): criterion_plot(res, backend=backend) +@pytest.mark.parametrize("backend", BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION.keys()) +def test_params_plot_different_backends(minimize_result, backend): + res = minimize_result[False][0] + params_plot(res, backend=backend) + + def test_harmonize_inputs_to_dict_single_result(): res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") assert _harmonize_inputs_to_dict(results=res, names=None) == {"0": res} @@ -218,8 +226,8 @@ def test_retrieve_data_from_result(minimize_result): res = minimize_result[False][0] results = {"bla": res} - data = _retrieve_optimization_data( - results=results, stack_multistart=False, show_exploration=False + data = _retrieve_optimization_data_from_results( + results=results, stack_multistart=False, show_exploration=False, plot_name="bla" ) assert isinstance(data, list) and len(data) == 1 @@ -238,8 +246,8 @@ def test_retrieve_data_from_logged_result(tmp_path): ) results = {"logged": tmp_path / "test.db"} - data = _retrieve_optimization_data( - results=results, stack_multistart=False, show_exploration=False + data = _retrieve_optimization_data_from_results( + results=results, stack_multistart=False, show_exploration=False, plot_name="bla" ) assert isinstance(data, list) and len(data) == 1 @@ -254,8 +262,11 @@ def test_retrieve_data_from_multistart_result(minimize_result, stack_multistart) res = minimize_result[True][0] results = {"multistart": res} - data = _retrieve_optimization_data( - results=results, stack_multistart=stack_multistart, show_exploration=False + data = _retrieve_optimization_data_from_results( + results=results, + stack_multistart=stack_multistart, + show_exploration=False, + plot_name="bla", ) assert isinstance(data, list) and len(data) == 1 @@ -275,8 +286,8 @@ def test_retrieve_data_from_multistart_result(minimize_result, stack_multistart) def test_extract_criterion_plot_lines(minimize_result): res = minimize_result[True][0] results = {"multistart": res} - data = _retrieve_optimization_data( - results=results, stack_multistart=False, show_exploration=False + data = _retrieve_optimization_data_from_results( + results=results, stack_multistart=False, show_exploration=False, plot_name="bla" ) palette_cycle = itertools.cycle(["red", "green", "blue"]) @@ -301,3 +312,32 @@ def test_extract_criterion_plot_lines(minimize_result): isinstance(line, LineData) for line in multistart_lines ) assert len(multistart_lines) == 5 + + +def test_extract_params_plot_lines(minimize_result): + res = minimize_result[False][0] + data = _retrieve_optimization_data_from_single_result( + result=res, + stack_multistart=False, + show_exploration=False, + plot_name="params_plot", + ) + + palette_cycle = itertools.cycle(["red", "green", "blue"]) + + lines = _extract_params_plot_lines( + data=data, + selector=None, + max_evaluations=None, + palette_cycle=palette_cycle, + ) + + params = np.array(res.history.params) + num_params = params.shape[1] + + assert isinstance(lines, list) and len(lines) == num_params + assert all(isinstance(line, LineData) for line in lines) + + for i, line in enumerate(lines): + assert_array_equal(line.x, np.arange(len(params))) + assert_array_equal(line.y, params[:, i])