diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index be89fa6f7..77443b776 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -50,6 +50,82 @@ def get_stallion_data() -> pd.DataFrame: return pd.read_parquet(fname) +def get_stallion_dummy_data(seed: int | None = 0) -> pd.DataFrame: + """ + Small dummy dataset for testing. + + Returns: + pd.DataFrame: data + """ + rng = np.random.default_rng(seed) + n_series = 350 + dates = pd.date_range("2018-01-01", periods=24, freq="M") + agency_list = [f"Agency_{i:02d}" for i in range(20)] + sku_list = [f"SKU_{i:02d}" for i in range(25)] + pairs = [(a, s) for a in agency_list for s in sku_list] + selected_pairs = pairs[:n_series] + agencies = np.array([p[0] for p in selected_pairs]) + skus = np.array([p[1] for p in selected_pairs]) + n_rows = len(selected_pairs) * len(dates) + + df = pd.DataFrame( + { + "agency": np.repeat(agencies, len(dates)), + "sku": np.repeat(skus, len(dates)), + "volume": rng.lognormal(7.0, 0.8, n_rows), + "date": np.tile(dates, len(selected_pairs)), + "industry_volume": np.clip( + rng.normal(5.4e8, 6.3e7, n_rows), 4.0e8, None + ).astype(np.int64), + "soda_volume": np.clip( + rng.normal(8.5e8, 8.0e7, n_rows), 6.5e8, None + ).astype(np.int64), + "avg_max_temp": rng.normal(28.5, 4.0, n_rows), + "price_regular": np.clip(rng.normal(1500.0, 450.0, n_rows), 100.0, 2000.0), + "discount": rng.gamma(2.0, 5.0, n_rows), + "avg_population_2017": np.clip( + rng.normal(60000, 8000, n_rows), 20000, None + ).astype(np.int64), + "avg_yearly_household_income_2017": np.clip( + rng.normal(35000, 5000, n_rows), 15000, None + ).astype(np.int64), + "timeseries": np.repeat(np.arange(len(selected_pairs)), len(dates)), + } + ) + + df["price_actual"] = np.maximum(df["price_regular"] - df["discount"], 0.0) + df["discount_in_percent"] = ( + df["discount"] / np.maximum(df["price_regular"], 1.0) + ) * 100.0 + + holiday_cols = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + for col in holiday_cols: + df[col] = rng.binomial(1, 0.08, n_rows).astype(np.int64) + + df["agency"] = df["agency"].astype("category") + df["sku"] = df["sku"].astype("category") + df["avg_max_temp"] = df["avg_max_temp"].astype(np.float64) + df["price_regular"] = df["price_regular"].astype(np.float64) + df["price_actual"] = df["price_actual"].astype(np.float64) + df["discount"] = df["discount"].astype(np.float64) + df = df.sort_values(["agency", "sku", "date"]).reset_index(drop=True) + + return df + + def generate_ar_data( n_series: int = 10, timesteps: int = 400, diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index 8def0bfe2..def967ab9 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -4,7 +4,6 @@ from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data torch.manual_seed(23) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index c13ff0ae5..13ee4a8b1 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -6,14 +6,14 @@ from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_dummy_data from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) def data_with_covariates(): - data = get_stallion_data() + data = get_stallion_dummy_data() data["month"] = data.date.dt.month.astype(str) data["log_volume"] = np.log1p(data.volume) data["weight"] = 1 + np.sqrt(data.volume) diff --git a/tests/conftest.py b/tests/conftest.py index 608f06550..635955321 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,8 +8,7 @@ from pytorch_forecasting import TimeSeriesDataSet # isort:skip -from pytorch_forecasting.data.examples import get_stallion_data # isort:skip - +from pytorch_forecasting.data.examples import get_stallion_dummy_data # isort:skip # for vscode debugging: https://stackoverflow.com/a/62563106/14121677 if os.getenv("_PYTEST_RAISE", "0") != "0": @@ -25,7 +24,7 @@ def pytest_internalerror(excinfo): @pytest.fixture(scope="session") def test_data(): - data = get_stallion_data() + data = get_stallion_dummy_data() data["month"] = data.date.dt.month.astype(str) data["log_volume"] = np.log1p(data.volume) data["weight"] = 1 + np.sqrt(data.volume)