Skip to content

🎉 Sklearn pipelines wrapper #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
13d7819
(containers) minor fix docstrings
xroynard Jul 7, 2025
e00877c
feat(dataset.py) get_scalars_to_tabular create array with dtype depen…
xroynard Jul 7, 2025
bd5600f
feat(dataset.py) improve __getitem__ to work with slices
xroynard Jul 7, 2025
776b4e8
feat(dataset.py) add method extract_dataset to extract a dataset with…
xroynard Jul 7, 2025
52ff991
feat(dataset.py) add method merge_samples to merge scalars/fields/tre…
xroynard Jul 7, 2025
70962b2
feat(dataset.py) add methods to work with tabular fields the same way…
xroynard Jul 7, 2025
d4b848e
feat(dataset.py) add some tests for new functionalities -> to DEBUG
xroynard Jul 7, 2025
2dc9f54
(sklearn wrapper) add classes to wrap any sklearn block to use and re…
xroynard Jul 12, 2025
a660b51
(dataset) fix __getitem__ with slices
xroynard Jul 12, 2025
0ccac03
(notebooks) add to Pipelines examples with or without PLAID wrapping
xroynard Jul 15, 2025
0d1b97a
feat(dataset.py) add methods to handle time_series
xroynard Jul 16, 2025
26cf292
(sklearn wrapper) rename in/out_keys to in/out_features + update logi…
xroynard Jul 16, 2025
a77346e
(pipeline example) minor typo
xroynard Jul 16, 2025
b58438a
(sklearn wrapper) update convert_y_to_plaid logic -> needs debug
xroynard Jul 16, 2025
b7dd97e
(pipeline example) update notebook
xroynard Jul 16, 2025
7a9983a
(test_dataset/conftest) factorize fixtures in conftest + reorganize t…
xroynard Aug 7, 2025
4596a9f
(sklearn wrapper) minor ruff reformatting
xroynard Aug 7, 2025
7624d99
(dataset) clean methods on tabular + fix a bug
xroynard Aug 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,711 changes: 2,711 additions & 0 deletions docs/source/notebooks/pca_gp_plaid_pipeline.ipynb

Large diffs are not rendered by default.

7,899 changes: 7,899 additions & 0 deletions docs/source/notebooks/pca_gp_sklearn_pipeline.ipynb

Large diffs are not rendered by default.

397 changes: 370 additions & 27 deletions src/plaid/containers/dataset.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/plaid/containers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def read_index_range(pyTree: list, dim: list[int]):


class Sample(BaseModel):
"""Represents a single sample. It contains data and information related to a single observation or measurement within a dataset."""
"""Represent a single sample. It contains data and information related to a single observation or measurement within a dataset."""

def __init__(
self,
Expand Down
7 changes: 7 additions & 0 deletions src/plaid/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Wrapper functions for the PLAID library."""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
501 changes: 501 additions & 0 deletions src/plaid/wrappers/sklearn.py

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,63 @@ def sample_with_tree(sample, tree):
@pytest.fixture()
def sample():
    """Return a default-constructed (empty) Sample."""
    return Sample()


# fixtures for tabular scalar data
@pytest.fixture()
def nb_scalars():
    """Number of scalar features used by the tabular-scalar fixtures."""
    return 5


@pytest.fixture()
def scalar_tabular(nb_samples, nb_scalars):
    """Random standard-normal matrix of shape (nb_samples, nb_scalars)."""
    shape = (nb_samples, nb_scalars)
    return np.random.randn(*shape)


@pytest.fixture()
def scalar_names(nb_scalars):
    """Randomized scalar names with prefix ``test_scalar_``.

    Name collisions are possible in principle (random 9-digit suffixes)
    but vanishingly unlikely for 5 names.
    """
    names = []
    for _ in range(nb_scalars):
        suffix = np.random.randint(1e8, 1e9)
        names.append(f"test_scalar_{suffix}")
    return names


# fixtures for tabular time_series data
@pytest.fixture()
def nb_time_series():
    """Number of time series used by the tabular time-series fixtures."""
    return 5


@pytest.fixture()
def nb_timestamps():
    """Number of timestamps per time series in the tabular fixtures."""
    return 11


@pytest.fixture()
def time_series_tabular(nb_samples, nb_timestamps, nb_time_series):
    """Random 4-d tabular time-series data.

    Shape is (nb_samples, nb_timestamps, nb_time_series, 2); the trailing
    axis of size 2 presumably holds (timestamp, value) pairs — TODO confirm
    against Dataset.add_tabular_time_series.
    """
    shape = (nb_samples, nb_timestamps, nb_time_series, 2)
    return np.random.randn(*shape)


@pytest.fixture()
def time_series_names(nb_time_series):
    """Randomized time-series names with prefix ``test_time_series_``."""
    names = []
    for _ in range(nb_time_series):
        suffix = np.random.randint(1e8, 1e9)
        names.append(f"test_time_series_{suffix}")
    return names


# fixtures for tabular field data
@pytest.fixture()
def nb_fields():
    """Number of fields used by the tabular-field fixtures."""
    return 5


@pytest.fixture()
def nb_points():
    """Number of points per field in the tabular-field fixtures."""
    return 13


@pytest.fixture()
def field_tabular(nb_samples, nb_points, nb_fields):
    """Random field data of shape (nb_samples, nb_points, nb_fields)."""
    shape = (nb_samples, nb_points, nb_fields)
    return np.random.randn(*shape)


@pytest.fixture()
def field_names(nb_fields):
    """Randomized field names with prefix ``test_field_``."""
    names = []
    for _ in range(nb_fields):
        suffix = np.random.randint(1e8, 1e9)
        names.append(f"test_field_{suffix}")
    return names
186 changes: 134 additions & 52 deletions tests/containers/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,6 @@
# %% Fixtures


@pytest.fixture()
def nb_scalars():
return 5


@pytest.fixture()
def tabular(nb_samples, nb_scalars):
return np.random.randn(nb_samples, nb_scalars)


@pytest.fixture()
def scalar_names(nb_scalars):
return [f"test_scalar_{np.random.randint(1e8, 1e9)}" for _ in range(nb_scalars)]


@pytest.fixture()
def sample(zone_name, base_name):
sample = Sample()
Expand All @@ -51,16 +36,6 @@ def sample(zone_name, base_name):
return sample


@pytest.fixture
def empty_sample():
return Sample()


@pytest.fixture()
def empty_dataset():
return Dataset()


@pytest.fixture()
def current_directory():
return Path(__file__).absolute().parent
Expand Down Expand Up @@ -200,30 +175,30 @@ def test_add_sample_not_a_sample(self, dataset):
with pytest.raises(TypeError):
dataset.add_sample(1)

def test_add_samples_empty(self, empty_dataset):
def test_add_samples_empty(self, dataset):
with pytest.raises(ValueError):
empty_dataset.add_samples([])
dataset.add_samples([])
with pytest.raises(ValueError):
empty_dataset.add_samples([], 1)
dataset.add_samples([], 1)

def test_add_samples_empty_with_ids(self, empty_dataset, sample):
def test_add_samples_empty_with_ids(self, dataset, sample):
with pytest.raises(ValueError):
empty_dataset.add_samples([sample], [1, 2, 3])
dataset.add_samples([sample], [1, 2, 3])

def test_add_samples_bad_number_ids_inf(self, empty_dataset, sample):
def test_add_samples_bad_number_ids_inf(self, dataset, sample):
with pytest.raises(ValueError):
samples = [sample, sample, sample]
empty_dataset.add_samples(samples, [1, 2])
dataset.add_samples(samples, [1, 2])

def test_add_samples_bad_number_ids_supp(self, empty_dataset, sample):
def test_add_samples_bad_number_ids_supp(self, dataset, sample):
with pytest.raises(ValueError):
samples = [sample, sample, sample]
empty_dataset.add_samples(samples, [1, 2, 3, 4])
dataset.add_samples(samples, [1, 2, 3, 4])

def test_add_samples_with_same_ids(self, empty_dataset, sample):
def test_add_samples_with_same_ids(self, dataset, sample):
with pytest.raises(ValueError):
samples = [sample, sample, sample]
empty_dataset.add_samples(samples, [1, 1, 1])
dataset.add_samples(samples, [1, 1, 1])

def test_add_samples_with_ids_good(self, dataset, sample):
samples = [sample, sample, sample]
Expand Down Expand Up @@ -321,46 +296,150 @@ def test_get_field_names(self, dataset_with_samples, nb_samples):
dataset_with_samples.get_field_names(np.random.randint(2, nb_samples, size=2))

# -------------------------------------------------------------------------#
def test_add_tabular_scalars(self, dataset, tabular, scalar_names, nb_samples):
dataset.add_tabular_scalars(tabular, scalar_names)
def test_add_tabular_scalars(
self, dataset, scalar_tabular, scalar_names, nb_samples
):
dataset.add_tabular_scalars(scalar_tabular, scalar_names)
assert len(dataset) == nb_samples

def test_add_tabular_scalars_no_names(self, dataset, tabular, nb_samples):
dataset.add_tabular_scalars(tabular)
def test_add_tabular_scalars_no_names(self, dataset, scalar_tabular, nb_samples):
dataset.add_tabular_scalars(scalar_tabular)
assert len(dataset) == nb_samples

def test_add_tabular_scalars_bad_ndim(self, dataset, tabular, scalar_names):
def test_add_tabular_scalars_bad_ndim(self, dataset, scalar_tabular, scalar_names):
with pytest.raises(ShapeError):
dataset.add_tabular_scalars(tabular.reshape((-1)), scalar_names)
dataset.add_tabular_scalars(scalar_tabular.reshape((-1)), scalar_names)

def test_add_tabular_scalars_bad_shape(self, dataset, tabular, scalar_names):
tabular = np.concatenate((tabular, np.zeros((len(tabular), 1))), axis=1)
def test_add_tabular_scalars_bad_shape(self, dataset, scalar_tabular, scalar_names):
scalar_tabular = np.concatenate(
(scalar_tabular, np.zeros((len(scalar_tabular), 1))), axis=1
)
with pytest.raises(ShapeError):
dataset.add_tabular_scalars(tabular, scalar_names)
dataset.add_tabular_scalars(scalar_tabular, scalar_names)

def test_get_scalars_to_tabular(self, dataset, tabular, scalar_names):
def test_get_scalars_to_tabular(self, dataset, scalar_tabular, scalar_names):
assert len(dataset.get_scalars_to_tabular()) == 0
assert dataset.get_scalars_to_tabular() == {}
dataset.add_tabular_scalars(tabular, scalar_names)
dataset.add_tabular_scalars(scalar_tabular, scalar_names)
assert dataset.get_scalars_to_tabular(as_nparray=True).shape == (
len(tabular),
len(scalar_tabular),
len(scalar_names),
)
dict_tabular = dataset.get_scalars_to_tabular()
for i_s, sname in enumerate(scalar_names):
assert np.all(dict_tabular[sname] == tabular[:, i_s])
assert np.all(dict_tabular[sname] == scalar_tabular[:, i_s])

def test_get_scalars_to_tabular_same_scalars_name(
self, dataset, tabular, scalar_names
self, dataset, scalar_tabular, scalar_names
):
dataset.add_tabular_scalars(tabular, scalar_names)
dataset.add_tabular_scalars(scalar_tabular, scalar_names)
assert dataset.get_scalars_to_tabular(as_nparray=True).shape == (
len(tabular),
len(scalar_tabular),
len(scalar_names),
)
dataset.get_scalars_to_tabular(sample_ids=[0, 0])
dataset.get_scalars_to_tabular(scalar_names=["test", "test"])

# -------------------------------------------------------------------------#
def test_add_tabular_time_series(
    self, dataset, time_series_tabular, time_series_names, nb_samples
):
    """Adding named tabular time series yields one sample per input row."""
    expected = nb_samples
    dataset.add_tabular_time_series(time_series_tabular, time_series_names)
    assert len(dataset) == expected

def test_add_tabular_time_series_no_names(
    self, dataset, time_series_tabular, nb_samples
):
    """Adding tabular time series without names still creates the samples."""
    expected = nb_samples
    dataset.add_tabular_time_series(time_series_tabular)
    assert len(dataset) == expected

def test_add_tabular_time_series_bad_ndim(
    self, dataset, time_series_tabular, time_series_names
):
    """A flattened (1-d) tabular must be rejected with ShapeError."""
    flattened = time_series_tabular.reshape((-1))
    with pytest.raises(ShapeError):
        dataset.add_tabular_time_series(flattened, time_series_names)

def test_add_tabular_time_series_bad_shape(
    self, dataset, time_series_tabular, time_series_names
):
    """A tabular whose shape disagrees with the names raises ShapeError.

    Fix vs. original: the original concatenated a 2-d ``np.zeros((n, 1))``
    block onto the 4-d tabular, which makes ``np.concatenate`` itself raise
    ValueError (ndim mismatch) — and it did so *outside* the
    ``pytest.raises`` block, so the test crashed before the dataset method
    was ever exercised.  Build the extra slab with a matching ndim instead.
    """
    extra = np.zeros(
        (len(time_series_tabular), 1) + time_series_tabular.shape[2:]
    )
    time_series_tabular = np.concatenate((time_series_tabular, extra), axis=1)
    # NOTE(review): axis=1 grows the timestamp axis; if
    # add_tabular_time_series validates the series axis against the names,
    # concatenate along axis=2 instead — confirm against the implementation.
    with pytest.raises(ShapeError):
        dataset.add_tabular_time_series(time_series_tabular, time_series_names)

def test_get_time_series_to_tabular(
    self, dataset, time_series_tabular, time_series_names
):
    """Round-trip: tabular time series come back keyed by series name.

    Fix vs. original: ``time_series_tabular`` has shape
    (nb_samples, nb_timestamps, nb_time_series, 2) and ``i_s`` enumerates
    series names, so the comparison must slice the *series* axis
    (``[:, :, i_s]``), not axis 1, which is the timestamp axis.
    """
    # An empty dataset yields an empty mapping.
    assert len(dataset.get_time_series_to_tabular()) == 0
    assert dataset.get_time_series_to_tabular() == {}
    dataset.add_tabular_time_series(time_series_tabular, time_series_names)
    # NOTE(review): a 2-tuple shape for 4-d input looks suspicious —
    # confirm against get_time_series_to_tabular(as_nparray=True).
    assert dataset.get_time_series_to_tabular(as_nparray=True).shape == (
        len(time_series_tabular),
        len(time_series_names),
    )
    by_name = dataset.get_time_series_to_tabular()
    for i_s, sname in enumerate(time_series_names):
        assert np.all(by_name[sname] == time_series_tabular[:, :, i_s])

def test_get_time_series_to_tabular_same_time_series_name(
    self, dataset, time_series_tabular, time_series_names
):
    """Duplicate ids / repeated names are accepted without raising."""
    dataset.add_tabular_time_series(time_series_tabular, time_series_names)
    observed = dataset.get_time_series_to_tabular(as_nparray=True).shape
    assert observed == (len(time_series_tabular), len(time_series_names))
    # Smoke calls: duplicated sample ids and repeated/unknown names.
    dataset.get_time_series_to_tabular(sample_ids=[0, 0])
    dataset.get_time_series_to_tabular(time_series_names=["test", "test"])

# -------------------------------------------------------------------------#
def test_add_tabular_fields(self, dataset, field_tabular, field_names, nb_samples):
    """Adding named tabular fields yields one sample per input row."""
    expected = nb_samples
    dataset.add_tabular_fields(field_tabular, field_names)
    assert len(dataset) == expected

def test_add_tabular_fields_no_names(self, dataset, field_tabular, nb_samples):
    """Adding tabular fields without names still creates the samples."""
    expected = nb_samples
    dataset.add_tabular_fields(field_tabular)
    assert len(dataset) == expected

def test_add_tabular_fields_bad_ndim(self, dataset, field_tabular, field_names):
    """A flattened (1-d) tabular must be rejected with ShapeError."""
    flattened = field_tabular.reshape((-1))
    with pytest.raises(ShapeError):
        dataset.add_tabular_fields(flattened, field_names)

def test_add_tabular_fields_bad_shape(
    self, dataset, field_tabular, nb_points, field_names
):
    """One field column more than the names list raises ShapeError."""
    extra = np.zeros((len(field_tabular), nb_points, 1))
    field_tabular = np.concatenate((field_tabular, extra), axis=-1)
    with pytest.raises(ShapeError):
        dataset.add_tabular_fields(field_tabular, field_names)

def test_get_fields_to_tabular(self, dataset, field_tabular, field_names):
    """Round-trip: tabular fields come back keyed by field name.

    Fix vs. original: ``field_tabular`` has shape
    (nb_samples, nb_points, nb_fields) and ``i_s`` enumerates field names,
    so the comparison must slice the last (fields) axis — ``[..., i_s]`` —
    not axis 1, which is the points axis (and has a different size).
    """
    # An empty dataset yields an empty mapping.
    assert len(dataset.get_fields_to_tabular()) == 0
    assert dataset.get_fields_to_tabular() == {}
    dataset.add_tabular_fields(field_tabular, field_names)
    # NOTE(review): a 2-tuple shape for 3-d input looks suspicious —
    # confirm against get_fields_to_tabular(as_nparray=True).
    assert dataset.get_fields_to_tabular(as_nparray=True).shape == (
        len(field_tabular),
        len(field_names),
    )
    by_name = dataset.get_fields_to_tabular()
    for i_s, sname in enumerate(field_names):
        assert np.all(by_name[sname] == field_tabular[..., i_s])

def test_get_fields_to_tabular_same_fields_name(
    self, dataset, field_tabular, field_names
):
    """Duplicate ids / repeated names are accepted without raising."""
    dataset.add_tabular_fields(field_tabular, field_names)
    observed = dataset.get_fields_to_tabular(as_nparray=True).shape
    assert observed == (len(field_tabular), len(field_names))
    # Smoke calls: duplicated sample ids and repeated/unknown names.
    dataset.get_fields_to_tabular(sample_ids=[0, 0])
    dataset.get_fields_to_tabular(field_names=["test", "test"])

# -------------------------------------------------------------------------#
def test_add_info(self, dataset):
dataset.add_info("legal", "owner", "PLAID")
Expand Down Expand Up @@ -415,6 +494,9 @@ def test_merge_dataset_with_bad_type(self, dataset_with_samples):
with pytest.raises(ValueError):
dataset_with_samples.merge_dataset(3)

def test_merge_samples(self, dataset_with_samples, other_dataset_with_samples):
    """Smoke test: merging samples from another dataset must not raise."""
    dataset_with_samples.merge_samples(other_dataset_with_samples)

# -------------------------------------------------------------------------#

def test_save(self, dataset_with_samples, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion tests/problem_definition/problem_infos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ input_fields:
- test_field
output_fields:
- field
- test_field
- predict_field
- test_field
input_timeseries:
- predict_timeseries
- test_timeseries
Expand Down
Loading