1- from typing import TYPE_CHECKING , Any , Literal , Protocol , runtime_checkable
1+ import itertools
2+ from typing import TYPE_CHECKING , Any , Literal , Protocol , cast , runtime_checkable
23
4+ import numpy as np
35import plotly .graph_objects as go
46
57from optimagic .config import IS_MATPLOTLIB_INSTALLED
@@ -25,6 +27,26 @@ def __call__(
2527 legend_properties : dict [str , Any ] | None ,
2628 margin_properties : dict [str , Any ] | None ,
2729 horizontal_line : float | None ,
30+ subplot : Any | None = None ,
31+ ) -> Any : ...
32+
33+
34+ @runtime_checkable
35+ class GridLinePlotFunction (Protocol ):
36+ def __call__ (
37+ self ,
38+ lines_list : list [list [LineData ]],
39+ * ,
40+ n_rows : int ,
41+ n_cols : int ,
42+ titles : list [str ] | None ,
43+ xlabel : str | None ,
44+ ylabel : str | None ,
45+ template : str | None ,
46+ height : int | None ,
47+ width : int | None ,
48+ legend_properties : dict [str , Any ] | None ,
49+ margin_properties : dict [str , Any ] | None ,
2850 ) -> Any : ...
2951
3052
@@ -40,16 +62,20 @@ def _line_plot_plotly(
4062 legend_properties : dict [str , Any ] | None ,
4163 margin_properties : dict [str , Any ] | None ,
4264 horizontal_line : float | None ,
65+ subplot : tuple [go .Figure , int , int ] | None = None ,
4366) -> go .Figure :
4467 if template is None :
4568 template = "simple_white"
4669
47- fig = go .Figure ()
70+ if subplot is None :
71+ fig = go .Figure ()
72+ row , col = None , None
73+
74+ else :
75+ fig , row , col = subplot
4876
4977 fig .update_layout (
5078 title = title ,
51- xaxis_title = xlabel .format (linebreak = "<br>" ) if xlabel else None ,
52- yaxis_title = ylabel ,
5379 template = template ,
5480 height = height ,
5581 width = width ,
@@ -62,6 +88,8 @@ def _line_plot_plotly(
6288 y = horizontal_line ,
6389 line_width = fig .layout .yaxis .linewidth or 1 ,
6490 opacity = 1.0 ,
91+ row = row ,
92+ col = col ,
6593 )
6694
6795 for line in lines :
@@ -72,8 +100,60 @@ def _line_plot_plotly(
72100 line_color = line .color ,
73101 mode = "lines" ,
74102 showlegend = line .show_in_legend ,
103+ legendgroup = line .name ,
104+ )
105+ fig .add_trace (trace , row = row , col = col )
106+
107+ fig .update_xaxes (
108+ title = xlabel .format (linebreak = "<br>" ) if xlabel else None , row = row , col = col
109+ )
110+ fig .update_yaxes (
111+ title = ylabel .format (linebreak = "<br>" ) if ylabel else None , row = row , col = col
112+ )
113+
114+ return fig
115+
116+
117+ def _grid_line_plot_plotly (
118+ lines_list : list [list [LineData ]],
119+ * ,
120+ n_rows : int ,
121+ n_cols : int ,
122+ titles : list [str ] | None ,
123+ xlabel : str | None ,
124+ ylabel : str | None ,
125+ template : str | None ,
126+ height : int | None ,
127+ width : int | None ,
128+ legend_properties : dict [str , Any ] | None ,
129+ margin_properties : dict [str , Any ] | None ,
130+ ) -> go .Figure :
131+ from plotly .subplots import make_subplots
132+
133+ fig = make_subplots (
134+ rows = n_rows ,
135+ cols = n_cols ,
136+ subplot_titles = titles ,
137+ )
138+
139+ for lines , (row , col ) in zip (
140+ lines_list ,
141+ itertools .product (range (1 , n_rows + 1 ), range (1 , n_cols + 1 )),
142+ strict = False ,
143+ ):
144+ _line_plot_plotly (
145+ lines ,
146+ title = None ,
147+ xlabel = xlabel ,
148+ ylabel = ylabel ,
149+ template = template ,
150+ height = height ,
151+ width = width ,
152+ legend_properties = legend_properties ,
153+ margin_properties = margin_properties ,
154+ horizontal_line = None ,
155+ subplot = (fig , row , col ),
75156 )
76- fig .add_trace (trace )
77157
78158 return fig
79159
@@ -90,6 +170,7 @@ def _line_plot_matplotlib(
90170 legend_properties : dict [str , Any ] | None ,
91171 margin_properties : dict [str , Any ] | None ,
92172 horizontal_line : float | None ,
173+ subplot : "plt.Axes | None" = None ,
93174) -> "plt.Axes" :
94175 import matplotlib .pyplot as plt
95176
@@ -105,10 +186,14 @@ def _line_plot_matplotlib(
105186 template = "default"
106187
107188 with plt .style .context (template ):
108- px = 1 / plt .rcParams ["figure.dpi" ] # pixel in inches
109- fig , ax = plt .subplots (
110- figsize = (width * px , height * px ) if width and height else None
111- )
189+ if subplot is None :
190+ px = 1 / plt .rcParams ["figure.dpi" ] # pixel in inches
191+ fig , ax = plt .subplots (
192+ figsize = (width * px , height * px ) if width and height else None ,
193+ layout = "constrained" ,
194+ )
195+ else :
196+ ax = subplot
112197
113198 if horizontal_line is not None :
114199 ax .axhline (
@@ -128,23 +213,69 @@ def _line_plot_matplotlib(
128213 ax .set (
129214 title = title ,
130215 xlabel = xlabel .format (linebreak = "\n " ) if xlabel else None ,
131- ylabel = ylabel ,
216+ ylabel = ylabel . format ( linebreak = " \n " ) if ylabel else None ,
132217 )
133218
134- if legend_properties is None :
135- legend_properties = {}
136- ax .legend (** legend_properties )
137-
138- fig .tight_layout ()
219+ if legend_properties is not None :
220+ ax .legend (** legend_properties )
139221
140222 return ax
141223
142224
225+ def _grid_line_plot_matplotlib (
226+ lines_list : list [list [LineData ]],
227+ * ,
228+ n_rows : int ,
229+ n_cols : int ,
230+ titles : list [str ] | None ,
231+ xlabel : str | None ,
232+ ylabel : str | None ,
233+ template : str | None ,
234+ height : int | None ,
235+ width : int | None ,
236+ legend_properties : dict [str , Any ] | None ,
237+ margin_properties : dict [str , Any ] | None ,
238+ ) -> np .ndarray :
239+ import matplotlib .pyplot as plt
240+
241+ px = 1 / plt .rcParams ["figure.dpi" ] # pixel in inches
242+ fig , axes = plt .subplots (
243+ nrows = n_rows ,
244+ ncols = n_cols ,
245+ squeeze = False ,
246+ figsize = (width * px , height * px ) if width and height else None ,
247+ layout = "constrained" ,
248+ )
249+
250+ for i , (ax , lines ) in enumerate (zip (axes .ravel (), lines_list , strict = False )):
251+ _line_plot_matplotlib (
252+ lines ,
253+ title = titles [i ] if titles else None ,
254+ xlabel = xlabel ,
255+ ylabel = ylabel ,
256+ template = template ,
257+ height = None ,
258+ width = None ,
259+ legend_properties = None ,
260+ margin_properties = None ,
261+ horizontal_line = None ,
262+ subplot = ax ,
263+ )
264+
265+ fig .legend (** legend_properties or {})
266+
267+ return axes
268+
269+
143270BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION : dict [
144- str , tuple [bool , LinePlotFunction ]
271+ str , tuple [bool , LinePlotFunction , GridLinePlotFunction ]
145272] = {
146- "plotly" : (True , _line_plot_plotly ),
147- "matplotlib" : (IS_MATPLOTLIB_INSTALLED , _line_plot_matplotlib ),
273+ "plotly" : (True , _line_plot_plotly , _grid_line_plot_plotly ),
274+ "matplotlib" : (
275+ IS_MATPLOTLIB_INSTALLED ,
276+ _line_plot_matplotlib ,
277+ _grid_line_plot_matplotlib ,
278+ ),
148279}
149280
150281
@@ -184,6 +315,64 @@ def line_plot(
184315 A figure object corresponding to the specified backend.
185316
186317 """
318+ _line_plot_backend_function = cast (
319+ LinePlotFunction , _get_plot_function (backend , grid_plot = False )
320+ )
321+
322+ fig = _line_plot_backend_function (
323+ lines ,
324+ title = title ,
325+ xlabel = xlabel ,
326+ ylabel = ylabel ,
327+ template = template ,
328+ height = height ,
329+ width = width ,
330+ legend_properties = legend_properties ,
331+ margin_properties = margin_properties ,
332+ horizontal_line = horizontal_line ,
333+ )
334+
335+ return fig
336+
337+
338+ def grid_line_plot (
339+ lines_list : list [list [LineData ]],
340+ backend : Literal ["plotly" , "matplotlib" ] = "plotly" ,
341+ * ,
342+ n_rows : int ,
343+ n_cols : int ,
344+ titles : list [str ] | None = None ,
345+ xlabel : str | None = None ,
346+ ylabel : str | None = None ,
347+ template : str | None = None ,
348+ height : int | None = None ,
349+ width : int | None = None ,
350+ legend_properties : dict [str , Any ] | None = None ,
351+ margin_properties : dict [str , Any ] | None = None ,
352+ ) -> Any :
353+ _grid_line_plot_backend_function = cast (
354+ GridLinePlotFunction , _get_plot_function (backend , grid_plot = True )
355+ )
356+ fig = _grid_line_plot_backend_function (
357+ lines_list ,
358+ n_rows = n_rows ,
359+ n_cols = n_cols ,
360+ titles = titles ,
361+ xlabel = xlabel ,
362+ ylabel = ylabel ,
363+ template = template ,
364+ height = height ,
365+ width = width ,
366+ legend_properties = legend_properties ,
367+ margin_properties = margin_properties ,
368+ )
369+
370+ return fig
371+
372+
373+ def _get_plot_function (
374+ backend : str , grid_plot : bool
375+ ) -> LinePlotFunction | GridLinePlotFunction :
187376 if backend not in BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION :
188377 msg = (
189378 f"Invalid plotting backend '{ backend } '. "
@@ -192,9 +381,11 @@ def line_plot(
192381 )
193382 raise InvalidPlottingBackendError (msg )
194383
195- _is_backend_available , _line_plot_backend_function = (
196- BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION [backend ]
197- )
384+ (
385+ _is_backend_available ,
386+ _line_plot_backend_function ,
387+ _grid_line_plot_backend_function ,
388+ ) = BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION [backend ]
198389
199390 if not _is_backend_available :
200391 msg = (
@@ -204,17 +395,7 @@ def line_plot(
204395 )
205396 raise NotInstalledError (msg )
206397
207- fig = _line_plot_backend_function (
208- lines ,
209- title = title ,
210- xlabel = xlabel ,
211- ylabel = ylabel ,
212- template = template ,
213- height = height ,
214- width = width ,
215- legend_properties = legend_properties ,
216- margin_properties = margin_properties ,
217- horizontal_line = horizontal_line ,
218- )
219-
220- return fig
398+ if grid_plot :
399+ return _grid_line_plot_backend_function
400+ else :
401+ return _line_plot_backend_function
0 commit comments