44 changes: 43 additions & 1 deletion pyprep/find_noisy_channels.py
@@ -4,7 +4,7 @@
import mne
import numpy as np
from mne.utils import check_random_state, logger
from scipy import signal
from scipy import signal, stats

from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend
@@ -70,6 +70,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
"bad_by_hf_noise": {},
"bad_by_correlation": {},
"bad_by_dropout": {},
"bad_by_psd": {},
"bad_by_ransac": {},
}

@@ -84,6 +85,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
self.bad_by_correlation = []
self.bad_by_SNR = []
self.bad_by_dropout = []
self.bad_by_psd = []
self.bad_by_ransac = []

# Get original EEG channel names, channel count & samples
@@ -486,6 +488,46 @@ def find_bad_by_SNR(self):
# Flag channels bad by both HF noise and low correlation as bad by low SNR
self.bad_by_SNR = list(bad_by_corr.intersection(bad_by_hf))

def find_bad_by_PSD(self, zscore_threshold=3.0):
"""
Detect channels with abnormally high or low overall power spectral density
(PSD) values.

A channel is considered "bad-by-psd" if its psd value deviates
considerably from the median channel psd, as calculated using a
Z-scoring method and the given z-score threshold.
PSD calculation is done using the Welch method.
Uses the Welch method for PSD calculation

Parameters
----------
zscore_threshold : float, optional
The minimum absolute PSD-deviation z-score of a channel for it to be
considered bad-by-psd. Defaults to ``3.0``.
"""
if self.EEGFiltered is None:
self.EEGFiltered = self._get_filtered_data()
psd = self.EEGFiltered.compute_psd(method="welch", fmin=1, fmax=50)
log_psd = 10 * np.log10(psd.get_data())  # convert power to dB
median_channel_psd = np.median(log_psd, axis=0)

# Calculate z-scores of each channel's summed deviation from the median PSD
psd_zscore = np.zeros(self.n_chans_original)
psd_zscore[self.usable_idx] = stats.zscore(np.sum(log_psd - median_channel_psd, axis=1))

# Flag channels with unusually high or low PSD values relative to the median channel
psd_channel_mask = np.isnan(psd_zscore) | (np.abs(psd_zscore) > zscore_threshold)
abnormal_psd_channels = self.ch_names_original[psd_channel_mask]

# Update names of bad channels by abnormal PSD & save additional info
self.bad_by_psd = abnormal_psd_channels.tolist()
self._extra_info["bad_by_psd"].update(
{
"median_channel_psd": median_channel_psd,
"psd_zscore": psd_zscore,
}
)

def find_bad_by_ransac(
self,
n_samples=50,
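For reference, a minimal usage sketch of the new detector, assuming a recent MNE version and the `NoisyChannels` API from this module; the sample-dataset path and the `zscore_threshold` value are illustrative only, not part of this PR:

    import os.path as op

    import mne

    from pyprep.find_noisy_channels import NoisyChannels

    # Load an example EEG recording (any preloaded, EEG-only mne.io.Raw works here).
    # The MNE sample dataset is used purely for illustration.
    data_path = mne.datasets.sample.data_path()
    raw_fname = op.join(str(data_path), "MEG", "sample", "sample_audvis_raw.fif")
    raw = mne.io.read_raw_fif(raw_fname, preload=True)
    raw.pick("eeg")

    # Run only the PSD-based check added in this diff and inspect the result
    nd = NoisyChannels(raw, do_detrend=True, random_state=42)
    nd.find_bad_by_PSD(zscore_threshold=3.0)
    print("Bad by PSD:", nd.bad_by_psd)

Channels flagged here should also appear in the per-channel z-scores stored under the new ``"bad_by_psd"`` key of the detector's extra-info dictionary.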