
Commit 5384222

committed
Modified version of FBCSP using the Welch method; this replicates the current implementation on NP.
1 parent 5b7db7f commit 5384222

File tree

1 file changed: +249 -0 lines changed

code/paradigms/ParadigmFBCSPMod.m

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
classdef ParadigmFBCSPMod < ParadigmDataflowSimplified
    % Paradigm for complex oscillatory processes using the filter-bank CSP algorithm.
    % Result = para_multiband_csp(Input-Data, Operation-Mode, Options...)
    %
    % Filter-bank CSP [1,2] is a simple extension of the basic CSP method (see ParadigmCSP), in which for
    % each of several time/frequency filters a set of CSP filters is learned, followed by log-variance
    % feature extraction, concatenation of all features (over all chosen spectral filters) and
    % subsequent machine learning. It is not a general replacement for CSP due to the problem of
    % overfitting, but is very useful whenever oscillatory processes in different frequency bands (and
    % with different spatial topographies) are jointly active, and their concerted behavior must be
    % taken into account for a given prediction task. Filter-bank CSP can also be used to capture
    % oscillations in multiple time windows, instead of frequency windows (for example for the detection
    % of complex event-related dynamics).
    %
    % Since the dimensionality of the feature space is larger than in CSP, and since complex
    % interactions may be present, a more complex classifier than the default LDA may be necessary to
    % learn an appropriate model. On the other hand, more flexibility amplifies the risk of overfitting
    % (especially with little calibration data), so the performance should always be compared
    % to standard CSP (and Spec-CSP). A further caveat is that complex (relevant) interactions between
    % different frequency bands seem to be only rarely observed in practice. The most important
    % user-configurable parameters are the selection regions in time and frequency and the learner
    % component.
    %
    % Typical applications would be those in which either complex event-related oscillatory dynamics
    % happen (for example when reacting to a particular stimulus) and/or where non-trivial interactions
    % between frequency bands (e.g. alpha/theta) are relevant, for example in workload
    % measurements.
    %
    % Example: Consider a calibration data set in which a subject is maintaining and updating
    % different numbers of items in his/her working memory at different times, e.g. while performing
    % the n-Back task [3]. Events with types 'n1','n2','n3' indicate challenge stimuli in which the
    % respective number of items is being processed by the person. The goal is to be able to predict
    % the working-memory load of the person following the presentation of such a memory-related
    % challenge. An epoch of 3 seconds relative to each challenge is selected, and three different
    % regions are chosen, two of them over the entire interval, covering the theta and alpha rhythm,
    % respectively, and one region that is restricted to a window around the time of heaviest
    % cognitive processing. The three regions are specified as a cell array of flt_select
    % parameters.
    %
    %   data = io_loadset('data sets/mary/nback.eeg')
    %   myapproach = {'FBCSP' 'SignalProcessing',{'EpochExtraction',[-0.5 2.5]}, ...
    %       'Prediction', {'FeatureExtraction',{'FreqWindows',[4 6; 7 15; 7 15],'TimeWindows',[-0.5 2.5; -0.5 2.5; 0.25 1.25]}, ...
    %       'MachineLearning',{'Learner','logreg'}}}
    %   [loss,model,stats] = bci_train('Data',data, 'Approach',myapproach, 'TargetMarkers',{'n1','n2','n3'})
    %
    % References:
    %   [1] Quadrianto Novi, Cuntai Guan, Tran Huy Dat, and Ping Xue, "Sub-band Common Spatial Pattern (SBCSP) for Brain-Computer Interface",
    %       Proceedings of the 3rd International IEEE EMBS Conference on Neural Engineering, Kohala Coast, Hawaii, USA, May 2-5, 2007
    %   [2] Kai K. Ang, Zhang Y. Chin, Haihong Zhang, Cuntai Guan, "Filter Bank Common Spatial Pattern (FBCSP) in Brain-Computer Interface",
    %       2008 IEEE International Joint Conference on Neural Networks (IEEE World Congress on Computational Intelligence), June 2008, pp. 2390-2397
    %   [3] Owen, A. M., McMillan, K. M., Laird, A. R. & Bullmore, E., "N-back working memory paradigm: A meta-analysis of normative functional neuroimaging studies",
    %       Human Brain Mapping, 25, 46-59, 2005
    %
    % Name:
    %   Filter-Bank CSP
    %
    % Christian Kothe, Swartz Center for Computational Neuroscience, UCSD
    % 2010-04-29

    methods

        function defaults = preprocessing_defaults(self)
            % define the default pre-processing parameters of this paradigm
            defaults = {'EpochExtraction',[0.5 3.5],'Resampling',200};
        end

        function model = feature_adapt(self,varargin)
            % adapt a feature representation using the CSP algorithm
            args = arg_define(varargin, ...
                arg_norep('signal'), ...
                arg({'patterns','PatternPairs'},3,uint32([1 1 64 10000]),'CSP patterns per band (times two).','cat','Feature Extraction'), ...
                arg({'freqwnds','FreqWindows'},[0.5 3; 4 7; 8 12; 13 30; 31 42],[0 0.5 200 1000],'Frequency bands of interest. Matrix containing one row for the start and end of each frequency band from which CSP patterns shall be computed. Values in Hz.','cat','Feature Extraction'), ...
                arg({'timewnds','TimeWindows'},[],[],'Time windows of interest. Matrix containing one row for the start and end of each time window from which CSP patterns shall be computed. Values in seconds. If both this and the freqwnds parameter are non-empty, they should have the same number of rows.','cat','Feature Extraction'), ...
                arg({'winfunc','WindowFunction'},'rect',{'barthann','bartlett','blackman','blackmanharris','bohman','cheb','flattop','gauss','hamming','hann','kaiser','nuttall','parzen','rect','taylor','triang','tukey'},'Type of window function. Typical choices are rect (rectangular), hann, gauss, blackman and kaiser.'),...
                arg({'winparam','WindowParameter','param'},[],[],'Parameter of the window function. This is mandatory for cheb, kaiser and tukey and optional for some others.','shape','scalar'),...
                arg({'nfft','NFFT'},[],[],'Size of the FFT used in spectrum calculation. If empty, the next power of 2 greater than or equal to the epoch length is used.'),...
                arg({'winlen','WinLen'},100,[10, 1000],'Divide the signal into sections of this length for the Welch spectrum calculation.'),...
                arg({'numoverlap','NumOverlap'},[],[10, 1000],'Number of overlapping samples between consecutive sections for the Welch spectrum calculation.'));

            if args.signal.nbchan == 1
                error('Multi-band CSP intrinsically does not support single-channel data (it is a spatial filter).'); end
            if args.signal.nbchan < args.patterns
                error('Multi-band CSP requires at least as many channels as the number of requested output patterns. Please reduce the number of pattern pairs.'); end
            if ~isempty(args.freqwnds) && ~isempty(args.timewnds) && size(args.freqwnds,1) ~= size(args.timewnds,1)
                error('If both time and frequency windows are specified, both arrays must have the same number of rows (together they define the windows in time and frequency).'); end
            if isempty(args.timewnds)
                args.timewnds = zeros(size(args.freqwnds,1),0); end
            if isempty(args.freqwnds)
                args.freqwnds = zeros(size(args.timewnds,1),0); end

            [signal, nof, freqwnds, timewnds, winfunc, winparam, nfft, winlen, numoverlap] = deal(args.signal, args.patterns, args.freqwnds, ...
                args.timewnds, args.winfunc, args.winparam, args.nfft, args.winlen, args.numoverlap);

            [C,S,dum] = size(signal.data);
            Fs = signal.srate;

            if isempty(numoverlap)
                numoverlap = floor(0.5*winlen); end

            if (winlen > S) || (numoverlap > winlen)
                error('In the Welch method, the window length must be smaller than the signal length, and the number of overlapping samples must be smaller than the window length.');
            else
                win = window_func(winfunc,winlen,winparam);
                nwin = floor((S-numoverlap)/(winlen-numoverlap));
            end
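            % For example, with the preprocessing defaults above (epoch [0.5 3.5] s resampled to
            % 200 Hz, i.e. roughly S = 600 samples), the default winlen = 100 and numoverlap = 50
            % yield nwin = floor((600-50)/(100-50)) = 11 half-overlapping Welch segments per epoch.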

            % innfft is the FFT length used internally to design and apply the CSP filters
            if isempty(nfft)
                innfft = 2^(nextpow2(signal.pnts));
            else
                innfft = max(nfft, 2^(nextpow2(signal.pnts)));
                if innfft > nfft
                    error('The chosen FFT length is too short for the epoch length.'); end
            end
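            % With the defaults (roughly 600-sample epochs) this resolves to
            % innfft = 2^nextpow2(600) = 1024 FFT bins, i.e. a frequency grid with spacing
            % Fs/innfft (about 0.195 Hz at Fs = 200 Hz).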

            allfreqs = 0:Fs/innfft:Fs;
            allfreqs = allfreqs(1:innfft);

            for c=1:2
                % compute the per-class epoched data X and its Fourier transform (along time), Xfft
                X{c} = exp_eval_optimized(set_picktrials(signal,'rank',c));
                [C,S,T] = size(X{c}.data);
                Xfft{c} = zeros(C,T,innfft,nwin);

                signal_idx = bsxfun(@plus,[1:winlen]',[0:nwin-1]*(winlen-numoverlap));
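                % signal_idx is a winlen-by-nwin index matrix whose columns select the overlapping
                % Welch segments along the sample dimension; e.g. with winlen = 4, numoverlap = 2
                % and nwin = 3 its columns would be [1 2 3 4]', [3 4 5 6]' and [5 6 7 8]'.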
                Xdata_seg = repmat(X{c}.data,1,1,1,nwin);                        % Xdata_seg  -> C,S,T,nwin
                Xdata_seg2 = permute(Xdata_seg,[1 3 2 4]);                       % Xdata_seg2 -> C,T,S,nwin
                Xdata_seg3 = reshape(Xdata_seg2(:,:,signal_idx),C,T,[],nwin);    % Xdata_seg3 -> C,T,winlen,nwin
                Xdata_win = bsxfun(@times,Xdata_seg3,reshape(win,1,1,winlen,1)); % Xdata_win  -> C,T,winlen,nwin
                Xfft{c} = fft(Xdata_win,innfft,3);                               % Xfft       -> C,T,innfft,nwin
            end

            filters = [];
            patterns = [];
            alphas = [];
            for fb=1:size(args.freqwnds,1)

                [freqs,findx] = getfgrid(Fs,innfft,freqwnds(fb,:));
                I = zeros(innfft,1); I(findx) = 1;
                for cc=1:2
                    [C,S,T] = size(X{cc}.data);
                    Xspec{cc} = zeros(C,C);
                    % Xfft_crop -> C,T,numbins,nwin
                    Xfft_crop = Xfft{cc}(:,:,findx,:);
                    F = 2 * real(squeeze(sum(bsxfun(@times,conj(permute(Xfft_crop,[1,5,3,2,4])),permute(Xfft_crop,[5,1,3,2,4])),5))./nwin); % F -> C,C,numbins,T
                    % compute the cross-spectrum as an average over trials
                    Xspec{cc} = sum(squeeze(mean(F,4)),3);
                end

                [V,D] = eig(Xspec{1},Xspec{1}+Xspec{2});
                P = inv(V);
                filters = [filters V(:,[1:nof end-nof+1:end])];
                patterns = [patterns P([1:nof end-nof+1:end],:)'];
                alphas = [alphas repmat(I,1,2*nof)];

            end
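            % The eig(Xspec{1},Xspec{1}+Xspec{2}) call above solves the usual CSP generalized
            % eigenvalue problem on the per-band cross-spectral (covariance-like) matrices; the
            % first and last nof eigenvector columns are kept as the spatial filters whose
            % band-limited variance best separates the two classes, and the matching rows of
            % inv(V) are kept as the forward-projection patterns. alphas records, for each kept
            % filter, the 0/1 mask over the innfft FFT bins of its band. After the loop, filters
            % and patterns are C x (2*nof*numbands) and alphas is innfft x (2*nof*numbands),
            % where numbands = size(freqwnds,1).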
            freq_args.nfft = innfft;
            freq_args.win = win;
            freq_args.winlen = winlen;
            freq_args.numoverlap = numoverlap;
            freq_args.nwin = nwin;
            model = struct('filters',{filters},'patterns',{patterns},'alphas',{alphas},'freq_args',{freq_args},'chanlocs',{args.signal.chanlocs});
        end

        function features = feature_extract(self,signal,featuremodel)
            freq_args = featuremodel.freq_args;
            [nfft, win, winlen, numoverlap, nwin] = deal(freq_args.nfft, freq_args.win, freq_args.winlen, freq_args.numoverlap, freq_args.nwin);
            X = signal.data;
            [C,S,T] = size(X);
            numf = size(featuremodel.filters,2);
            features = zeros(T,numf);

            Xtemp = permute(X,[3,2,1]);  % Xtemp -> T,S,C
            Xtemp2 = zeros(T,S,numf);
            for t=1:T
                Xtemp2(t,:,:) = squeeze(Xtemp(t,:,:)) * featuremodel.filters;
            end
            % Xtemp2 -> T,S,numf
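            % Each trial's [S x C] samples are projected through the learned spatial filters,
            % i.e. [S x C] * [C x numf] = [S x numf], giving one spatially filtered time course
            % per CSP filter and per trial.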

            % Welch-style segmentation and averaging of the segment FFTs
            signal_idx = bsxfun(@plus,[1:winlen]',[0:nwin-1]*(winlen-numoverlap));
            Xdata_seg = repmat(Xtemp2,1,1,1,nwin);                               % Xdata_seg  -> T,S,numf,nwin
            Xdata_seg2 = permute(Xdata_seg,[1 3 2 4]);                           % Xdata_seg2 -> T,numf,S,nwin
            Xdata_seg3 = reshape(Xdata_seg2(:,:,signal_idx),T,numf,[],nwin);     % Xdata_seg3 -> T,numf,winlen,nwin
            Xdata_win = bsxfun(@times,Xdata_seg3,reshape(win,1,1,winlen,1));     % Xdata_win  -> T,numf,winlen,nwin
            Xfft = fft(Xdata_win,nfft,3);                                        % Xfft       -> T,numf,nfft,nwin
            Xdata1 = squeeze(sum(Xfft,4)./nwin);                                 % Xdata1     -> T,numf,nfft
            Xdata = permute(Xdata1,[1,3,2]);                                     % Xdata      -> T,nfft,numf
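            % The averaged segment spectra are now masked per filter: featuremodel.alphas is the
            % innfft x numf matrix of 0/1 band masks built during training, so multiplying along
            % the FFT dimension and applying the inverse FFT returns each spatially filtered time
            % course band-limited to the frequency window its CSP filter was trained on; the final
            % feature is the log-variance of that band-limited signal.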

            % Xdata = fft(Xtemp2,nfft,2); % alternative: plain FFT over the whole epoch instead of Welch averaging

            Xdata2 = bsxfun(@times,Xdata,permute(featuremodel.alphas,[3,1,2]));
            Xdata3 = ifft(Xdata2,nfft,2);
            Xdata4 = 2*real(Xdata3(:,1:S,:));
            features = log(squeeze(var(Xdata4,0,2)));
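            % The resulting feature matrix is T x (2*PatternPairs*numbands); with the defaults
            % (3 pattern pairs, 5 frequency bands) that is 30 log-variance features per trial.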

        end

        function visualize_model(self,varargin) %#ok<*INUSD>
            args = arg_define([0 3],varargin, ...
                arg_norep({'myparent','Parent'},[],[],'Parent figure.'), ...
                arg_norep({'featuremodel','FeatureModel'},[],[],'Feature model. This is the part of the model that describes the feature extraction.'), ...
                arg_norep({'predictivemodel','PredictiveModel'},[],[],'Predictive model. This is the part of the model that describes the predictive mapping.'), ...
                arg({'patterns','PlotPatterns'},true,[],'Plot patterns instead of filters. Whether to plot spatial patterns (forward projections) rather than spatial filters.'), ...
                arg({'weight_scaled','WeightScaled'},false,[],'Scaled by weight. Whether to scale the patterns by weight.'));
            arg_toworkspace(args);

            % find the relevant components
            scores = predictivemodel.model.w;
            scores = sqrt(abs(scores));
            % optionally remove the bias if it is included in w
            if length(scores) == size(featuremodel.patterns,2)+1
                scores = scores(1:end-1); end
            % frequency labels
            % titles = repmat({'delta','theta','alpha','beta','gamma'},8,1); titles = titles(:);
            % extract the patterns/filters with non-zero classifier weight
            patterns = featuremodel.patterns(:,find(scores)); %#ok<FNDSB>
            filters = featuremodel.filters(:,find(scores));   %#ok<FNDSB>
            % plot them as scalp maps
            if args.weight_scaled
                if args.patterns
                    topoplot_grid(patterns,featuremodel.chanlocs,'scales',scores(find(scores))/max(scores));
                else
                    topoplot_grid(filters,featuremodel.chanlocs,'scales',scores(find(scores))/max(scores));
                end
            else
                if args.patterns
                    topoplot_grid(patterns,featuremodel.chanlocs);
                else
                    topoplot_grid(filters,featuremodel.chanlocs);
                end
            end
            % figure;
        end

        function layout = dialog_layout_defaults(self)
            % define the default configuration dialog layout
            layout = {'SignalProcessing.Resampling.SamplingRate', 'SignalProcessing.EpochExtraction', '', ...
                'Prediction.FeatureExtraction.FreqWindows', 'Prediction.FeatureExtraction.TimeWindows', ...
                'Prediction.FeatureExtraction.WindowFunction', '', 'Prediction.FeatureExtraction.PatternPairs', '', ...
                'Prediction.MachineLearning.Learner'};
        end

        function tf = needs_voting(self)
            tf = true;
        end
    end
end
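
% Usage sketch (assumption: this class is on the BCILAB path and can be referenced by the
% approach name 'FBCSPMod', analogous to the 'FBCSP' example in the help text above). The
% Welch-related options added by this modified paradigm are set under FeatureExtraction:
%
%   data = io_loadset('data sets/mary/nback.eeg');
%   myapproach = {'FBCSPMod' 'SignalProcessing',{'EpochExtraction',[-0.5 2.5]}, ...
%       'Prediction',{'FeatureExtraction',{'FreqWindows',[4 7; 8 12],'WinLen',100,'NumOverlap',50}, ...
%                     'MachineLearning',{'Learner','logreg'}}};
%   [loss,model,stats] = bci_train('Data',data, 'Approach',myapproach, 'TargetMarkers',{'n1','n2','n3'});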
