
Commit 400c042

Merge pull request #1014 from samuelgarcia/ced_read_spikes
Add read spike times in CedIO
2 parents 0bf7a5e + c03c22e commit 400c042

File tree

1 file changed (+59, -6 lines)


neo/rawio/cedrawio.py

Lines changed: 59 additions & 6 deletions
@@ -58,10 +58,16 @@ def _parse_header(self):
         self.smrx_file = sonpy.lib.SonFile(sName=str(self.filename), bReadOnly=True)
         smrx = self.smrx_file
 
+        self._time_base = smrx.GetTimeBase()
+
         channel_infos = []
         signal_channels = []
+        spike_channels = []
+        self._all_spike_ticks = {}
+
         for chan_ind in range(smrx.MaxChannels()):
             chan_type = smrx.ChannelType(chan_ind)
+            chan_id = str(chan_ind)
             if chan_type == sonpy.lib.DataType.Adc:
                 physical_chan = smrx.PhysicalChannel(chan_ind)
                 divide = smrx.ChannelDivide(chan_ind)
@@ -78,13 +84,35 @@ def _parse_header(self):
                     offset = smrx.GetChannelOffset(chan_ind)
                     units = smrx.GetChannelUnits(chan_ind)
                     ch_name = smrx.GetChannelTitle(chan_ind)
-                    chan_id = str(chan_ind)
+
                     dtype = 'int16'
                     # set later after grouping
                     stream_id = '0'
                     signal_channels.append((ch_name, chan_id, sr, dtype,
                                             units, gain, offset, stream_id))
 
+            elif chan_type == sonpy.lib.DataType.AdcMark:
+                # spike and waveforms: only spike times are used here
+                ch_name = smrx.GetChannelTitle(chan_ind)
+                first_time = smrx.FirstTime(chan_ind, 0, max_time)
+                max_time = smrx.ChannelMaxTime(chan_ind)
+                divide = smrx.ChannelDivide(chan_ind)
+                # here we don't use a filter (sonpy.lib.MarkerFilter()) so we get all markers
+                wave_marks = smrx.ReadWaveMarks(chan_ind, int(max_time / divide), 0, max_time)
+
+                # load all spikes into memory at once because access through
+                # ReadWaveMarks is really slow
+                spike_ticks = np.array([t.Tick for t in wave_marks])
+                spike_codes = np.array([t.Code1 for t in wave_marks])
+
+                unit_ids = np.unique(spike_codes)
+                for unit_id in unit_ids:
+                    name = f'{ch_name}#{unit_id}'
+                    spike_chan_id = f'ch{chan_id}#{unit_id}'
+                    spike_channels.append((name, spike_chan_id, '', 1, 0, 0, 0))
+                    mask = spike_codes == unit_id
+                    self._all_spike_ticks[spike_chan_id] = spike_ticks[mask]
+
         signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
 
         # channels are grouped into stream if they have a common start, stop, size, divide and sampling_rate
@@ -104,8 +132,7 @@ def _parse_header(self):
         signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
 
         # spike channels not handled
-        spike_channels = []
-        spike_channels = np.array([], dtype=_spike_channel_dtype)
+        spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
 
         # event channels not handled
         event_channels = []
@@ -115,9 +142,10 @@ def _parse_header(self):
         self._seg_t_stop = -np.inf
         for info in self.stream_info:
             self._seg_t_start = min(self._seg_t_start,
-                                    info['first_time'] / info['sampling_rate'])
+                                    info['first_time'] * self._time_base)
+
             self._seg_t_stop = max(self._seg_t_stop,
-                                   info['max_time'] / info['sampling_rate'])
+                                   info['max_time'] * self._time_base)
 
         self.header = {}
         self.header['nb_block'] = 1
@@ -141,7 +169,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
 
     def _get_signal_t_start(self, block_index, seg_index, stream_index):
        info = self.stream_info[stream_index]
-        t_start = info['first_time'] / info['sampling_rate']
+        t_start = info['first_time'] * self._time_base
        return t_start
 
    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
@@ -175,3 +203,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
             sigs[:, i] = sig
 
         return sigs
+
+    def _spike_count(self, block_index, seg_index, unit_index):
+        unit_id = self.header['spike_channels'][unit_index]['id']
+        spike_ticks = self._all_spike_ticks[unit_id]
+        return spike_ticks.size
+
+    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
+        unit_id = self.header['spike_channels'][unit_index]['id']
+        spike_ticks = self._all_spike_ticks[unit_id]
+        if t_start is not None:
+            tick_start = int(t_start / self._time_base)
+            spike_ticks = spike_ticks[spike_ticks >= tick_start]
+        if t_stop is not None:
+            tick_stop = int(t_stop / self._time_base)
+            spike_ticks = spike_ticks[spike_ticks <= tick_stop]
+        return spike_ticks
+
+    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
+        spike_times = spike_timestamps.astype(dtype)
+        spike_times *= self._time_base
+        return spike_times
+
+    def _get_spike_raw_waveforms(self, block_index, seg_index,
+                                 spike_channel_index, t_start, t_stop):
+        return None
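
For orientation, a minimal sketch of how the spike data added by this commit can be read through neo's rawio layer. The file name 'data.smrx' is a placeholder, and the snippet assumes a neo build that includes this commit together with the sonpy package:

from neo.rawio import CedRawIO

# 'data.smrx' is a placeholder; use any Spike2 file that contains AdcMark (spike) channels.
reader = CedRawIO(filename='data.smrx')
reader.parse_header()

# _parse_header() above creates one spike channel per (AdcMark channel, unit code) pair,
# with ids of the form 'ch<chan_id>#<unit_id>'.
for i, chan in enumerate(reader.header['spike_channels']):
    n_spikes = reader.spike_count(0, 0, i)                    # block 0, segment 0, unit i
    ticks = reader.get_spike_timestamps(0, 0, i)              # raw ticks, as stored on disk
    times = reader.rescale_spike_timestamp(ticks, 'float64')  # seconds = ticks * time base
    print(chan['name'], chan['id'], n_spikes, times[:5])

rescale_spike_timestamp() applies the same tick-to-seconds factor (smrx.GetTimeBase()) that the segment and signal t_start computations now use.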
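
At the neo.io level the same spikes surface as SpikeTrain objects; a short sketch, again with a placeholder file name and assuming the higher-level CedIO wrapper around this rawio class:

from neo.io import CedIO

io = CedIO(filename='data.smrx')  # placeholder path
block = io.read_block()
# With this commit, each sorted unit of an AdcMark channel yields one SpikeTrain.
for st in block.segments[0].spiketrains:
    print(st.name, len(st), st.times[:5])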
