Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ repos:
hooks:
- id: ruff
args: [--fix]
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
- id: ruff-format
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.8.7
hooks:
Expand Down
31 changes: 7 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ exclude = [
]
target-version = "py39"

[tool.ruff.format]
# Enable formatting
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"

[tool.ruff.lint]
select = ["E", "F", "W", "C4", "S"]
extend-select = [
Expand Down Expand Up @@ -170,30 +177,6 @@ force-sort-within-sections = true
"E501", # Line too long being fixed in #1746 To be removed after merging
]

[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
| docs/build/
| node_modules/
| venve/
| .venv/
)
'''

[tool.nbqa.mutate]
ruff = 1
black = 1
Expand Down
1 change: 0 additions & 1 deletion pytorch_forecasting/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def __init__(
num_workers: int = 0,
train_val_test_split: tuple = (0.7, 0.15, 0.15),
):

self.time_series_dataset = time_series_dataset
self.max_encoder_length = max_encoder_length
self.min_encoder_length = min_encoder_length
Expand Down
5 changes: 3 additions & 2 deletions pytorch_forecasting/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def __init__(
)
if not isinstance(drop_last, bool):
raise ValueError(
"drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last)
"drop_last should be a boolean value, but got drop_last={}".format(
drop_last
)
)
self.sampler = sampler
self.batch_size = batch_size
Expand Down
12 changes: 4 additions & 8 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,27 +2198,23 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
# select subset of sequence of new sequence
if new_encoder_length + new_decoder_length < len(target[0]):
data_cat = data_cat[
encoder_length
- new_encoder_length : encoder_length
encoder_length - new_encoder_length : encoder_length
+ new_decoder_length
]
data_cont = data_cont[
encoder_length
- new_encoder_length : encoder_length
encoder_length - new_encoder_length : encoder_length
+ new_decoder_length
]
target = [
t[
encoder_length
- new_encoder_length : encoder_length
encoder_length - new_encoder_length : encoder_length
+ new_decoder_length
]
for t in target
]
if weight is not None:
weight = weight[
encoder_length
- new_encoder_length : encoder_length
encoder_length - new_encoder_length : encoder_length
+ new_decoder_length
]
encoder_length = new_encoder_length
Expand Down
1 change: 0 additions & 1 deletion pytorch_forecasting/data/timeseries/_timeseries_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(
unknown: Optional[List[Union[str, List[str]]]] = None,
static: Optional[List[Union[str, List[str]]]] = None,
):

self.data = data
self.data_future = data_future
self.time = time
Expand Down
4 changes: 3 additions & 1 deletion pytorch_forecasting/metrics/base_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,9 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor:
torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples)
""" # noqa: E501
dist = self.map_x_to_distribution(y_pred)
samples = dist.sample((n_samples,)).permute(
samples = dist.sample(
(n_samples,)
).permute(
2, 1, 0
) # returned as (n_samples, n_timesteps, batch_size), so reshape to (batch_size, n_timesteps, n_samples) # noqa: E501
return samples
Expand Down
1 change: 0 additions & 1 deletion pytorch_forecasting/models/base/_base_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class _BaseObject(_SkbaseBaseObject):

pass


Expand Down
17 changes: 11 additions & 6 deletions pytorch_forecasting/models/deepar/_deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,18 @@ def from_dataset(
MultiLoss([NormalDistributionLoss()] * len(dataset.target_names)),
)
new_kwargs.update(kwargs)
assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
or all(
not isinstance(normalizer, NaNLabelEncoder)
for normalizer in dataset.target_normalizer
assert (
not isinstance(dataset.target_normalizer, NaNLabelEncoder)
and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
or all(
not isinstance(normalizer, NaNLabelEncoder)
for normalizer in dataset.target_normalizer
)
)
), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction # noqa: E501
), (
"target(s) should be continuous - categorical targets are not supported"
) # todo: remove this restriction # noqa: E501
if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss):
assert (
dataset.min_prediction_length == dataset.max_prediction_length
Expand Down
17 changes: 11 additions & 6 deletions pytorch_forecasting/models/rnn/_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,18 @@ def from_dataset(
dataset=dataset, kwargs=kwargs, default_loss=MAE()
)
)
assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
or all(
not isinstance(normalizer, NaNLabelEncoder)
for normalizer in dataset.target_normalizer
assert (
not isinstance(dataset.target_normalizer, NaNLabelEncoder)
and (
not isinstance(dataset.target_normalizer, MultiNormalizer)
or all(
not isinstance(normalizer, NaNLabelEncoder)
for normalizer in dataset.target_normalizer
)
)
), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction # noqa: E501
), (
"target(s) should be continuous - categorical targets are not supported"
) # todo: remove this restriction # noqa: E501
return super().from_dataset(
dataset,
allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
static_context = self.static_context_linear(static_input)
static_context = static_context.view(batch_size, self.hidden_size)
else:

static_input = torch.cat([static_cont, static_cat], dim=2).to(
dtype=self.static_context_linear.weight.dtype
)
Expand Down
Loading