Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Fixed**

**Dependencies**
- Changed the `TimeSeries.plot()` implementation to no longer rely on xarray under the hood, while keeping the same functionality. [#2932](https://github.com/unit8co/darts/pull/2932) by [Jakub Chłapek](https://github.com/jakubchlapek)

### For developers of the library:

Expand Down
246 changes: 246 additions & 0 deletions darts/tests/test_timeseries_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
from itertools import product
from unittest.mock import patch

import matplotlib.collections as mcollections
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest

from darts import TimeSeries
from darts.utils.utils import generate_index


class TestTimeSeriesPlot:
    """Tests for ``TimeSeries.plot()``.

    Four class-level fixture series cover the combinations of index type
    (datetime ``dt`` / range ``ri``) and stochasticity (deterministic ``d`` /
    probabilistic ``p``). Tests look them up by name via
    ``series_{index_type}_{stoch_type}``.

    Several tests inspect the artists on the *current* matplotlib axes, so an
    autouse fixture closes every figure before and after each test. Without
    it, a failing assertion would skip the test's own clean-up and the next
    test would count artists left over from the failed one.
    """

    n_comps = 2
    # seeded generator: fixture values are identical on every run, so a
    # failure can be reproduced
    rng = np.random.default_rng(42)

    # datetime index, deterministic
    series_dt_d = TimeSeries.from_times_and_values(
        times=generate_index(start="2000-01-01", length=10, freq="D"),
        values=rng.random((10, n_comps, 1)),
    )
    # datetime index, probabilistic
    series_dt_p = TimeSeries.from_times_and_values(
        times=generate_index(start="2000-01-01", length=10, freq="D"),
        values=rng.random((10, n_comps, 5)),
    )
    # range index, deterministic
    series_ri_d = TimeSeries.from_times_and_values(
        times=generate_index(start=0, length=10, freq=1),
        values=rng.random((10, n_comps, 1)),
    )
    # range index, probabilistic
    series_ri_p = TimeSeries.from_times_and_values(
        times=generate_index(start=0, length=10, freq=1),
        values=rng.random((10, n_comps, 5)),
    )

    @pytest.fixture(autouse=True)
    def _isolated_figures(self):
        """Give every test a clean pyplot state, even if the test fails."""
        plt.close("all")
        yield
        plt.close("all")

    @staticmethod
    def _get_lines(ax):
        """Return the ``Line2D`` artists on `ax` holding more than one point."""
        return [line for line in ax.lines if len(line.get_xdata()) > 1]

    @staticmethod
    def _get_points(ax):
        """Return the ``Line2D`` artists on `ax` holding one point with a marker."""
        return [
            line
            for line in ax.lines
            if len(line.get_xdata()) == 1 and line.get_marker() != "None"
        ]

    @staticmethod
    def _get_areas(ax):
        """Return the ``PolyCollection`` artists on `ax` (from ``fill_between``)."""
        return [
            coll
            for coll in ax.collections
            if isinstance(coll, mcollections.PolyCollection)
        ]

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            ["d", "p"],
            [True, False],
        ),
    )
    def test_plot_single_series(self, mock_show, config):
        """A series with len > 1 draws one line (and, if stochastic, one area)
        per component, both on a user-supplied and on the current axes."""
        index_type, stoch_type, use_ax = config
        series = getattr(self, f"series_{index_type}_{stoch_type}")
        if use_ax:
            _, ax = plt.subplots()
        else:
            ax = None
        series.plot(ax=ax)

        # without an explicit `ax`, the plot lands on the current axes
        ax = ax if use_ax else plt.gca()

        # one line per component
        assert len(self._get_lines(ax)) == self.n_comps

        # probabilistic: additionally one filled area per component
        if series.is_stochastic:
            assert len(self._get_areas(ax)) == self.n_comps

        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            ["d", "p"],
        ),
    )
    def test_plot_point_series(self, mock_show, config):
        """A series with len == 1 draws one point per component; if stochastic,
        additionally one vertical line per component."""
        index_type, stoch_type = config
        series = getattr(self, f"series_{index_type}_{stoch_type}")
        series = series[:1]
        series.plot()

        ax = plt.gca()

        # one marker point per component
        assert len(self._get_points(ax)) == self.n_comps

        # probabilistic: the interval is drawn as a two-point vertical line
        if series.is_stochastic:
            vert_lines = []
            for line in ax.lines:
                xdata = np.asarray(line.get_xdata())
                ydata = np.asarray(line.get_ydata())
                if len(xdata) == 2 and len(ydata) == 2:
                    # vertical line <=> both x-coordinates are the same
                    xdiff = xdata[0] - xdata[1]

                    # datetime x-values subtract to a Timedelta, which cannot
                    # be compared against a float directly
                    if isinstance(xdiff, pd.Timedelta):
                        xdiff = xdiff.total_seconds()

                    if abs(xdiff) < 1e-10:
                        vert_lines.append(line)
            assert len(vert_lines) == self.n_comps

        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            ["d", "p"],
        ),
    )
    def test_plot_empty_series(self, mock_show, config):
        """A series with len == 0 draws no points, no lines and no areas."""
        index_type, stoch_type = config
        series = getattr(self, f"series_{index_type}_{stoch_type}")
        series = series[:0]
        series.plot()

        ax = plt.gca()
        # an empty plot still creates a line artist with empty data, so only
        # artists that actually carry data are counted
        assert len(self._get_points(ax)) == 0
        assert len(self._get_lines(ax)) == 0
        assert len(self._get_areas(ax)) == 0

        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            ["d", "p"],
            [
                {"new_plot": True},
                {"default_formatting": False},
                {"title": "my title"},
                {"label": "comps"},
                {"label": ["comps_1", "comps_2"]},
                {"alpha": 0.1, "color": "blue"},
                {"color": ["blue", "red"]},
                {"lw": 2},
            ],
        ),
    )
    def test_plot_params(self, mock_show, config):
        """Smoke test: general plotting keyword arguments are accepted."""
        index_type, stoch_type, kwargs = config
        series = getattr(self, f"series_{index_type}_{stoch_type}")
        series.plot(**kwargs)
        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            [
                {"central_quantile": "mean"},
                {"central_quantile": 0.5},
                {
                    "low_quantile": 0.2,
                    "central_quantile": 0.6,
                    "high_quantile": 0.7,
                    "alpha": 0.1,
                },
            ],
        ),
    )
    def test_plot_stochastic_params(self, mock_show, config):
        """Smoke test: quantile-related keyword arguments are accepted for
        probabilistic series."""
        index_type, kwargs = config
        series = getattr(self, f"series_{index_type}_p")
        series.plot(**kwargs)
        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize(
        "config",
        product(
            ["dt", "ri"],
            ["d", "p"],
        ),
    )
    def test_plot_multiple_series(self, mock_show, config):
        """Two plots of the same kind on the same axes accumulate their artists."""
        index_type, stoch_type = config
        series = getattr(self, f"series_{index_type}_{stoch_type}")
        series.plot()
        series.plot()

        ax = plt.gca()
        # each plot call contributes one line per component
        assert len(self._get_lines(ax)) == 2 * self.n_comps
        # each probabilistic plot call contributes one area per component
        expected_areas = 2 * self.n_comps if series.is_stochastic else 0
        assert len(self._get_areas(ax)) == expected_areas

        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize("config", ["dt", "ri"])
    def test_plot_deterministic_and_stochastic(self, mock_show, config):
        """A deterministic and a probabilistic series can share the same axes."""
        index_type = config
        series1 = getattr(self, f"series_{index_type}_d")
        series2 = getattr(self, f"series_{index_type}_p")
        series1.plot()
        series2.plot()

        ax = plt.gca()
        # both series contribute one line per component
        assert len(self._get_lines(ax)) == 2 * self.n_comps
        # only the probabilistic series contributes filled areas
        assert len(self._get_areas(ax)) == self.n_comps

        plt.show()

    @patch("matplotlib.pyplot.show")
    @pytest.mark.parametrize("config", ["d", "p"])
    def test_cannot_plot_different_index_types(self, mock_show, config):
        """Plotting a range-index series on axes that already hold a
        datetime-index series raises a ``TypeError``."""
        stoch_type = config
        series1 = getattr(self, f"series_dt_{stoch_type}")
        series2 = getattr(self, f"series_ri_{stoch_type}")
        # datetime index plot changes x-axis to use datetime index
        series1.plot()
        # cannot plot a range index on datetime index
        with pytest.raises(TypeError):
            series2.plot()
        plt.show()
65 changes: 37 additions & 28 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4323,8 +4323,6 @@ def plot(
) -> matplotlib.axes.Axes:
"""Plot the series.

This is a wrapper method around :func:`xarray.DataArray.plot()`.

Parameters
----------
new_plot
Expand All @@ -4345,7 +4343,7 @@ def plot(
default_formatting
Whether to use the darts default scheme.
title
Optionally, a custom plot title. If `None`, will use the name of the underlying `xarray.DataArray`.
Optionally, a plot title.
label
Can either be a string or list of strings. If a string and the series only has a single component, it is
used as the label for that component. If a string and the series has multiple components, it is used as
Expand Down Expand Up @@ -4457,19 +4455,18 @@ def plot(
if ax is None:
ax = plt.gca()

# TODO: migrate from xarray plotting to something else
data_array = self.data_array(copy=False)
for i, c in enumerate(data_array.component[:n_components_to_plot]):
comp_name = str(c.values)
comp = data_array.sel(component=c)
for i, comp_name in enumerate(self.components[:n_components_to_plot]):
comp_ts = self[comp_name]

if comp.sample.size > 1:
if self.is_stochastic:
if central_quantile == "mean":
central_series = comp.mean(dim=DIMS[2])
central_ts = comp_ts.mean()
else:
central_series = comp.quantile(q=central_quantile, dim=DIMS[2])
central_ts = comp_ts.quantile(q=central_quantile)
else:
central_series = comp.mean(dim=DIMS[2])
central_ts = comp_ts

central_series = central_ts.to_series() # shape: (time,)

if custom_labels:
label_to_use = label[i]
Expand All @@ -4484,46 +4481,58 @@ def plot(
kwargs["c"] = color[i] if custom_colors else color

kwargs_central = deepcopy(kwargs)
if not self.is_deterministic:
if self.is_stochastic:
kwargs_central["alpha"] = 1
if central_series.shape[0] > 1:
p = central_series.plot(*args, ax=ax, **kwargs_central)
# empty TimeSeries
elif central_series.shape[0] == 0:
p = ax.plot(
[],
[],
# line plot
if len(central_series) > 1:
p = central_series.plot(
*args,
ax=ax,
**kwargs_central,
)
else:
color_used = (
p.get_lines()[-1].get_color() if default_formatting else None
)
# point plot
elif len(central_series) == 1:
p = ax.plot(
[self.start_time()],
central_series.values[0],
"o",
*args,
**kwargs_central,
)
color_used = p[0].get_color() if default_formatting else None
# empty plot
else:
p = ax.plot(
[],
[],
*args,
**kwargs_central,
)
color_used = p[0].get_color() if default_formatting else None
ax.set_xlabel(self.time_dim)
color_used = p[0].get_color() if default_formatting else None

# Optionally show confidence intervals
if (
comp.sample.size > 1
self.is_stochastic
and low_quantile is not None
and high_quantile is not None
):
low_series = comp.quantile(q=low_quantile, dim=DIMS[2])
high_series = comp.quantile(q=high_quantile, dim=DIMS[2])
if low_series.shape[0] > 1:
low_series = comp_ts.quantile(q=low_quantile).to_series()
high_series = comp_ts.quantile(q=high_quantile).to_series()
# filled area
if len(low_series) > 1:
ax.fill_between(
self.time_index,
low_series,
high_series,
color=color_used,
alpha=(alpha if alpha is not None else alpha_confidence_intvls),
)
else:
# filled line
elif len(low_series) == 1:
ax.plot(
[self.start_time(), self.start_time()],
[low_series.values[0], high_series.values[0]],
Expand All @@ -4533,7 +4542,7 @@ def plot(
)

ax.legend()
ax.set_title(title if title is not None else data_array.name)
ax.set_title(title if title is not None else "")
return ax

def with_columns_renamed(
Expand Down