Skip to content

Commit 6d569a2

Browse files
committed
Tidy up, general checks.
1 parent 02ec55a commit 6d569a2

File tree

2 files changed

+114
-63
lines changed

2 files changed

+114
-63
lines changed

src/spikeinterface/working/load_kilosort_utils.py

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,31 @@
77

88
from scipy import stats
99

10-
# TODO: spike_times -> spike_indexes
10+
# TODO: spike_times -> spike_indices
1111
"""
1212
Notes
1313
-----
1414
- not everything is used for current purposes
1515
- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
1616
"""
1717

18+
########################################################################################################################
19+
# Get Spike Data
20+
########################################################################################################################
21+
1822

1923
def compute_spike_amplitude_and_depth(
2024
sorter_output: str | Path,
2125
localised_spikes_only,
2226
exclude_noise,
2327
gain: float | None = None,
24-
localised_spikes_channel_cutoff: int = None, # TODO
28+
localised_spikes_channel_cutoff: int = None,
2529
) -> tuple[np.ndarray, ...]:
2630
"""
2731
Compute the amplitude and depth of all detected spikes from the kilosort output.
2832
29-
This function was ported from Nick Steinmetz's `spikes` repository
30-
MATLAB code, https://github.com/cortex-lab/spikes
33+
This function is based on code in Nick Steinmetz's `spikes` repository,
34+
https://github.com/cortex-lab/spikes
3135
3236
Parameters
3337
----------
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
4650
4751
Returns
4852
-------
49-
spike_indexes : np.ndarray
50-
(num_spikes,) array of spike indexes.
53+
spike_indices : np.ndarray
54+
(num_spikes,) array of spike indices.
5155
spike_amplitudes : np.ndarray
5256
(num_spikes,) array of corresponding spike amplitudes.
5357
spike_depths : np.ndarray
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
6670
if isinstance(sorter_output, str):
6771
sorter_output = Path(sorter_output)
6872

69-
params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
73+
params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
7074

7175
if localised_spikes_only:
7276
localised_templates = []
@@ -81,10 +85,56 @@ def compute_spike_amplitude_and_depth(
8185

8286
localised_template_by_spike = np.isin(params["spike_templates"], localised_templates)
8387

84-
_strip_spikes(params, localised_template_by_spike)
88+
params["spike_templates"] = params["spike_templates"][localised_template_by_spike]
89+
params["spike_indices"] = params["spike_indices"][localised_template_by_spike]
90+
params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike]
91+
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike]
92+
params["pc_features"] = params["pc_features"][localised_template_by_spike]
93+
94+
spike_locations, spike_max_sites = _get_locations_from_pc_features(params)
95+
96+
# Amplitude is calculated for each spike as the template amplitude
97+
# multiplied by the `template_scaling_amplitudes`.
98+
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
99+
params["templates"],
100+
params["whitening_matrix_inv"],
101+
params["channel_positions"],
102+
)
103+
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
104+
105+
if gain is not None:
106+
spike_amplitudes *= gain
107+
108+
compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes)
109+
110+
if localised_spikes_only:
111+
# Interpolate the channel ids to location.
112+
# Remove spikes > 5 um from average position
113+
# Above we already removed non-localized templates, but that on its own is insufficient.
114+
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115+
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116+
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117+
# 3) just use depth. Probably go for that. check with others.
118+
spike_depths = spike_locations[:, 1]
119+
b = stats.linregress(spike_depths, spike_max_sites).slope
120+
i = np.abs(spike_max_sites - b * spike_depths) <= 5
85121

122+
params["spike_indices"] = params["spike_indices"][i]
123+
spike_amplitudes = spike_amplitudes[i]
124+
spike_locations = spike_locations[i, :]
125+
spike_max_sites = spike_max_sites[i]
126+
127+
return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites
128+
129+
130+
def _get_locations_from_pc_features(params):
131+
"""
132+
133+
This function is based on code in Nick Steinmetz's `spikes` repository,
134+
https://github.com/cortex-lab/spikes
135+
"""
86136
# Compute spike depths
87-
pc_features = params["pc_features"][:, 0, :] # Do this compute
137+
pc_features = params["pc_features"][:, 0, :]
88138
pc_features[pc_features < 0] = 0
89139

90140
# Some spikes do not load at all onto the first PC. To avoid biasing the
@@ -109,58 +159,28 @@ def compute_spike_amplitude_and_depth(
109159
"to extend this code section to handle more components."
110160
)
111161

112-
# Get the channel indexes corresponding to the 32 channels from the PC.
162+
# Get the channel indices corresponding to the 32 channels from the PC.
113163
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
114164

115165
# Compute the spike locations as the center of mass of the PC scores
116166
spike_feature_coords = params["channel_positions"][spike_features_indices, :]
117-
norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square
167+
norm_weights = (
168+
pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]
169+
) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
118170
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
119171
spike_locations = np.sum(spike_locations, axis=1)
120172

121173
# TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122-
spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)]
174+
spike_max_sites = spike_features_indices[
175+
np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)
176+
]
123177

124-
# Amplitude is calculated for each spike as the template amplitude
125-
# multiplied by the `template_scaling_amplitudes`.
126-
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
127-
params["templates"],
128-
params["whitening_matrix_inv"],
129-
params["channel_positions"],
130-
)
131-
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
178+
return spike_locations, spike_max_sites
132179

133-
if gain is not None:
134-
spike_amplitudes *= gain
135-
136-
if localised_spikes_only:
137-
# Interpolate the channel ids to location.
138-
# Remove spikes > 5 um from average position
139-
# Above we already removed non-localized templates, but that on its own is insufficient.
140-
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
141-
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
142-
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
143-
# 3) just use depth. Probably go for that. check with others.
144-
spike_depths = spike_locations[:, 1]
145-
b = stats.linregress(spike_depths, spike_sites).slope
146-
i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this
147-
148-
params["spike_indexes"] = params["spike_indexes"][i]
149-
spike_amplitudes = spike_amplitudes[i]
150-
spike_locations = spike_locations[i, :]
151-
152-
return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites
153180

154-
155-
def _strip_spikes_in_place(params, indices):
156-
""" """
157-
params["spike_templates"] = params["spike_templates"][
158-
indices
159-
] # TODO: make an function for this. because we do this a lot
160-
params["spike_indexes"] = params["spike_indexes"][indices]
161-
params["spike_clusters"] = params["spike_clusters"][indices]
162-
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices]
163-
params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices
181+
########################################################################################################################
182+
# Get Template Data
183+
########################################################################################################################
164184

165185

166186
def get_unwhite_template_info(
@@ -173,8 +193,8 @@ def get_unwhite_template_info(
173193
Amplitude is calculated for each spike as the template amplitude
174194
multiplied by the `template_scaling_amplitudes`.
175195
176-
This function was ported from Nick Steinmetz's `spikes` repository
177-
MATLAB code, https://github.com/cortex-lab/spikes
196+
This function is based on code in Nick Steinmetz's `spikes` repository,
197+
https://github.com/cortex-lab/spikes
178198
179199
Parameters
180200
----------
@@ -213,7 +233,7 @@ def get_unwhite_template_info(
213233

214234
template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)
215235

216-
# Zero any small channel amplitudes
236+
# Zero any small channel amplitudes TODO: removed this.
217237
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218238
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
219239

@@ -253,9 +273,14 @@ def get_unwhite_template_info(
253273
)
254274

255275

256-
def compute_template_amplitudes_from_spikes():
257-
# Take the average of all spike amplitudes to get actual template amplitudes
258-
# (since tempScalingAmps are equal mean for all templates)
276+
def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes):
277+
"""
278+
Take the average of all spike amplitudes to get actual template amplitudes
279+
(since tempScalingAmps are equal mean for all templates)
280+
281+
This function is ported from Nick Steinmetz's `spikes` repository,
282+
https://github.com/cortex-lab/spikes
283+
"""
259284
num_indices = templates.shape[0]
260285
sum_per_index = np.zeros(num_indices, dtype=np.float64)
261286
np.add.at(sum_per_index, spike_templates, spike_amplitudes)
@@ -264,7 +289,12 @@ def compute_template_amplitudes_from_spikes():
264289
return template_amplitudes
265290

266291

267-
def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
292+
########################################################################################################################
293+
# Load Parameters from KS Directory
294+
########################################################################################################################
295+
296+
297+
def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
268298
"""
269299
Loads the output of Kilosort into a `params` dict.
270300
@@ -300,7 +330,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300330

301331
params = read_python(sorter_output / "params.py")
302332

303-
spike_indexes = np.load(sorter_output / "spike_times.npy")
333+
spike_indices = np.load(sorter_output / "spike_times.npy")
304334
spike_templates = np.load(sorter_output / "spike_templates.npy")
305335

306336
if (clusters_path := sorter_output / "spike_clusters.csv").is_dir():
@@ -328,7 +358,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328358
noise_cluster_ids = cluster_ids[cluster_groups == 0]
329359
not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids)
330360

331-
spike_indexes = spike_indexes[not_noise_clusters_by_spike]
361+
spike_indices = spike_indices[not_noise_clusters_by_spike]
332362
spike_templates = spike_templates[not_noise_clusters_by_spike]
333363
temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike]
334364

@@ -343,7 +373,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343373
cluster_groups = 3 * np.ones(cluster_ids.size)
344374

345375
new_params = {
346-
"spike_indexes": spike_indexes.squeeze(),
376+
"spike_indices": spike_indices.squeeze(),
347377
"spike_templates": spike_templates.squeeze(),
348378
"spike_clusters": spike_clusters.squeeze(),
349379
"pc_features": pc_features,

src/spikeinterface/working/plot_kilosort_drift_map.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from pathlib import Path
2-
from spikeinterface.widgets.base import BaseWidget, to_attr
32
import matplotlib.axis
43
import scipy.signal
5-
from spikeinterface.core import read_python
4+
5+
# from spikeinterface.core import read_python
66
import numpy as np
77
import pandas as pd
88

99
import matplotlib.pyplot as plt
1010
from scipy import stats
1111
import load_kilosort_utils
1212

13+
from spikeinterface.widgets.base import BaseWidget, to_attr
14+
1315

1416
class KilosortDriftMapWidget(BaseWidget):
1517
"""
@@ -399,5 +401,24 @@ def _filter_large_amplitude_spikes(
399401
return spike_times, spike_amplitudes, spike_depths
400402

401403

402-
KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2")
404+
KilosortDriftMapWidget(
405+
"/Users/joeziminski/data/bombcelll/sorter_output",
406+
only_include_large_amplitude_spikes=False,
407+
localised_spikes_only=True,
408+
)
403409
plt.show()
410+
411+
"""
412+
sorter_output: str | Path,
413+
only_include_large_amplitude_spikes: bool = True,
414+
decimate: None | int = None,
415+
add_histogram_plot: bool = False,
416+
add_histogram_peaks_and_boundaries: bool = True,
417+
add_drift_events: bool = True,
418+
weight_histogram_by_amplitude: bool = False,
419+
localised_spikes_only: bool = False,
420+
exclude_noise: bool = False,
421+
gain: float | None = None,
422+
large_amplitude_only_segment_size: float = 800.0,
423+
localised_spikes_channel_cutoff: int = 20,
424+
"""

0 commit comments

Comments
 (0)