diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index caff5c68a..c14a64939 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -14,6 +14,7 @@ import mne_bids import pandas as pd +from mne_bids.path import _find_matching_sidecar from sklearn.pipeline import Pipeline from moabb.datasets.bids_interface import StepType, _interface_map @@ -698,6 +699,34 @@ def data_path( """ # noqa: E501 pass + def get_additional_metadata( + self, subject: str, session: str, run: str + ) -> None | pd.DataFrame: + """ + Load additional metadata for a specific subject, session, and run. + + This method is intended to be overridden by subclasses to provide + additional metadata specific to the dataset. The metadata is typically + loaded from an `events.tsv` file or similar data source. + + Parameters + ---------- + subject : str + The identifier for the subject. + session : str + The identifier for the session. + run : str + The identifier for the run. + + Returns + ------- + None | pd.DataFrame + A DataFrame containing the additional metadata if available, + otherwise None. + """ + + return None + class BaseBIDSDataset(BaseDataset): """Abstract BIDS dataset class. @@ -788,6 +817,69 @@ def _get_single_subject_data(self, subject): data.setdefault(session, {})[run] = raw return data + def get_additional_metadata( + self, subject: str, session: str, run: str + ) -> None | pd.DataFrame: + """ + Load additional metadata for a specific subject, session, and run. + This is just loading all metadata, filtering down to epochs levels + is done at ... + + + Parameters + ---------- + subject : str + The identifier for the subject. + session : str + The identifier for the session. + run : str + The identifier for the run. + + Returns + ------- + None | pd.DataFrame + A DataFrame containing the additional metadata if available, + otherwise None. 
+ """ + + bids_paths = self.bids_paths(subject) + + # select only with matching session and run + bids_path_selected = [ + pth + for pth in bids_paths + if f"ses-{session}" in pth.basename and f"run-{run}" in pth.basename + ] + + if len(bids_path_selected) > 1: + raise ValueError("More than one matching BIDS path found.") + bids_path = bids_path_selected[0] + + events_fname = _find_matching_sidecar( + bids_path, suffix="events", extension=".tsv", on_error="warn" + ) + + dm = pd.read_csv(events_fname, sep="\t").assign( + subject=subject, session=session, run=run + ) + + # As long as this is not part of mne-bids https://github.com/mne-tools/mne-bids/pull/1389, + # we cannot will functionally replicate the filtering (as we only) + # need the dropping part + dm = dm[(dm.onset != "n/a") & (~dm.onset.isna())] + dm["onset"] = dm["onset"].astype(float) + + if "trial_type" in dm.columns: + dm = dm[(dm.trial_type != "n/a") & (~dm.onset.isna())] + elif "value" in dm.columns: + dm = dm[(dm.value != "n/a") & (~dm.onset.isna())] + + # for the bids_dataset we can assume that the events are taken from + # a `trial_type` columns -> filter on this + dm = dm[dm["trial_type"].isin(self.event_id.keys())] + + return dm + class LocalBIDSDataset(BaseBIDSDataset): """Generic local/private BIDS datasets. diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index d9757a948..578157723 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import abc import logging from operator import methodcaller -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple import mne import numpy as np @@ -232,6 +234,7 @@ def get_data( # noqa: C901 return_raws=False, cache_config=None, postprocess_pipeline=None, + additional_metadata: Literal["all"] | list[str] = None, ): """ Return the data for a list of subject. 
@@ -265,6 +268,13 @@ def get_data( # noqa: C901 This pipeline must return an ``np.ndarray``. This pipeline must be "fixed" because it will not be trained, i.e. no call to ``fit`` will be made. + additional_metadata: Literal["all"] | list[str] | None + Additional metadata to be loaded if return_epochs=True. + If None, the default metadata will be loaded containing containing + `subject`, `session` and `run`. If "all", all columns of the `events.tsv` + file will be loaded. A list of column names can be passed to just + select these columns in addition to the three default values mentioned + before. Returns ------- @@ -306,6 +316,22 @@ def get_data( # noqa: C901 for session, runs in sessions.items(): for run in runs.keys(): proc = [data_i[subject][session][run] for data_i in data] + + if additional_metadata: + ext_metadata = [ + dataset.get_additional_metadata( + subject=subject, session=session, run=run + ) + ] * len(process_pipelines) + + if isinstance(additional_metadata, list): + ext_metadata = [ + dm[["session", "subject", "run"] + additional_metadata] + for dm in ext_metadata + ] + else: + ext_metadata = [None] * len(process_pipelines) + if any(obj is None for obj in proc): # this mean the run did not contain any selected event # go to next @@ -321,6 +347,7 @@ def get_data( # noqa: C901 if len(self.filters) == 1 else mne.concatenate_epochs(proc) ) + elif return_raws: assert all(len(proc[0]) == len(p) for p in proc[1:]) n = 1 @@ -350,16 +377,30 @@ def get_data( # noqa: C901 met["subject"] = subject met["session"] = session met["run"] = run + metadata.append(met) + # overwrite if additional is required + if additional_metadata: + # extend the metadata according to the filters + + dmeta_ext = ( + ext_metadata[0].copy() + if isinstance(ext_metadata[0], pd.DataFrame) + else pd.DataFrame() + ) + metadata[-1] = dmeta_ext + if return_epochs: x.metadata = ( - met.copy() + metadata[-1].copy() if len(self.filters) == 1 else pd.concat( - [met.copy()] * len(self.filters), 
class TestMetadata:
    """Checks for the ``additional_metadata`` option of ``get_data``.

    A ``FakeDataset`` is cached once in BIDS format and then re-opened as a
    ``LocalBIDSDataset`` so that the ``events.tsv`` sidecars are available.
    """

    @pytest.fixture(scope="class")
    def cached_dataset_root(self, tmpdir_factory):
        """Write a fake dataset to a temporary BIDS cache and return its root."""
        root = tmpdir_factory.mktemp("fake_bids")
        dataset = FakeDataset(
            event_list=["fake1", "fake2"], n_sessions=2, n_subjects=2, n_runs=1
        )
        dataset.get_data(cache_config=dict(save_raw=True, overwrite_raw=False, path=root))
        return root / "MNE-BIDS-fake-dataset-imagery-2-2--60--120--fake1-fake2--c3-cz-c4"

    def test_additional_metadata_extracts_aligned(self, cached_dataset_root):
        """
        Test extraction of additional metadata if all rows in the events.tsv
        were used to create annotations on the raw -> used for epoching
        """

        # --- The tsv files have metadata which would contain the following
        #
        # onset       duration    trial_type  value   sample
        # 0.0078125   3.0         fake1       1       1
        # 1.984375    3.0         fake2       2       254
        # 3.96875     3.0         fake1       1       508
        # 5.953125    3.0         fake2       2       762
        #
        # --- While onset, duration and trial_type are implicitly available
        # --- through the epoch design, `value` and/or `sample` may be
        # --- wanted as well

        dataset = LocalBIDSDataset(
            cached_dataset_root,
            events={"fake1": 1, "fake2": 2},
            interval=[0, 3],
            paradigm="imagery",
        )
        paradigm = MotorImagery()

        epo1, labels1, metadata1 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
        )

        raw, raw_labels, raw_metadata = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=False,
            additional_metadata="all",
        )

        epo2, labels2, metadata2 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata="all",
        )

        epo3, labels3, metadata3 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata=["value"],
        )

        epo4, labels4, metadata4 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata=["value", "duration"],
        )

        assert epo1 == epo2 == epo3
        assert (labels1 == labels2).all()
        assert (labels2 == labels3).all()

        assert (raw_metadata == epo2.metadata).all().all()
        assert (metadata2 == epo2.metadata).all().all()

        assert (metadata1.columns == ["subject", "session", "run"]).all()

        assert "value" in metadata2.columns
        assert "sample" in metadata2.columns
        assert "value" in metadata3.columns
        assert "value" in metadata4.columns
        assert "duration" in metadata4.columns
        assert "sample" not in metadata3.columns
        assert "sample" not in metadata4.columns

    def test_additional_metadata_extracts_non_aligned(self, cached_dataset_root):
        """
        Test extraction of additional metadata if NOT all rows in the events.tsv
        were used to create annotations on the raw -> used for epoching
        """

        dataset = LocalBIDSDataset(
            cached_dataset_root,
            events={"fake1": 1},
            interval=[0, 3],
            paradigm="imagery",
        )

        # modify the events.tsv to contain 'n/a'
        # NOTE: this rewrites a file of the class-scoped cache, so the test
        # has to run after the ones that expect the untouched events.tsv.
        events_fname = _find_matching_sidecar(
            dataset.bids_paths("1")[0], suffix="events", extension=".tsv"
        )
        df = pd.read_csv(events_fname, sep="\t")
        df = df.assign(ix=range(len(df)))
        for c in ["onset", "trial_type"]:
            df[c] = df[c].astype(str)
        df.loc[0, "onset"] = "n/a"
        df.loc[2, "trial_type"] = "n/a"
        df.to_csv(events_fname, sep="\t", index=False)

        paradigm = MotorImagery()

        epo, labels, metadata = paradigm.get_data(
            dataset=dataset, subjects=["1"], return_epochs=True, additional_metadata="all"
        )

        assert (
            len(epo[metadata["session"] == "0"])
            == len(df[df["trial_type"] == "fake1"]) - 1
        )  # -1 for the one onset which is n/a!

        # test that the first fake1 value is skipped since the onset is n/a
        # and the second is skipped as the trial_type is n/a
        assert metadata.onset.iloc[0] == float(df[df.trial_type == "fake1"].onset.iloc[1])

        # `in` on a Series looks at the index labels, so the values have to be
        # compared explicitly; the check belongs on the returned metadata,
        # because `df` was deliberately given an "n/a" trial_type above.
        assert not (metadata["trial_type"] == "n/a").any()

        assert (epo.metadata.fillna(0) == metadata.fillna(0)).all().all()