diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e3dc886c1..d5acad4d8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,10 +13,7 @@ repos:
     hooks:
       - id: ruff
         args: [--fix]
-  - repo: https://github.com/psf/black
-    rev: 24.8.0
-    hooks:
-      - id: black
+      - id: ruff-format
   - repo: https://github.com/nbQA-dev/nbQA
     rev: 1.8.7
     hooks:
diff --git a/pyproject.toml b/pyproject.toml
index e3d0f62e7..d35a9d28e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -142,6 +142,13 @@ exclude = [
 ]
 target-version = "py39"
 
+[tool.ruff.format]
+# Enable formatting
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+
 [tool.ruff.lint]
 select = ["E", "F", "W", "C4", "S"]
 extend-select = [
@@ -171,30 +178,6 @@ force-sort-within-sections = true
     "E501", # Line too long being fixed in #1746 To be removed after merging
 ]
 
-[tool.black]
-line-length = 88
-include = '\.pyi?$'
-exclude = '''
-(
-  /(
-    \.eggs # exclude a few common directories in the
-    | \.git # root of the project
-    | \.hg
-    | \.mypy_cache
-    | \.tox
-    | \.venv
-    | _build
-    | buck-out
-    | build
-    | dist
-  )/
-  | docs/build/
-  | node_modules/
-  | venve/
-  | .venv/
-)
-'''
-
 [tool.nbqa.mutate]
 ruff = 1
 black = 1
diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 574f073e8..488a56f33 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -108,7 +108,6 @@ def __init__(
         num_workers: int = 0,
         train_val_test_split: tuple = (0.7, 0.15, 0.15),
     ):
-
         self.time_series_dataset = time_series_dataset
         self.max_encoder_length = max_encoder_length
         self.min_encoder_length = min_encoder_length
diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py
index b89b741d9..043a51d1e 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries.py
@@ -2198,27 +2198,23 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
         # select subset of sequence of new sequence
         if new_encoder_length + new_decoder_length < len(target[0]):
             data_cat = data_cat[
-                encoder_length
-                - new_encoder_length : encoder_length
+                encoder_length - new_encoder_length : encoder_length
                 + new_decoder_length
             ]
             data_cont = data_cont[
-                encoder_length
-                - new_encoder_length : encoder_length
+                encoder_length - new_encoder_length : encoder_length
                 + new_decoder_length
             ]
             target = [
                 t[
-                    encoder_length
-                    - new_encoder_length : encoder_length
+                    encoder_length - new_encoder_length : encoder_length
                     + new_decoder_length
                 ]
                 for t in target
             ]
             if weight is not None:
                 weight = weight[
-                    encoder_length
-                    - new_encoder_length : encoder_length
+                    encoder_length - new_encoder_length : encoder_length
                     + new_decoder_length
                 ]
             encoder_length = new_encoder_length
diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
index c129e1790..9a9e02c25 100644
--- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py
+++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -90,7 +90,6 @@ def __init__(
         unknown: Optional[list[Union[str, list[str]]]] = None,
         static: Optional[list[Union[str, list[str]]]] = None,
     ):
-
         self.data = data
         self.data_future = data_future
         self.time = time
diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py
index c0949ca21..8b4ee141f 100644
--- a/pytorch_forecasting/metrics/base_metrics.py
+++ b/pytorch_forecasting/metrics/base_metrics.py
@@ -1123,7 +1123,9 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor:
             torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples)
         """  # noqa: E501
         dist = self.map_x_to_distribution(y_pred)
-        samples = dist.sample((n_samples,)).permute(
+        samples = dist.sample(
+            (n_samples,)
+        ).permute(
             2, 1, 0
         )  # returned as (n_samples, n_timesteps, batch_size), so reshape to (batch_size, n_timesteps, n_samples)  # noqa: E501
         return samples
diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py
index 7fd59d6a4..0106b7afa 100644
--- a/pytorch_forecasting/models/base/_base_object.py
+++ b/pytorch_forecasting/models/base/_base_object.py
@@ -8,7 +8,6 @@
 
 
 class _BaseObject(_SkbaseBaseObject):
-
     pass
 
 
diff --git a/pytorch_forecasting/models/deepar/_deepar.py b/pytorch_forecasting/models/deepar/_deepar.py
index c6aa0839a..3ac104f62 100644
--- a/pytorch_forecasting/models/deepar/_deepar.py
+++ b/pytorch_forecasting/models/deepar/_deepar.py
@@ -222,13 +222,18 @@ def from_dataset(
             MultiLoss([NormalDistributionLoss()] * len(dataset.target_names)),
         )
         new_kwargs.update(kwargs)
-        assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
-            not isinstance(dataset.target_normalizer, MultiNormalizer)
-            or all(
-                not isinstance(normalizer, NaNLabelEncoder)
-                for normalizer in dataset.target_normalizer
+        assert (
+            not isinstance(dataset.target_normalizer, NaNLabelEncoder)
+            and (
+                not isinstance(dataset.target_normalizer, MultiNormalizer)
+                or all(
+                    not isinstance(normalizer, NaNLabelEncoder)
+                    for normalizer in dataset.target_normalizer
+                )
             )
-        ), "target(s) should be continuous - categorical targets are not supported"  # todo: remove this restriction  # noqa: E501
+        ), (
+            "target(s) should be continuous - categorical targets are not supported"
+        )  # todo: remove this restriction  # noqa: E501
         if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss):
             assert (
                 dataset.min_prediction_length == dataset.max_prediction_length
diff --git a/pytorch_forecasting/models/rnn/_rnn.py b/pytorch_forecasting/models/rnn/_rnn.py
index e9141816e..c7aa7ac73 100644
--- a/pytorch_forecasting/models/rnn/_rnn.py
+++ b/pytorch_forecasting/models/rnn/_rnn.py
@@ -204,13 +204,18 @@ def from_dataset(
                 dataset=dataset, kwargs=kwargs, default_loss=MAE()
             )
         )
-        assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
-            not isinstance(dataset.target_normalizer, MultiNormalizer)
-            or all(
-                not isinstance(normalizer, NaNLabelEncoder)
-                for normalizer in dataset.target_normalizer
+        assert (
+            not isinstance(dataset.target_normalizer, NaNLabelEncoder)
+            and (
+                not isinstance(dataset.target_normalizer, MultiNormalizer)
+                or all(
+                    not isinstance(normalizer, NaNLabelEncoder)
+                    for normalizer in dataset.target_normalizer
+                )
             )
-        ), "target(s) should be continuous - categorical targets are not supported"  # todo: remove this restriction  # noqa: E501
+        ), (
+            "target(s) should be continuous - categorical targets are not supported"
+        )  # todo: remove this restriction  # noqa: E501
         return super().from_dataset(
             dataset,
             allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py
index e74f7cf32..3fffdf5cf 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py
@@ -179,7 +179,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
             static_context = self.static_context_linear(static_input)
             static_context = static_context.view(batch_size, self.hidden_size)
         else:
-
             static_input = torch.cat([static_cont, static_cat], dim=2).to(
                 dtype=self.static_context_linear.weight.dtype
             )