Skip to content

Commit 605e966

Browse files
committed
Implementation of grid plot functions for plotly and matplotlib
1 parent 5bf11e5 commit 605e966

File tree

1 file changed

+216
-35
lines changed

1 file changed

+216
-35
lines changed

src/optimagic/visualization/backends.py

Lines changed: 216 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
1+
import itertools
2+
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast, runtime_checkable
23

4+
import numpy as np
35
import plotly.graph_objects as go
46

57
from optimagic.config import IS_MATPLOTLIB_INSTALLED
@@ -25,6 +27,26 @@ def __call__(
2527
legend_properties: dict[str, Any] | None,
2628
margin_properties: dict[str, Any] | None,
2729
horizontal_line: float | None,
30+
subplot: Any | None = None,
31+
) -> Any: ...
32+
33+
34+
@runtime_checkable
35+
class GridLinePlotFunction(Protocol):
36+
def __call__(
37+
self,
38+
lines_list: list[list[LineData]],
39+
*,
40+
n_rows: int,
41+
n_cols: int,
42+
titles: list[str] | None,
43+
xlabel: str | None,
44+
ylabel: str | None,
45+
template: str | None,
46+
height: int | None,
47+
width: int | None,
48+
legend_properties: dict[str, Any] | None,
49+
margin_properties: dict[str, Any] | None,
2850
) -> Any: ...
2951

3052

@@ -40,16 +62,20 @@ def _line_plot_plotly(
4062
legend_properties: dict[str, Any] | None,
4163
margin_properties: dict[str, Any] | None,
4264
horizontal_line: float | None,
65+
subplot: tuple[go.Figure, int, int] | None = None,
4366
) -> go.Figure:
4467
if template is None:
4568
template = "simple_white"
4669

47-
fig = go.Figure()
70+
if subplot is None:
71+
fig = go.Figure()
72+
row, col = None, None
73+
74+
else:
75+
fig, row, col = subplot
4876

4977
fig.update_layout(
5078
title=title,
51-
xaxis_title=xlabel.format(linebreak="<br>") if xlabel else None,
52-
yaxis_title=ylabel,
5379
template=template,
5480
height=height,
5581
width=width,
@@ -62,6 +88,8 @@ def _line_plot_plotly(
6288
y=horizontal_line,
6389
line_width=fig.layout.yaxis.linewidth or 1,
6490
opacity=1.0,
91+
row=row,
92+
col=col,
6593
)
6694

6795
for line in lines:
@@ -72,8 +100,60 @@ def _line_plot_plotly(
72100
line_color=line.color,
73101
mode="lines",
74102
showlegend=line.show_in_legend,
103+
legendgroup=line.name,
104+
)
105+
fig.add_trace(trace, row=row, col=col)
106+
107+
fig.update_xaxes(
108+
title=xlabel.format(linebreak="<br>") if xlabel else None, row=row, col=col
109+
)
110+
fig.update_yaxes(
111+
title=ylabel.format(linebreak="<br>") if ylabel else None, row=row, col=col
112+
)
113+
114+
return fig
115+
116+
117+
def _grid_line_plot_plotly(
118+
lines_list: list[list[LineData]],
119+
*,
120+
n_rows: int,
121+
n_cols: int,
122+
titles: list[str] | None,
123+
xlabel: str | None,
124+
ylabel: str | None,
125+
template: str | None,
126+
height: int | None,
127+
width: int | None,
128+
legend_properties: dict[str, Any] | None,
129+
margin_properties: dict[str, Any] | None,
130+
) -> go.Figure:
131+
from plotly.subplots import make_subplots
132+
133+
fig = make_subplots(
134+
rows=n_rows,
135+
cols=n_cols,
136+
subplot_titles=titles,
137+
)
138+
139+
for lines, (row, col) in zip(
140+
lines_list,
141+
itertools.product(range(1, n_rows + 1), range(1, n_cols + 1)),
142+
strict=False,
143+
):
144+
_line_plot_plotly(
145+
lines,
146+
title=None,
147+
xlabel=xlabel,
148+
ylabel=ylabel,
149+
template=template,
150+
height=height,
151+
width=width,
152+
legend_properties=legend_properties,
153+
margin_properties=margin_properties,
154+
horizontal_line=None,
155+
subplot=(fig, row, col),
75156
)
76-
fig.add_trace(trace)
77157

78158
return fig
79159

@@ -90,6 +170,7 @@ def _line_plot_matplotlib(
90170
legend_properties: dict[str, Any] | None,
91171
margin_properties: dict[str, Any] | None,
92172
horizontal_line: float | None,
173+
subplot: "plt.Axes | None" = None,
93174
) -> "plt.Axes":
94175
import matplotlib.pyplot as plt
95176

@@ -105,10 +186,14 @@ def _line_plot_matplotlib(
105186
template = "default"
106187

107188
with plt.style.context(template):
108-
px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
109-
fig, ax = plt.subplots(
110-
figsize=(width * px, height * px) if width and height else None
111-
)
189+
if subplot is None:
190+
px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
191+
fig, ax = plt.subplots(
192+
figsize=(width * px, height * px) if width and height else None,
193+
layout="constrained",
194+
)
195+
else:
196+
ax = subplot
112197

113198
if horizontal_line is not None:
114199
ax.axhline(
@@ -128,23 +213,69 @@ def _line_plot_matplotlib(
128213
ax.set(
129214
title=title,
130215
xlabel=xlabel.format(linebreak="\n") if xlabel else None,
131-
ylabel=ylabel,
216+
ylabel=ylabel.format(linebreak="\n") if ylabel else None,
132217
)
133218

134-
if legend_properties is None:
135-
legend_properties = {}
136-
ax.legend(**legend_properties)
137-
138-
fig.tight_layout()
219+
if legend_properties is not None:
220+
ax.legend(**legend_properties)
139221

140222
return ax
141223

142224

225+
def _grid_line_plot_matplotlib(
226+
lines_list: list[list[LineData]],
227+
*,
228+
n_rows: int,
229+
n_cols: int,
230+
titles: list[str] | None,
231+
xlabel: str | None,
232+
ylabel: str | None,
233+
template: str | None,
234+
height: int | None,
235+
width: int | None,
236+
legend_properties: dict[str, Any] | None,
237+
margin_properties: dict[str, Any] | None,
238+
) -> np.ndarray:
239+
import matplotlib.pyplot as plt
240+
241+
px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
242+
fig, axes = plt.subplots(
243+
nrows=n_rows,
244+
ncols=n_cols,
245+
squeeze=False,
246+
figsize=(width * px, height * px) if width and height else None,
247+
layout="constrained",
248+
)
249+
250+
for i, (ax, lines) in enumerate(zip(axes.ravel(), lines_list, strict=False)):
251+
_line_plot_matplotlib(
252+
lines,
253+
title=titles[i] if titles else None,
254+
xlabel=xlabel,
255+
ylabel=ylabel,
256+
template=template,
257+
height=None,
258+
width=None,
259+
legend_properties=None,
260+
margin_properties=None,
261+
horizontal_line=None,
262+
subplot=ax,
263+
)
264+
265+
fig.legend(**legend_properties or {})
266+
267+
return axes
268+
269+
143270
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: dict[
144-
str, tuple[bool, LinePlotFunction]
271+
str, tuple[bool, LinePlotFunction, GridLinePlotFunction]
145272
] = {
146-
"plotly": (True, _line_plot_plotly),
147-
"matplotlib": (IS_MATPLOTLIB_INSTALLED, _line_plot_matplotlib),
273+
"plotly": (True, _line_plot_plotly, _grid_line_plot_plotly),
274+
"matplotlib": (
275+
IS_MATPLOTLIB_INSTALLED,
276+
_line_plot_matplotlib,
277+
_grid_line_plot_matplotlib,
278+
),
148279
}
149280

150281

@@ -184,6 +315,64 @@ def line_plot(
184315
A figure object corresponding to the specified backend.
185316
186317
"""
318+
_line_plot_backend_function = cast(
319+
LinePlotFunction, _get_plot_function(backend, grid_plot=False)
320+
)
321+
322+
fig = _line_plot_backend_function(
323+
lines,
324+
title=title,
325+
xlabel=xlabel,
326+
ylabel=ylabel,
327+
template=template,
328+
height=height,
329+
width=width,
330+
legend_properties=legend_properties,
331+
margin_properties=margin_properties,
332+
horizontal_line=horizontal_line,
333+
)
334+
335+
return fig
336+
337+
338+
def grid_line_plot(
339+
lines_list: list[list[LineData]],
340+
backend: Literal["plotly", "matplotlib"] = "plotly",
341+
*,
342+
n_rows: int,
343+
n_cols: int,
344+
titles: list[str] | None = None,
345+
xlabel: str | None = None,
346+
ylabel: str | None = None,
347+
template: str | None = None,
348+
height: int | None = None,
349+
width: int | None = None,
350+
legend_properties: dict[str, Any] | None = None,
351+
margin_properties: dict[str, Any] | None = None,
352+
) -> Any:
353+
_grid_line_plot_backend_function = cast(
354+
GridLinePlotFunction, _get_plot_function(backend, grid_plot=True)
355+
)
356+
fig = _grid_line_plot_backend_function(
357+
lines_list,
358+
n_rows=n_rows,
359+
n_cols=n_cols,
360+
titles=titles,
361+
xlabel=xlabel,
362+
ylabel=ylabel,
363+
template=template,
364+
height=height,
365+
width=width,
366+
legend_properties=legend_properties,
367+
margin_properties=margin_properties,
368+
)
369+
370+
return fig
371+
372+
373+
def _get_plot_function(
374+
backend: str, grid_plot: bool
375+
) -> LinePlotFunction | GridLinePlotFunction:
187376
if backend not in BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION:
188377
msg = (
189378
f"Invalid plotting backend '{backend}'. "
@@ -192,9 +381,11 @@ def line_plot(
192381
)
193382
raise InvalidPlottingBackendError(msg)
194383

195-
_is_backend_available, _line_plot_backend_function = (
196-
BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend]
197-
)
384+
(
385+
_is_backend_available,
386+
_line_plot_backend_function,
387+
_grid_line_plot_backend_function,
388+
) = BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend]
198389

199390
if not _is_backend_available:
200391
msg = (
@@ -204,17 +395,7 @@ def line_plot(
204395
)
205396
raise NotInstalledError(msg)
206397

207-
fig = _line_plot_backend_function(
208-
lines,
209-
title=title,
210-
xlabel=xlabel,
211-
ylabel=ylabel,
212-
template=template,
213-
height=height,
214-
width=width,
215-
legend_properties=legend_properties,
216-
margin_properties=margin_properties,
217-
horizontal_line=horizontal_line,
218-
)
219-
220-
return fig
398+
if grid_plot:
399+
return _grid_line_plot_backend_function
400+
else:
401+
return _line_plot_backend_function

0 commit comments

Comments
 (0)