Skip to content

Commit 49b18fd

Browse files
committed
Refactor convergence_plot to use backend plotting
1 parent 605e966 commit 49b18fd

File tree

3 files changed

+265
-175
lines changed

3 files changed

+265
-175
lines changed

src/optimagic/visualization/backends.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def _line_plot_plotly(
8282
legend=legend_properties,
8383
margin=margin_properties,
8484
)
85+
fig.update_xaxes(
86+
title=xlabel.format(linebreak="<br>") if xlabel else None, row=row, col=col
87+
)
88+
fig.update_yaxes(
89+
title=ylabel.format(linebreak="<br>") if ylabel else None, row=row, col=col
90+
)
8591

8692
if horizontal_line is not None:
8793
fig.add_hline(
@@ -104,13 +110,6 @@ def _line_plot_plotly(
104110
)
105111
fig.add_trace(trace, row=row, col=col)
106112

107-
fig.update_xaxes(
108-
title=xlabel.format(linebreak="<br>") if xlabel else None, row=row, col=col
109-
)
110-
fig.update_yaxes(
111-
title=ylabel.format(linebreak="<br>") if ylabel else None, row=row, col=col
112-
)
113-
114113
return fig
115114

116115

@@ -217,7 +216,7 @@ def _line_plot_matplotlib(
217216
)
218217

219218
if legend_properties is not None:
220-
ax.legend(**legend_properties)
219+
fig.legend(**legend_properties)
221220

222221
return ax
223222

@@ -247,9 +246,13 @@ def _grid_line_plot_matplotlib(
247246
layout="constrained",
248247
)
249248

250-
for i, (ax, lines) in enumerate(zip(axes.ravel(), lines_list, strict=False)):
249+
for i, (row, col) in enumerate(itertools.product(range(n_rows), range(n_cols))):
250+
if i >= len(lines_list):
251+
axes[row, col].set_visible(False)
252+
continue
253+
251254
_line_plot_matplotlib(
252-
lines,
255+
lines_list[i],
253256
title=titles[i] if titles else None,
254257
xlabel=xlabel,
255258
ylabel=ylabel,
@@ -259,26 +262,14 @@ def _grid_line_plot_matplotlib(
259262
legend_properties=None,
260263
margin_properties=None,
261264
horizontal_line=None,
262-
subplot=ax,
265+
subplot=axes[row, col],
263266
)
264267

265268
fig.legend(**legend_properties or {})
266269

267270
return axes
268271

269272

270-
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: dict[
271-
str, tuple[bool, LinePlotFunction, GridLinePlotFunction]
272-
] = {
273-
"plotly": (True, _line_plot_plotly, _grid_line_plot_plotly),
274-
"matplotlib": (
275-
IS_MATPLOTLIB_INSTALLED,
276-
_line_plot_matplotlib,
277-
_grid_line_plot_matplotlib,
278-
),
279-
}
280-
281-
282273
def line_plot(
283274
lines: list[LineData],
284275
backend: Literal["plotly", "matplotlib"] = "plotly",
@@ -350,9 +341,33 @@ def grid_line_plot(
350341
legend_properties: dict[str, Any] | None = None,
351342
margin_properties: dict[str, Any] | None = None,
352343
) -> Any:
344+
"""Create a grid of line plots corresponding to the specified backend.
345+
346+
Args:
347+
lines_list: A list where each element is a list of objects containing data
348+
for the lines in a subplot. The order of sublists determines the order
349+
of subplots in the grid (row-wise), and the order of lines within each
350+
sublist determines the order of lines in that subplot.
351+
backend: The backend to use for plotting.
352+
n_rows: Number of rows in the grid.
353+
n_cols: Number of columns in the grid.
354+
titles: Titles for each subplot in the grid.
355+
xlabel: Label for the x-axis of each subplot.
356+
ylabel: Label for the y-axis of each subplot.
357+
template: Backend-specific template for styling the plots.
358+
height: Height of the entire grid plot (in pixels).
359+
width: Width of the entire grid plot (in pixels).
360+
legend_properties: Backend-specific properties for the legend.
361+
margin_properties: Backend-specific properties for the plot margins.
362+
363+
Returns:
364+
A figure object corresponding to the specified backend.
365+
366+
"""
353367
_grid_line_plot_backend_function = cast(
354368
GridLinePlotFunction, _get_plot_function(backend, grid_plot=True)
355369
)
370+
356371
fig = _grid_line_plot_backend_function(
357372
lines_list,
358373
n_rows=n_rows,
@@ -370,6 +385,18 @@ def grid_line_plot(
370385
return fig
371386

372387

388+
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: dict[
389+
str, tuple[bool, LinePlotFunction, GridLinePlotFunction]
390+
] = {
391+
"plotly": (True, _line_plot_plotly, _grid_line_plot_plotly),
392+
"matplotlib": (
393+
IS_MATPLOTLIB_INSTALLED,
394+
_line_plot_matplotlib,
395+
_grid_line_plot_matplotlib,
396+
),
397+
}
398+
399+
373400
def _get_plot_function(
374401
backend: str, grid_plot: bool
375402
) -> LinePlotFunction | GridLinePlotFunction:

0 commit comments

Comments
 (0)