From f3bcdded8a95661cd0fb304d7bd6c96dcdeb727e Mon Sep 17 00:00:00 2001 From: Jirka B Date: Wed, 28 May 2025 12:18:06 +0200 Subject: [PATCH 1/3] Replace Black with Ruff formatting and update configuration Switched from Black to Ruff for code formatting to simplify tooling. Updated `.pre-commit-config.yaml` and `pyproject.toml` to reflect the removal of Black and added Ruff-specific formatting configurations. This change ensures consistency and leverages Ruff's integrated formatting capabilities. --- .pre-commit-config.yaml | 5 +---- pyproject.toml | 31 +++++++------------------------ 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3dc886c1..d5acad4d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,10 +13,7 @@ repos: hooks: - id: ruff args: [--fix] - - repo: https://github.com/psf/black - rev: 24.8.0 - hooks: - - id: black + - id: ruff-format - repo: https://github.com/nbQA-dev/nbQA rev: 1.8.7 hooks: diff --git a/pyproject.toml b/pyproject.toml index 6bf315aa0..64c481f30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,6 +142,13 @@ exclude = [ ] target-version = "py39" +[tool.ruff.format] +# Enable formatting +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + [tool.ruff.lint] select = ["E", "F", "W", "C4", "S"] extend-select = [ @@ -170,30 +177,6 @@ force-sort-within-sections = true "E501", # Line too long being fixed in #1746 To be removed after merging ] -[tool.black] -line-length = 88 -include = '\.pyi?$' -exclude = ''' -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - )/ - | docs/build/ - | node_modules/ - | venve/ - | .venv/ -) -''' - [tool.nbqa.mutate] ruff = 1 black = 1 From 279f13e3ed86534d96f75ddbf65e153b8289be45 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Wed, 28 May 2025 12:19:21 +0200 Subject: [PATCH 2/3] apply formatting --- pytorch_forecasting/data/data_module.py | 1 - pytorch_forecasting/data/samplers.py | 5 +++-- .../data/timeseries/_timeseries.py | 12 ++++-------- .../data/timeseries/_timeseries_v2.py | 1 - pytorch_forecasting/metrics/base_metrics.py | 4 +++- pytorch_forecasting/models/base/_base_object.py | 1 - pytorch_forecasting/models/deepar/_deepar.py | 17 +++++++++++------ pytorch_forecasting/models/rnn/_rnn.py | 17 +++++++++++------ .../temporal_fusion_transformer/_tft_v2.py | 1 - 9 files changed, 32 insertions(+), 27 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index ec3c11de9..de7ab5d83 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -108,7 +108,6 @@ def __init__( num_workers: int = 0, train_val_test_split: tuple = (0.7, 0.15, 0.15), ): - self.time_series_dataset = time_series_dataset self.max_encoder_length = max_encoder_length self.min_encoder_length = min_encoder_length diff --git a/pytorch_forecasting/data/samplers.py b/pytorch_forecasting/data/samplers.py index 27add6a31..3fbbd593c 100644 --- a/pytorch_forecasting/data/samplers.py +++ b/pytorch_forecasting/data/samplers.py @@ -51,8 +51,9 @@ def __init__( ) if not isinstance(drop_last, bool): raise ValueError( - "drop_last should be a boolean value, but got " - "drop_last={}".format(drop_last) + "drop_last should be a boolean value, but got " "drop_last={}".format( + drop_last + ) ) self.sampler = sampler self.batch_size = batch_size diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index f384367aa..2c9063633 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -2198,27 +2198,23 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: # select subset of sequence of new sequence if new_encoder_length + new_decoder_length < len(target[0]): data_cat = data_cat[ - encoder_length - - new_encoder_length : encoder_length + encoder_length - new_encoder_length : encoder_length + new_decoder_length ] data_cont = data_cont[ - encoder_length - - new_encoder_length : encoder_length + encoder_length - new_encoder_length : encoder_length + new_decoder_length ] target = [ t[ - encoder_length - - new_encoder_length : encoder_length + encoder_length - new_encoder_length : encoder_length + new_decoder_length ] for t in target ] if weight is not None: weight = weight[ - encoder_length - - new_encoder_length : encoder_length + encoder_length - new_encoder_length : encoder_length + new_decoder_length ] encoder_length = new_encoder_length diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index d5ecbcabb..51db2a108 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -90,7 +90,6 @@ def __init__( unknown: Optional[List[Union[str, List[str]]]] = None, static: Optional[List[Union[str, List[str]]]] = None, ): - self.data = data self.data_future = data_future self.time = time diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py index c9952ef96..02f2f22c1 100644 --- a/pytorch_forecasting/metrics/base_metrics.py +++ b/pytorch_forecasting/metrics/base_metrics.py @@ -1123,7 +1123,9 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor: torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples) """ # noqa: E501 dist = self.map_x_to_distribution(y_pred) - samples = dist.sample((n_samples,)).permute( + samples = dist.sample( + (n_samples,) + ).permute( 2, 1, 0 ) # returned as (n_samples, n_timesteps, batch_size), so reshape to (batch_size, n_timesteps, n_samples) # noqa: E501 return samples diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 7fd59d6a4..0106b7afa 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -8,7 +8,6 @@ class _BaseObject(_SkbaseBaseObject): - pass diff --git a/pytorch_forecasting/models/deepar/_deepar.py b/pytorch_forecasting/models/deepar/_deepar.py index 3a769d7f0..ae4982cc6 100644 --- a/pytorch_forecasting/models/deepar/_deepar.py +++ b/pytorch_forecasting/models/deepar/_deepar.py @@ -222,13 +222,18 @@ def from_dataset( MultiLoss([NormalDistributionLoss()] * len(dataset.target_names)), ) new_kwargs.update(kwargs) - assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( - not isinstance(dataset.target_normalizer, MultiNormalizer) - or all( - not isinstance(normalizer, NaNLabelEncoder) - for normalizer in dataset.target_normalizer + assert ( + not isinstance(dataset.target_normalizer, NaNLabelEncoder) + and ( + not isinstance(dataset.target_normalizer, MultiNormalizer) + or all( + not isinstance(normalizer, NaNLabelEncoder) + for normalizer in dataset.target_normalizer + ) ) - ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction # noqa: E501 + ), ( + "target(s) should be continuous - categorical targets are not supported" + ) # todo: remove this restriction # noqa: E501 if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss): assert ( dataset.min_prediction_length == dataset.max_prediction_length diff --git a/pytorch_forecasting/models/rnn/_rnn.py b/pytorch_forecasting/models/rnn/_rnn.py index a1c0fabd1..c2c4978f7 100644 --- a/pytorch_forecasting/models/rnn/_rnn.py +++ b/pytorch_forecasting/models/rnn/_rnn.py @@ -204,13 +204,18 @@ def from_dataset( dataset=dataset, kwargs=kwargs, default_loss=MAE() ) ) - assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( - not isinstance(dataset.target_normalizer, MultiNormalizer) - or all( - not isinstance(normalizer, NaNLabelEncoder) - for normalizer in dataset.target_normalizer + assert ( + not isinstance(dataset.target_normalizer, NaNLabelEncoder) + and ( + not isinstance(dataset.target_normalizer, MultiNormalizer) + or all( + not isinstance(normalizer, NaNLabelEncoder) + for normalizer in dataset.target_normalizer + ) ) - ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction # noqa: E501 + ), ( + "target(s) should be continuous - categorical targets are not supported" + ) # todo: remove this restriction # noqa: E501 return super().from_dataset( dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index a0cf7d39e..c739e9ad5 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -179,7 +179,6 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: static_context = self.static_context_linear(static_input) static_context = static_context.view(batch_size, self.hidden_size) else: - static_input = torch.cat([static_cont, static_cat], dim=2).to( dtype=self.static_context_linear.weight.dtype ) From 424274dbd5c3535252b66a48c96dc723df147a02 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 28 May 2025 12:20:13 +0200 Subject: [PATCH 3/3] Apply suggestions from code review --- pytorch_forecasting/data/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/samplers.py b/pytorch_forecasting/data/samplers.py index 3fbbd593c..0ae4677d6 100644 --- a/pytorch_forecasting/data/samplers.py +++ b/pytorch_forecasting/data/samplers.py @@ -51,7 +51,7 @@ def __init__( ) if not isinstance(drop_last, bool): raise ValueError( - "drop_last should be a boolean value, but got " "drop_last={}".format( + "drop_last should be a boolean value, but got drop_last={}".format( drop_last ) )