@@ -82,6 +82,12 @@ def _line_plot_plotly(
8282 legend = legend_properties ,
8383 margin = margin_properties ,
8484 )
85+ fig .update_xaxes (
86+ title = xlabel .format (linebreak = "<br>" ) if xlabel else None , row = row , col = col
87+ )
88+ fig .update_yaxes (
89+ title = ylabel .format (linebreak = "<br>" ) if ylabel else None , row = row , col = col
90+ )
8591
8692 if horizontal_line is not None :
8793 fig .add_hline (
@@ -104,13 +110,6 @@ def _line_plot_plotly(
104110 )
105111 fig .add_trace (trace , row = row , col = col )
106112
107- fig .update_xaxes (
108- title = xlabel .format (linebreak = "<br>" ) if xlabel else None , row = row , col = col
109- )
110- fig .update_yaxes (
111- title = ylabel .format (linebreak = "<br>" ) if ylabel else None , row = row , col = col
112- )
113-
114113 return fig
115114
116115
@@ -217,7 +216,7 @@ def _line_plot_matplotlib(
217216 )
218217
219218 if legend_properties is not None :
220- ax .legend (** legend_properties )
219+ fig .legend (** legend_properties )
221220
222221 return ax
223222
@@ -247,9 +246,13 @@ def _grid_line_plot_matplotlib(
247246 layout = "constrained" ,
248247 )
249248
250- for i , (ax , lines ) in enumerate (zip (axes .ravel (), lines_list , strict = False )):
249+ for i , (row , col ) in enumerate (itertools .product (range (n_rows ), range (n_cols ))):
250+ if i >= len (lines_list ):
251+ axes [row , col ].set_visible (False )
252+ continue
253+
251254 _line_plot_matplotlib (
252- lines ,
255+ lines_list [ i ] ,
253256 title = titles [i ] if titles else None ,
254257 xlabel = xlabel ,
255258 ylabel = ylabel ,
@@ -259,26 +262,14 @@ def _grid_line_plot_matplotlib(
259262 legend_properties = None ,
260263 margin_properties = None ,
261264 horizontal_line = None ,
262- subplot = ax ,
265+ subplot = axes [ row , col ] ,
263266 )
264267
265268 fig .legend (** legend_properties or {})
266269
267270 return axes
268271
269272
270- BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION : dict [
271- str , tuple [bool , LinePlotFunction , GridLinePlotFunction ]
272- ] = {
273- "plotly" : (True , _line_plot_plotly , _grid_line_plot_plotly ),
274- "matplotlib" : (
275- IS_MATPLOTLIB_INSTALLED ,
276- _line_plot_matplotlib ,
277- _grid_line_plot_matplotlib ,
278- ),
279- }
280-
281-
282273def line_plot (
283274 lines : list [LineData ],
284275 backend : Literal ["plotly" , "matplotlib" ] = "plotly" ,
@@ -350,9 +341,33 @@ def grid_line_plot(
350341 legend_properties : dict [str , Any ] | None = None ,
351342 margin_properties : dict [str , Any ] | None = None ,
352343) -> Any :
344+ """Create a grid of line plots corresponding to the specified backend.
345+
346+ Args:
347+ lines_list: A list where each element is a list of objects containing data
348+ for the lines in a subplot. The order of sublists determines the order
349+ of subplots in the grid (row-wise), and the order of lines within each
350+ sublist determines the order of lines in that subplot.
351+ backend: The backend to use for plotting.
352+ n_rows: Number of rows in the grid.
353+ n_cols: Number of columns in the grid.
354+ titles: Titles for each subplot in the grid.
355+ xlabel: Label for the x-axis of each subplot.
356+ ylabel: Label for the y-axis of each subplot.
357+ template: Backend-specific template for styling the plots.
358+ height: Height of the entire grid plot (in pixels).
359+ width: Width of the entire grid plot (in pixels).
360+ legend_properties: Backend-specific properties for the legend.
361+ margin_properties: Backend-specific properties for the plot margins.
362+
363+ Returns:
364+ A figure object corresponding to the specified backend.
365+
366+ """
353367 _grid_line_plot_backend_function = cast (
354368 GridLinePlotFunction , _get_plot_function (backend , grid_plot = True )
355369 )
370+
356371 fig = _grid_line_plot_backend_function (
357372 lines_list ,
358373 n_rows = n_rows ,
@@ -370,6 +385,18 @@ def grid_line_plot(
370385 return fig
371386
372387
388+ BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION : dict [
389+ str , tuple [bool , LinePlotFunction , GridLinePlotFunction ]
390+ ] = {
391+ "plotly" : (True , _line_plot_plotly , _grid_line_plot_plotly ),
392+ "matplotlib" : (
393+ IS_MATPLOTLIB_INSTALLED ,
394+ _line_plot_matplotlib ,
395+ _grid_line_plot_matplotlib ,
396+ ),
397+ }
398+
399+
373400def _get_plot_function (
374401 backend : str , grid_plot : bool
375402) -> LinePlotFunction | GridLinePlotFunction :
0 commit comments