5 changes: 1 addition & 4 deletions .pre-commit-config.yaml
@@ -13,10 +13,7 @@ repos:
     hooks:
       - id: ruff
         args: [--fix]
-  - repo: https://github.com/psf/black
-    rev: 24.8.0
-    hooks:
-      - id: black
+      - id: ruff-format
   - repo: https://github.com/nbQA-dev/nbQA
     rev: 1.8.7
     hooks:
31 changes: 7 additions & 24 deletions pyproject.toml
@@ -142,6 +142,13 @@ exclude = [
 ]
 target-version = "py39"
 
+[tool.ruff.format]
+# Enable formatting
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+
 [tool.ruff.lint]
 select = ["E", "F", "W", "C4", "S"]
 extend-select = [
@@ -171,30 +178,6 @@ force-sort-within-sections = true
     "E501", # Line too long being fixed in #1746 To be removed after merging
 ]
 
-[tool.black]
-line-length = 88
-include = '\.pyi?$'
-exclude = '''
-(
-  /(
-      \.eggs  # exclude a few common directories in the
-    | \.git   # root of the project
-    | \.hg
-    | \.mypy_cache
-    | \.tox
-    | \.venv
-    | _build
-    | buck-out
-    | build
-    | dist
-  )/
-  | docs/build/
-  | node_modules/
-  | venve/
-  | .venv/
-)
-'''
 
 [tool.nbqa.mutate]
 ruff = 1
-black = 1
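The `[tool.ruff.format]` table above takes over what the removed `[tool.black]` section used to do. A quick way to see the effect of these options is to run `ruff format` on a scratch file; the sketch below is illustrative only (the sample code and temp file are made up), and it relies on ruff's defaults, which match the options set above (double quotes, magic trailing comma respected).

```python
# Illustrative sketch (assumes ruff is installed): format a scratch file and
# observe the behaviour configured in [tool.ruff.format] above.
import pathlib
import subprocess
import tempfile

sample = "x = {'a': 1, 'b': 2,}\n"  # single quotes + magic trailing comma

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "sample.py"
    path.write_text(sample)
    # quote-style = "double" rewrites the strings; because
    # skip-magic-trailing-comma = false, the trailing comma keeps the dict
    # expanded one item per line.
    subprocess.run(["ruff", "format", str(path)], check=True)
    print(path.read_text())
```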
1 change: 0 additions & 1 deletion pytorch_forecasting/data/data_module.py
@@ -108,7 +108,6 @@ def __init__(
         num_workers: int = 0,
         train_val_test_split: tuple = (0.7, 0.15, 0.15),
     ):
-
         self.time_series_dataset = time_series_dataset
         self.max_encoder_length = max_encoder_length
         self.min_encoder_length = min_encoder_length
12 changes: 4 additions & 8 deletions pytorch_forecasting/data/timeseries/_timeseries.py
@@ -2198,27 +2198,23 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
             # select subset of sequence of new sequence
             if new_encoder_length + new_decoder_length < len(target[0]):
                 data_cat = data_cat[
-                    encoder_length
-                    - new_encoder_length : encoder_length
+                    encoder_length - new_encoder_length : encoder_length
                     + new_decoder_length
                 ]
                 data_cont = data_cont[
-                    encoder_length
-                    - new_encoder_length : encoder_length
+                    encoder_length - new_encoder_length : encoder_length
                     + new_decoder_length
                 ]
                 target = [
                     t[
-                        encoder_length
-                        - new_encoder_length : encoder_length
+                        encoder_length - new_encoder_length : encoder_length
                         + new_decoder_length
                     ]
                     for t in target
                 ]
                 if weight is not None:
                     weight = weight[
-                        encoder_length
-                        - new_encoder_length : encoder_length
+                        encoder_length - new_encoder_length : encoder_length
                         + new_decoder_length
                     ]
             encoder_length = new_encoder_length
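For readers skimming the hunk above: the reflowed slice is behaviour-preserving and keeps the last `new_encoder_length` encoder steps plus the first `new_decoder_length` decoder steps of the sequence. A minimal standalone sketch with a toy tensor and made-up lengths:

```python
# Toy illustration of the slice reformatted above; values are illustrative.
import torch

encoder_length, new_encoder_length, new_decoder_length = 5, 3, 2
data_cont = torch.arange(10).unsqueeze(-1)  # 10 time steps, 1 feature

# Keep the last 3 encoder steps and the first 2 decoder steps.
window = data_cont[
    encoder_length - new_encoder_length : encoder_length + new_decoder_length
]
print(window.squeeze(-1))  # tensor([2, 3, 4, 5, 6])
```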
1 change: 0 additions & 1 deletion pytorch_forecasting/data/timeseries/_timeseries_v2.py
@@ -90,7 +90,6 @@ def __init__(
         unknown: Optional[list[Union[str, list[str]]]] = None,
         static: Optional[list[Union[str, list[str]]]] = None,
     ):
-
         self.data = data
         self.data_future = data_future
         self.time = time
4 changes: 3 additions & 1 deletion pytorch_forecasting/metrics/base_metrics.py
@@ -1123,7 +1123,9 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor:
             torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples)
         """  # noqa: E501
         dist = self.map_x_to_distribution(y_pred)
-        samples = dist.sample((n_samples,)).permute(
+        samples = dist.sample(
+            (n_samples,)
+        ).permute(
             2, 1, 0
         )  # returned as (n_samples, n_timesteps, batch_size), so reshape to (batch_size, n_timesteps, n_samples)  # noqa: E501
         return samples
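The behaviour of the reformatted call is unchanged: draw `n_samples` from the mapped distribution and permute `(n_samples, n_timesteps, batch_size)` into `(batch_size, n_timesteps, n_samples)`. A hedged standalone sketch using a plain `torch.distributions.Normal` in place of `map_x_to_distribution` (shapes are illustrative):

```python
# Standalone sketch of the sample-and-permute pattern above; the real code
# samples from the distribution returned by map_x_to_distribution.
import torch

batch_size, n_timesteps, n_samples = 4, 6, 100
dist = torch.distributions.Normal(
    loc=torch.zeros(n_timesteps, batch_size),
    scale=torch.ones(n_timesteps, batch_size),
)
samples = dist.sample((n_samples,))  # (n_samples, n_timesteps, batch_size)
samples = samples.permute(2, 1, 0)   # (batch_size, n_timesteps, n_samples)
print(samples.shape)                 # torch.Size([4, 6, 100])
```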
1 change: 0 additions & 1 deletion pytorch_forecasting/models/base/_base_object.py
@@ -8,7 +8,6 @@
 
 
 class _BaseObject(_SkbaseBaseObject):
-
     pass
 
 
17 changes: 11 additions & 6 deletions pytorch_forecasting/models/deepar/_deepar.py
@@ -222,13 +222,18 @@ def from_dataset(
             MultiLoss([NormalDistributionLoss()] * len(dataset.target_names)),
         )
         new_kwargs.update(kwargs)
-        assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
-            not isinstance(dataset.target_normalizer, MultiNormalizer)
-            or all(
-                not isinstance(normalizer, NaNLabelEncoder)
-                for normalizer in dataset.target_normalizer
+        assert (
+            not isinstance(dataset.target_normalizer, NaNLabelEncoder)
+            and (
+                not isinstance(dataset.target_normalizer, MultiNormalizer)
+                or all(
+                    not isinstance(normalizer, NaNLabelEncoder)
+                    for normalizer in dataset.target_normalizer
+                )
             )
-        ), "target(s) should be continuous - categorical targets are not supported"  # todo: remove this restriction  # noqa: E501
+        ), (
+            "target(s) should be continuous - categorical targets are not supported"
+        )  # todo: remove this restriction  # noqa: E501
         if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss):
             assert (
                 dataset.min_prediction_length == dataset.max_prediction_length
17 changes: 11 additions & 6 deletions pytorch_forecasting/models/rnn/_rnn.py
@@ -204,13 +204,18 @@ def from_dataset(
                 dataset=dataset, kwargs=kwargs, default_loss=MAE()
             )
         )
-        assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
-            not isinstance(dataset.target_normalizer, MultiNormalizer)
-            or all(
-                not isinstance(normalizer, NaNLabelEncoder)
-                for normalizer in dataset.target_normalizer
+        assert (
+            not isinstance(dataset.target_normalizer, NaNLabelEncoder)
+            and (
+                not isinstance(dataset.target_normalizer, MultiNormalizer)
+                or all(
+                    not isinstance(normalizer, NaNLabelEncoder)
+                    for normalizer in dataset.target_normalizer
+                )
             )
-        ), "target(s) should be continuous - categorical targets are not supported"  # todo: remove this restriction  # noqa: E501
+        ), (
+            "target(s) should be continuous - categorical targets are not supported"
+        )  # todo: remove this restriction  # noqa: E501
         return super().from_dataset(
             dataset,
             allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
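The reflowed assertion in both `_deepar.py` and `_rnn.py` encodes the same rule: every target normalizer must be continuous, i.e. not a `NaNLabelEncoder`, including each entry inside a `MultiNormalizer`. Restated as a small helper for clarity (a sketch, not part of the PR; the function name is made up):

```python
# Sketch restating the condition asserted in from_dataset above.
from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder


def targets_are_continuous(target_normalizer) -> bool:
    """Return True unless any target uses a NaNLabelEncoder (categorical)."""
    if isinstance(target_normalizer, NaNLabelEncoder):
        return False
    if isinstance(target_normalizer, MultiNormalizer):
        return all(
            not isinstance(normalizer, NaNLabelEncoder)
            for normalizer in target_normalizer
        )
    return True
```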
@@ -179,7 +179,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
             static_context = self.static_context_linear(static_input)
             static_context = static_context.view(batch_size, self.hidden_size)
         else:
-
             static_input = torch.cat([static_cont, static_cat], dim=2).to(
                 dtype=self.static_context_linear.weight.dtype
             )
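For context on the branch touched above: the static continuous and categorical features are concatenated along the feature axis and projected to `hidden_size`. A minimal sketch with made-up sizes (the real layer is the model's `self.static_context_linear`):

```python
# Minimal sketch of the static-context projection; sizes are illustrative.
import torch
import torch.nn as nn

batch_size, n_static_cont, n_static_cat, hidden_size = 4, 3, 5, 16
static_cont = torch.randn(batch_size, 1, n_static_cont)
static_cat = torch.randn(batch_size, 1, n_static_cat)

static_context_linear = nn.Linear(n_static_cont + n_static_cat, hidden_size)

static_input = torch.cat([static_cont, static_cat], dim=2).to(
    dtype=static_context_linear.weight.dtype
)
static_context = static_context_linear(static_input).view(batch_size, hidden_size)
print(static_context.shape)  # torch.Size([4, 16])
```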