76 changes: 76 additions & 0 deletions pytorch_forecasting/data/examples.py
@@ -50,6 +50,82 @@
    return pd.read_parquet(fname)


def get_stallion_dummy_data(seed: int | None = 0) -> pd.DataFrame:
    """
    Small synthetic dataset for testing, used in place of ``get_stallion_data``.

    Returns:
        pd.DataFrame: one row per (agency, sku, date) combination
    """
    rng = np.random.default_rng(seed)
    n_series = 350
    dates = pd.date_range("2018-01-01", periods=24, freq="M")

Check warning on line 62 in pytorch_forecasting/data/examples.py

GitHub Actions / no-softdeps (macos-latest, ubuntu-latest, windows-latest; Python 3.10, 3.11, 3.12, 3.13)

'M' is deprecated and will be removed in a future version, please use 'ME' instead.

    agency_list = [f"Agency_{i:02d}" for i in range(20)]
    sku_list = [f"SKU_{i:02d}" for i in range(25)]
    # 20 agencies x 25 SKUs give 500 candidate pairs; keep the first n_series
    pairs = [(a, s) for a in agency_list for s in sku_list]
    selected_pairs = pairs[:n_series]
    agencies = np.array([p[0] for p in selected_pairs])
    skus = np.array([p[1] for p in selected_pairs])
    n_rows = len(selected_pairs) * len(dates)

    # long format: each series id is repeated once per date, dates are tiled per series
    df = pd.DataFrame(
        {
            "agency": np.repeat(agencies, len(dates)),
            "sku": np.repeat(skus, len(dates)),
            "volume": rng.lognormal(7.0, 0.8, n_rows),
            "date": np.tile(dates, len(selected_pairs)),
            "industry_volume": np.clip(
                rng.normal(5.4e8, 6.3e7, n_rows), 4.0e8, None
            ).astype(np.int64),
            "soda_volume": np.clip(
                rng.normal(8.5e8, 8.0e7, n_rows), 6.5e8, None
            ).astype(np.int64),
            "avg_max_temp": rng.normal(28.5, 4.0, n_rows),
            "price_regular": np.clip(rng.normal(1500.0, 450.0, n_rows), 100.0, 2000.0),
            "discount": rng.gamma(2.0, 5.0, n_rows),
            "avg_population_2017": np.clip(
                rng.normal(60000, 8000, n_rows), 20000, None
            ).astype(np.int64),
            "avg_yearly_household_income_2017": np.clip(
                rng.normal(35000, 5000, n_rows), 15000, None
            ).astype(np.int64),
            "timeseries": np.repeat(np.arange(len(selected_pairs)), len(dates)),
        }
    )

    # derived price columns; clamp so the price is never negative and the
    # denominator never drops below 1
    df["price_actual"] = np.maximum(df["price_regular"] - df["discount"], 0.0)
    df["discount_in_percent"] = (
        df["discount"] / np.maximum(df["price_regular"], 1.0)
    ) * 100.0

    holiday_cols = [
        "easter_day",
        "good_friday",
        "new_year",
        "christmas",
        "labor_day",
        "independence_day",
        "revolution_day_memorial",
        "regional_games",
        "fifa_u_17_world_cup",
        "football_gold_cup",
        "beer_capital",
        "music_fest",
    ]
    # independent 0/1 flags, each set with probability 0.08
    for col in holiday_cols:
        df[col] = rng.binomial(1, 0.08, n_rows).astype(np.int64)

    df["agency"] = df["agency"].astype("category")
    df["sku"] = df["sku"].astype("category")
    df["avg_max_temp"] = df["avg_max_temp"].astype(np.float64)
    df["price_regular"] = df["price_regular"].astype(np.float64)
    df["price_actual"] = df["price_actual"].astype(np.float64)
    df["discount"] = df["discount"].astype(np.float64)
    df = df.sort_values(["agency", "sku", "date"]).reset_index(drop=True)

    return df
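
The deprecation warning reported for the `pd.date_range(..., freq="M")` call can probably be silenced without depending on which pandas release is installed, since the `'ME'` string alias is only understood by newer pandas versions. A minimal sketch, assuming month-end timestamps are what the dummy data should carry: pass an offset object instead of a string alias. The helper name `month_end_dates` is hypothetical and not part of this PR.

```python
import warnings

import pandas as pd


def month_end_dates(start: str = "2018-01-01", periods: int = 24) -> pd.DatetimeIndex:
    """Hypothetical helper (not part of the PR): monthly month-end timestamps."""
    # An offset object sidesteps the "M" vs "ME" string alias entirely.
    return pd.date_range(start, periods=periods, freq=pd.offsets.MonthEnd())


# Escalate FutureWarning to an error to show the call is warning-free.
with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)
    dates = month_end_dates()

print(dates[0].date(), dates[-1].date(), len(dates))  # 2018-01-31 2019-12-31 24
```

The resulting index holds the same 24 month-end dates that the string alias would produce, so the rest of the generator would not need to change.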


def generate_ar_data(
    n_series: int = 10,
    timesteps: int = 400,
1 change: 0 additions & 1 deletion pytorch_forecasting/tests/_conftest.py
@@ -4,7 +4,6 @@

 from pytorch_forecasting import TimeSeriesDataSet
 from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder
-from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data

 torch.manual_seed(23)

4 changes: 2 additions & 2 deletions pytorch_forecasting/tests/_data_scenarios.py
@@ -6,14 +6,14 @@

 from pytorch_forecasting import TimeSeriesDataSet
 from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder
-from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data
+from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_dummy_data
 from pytorch_forecasting.data.timeseries import TimeSeries

 torch.manual_seed(23)


 def data_with_covariates():
-    data = get_stallion_data()
+    data = get_stallion_dummy_data()
     data["month"] = data.date.dt.month.astype(str)
     data["log_volume"] = np.log1p(data.volume)
     data["weight"] = 1 + np.sqrt(data.volume)
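
The three derived columns added after loading the data can be checked in isolation. The sketch below applies the same expressions to a tiny made-up frame (the values are illustrative only and are not output of `get_stallion_dummy_data`) and confirms that `np.expm1` recovers the original volume from `log_volume`.

```python
import numpy as np
import pandas as pd

# Made-up toy frame; the values are illustrative only.
toy = pd.DataFrame(
    {
        "date": pd.to_datetime(["2018-01-31", "2018-02-28", "2018-03-31"]),
        "volume": [0.0, 99.0, 1096.0],
    }
)

# Same three expressions as in the diff above.
toy["month"] = toy.date.dt.month.astype(str)  # unpadded strings: "1", "2", "3"
toy["log_volume"] = np.log1p(toy.volume)      # log(1 + volume), defined at volume == 0
toy["weight"] = 1 + np.sqrt(toy.volume)       # never below 1

# expm1 is the inverse of log1p, so the original scale can be restored.
recovered = np.expm1(toy["log_volume"])
```

Because `month` is an unpadded string, it behaves as a categorical label rather than a number; sorting it lexicographically would place `"10"` before `"2"`.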
5 changes: 2 additions & 3 deletions tests/conftest.py
@@ -8,8 +8,7 @@


 from pytorch_forecasting import TimeSeriesDataSet  # isort:skip
-from pytorch_forecasting.data.examples import get_stallion_data  # isort:skip
-
+from pytorch_forecasting.data.examples import get_stallion_dummy_data  # isort:skip

 # for vscode debugging: https://stackoverflow.com/a/62563106/14121677
 if os.getenv("_PYTEST_RAISE", "0") != "0":
@@ -25,7 +24,7 @@ def pytest_internalerror(excinfo):

 @pytest.fixture(scope="session")
 def test_data():
-    data = get_stallion_data()
+    data = get_stallion_dummy_data()
     data["month"] = data.date.dt.month.astype(str)
     data["log_volume"] = np.log1p(data.volume)
     data["weight"] = 1 + np.sqrt(data.volume)