77
88from  scipy  import  stats 
99
10- # TODO: spike_times -> spike_indexes  
10+ # TODO: spike_times -> spike_indices  
1111""" 
1212Notes 
1313----- 
1414- not everything is used for current purposes 
1515- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude. 
1616""" 
1717
18+ ######################################################################################################################## 
19+ # Get Spike Data 
20+ ######################################################################################################################## 
21+ 
1822
1923def  compute_spike_amplitude_and_depth (
2024    sorter_output : str  |  Path ,
2125    localised_spikes_only ,
2226    exclude_noise ,
2327    gain : float  |  None  =  None ,
24-     localised_spikes_channel_cutoff : int  =  None ,   # TODO 
28+     localised_spikes_channel_cutoff : int  =  None ,
2529) ->  tuple [np .ndarray , ...]:
2630    """ 
2731    Compute the amplitude and depth of all detected spikes from the kilosort output. 
2832
29-     This function was ported from  Nick Steinmetz's `spikes` repository 
30-     MATLAB code,  https://github.com/cortex-lab/spikes 
33+     This function is based on code in  Nick Steinmetz's `spikes` repository,  
34+     https://github.com/cortex-lab/spikes 
3135
3236    Parameters 
3337    ---------- 
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
4650
4751    Returns 
4852    ------- 
49-     spike_indexes  : np.ndarray 
50-         (num_spikes,) array of spike indexes . 
53+     spike_indices  : np.ndarray 
54+         (num_spikes,) array of spike indices . 
5155    spike_amplitudes : np.ndarray 
5256        (num_spikes,) array of corresponding spike amplitudes. 
5357    spike_depths : np.ndarray 
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
6670    if  isinstance (sorter_output , str ):
6771        sorter_output  =  Path (sorter_output )
6872
69-     params  =  _load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
73+     params  =  load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
7074
7175    if  localised_spikes_only :
7276        localised_templates  =  []
@@ -81,10 +85,56 @@ def compute_spike_amplitude_and_depth(
8185
8286        localised_template_by_spike  =  np .isin (params ["spike_templates" ], localised_templates )
8387
84-         _strip_spikes (params , localised_template_by_spike )
88+         params ["spike_templates" ] =  params ["spike_templates" ][localised_template_by_spike ]
89+         params ["spike_indices" ] =  params ["spike_indices" ][localised_template_by_spike ]
90+         params ["spike_clusters" ] =  params ["spike_clusters" ][localised_template_by_spike ]
91+         params ["temp_scaling_amplitudes" ] =  params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
92+         params ["pc_features" ] =  params ["pc_features" ][localised_template_by_spike ]
93+ 
94+     spike_locations , spike_max_sites  =  _get_locations_from_pc_features (params )
95+ 
96+     #    Amplitude is calculated for each spike as the template amplitude 
97+     #    multiplied by the `template_scaling_amplitudes`. 
98+     template_amplitudes_unscaled , * _  =  get_unwhite_template_info (
99+         params ["templates" ],
100+         params ["whitening_matrix_inv" ],
101+         params ["channel_positions" ],
102+     )
103+     spike_amplitudes  =  template_amplitudes_unscaled [params ["spike_templates" ]] *  params ["temp_scaling_amplitudes" ]
104+ 
105+     if  gain  is  not None :
106+         spike_amplitudes  *=  gain 
107+ 
108+     compute_template_amplitudes_from_spikes (params ["templates" ], params ["spike_templates" ], spike_amplitudes )
109+ 
110+     if  localised_spikes_only :
111+         # Interpolate the channel ids to location. 
112+         # Remove spikes > 5 um from average position 
113+         # Above we already removed non-localized templates, but that on its own is insufficient. 
114+         # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient 
115+         # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere 
116+         # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. 
117+         # 3) just use depth. Probably go for that. check with others. 
118+         spike_depths  =  spike_locations [:, 1 ]
119+         b  =  stats .linregress (spike_depths , spike_max_sites ).slope 
120+         i  =  np .abs (spike_max_sites  -  b  *  spike_depths ) <=  5 
85121
122+         params ["spike_indices" ] =  params ["spike_indices" ][i ]
123+         spike_amplitudes  =  spike_amplitudes [i ]
124+         spike_locations  =  spike_locations [i , :]
125+         spike_max_sites  =  spike_max_sites [i ]
126+ 
127+     return  params ["spike_indices" ], spike_amplitudes , spike_locations , spike_max_sites 
128+ 
129+ 
130+ def  _get_locations_from_pc_features (params ):
131+     """ 
132+ 
133+     This function is based on code in Nick Steinmetz's `spikes` repository, 
134+     https://github.com/cortex-lab/spikes 
135+     """ 
86136    # Compute spike depths 
87-     pc_features  =  params ["pc_features" ][:, 0 , :]   # Do this compute 
137+     pc_features  =  params ["pc_features" ][:, 0 , :]
88138    pc_features [pc_features  <  0 ] =  0 
89139
90140    # Some spikes do not load at all onto the first PC. To avoid biasing the 
@@ -109,58 +159,28 @@ def compute_spike_amplitude_and_depth(
109159            "to extend this code section to handle more components." 
110160        )
111161
112-     # Get the channel indexes  corresponding to the 32 channels from the PC. 
162+     # Get the channel indices  corresponding to the 32 channels from the PC. 
113163    spike_features_indices  =  params ["pc_features_indices" ][params ["spike_templates" ], :]
114164
115165    # Compute the spike locations as the center of mass of the PC scores 
116166    spike_feature_coords  =  params ["channel_positions" ][spike_features_indices , :]
117-     norm_weights  =  pc_features  /  np .sum (pc_features , axis = 1 )[:, np .newaxis ]  # TOOD: see why they use square 
167+     norm_weights  =  (
168+         pc_features  /  np .sum (pc_features , axis = 1 )[:, np .newaxis ]
169+     )  # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI. 
118170    spike_locations  =  spike_feature_coords  *  norm_weights [:, :, np .newaxis ]
119171    spike_locations  =  np .sum (spike_locations , axis = 1 )
120172
121173    # TODO:  now max site per spike is computed from PCs, not as the channel max site as previous 
122-     spike_sites  =  spike_features_indices [np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )]
174+     spike_max_sites  =  spike_features_indices [
175+         np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )
176+     ]
123177
124-     #    Amplitude is calculated for each spike as the template amplitude 
125-     #    multiplied by the `template_scaling_amplitudes`. 
126-     template_amplitudes_unscaled , * _  =  get_unwhite_template_info (
127-         params ["templates" ],
128-         params ["whitening_matrix_inv" ],
129-         params ["channel_positions" ],
130-     )
131-     spike_amplitudes  =  template_amplitudes_unscaled [params ["spike_templates" ]] *  params ["temp_scaling_amplitudes" ]
178+     return  spike_locations , spike_max_sites 
132179
133-     if  gain  is  not None :
134-         spike_amplitudes  *=  gain 
135- 
136-     if  localised_spikes_only :
137-         # Interpolate the channel ids to location. 
138-         # Remove spikes > 5 um from average position 
139-         # Above we already removed non-localized templates, but that on its own is insufficient. 
140-         # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient 
141-         # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere 
142-         # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. 
143-         # 3) just use depth. Probably go for that. check with others. 
144-         spike_depths  =  spike_locations [:, 1 ]
145-         b  =  stats .linregress (spike_depths , spike_sites ).slope 
146-         i  =  np .abs (spike_sites  -  b  *  spike_depths ) <=  5   # TODO: need to expose this 
147- 
148-         params ["spike_indexes" ] =  params ["spike_indexes" ][i ]
149-         spike_amplitudes  =  spike_amplitudes [i ]
150-         spike_locations  =  spike_locations [i , :]
151- 
152-     return  params ["spike_indexes" ], spike_amplitudes , spike_locations , spike_sites 
153180
154- 
155- def  _strip_spikes_in_place (params , indices ):
156-     """ """ 
157-     params ["spike_templates" ] =  params ["spike_templates" ][
158-         indices 
159-     ]  # TODO: make an function for this. because we do this a lot 
160-     params ["spike_indexes" ] =  params ["spike_indexes" ][indices ]
161-     params ["spike_clusters" ] =  params ["spike_clusters" ][indices ]
162-     params ["temp_scaling_amplitudes" ] =  params ["temp_scaling_amplitudes" ][indices ]
163-     params ["pc_features" ] =  params ["pc_features" ][indices ]  # TODO: be conciststetn! change indees to indices 
181+ ######################################################################################################################## 
182+ # Get Template Data 
183+ ######################################################################################################################## 
164184
165185
166186def  get_unwhite_template_info (
@@ -173,8 +193,8 @@ def get_unwhite_template_info(
173193    Amplitude is calculated for each spike as the template amplitude 
174194    multiplied by the `template_scaling_amplitudes`. 
175195
176-     This function was ported from  Nick Steinmetz's `spikes` repository 
177-     MATLAB code,  https://github.com/cortex-lab/spikes 
196+     This function is based on code in  Nick Steinmetz's `spikes` repository,  
197+     https://github.com/cortex-lab/spikes 
178198
179199    Parameters 
180200    ---------- 
@@ -213,7 +233,7 @@ def get_unwhite_template_info(
213233
214234    template_amplitudes_unscaled  =  np .max (template_amplitudes_per_channel , axis = 1 )
215235
216-     #  Zero any small channel amplitudes 
236+     #  Zero any small channel amplitudes  TODO: removed this.  
217237    #    threshold_values = 0.3 * template_amplitudes_unscaled  TODO: remove this to be more general. Agree? 
218238    #    template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 
219239
@@ -253,9 +273,14 @@ def get_unwhite_template_info(
253273    )
254274
255275
256- def  compute_template_amplitudes_from_spikes ():
257-     # Take the average of all spike amplitudes to get actual template amplitudes 
258-     # (since tempScalingAmps are equal mean for all templates) 
276+ def  compute_template_amplitudes_from_spikes (templates , spike_templates , spike_amplitudes ):
277+     """ 
278+     Take the average of all spike amplitudes to get actual template amplitudes 
279+     (since tempScalingAmps are equal mean for all templates) 
280+ 
281+     This function is ported from Nick Steinmetz's `spikes` repository, 
282+     https://github.com/cortex-lab/spikes 
283+     """ 
259284    num_indices  =  templates .shape [0 ]
260285    sum_per_index  =  np .zeros (num_indices , dtype = np .float64 )
261286    np .add .at (sum_per_index , spike_templates , spike_amplitudes )
@@ -264,7 +289,12 @@ def compute_template_amplitudes_from_spikes():
264289    return  template_amplitudes 
265290
266291
267- def  _load_ks_dir (sorter_output : Path , exclude_noise : bool  =  True , load_pcs : bool  =  False ) ->  dict :
292+ ######################################################################################################################## 
293+ # Load Parameters from KS Directory 
294+ ######################################################################################################################## 
295+ 
296+ 
297+ def  load_ks_dir (sorter_output : Path , exclude_noise : bool  =  True , load_pcs : bool  =  False ) ->  dict :
268298    """ 
269299    Loads the output of Kilosort into a `params` dict. 
270300
@@ -300,7 +330,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300330
301331    params  =  read_python (sorter_output  /  "params.py" )
302332
303-     spike_indexes  =  np .load (sorter_output  /  "spike_times.npy" )
333+     spike_indices  =  np .load (sorter_output  /  "spike_times.npy" )
304334    spike_templates  =  np .load (sorter_output  /  "spike_templates.npy" )
305335
306336    if  (clusters_path  :=  sorter_output  /  "spike_clusters.csv" ).is_dir ():
@@ -328,7 +358,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328358        noise_cluster_ids  =  cluster_ids [cluster_groups  ==  0 ]
329359        not_noise_clusters_by_spike  =  ~ np .isin (spike_clusters .ravel (), noise_cluster_ids )
330360
331-         spike_indexes  =  spike_indexes [not_noise_clusters_by_spike ]
361+         spike_indices  =  spike_indices [not_noise_clusters_by_spike ]
332362        spike_templates  =  spike_templates [not_noise_clusters_by_spike ]
333363        temp_scaling_amplitudes  =  temp_scaling_amplitudes [not_noise_clusters_by_spike ]
334364
@@ -343,7 +373,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343373        cluster_groups  =  3  *  np .ones (cluster_ids .size )
344374
345375    new_params  =  {
346-         "spike_indexes " : spike_indexes .squeeze (),
376+         "spike_indices " : spike_indices .squeeze (),
347377        "spike_templates" : spike_templates .squeeze (),
348378        "spike_clusters" : spike_clusters .squeeze (),
349379        "pc_features" : pc_features ,
0 commit comments