92 changes: 92 additions & 0 deletions moabb/datasets/base.py
@@ -14,6 +14,7 @@

import mne_bids
import pandas as pd
from mne_bids.path import _find_matching_sidecar
from sklearn.pipeline import Pipeline

from moabb.datasets.bids_interface import StepType, _interface_map
@@ -698,6 +699,34 @@ def data_path(
""" # noqa: E501
pass

def get_additional_metadata(
self, subject: str, session: str, run: str
) -> None | pd.DataFrame:
"""
Load additional metadata for a specific subject, session, and run.

This method is intended to be overridden by subclasses to provide
additional metadata specific to the dataset. The metadata is typically
loaded from an `events.tsv` file or similar data source.

Parameters
----------
subject : str
The identifier for the subject.
session : str
The identifier for the session.
run : str
The identifier for the run.

Returns
-------
None | pd.DataFrame
A DataFrame containing the additional metadata if available,
otherwise None.
"""

return None
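
A subclass that is not BIDS-backed can override this hook directly. A minimal sketch, assuming a hypothetical per-run TSV file whose name scheme and columns are illustrative only:

import pandas as pd

from moabb.datasets.base import BaseDataset


class MyTSVDataset(BaseDataset):
    def get_additional_metadata(self, subject, session, run):
        # hypothetical path scheme; adapt to the dataset's real layout
        fname = f"sub-{subject}_ses-{session}_run-{run}_events.tsv"
        dm = pd.read_csv(fname, sep="\t")
        # attach the identifying triplet that downstream code expects
        return dm.assign(subject=subject, session=session, run=run)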


class BaseBIDSDataset(BaseDataset):
"""Abstract BIDS dataset class.
@@ -788,6 +817,69 @@ def _get_single_subject_data(self, subject):
data.setdefault(session, {})[run] = raw
return data

def get_additional_metadata(
self, subject: str, session: str, run: str
) -> None | pd.DataFrame:
"""
Load additional metadata for a specific subject, session, and run.
This loads the full metadata for the run; filtering down to the
epoch level is done downstream in the paradigm's ``get_data``.


Parameters
----------
subject : str
The identifier for the subject.
session : str
The identifier for the session.
run : str
The identifier for the run.

Returns
-------
None | pd.DataFrame
A DataFrame containing the additional metadata if available,
otherwise None.
"""

bids_paths = self.bids_paths(subject)

# keep only the paths with matching session and run
bids_path_selected = [
pth
for pth in bids_paths
if f"ses-{session}" in pth.basename and f"run-{run}" in pth.basename
]

if len(bids_path_selected) == 0:
raise ValueError("No matching BIDS path found.")
if len(bids_path_selected) > 1:
raise ValueError("More than one matching BIDS path found.")
bids_path = bids_path_selected[0]

events_fname = _find_matching_sidecar(
bids_path, suffix="events", extension=".tsv", on_error="warn"
)

dm = pd.read_csv(events_fname, sep="\t").assign(
subject=subject, session=session, run=run
)

# As long as this is not part of mne-bids https://github.com/mne-tools/mne-bids/pull/1389,
# we functionally replicate the filtering ourselves here (we only
# need the dropping part).
dm = dm[(dm.onset != "n/a") & (~dm.onset.isna())]
dm["onset"] = dm["onset"].astype(float)

if "trial_type" in dm.columns:
dm = dm[(dm.trial_type != "n/a") & (~dm.onset.isna())]
elif "value" in dm.columns:
dm = dm[(dm.value != "n/a") & (~dm.onset.isna())]

# for the bids_dataset we can assume that the events are taken from
# a `trial_type` column -> filter on this
dm = dm[dm["trial_type"].isin(self.event_id.keys())]

return dm
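
To make the dropping logic above concrete, a self-contained sketch of the same filtering applied to a toy events table (values borrowed from the test fixture below):

import pandas as pd

dm = pd.DataFrame(
    {
        "onset": ["n/a", "1.984375", "3.96875"],
        "trial_type": ["fake1", "n/a", "fake1"],
    }
)

# drop rows with missing onsets, then cast the remainder to float
dm = dm[(dm.onset != "n/a") & (~dm.onset.isna())]
dm["onset"] = dm["onset"].astype(float)

# drop rows with missing trial_type
dm = dm[(dm.trial_type != "n/a") & (~dm.trial_type.isna())]

# only the row with onset 3.96875 and trial_type "fake1" survives
print(dm)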


class LocalBIDSDataset(BaseBIDSDataset):
"""Generic local/private BIDS datasets.
74 changes: 71 additions & 3 deletions moabb/paradigms/base.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import abc
import logging
from operator import methodcaller
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import mne
import numpy as np
@@ -232,6 +234,7 @@ def get_data( # noqa: C901
return_raws=False,
cache_config=None,
postprocess_pipeline=None,
additional_metadata: Literal["all"] | list[str] | None = None,
):
"""
Return the data for a list of subject.
@@ -265,6 +268,13 @@
This pipeline must return an ``np.ndarray``.
This pipeline must be "fixed" because it will not be trained,
i.e. no call to ``fit`` will be made.
additional_metadata: Literal["all"] | list[str] | None
Additional metadata to be loaded if return_epochs=True.
If None, the default metadata will be loaded, containing
`subject`, `session` and `run`. If "all", all columns of the
`events.tsv` file will be loaded. A list of column names can be
passed to select just these columns in addition to the three
default columns mentioned before.

Collaborator comment:
The get_data() function returns a triplet (obj, labels, metadata).
obj contains the data and can be a np.ndarray, mne.Epochs or mne.io.Raw
depending on the return_epochs and return_raws parameters. But we should
always return some metadata, so the additional columns should always be
set when additional_metadata='all'.

Returns
-------
@@ -306,6 +316,22 @@ def get_data( # noqa: C901
for session, runs in sessions.items():
for run in runs.keys():
proc = [data_i[subject][session][run] for data_i in data]

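# load the run's additional metadata once and replicate it per processing pipeline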
if additional_metadata:
ext_metadata = [
dataset.get_additional_metadata(
subject=subject, session=session, run=run
)
] * len(process_pipelines)

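# when an explicit column list is given, keep those plus the identifying triplet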
if isinstance(additional_metadata, list):
ext_metadata = [
dm[["session", "subject", "run"] + additional_metadata]
for dm in ext_metadata
]
else:
ext_metadata = [None] * len(process_pipelines)

if any(obj is None for obj in proc):
# this means the run did not contain any selected event
# go to next
@@ -321,6 +347,7 @@
if len(self.filters) == 1
else mne.concatenate_epochs(proc)
)

elif return_raws:
assert all(len(proc[0]) == len(p) for p in proc[1:])
n = 1
@@ -350,16 +377,30 @@
met["subject"] = subject
met["session"] = session
met["run"] = run

metadata.append(met)

# overwrite the default metadata if additional metadata is requested
if additional_metadata:
# extend the metadata according to the filters

dmeta_ext = (
ext_metadata[0].copy()
if isinstance(ext_metadata[0], pd.DataFrame)
else pd.DataFrame()
)
metadata[-1] = dmeta_ext

if return_epochs:
x.metadata = (
met.copy()
metadata[-1].copy()
if len(self.filters) == 1
else pd.concat(
[met.copy()] * len(self.filters), ignore_index=True
[metadata[-1].copy()] * len(self.filters),
ignore_index=True,
)
)

X.append(x)
labels.append(lbs)
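
As a usage sketch (mirroring the tests added below; `dataset` is assumed to be an already constructed BIDS-backed dataset):

from moabb.paradigms import MotorImagery

paradigm = MotorImagery()

# default: metadata carries only subject, session and run
epochs, labels, metadata = paradigm.get_data(
    dataset=dataset, subjects=["1"], return_epochs=True
)

# "all": every column of the run's events.tsv is attached
epochs, labels, metadata = paradigm.get_data(
    dataset=dataset,
    subjects=["1"],
    return_epochs=True,
    additional_metadata="all",
)

# explicit list: the named columns plus subject, session and run
epochs, labels, metadata = paradigm.get_data(
    dataset=dataset,
    subjects=["1"],
    return_epochs=True,
    additional_metadata=["value", "duration"],
)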

@@ -556,3 +597,30 @@ def scoring(self):
def _get_events_pipeline(self, dataset):
event_id = self.used_events(dataset)
return RawToEvents(event_id=event_id, interval=dataset.interval)


# def load_bids_event_metadata(
# data_set: BaseBIDSDataset, subject: str, session: str, run: str
# ) -> pd.DataFrame:
# bids_paths = data_set.bids_paths(subject)
#
# # select only with matching session and run
# bids_path_selected = [
# pth
# for pth in bids_paths
# if f"ses-{session}" in pth.basename and f"run-{run}" in pth.basename
# ]
#
# if len(bids_path_selected) > 1:
# raise ValueError("More than one matching BIDS path found.")
# bids_path = bids_path_selected[0]
#
# events_fname = _find_matching_sidecar(
# bids_path, suffix="events", extension=".tsv", on_error="warn"
# )
#
# dm = pd.read_csv(events_fname, sep="\t").assign(
# subject=subject, session=session, run=run
# )
#
# return dm
135 changes: 135 additions & 0 deletions moabb/tests/test_paradigms.py
@@ -9,8 +9,12 @@
import pytest
from mne import BaseEpochs
from mne.io import BaseRaw
from mne_bids.path import _find_matching_sidecar

from moabb.datasets import BNCI2014_001
from moabb.datasets.base import (
LocalBIDSDataset,
)
from moabb.datasets.fake import FakeDataset
from moabb.paradigms import (
CVEP,
@@ -1237,3 +1241,134 @@ def test_epochs(self, epochs_labels_metadata, dataset):
np.testing.assert_array_almost_equal(
epo.get_data()[0, :, 0] * dataset.unit_factor, X
)


class TestMetadata:

@pytest.fixture(scope="class")
def cached_dataset_root(self, tmpdir_factory):
root = tmpdir_factory.mktemp("fake_bids")
dataset = FakeDataset(
event_list=["fake1", "fake2"], n_sessions=2, n_subjects=2, n_runs=1
)
dataset.get_data(cache_config=dict(save_raw=True, overwrite_raw=False, path=root))
return root / "MNE-BIDS-fake-dataset-imagery-2-2--60--120--fake1-fake2--c3-cz-c4"

def test_additional_metadata_extracts_aligned(self, cached_dataset_root):
"""
Test extraction of additional metadata when all rows in the events.tsv
were used to create annotations on the raw -> used for epoching.
"""

# --- The tsv files have metadata which would contain the following
#
# onset duration trial_type value sample
# 0.0078125 3.0 fake1 1 1
# 1.984375 3.0 fake2 2 254
# 3.96875 3.0 fake1 1 508
# 5.953125 3.0 fake2 2 762
#
# --- While onset, duration and trial_type are implicitly available
# --- by the epoch design, we might want `value` and/or `sample` as well

dataset = LocalBIDSDataset(
cached_dataset_root,
events={"fake1": 1, "fake2": 2},
interval=[0, 3],
paradigm="imagery",
)
paradigm = MotorImagery()

epo1, labels1, metadata1 = paradigm.get_data(
dataset=dataset,
subjects=["1"],
return_epochs=True,
)

raw, raw_labels, raw_metadata = paradigm.get_data(
dataset=dataset,
subjects=["1"],
return_epochs=False,
additional_metadata="all",
)

epo2, labels2, metadata2 = paradigm.get_data(
dataset=dataset,
subjects=["1"],
return_epochs=True,
additional_metadata="all",
)

epo3, labels3, metadata3 = paradigm.get_data(
dataset=dataset,
subjects=["1"],
return_epochs=True,
additional_metadata=["value"],
)

epo4, labels4, metadata4 = paradigm.get_data(
dataset=dataset,
subjects=["1"],
return_epochs=True,
additional_metadata=["value", "duration"],
)

np.testing.assert_array_equal(epo1.get_data(), epo2.get_data())
np.testing.assert_array_equal(epo2.get_data(), epo3.get_data())
assert (labels1 == labels2).all()
assert (labels2 == labels3).all()

assert (raw_metadata == epo2.metadata).all().all()
assert (metadata2 == epo2.metadata).all().all()

assert (metadata1.columns == ["subject", "session", "run"]).all()

assert "value" in metadata2.columns
assert "sample" in metadata2.columns
assert "value" in metadata3.columns
assert "value" in metadata4.columns
assert "duration" in metadata4.columns
assert "sample" not in metadata3.columns
assert "sample" not in metadata4.columns

def test_additional_metadata_extracts_non_aligned(self, cached_dataset_root):
"""
Test extraction of additional metadata when NOT all rows in the events.tsv
were used to create annotations on the raw -> used for epoching.
"""

dataset = LocalBIDSDataset(
cached_dataset_root,
events={"fake1": 1},
interval=[0, 3],
paradigm="imagery",
)

# modify the events.tsv to contain 'n/a'
events_fname = _find_matching_sidecar(
dataset.bids_paths("1")[0], suffix="events", extension=".tsv"
)
df = pd.read_csv(events_fname, sep="\t")
df = df.assign(ix=range(len(df)))
for c in ["onset", "trial_type"]:
df[c] = df[c].astype(str)
df.loc[0, "onset"] = "n/a"
df.loc[2, "trial_type"] = "n/a"
df.to_csv(events_fname, sep="\t", index=False)

paradigm = MotorImagery()

epo, labels, metadata = paradigm.get_data(
dataset=dataset, subjects=["1"], return_epochs=True, additional_metadata="all"
)

assert (
len(epo[metadata["session"] == "0"])
== len(df[df["trial_type"] == "fake1"]) - 1
) # -1 for the one onset which is n/a!

# test that the first fake1 value is skipped since the onset is n/a
# and the second is skipped as the trial_type is n/a
assert metadata.onset.iloc[0] == float(df[df.trial_type == "fake1"].onset.iloc[1])
assert "n/a" not in df.trial_type

assert (epo.metadata.fillna(0) == metadata.fillna(0)).all().all()