@@ -58,10 +58,16 @@ def _parse_header(self):
58
58
self .smrx_file = sonpy .lib .SonFile (sName = str (self .filename ), bReadOnly = True )
59
59
smrx = self .smrx_file
60
60
61
+ self ._time_base = smrx .GetTimeBase ()
62
+
61
63
channel_infos = []
62
64
signal_channels = []
65
+ spike_channels = []
66
+ self ._all_spike_ticks = {}
67
+
63
68
for chan_ind in range (smrx .MaxChannels ()):
64
69
chan_type = smrx .ChannelType (chan_ind )
70
+ chan_id = str (chan_ind )
65
71
if chan_type == sonpy .lib .DataType .Adc :
66
72
physical_chan = smrx .PhysicalChannel (chan_ind )
67
73
divide = smrx .ChannelDivide (chan_ind )
@@ -78,13 +84,35 @@ def _parse_header(self):
78
84
offset = smrx .GetChannelOffset (chan_ind )
79
85
units = smrx .GetChannelUnits (chan_ind )
80
86
ch_name = smrx .GetChannelTitle (chan_ind )
81
- chan_id = str ( chan_ind )
87
+
82
88
dtype = 'int16'
83
89
# set later after grouping
84
90
stream_id = '0'
85
91
signal_channels .append ((ch_name , chan_id , sr , dtype ,
86
92
units , gain , offset , stream_id ))
87
93
94
+ elif chan_type == sonpy .lib .DataType .AdcMark :
95
+ # spike and waveforms : only spike times is used here
96
+ ch_name = smrx .GetChannelTitle (chan_ind )
97
+ first_time = smrx .FirstTime (chan_ind , 0 , max_time )
98
+ max_time = smrx .ChannelMaxTime (chan_ind )
99
+ divide = smrx .ChannelDivide (chan_ind )
100
+ # here we don't use filter (sonpy.lib.MarkerFilter()) so we get all marker
101
+ wave_marks = smrx .ReadWaveMarks (chan_ind , int (max_time / divide ), 0 , max_time )
102
+
103
+ # here we load in memory all spike once because the access is really slow
104
+ # with the ReadWaveMarks
105
+ spike_ticks = np .array ([t .Tick for t in wave_marks ])
106
+ spike_codes = np .array ([t .Code1 for t in wave_marks ])
107
+
108
+ unit_ids = np .unique (spike_codes )
109
+ for unit_id in unit_ids :
110
+ name = f'{ ch_name } #{ unit_id } '
111
+ spike_chan_id = f'ch{ chan_id } #{ unit_id } '
112
+ spike_channels .append ((name , spike_chan_id , '' , 1 , 0 , 0 , 0 ))
113
+ mask = spike_codes == unit_id
114
+ self ._all_spike_ticks [spike_chan_id ] = spike_ticks [mask ]
115
+
88
116
signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
89
117
90
118
# channels are grouped into stream if they have a common start, stop, size, divide and sampling_rate
@@ -104,8 +132,7 @@ def _parse_header(self):
104
132
signal_streams = np .array (signal_streams , dtype = _signal_stream_dtype )
105
133
106
134
# spike channels not handled
107
- spike_channels = []
108
- spike_channels = np .array ([], dtype = _spike_channel_dtype )
135
+ spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
109
136
110
137
# event channels not handled
111
138
event_channels = []
@@ -115,9 +142,10 @@ def _parse_header(self):
115
142
self ._seg_t_stop = - np .inf
116
143
for info in self .stream_info :
117
144
self ._seg_t_start = min (self ._seg_t_start ,
118
- info ['first_time' ] / info ['sampling_rate' ])
145
+ info ['first_time' ] * self ._time_base )
146
+
119
147
self ._seg_t_stop = max (self ._seg_t_stop ,
120
- info ['max_time' ] / info [ 'sampling_rate' ] )
148
+ info ['max_time' ] * self . _time_base )
121
149
122
150
self .header = {}
123
151
self .header ['nb_block' ] = 1
@@ -141,7 +169,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
141
169
142
170
def _get_signal_t_start (self , block_index , seg_index , stream_index ):
143
171
info = self .stream_info [stream_index ]
144
- t_start = info ['first_time' ] / info [ 'sampling_rate' ]
172
+ t_start = info ['first_time' ] * self . _time_base
145
173
return t_start
146
174
147
175
def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
@@ -175,3 +203,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
175
203
sigs [:, i ] = sig
176
204
177
205
return sigs
206
+
207
+ def _spike_count (self , block_index , seg_index , unit_index ):
208
+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
209
+ spike_ticks = self ._all_spike_ticks [unit_id ]
210
+ return spike_ticks .size
211
+
212
+ def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
213
+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
214
+ spike_ticks = self ._all_spike_ticks [unit_id ]
215
+ if t_start is not None :
216
+ tick_start = int (t_start / self ._time_base )
217
+ spike_ticks = spike_ticks [spike_ticks >= tick_start ]
218
+ if t_stop is not None :
219
+ tick_stop = int (t_stop / self ._time_base )
220
+ spike_ticks = spike_ticks [spike_ticks <= tick_stop ]
221
+ return spike_ticks
222
+
223
+ def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
224
+ spike_times = spike_timestamps .astype (dtype )
225
+ spike_times *= self ._time_base
226
+ return spike_times
227
+
228
+ def _get_spike_raw_waveforms (self , block_index , seg_index ,
229
+ spike_channel_index , t_start , t_stop ):
230
+ return None
0 commit comments