diff --git a/.github/scripts/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py index 5836616d44..e097336c0e 100644 --- a/.github/scripts/determine_testing_environment.py +++ b/.github/scripts/determine_testing_environment.py @@ -20,7 +20,7 @@ plexon2_changed = False preprocessing_changed = False postprocessing_changed = False -qualitymetrics_changed = False +metrics_changed = False sorters_changed = False sorters_external_changed = False sorters_internal_changed = False @@ -58,8 +58,8 @@ preprocessing_changed = True elif "postprocessing" in changed_file.parts: postprocessing_changed = True - elif "qualitymetrics" in changed_file.parts: - qualitymetrics_changed = True + elif "metrics" in changed_file.parts: + metrics_changed = True elif "comparison" in changed_file.parts: comparison_changed = True elif "curation" in changed_file.parts: @@ -89,12 +89,12 @@ run_extractor_tests = run_everything or extractors_changed or plexon2_changed run_preprocessing_tests = run_everything or preprocessing_changed run_postprocessing_tests = run_everything or postprocessing_changed -run_qualitymetrics_tests = run_everything or qualitymetrics_changed +run_metrics_tests = run_everything or metrics_changed run_curation_tests = run_everything or curation_changed run_sortingcomponents_tests = run_everything or sortingcomponents_changed run_comparison_test = run_everything or run_generation_tests or comparison_changed -run_widgets_test = run_everything or run_qualitymetrics_tests or run_preprocessing_tests or widgets_changed +run_widgets_test = run_everything or run_metrics_tests or run_preprocessing_tests or widgets_changed run_exporters_test = run_everything or run_widgets_test or exporters_changed run_sorters_test = run_everything or sorters_changed @@ -109,7 +109,7 @@ "RUN_EXTRACTORS_TESTS": run_extractor_tests, "RUN_PREPROCESSING_TESTS": run_preprocessing_tests, "RUN_POSTPROCESSING_TESTS": run_postprocessing_tests, - "RUN_QUALITYMETRICS_TESTS": run_qualitymetrics_tests, + "RUN_METRICS_TESTS": run_metrics_tests, "RUN_CURATION_TESTS": run_curation_tests, "RUN_SORTINGCOMPONENTS_TESTS": run_sortingcomponents_tests, "RUN_GENERATION_TESTS": run_generation_tests, diff --git a/.github/scripts/import_test.py b/.github/scripts/import_test.py index 9e9fb7666b..0b166fab30 100644 --- a/.github/scripts/import_test.py +++ b/.github/scripts/import_test.py @@ -5,7 +5,7 @@ "import spikeinterface", "import spikeinterface.core", "import spikeinterface.extractors", - "import spikeinterface.qualitymetrics", + "import spikeinterface.metrics", "import spikeinterface.preprocessing", "import spikeinterface.comparison", "import spikeinterface.postprocessing", diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 9481bb74f2..9e48799e4d 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -58,7 +58,7 @@ jobs: echo "RUN_EXTRACTORS_TESTS=${RUN_EXTRACTORS_TESTS}" echo "RUN_PREPROCESSING_TESTS=${RUN_PREPROCESSING_TESTS}" echo "RUN_POSTPROCESSING_TESTS=${RUN_POSTPROCESSING_TESTS}" - echo "RUN_QUALITYMETRICS_TESTS=${RUN_QUALITYMETRICS_TESTS}" + echo "RUN_METRICS_TESTS=${RUN_METRICS_TESTS}" echo "RUN_CURATION_TESTS=${RUN_CURATION_TESTS}" echo "RUN_SORTINGCOMPONENTS_TESTS=${RUN_SORTINGCOMPONENTS_TESTS}" echo "RUN_GENERATION_TESTS=${RUN_GENERATION_TESTS}" @@ -164,13 +164,13 @@ jobs: pip list ./.github/run_tests.sh postprocessing --no-virtual-env - - name: Test quality metrics + - name: Test metrics shell: bash - if: env.RUN_QUALITYMETRICS_TESTS == 'true' + if: env.RUN_METRICS_TESTS == 'true' run: | - pip install -e .[qualitymetrics] + pip install -e .[metrics] pip list - ./.github/run_tests.sh qualitymetrics --no-virtual-env + ./.github/run_tests.sh metrics --no-virtual-env - name: Test comparison shell: bash diff --git a/doc/api.rst b/doc/api.rst index 2aa09767a9..a4997bcd5f 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -226,19 +226,31 @@ spikeinterface.postprocessing .. autofunction:: compute_correlograms .. autofunction:: compute_acgs_3d .. autofunction:: compute_isi_histograms - .. autofunction:: get_template_metric_names .. autofunction:: align_sorting -spikeinterface.qualitymetrics ------------------------------ +spikeinterface.metrics +---------------------- -.. automodule:: spikeinterface.qualitymetrics +.. automodule:: spikeinterface.metrics.quality .. autofunction:: compute_quality_metrics .. autofunction:: get_quality_metric_list .. autofunction:: get_quality_pca_metric_list - .. autofunction:: get_default_qm_params + .. autofunction:: get_default_quality_metrics_params + +.. automodule:: spikeinterface.metrics.template + + .. autofunction:: compute_template_metrics + .. autofunction:: get_template_metric_list + .. autofunction:: get_default_template_metrics_params + .. autofunction:: get_single_channel_template_metric_names + .. autofunction:: get_multi_channel_template_metric_names + +.. automodule:: spikeinterface.metrics.spiketrain + + .. autofunction:: get_spiketrain_metric_list + .. autofunction:: get_default_spiketrain_metrics_params spikeinterface.sorters @@ -419,7 +431,7 @@ Drift ~~~~~ .. automodule:: spikeinterface.generation - :no-index: + :noindex: .. autofunction:: generate_drifting_recording .. autofunction:: generate_displacement_vector @@ -434,7 +446,7 @@ Hybrid ~~~~~~ .. automodule:: spikeinterface.generation - :no-index: + :noindex: .. autofunction:: generate_hybrid_recording .. autofunction:: estimate_templates_from_recording @@ -451,7 +463,6 @@ Noise ~~~~~ .. automodule:: spikeinterface.generation - :no-index: .. autofunction:: generate_noise @@ -508,9 +519,6 @@ spikeinterface.benchmark .. automodule:: spikeinterface.benchmark.benchmark_peak_localization .. autoclass:: PeakLocalizationStudy - -.. automodule:: spikeinterface.benchmark.benchmark_peak_localization - .. autoclass:: UnitLocalizationStudy .. automodule:: spikeinterface.benchmark.benchmark_motion_estimation diff --git a/doc/conf.py b/doc/conf.py index 4f9b60a926..b4ff6e97fe 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -35,8 +35,8 @@ '../examples/tutorials/core/analyzer_some_units', '../examples/tutorials/core/analyzer.zarr', '../examples/tutorials/curation/my_folder', - '../examples/tutorials/qualitymetrics/curated_sorting', - '../examples/tutorials/qualitymetrics/clean_analyzer.zarr', + '../examples/tutorials/metrics/curated_sorting', + '../examples/tutorials/metrics/clean_analyzer.zarr', '../examples/tutorials/widgets/waveforms_mearec', ] @@ -129,7 +129,7 @@ '../examples/tutorials/core', '../examples/tutorials/extractors', '../examples/tutorials/curation', - '../examples/tutorials/qualitymetrics', + '../examples/tutorials/metrics', '../examples/tutorials/comparison', '../examples/tutorials/widgets', '../examples/tutorials/forhowto', diff --git a/doc/get_started/import.rst b/doc/get_started/import.rst index 30e841bdd8..d42e21fc4d 100644 --- a/doc/get_started/import.rst +++ b/doc/get_started/import.rst @@ -26,10 +26,10 @@ to import the :code:`core` module followed by: import spikeinterface.extractors as se import spikeinterface.preprocessing as spre import spikeinterface.sorters as ss - import spikinterface.postprocessing as spost - import spikeinterface.qualitymetrics as sqm + import spikeinterface.postprocessing as spost + import spikeinterface.metrics as sm import spikeinterface.exporters as sexp - import spikeinterface.comparsion as scmp + import spikeinterface.comparison as scmp import spikeinterface.curation as scur import spikeinterface.sortingcomponents as sc import spikeinterface.widgets as sw diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 1d532c9387..7b2b72a782 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -42,7 +42,7 @@ We need to import one by one different submodules separately import spikeinterface.preprocessing as spre import spikeinterface.sorters as ss import spikeinterface.postprocessing as spost - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm import spikeinterface.comparison as sc import spikeinterface.exporters as sexp import spikeinterface.curation as scur @@ -627,7 +627,7 @@ compute quality metrics (some quality metrics require certain extensions .. code:: ipython3 - qm_params = sqm.get_default_qm_params() + qm_params = sqm.get_default_quality_metrics_params() pprint(qm_params) diff --git a/doc/how_to/analyze_neuropixels.rst b/doc/how_to/analyze_neuropixels.rst index 602105bc1e..9ebaed7d46 100644 --- a/doc/how_to/analyze_neuropixels.rst +++ b/doc/how_to/analyze_neuropixels.rst @@ -699,15 +699,14 @@ make a copy of the analyzer and all computed extensions. Quality metrics --------------- -We have a single function ``compute_quality_metrics(SortingAnalyzer)`` -that returns a ``pandas.Dataframe`` with the desired metrics. +The ``analyzer.compute("quality_metrics").get_data()`` returns a ``pandas.Dataframe`` with the desired metrics. Note that this function is also an extension and so can be saved. And so -this is equivalent to do : +this is equivalent to do: ``metrics = analyzer.compute("quality_metrics").get_data()`` Please visit the `metrics -documentation `__ +documentation `__ for more information and a list of all supported metrics. Some metrics are based on PCA (like @@ -721,20 +720,12 @@ PCA for their computation. This can be achieved with: metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff'] - # metrics = analyzer.compute("quality_metrics").get_data() - # equivalent to - metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names) + metrics_ext = analyzer.compute("quality_metrics", metric_names=metric_names) + metrics = metrics_ext.get_data() metrics -.. parsed-literal:: - - /home/samuel.garcia/Documents/SpikeInterface/spikeinterface/src/spikeinterface/qualitymetrics/misc_metrics.py:846: UserWarning: Some units have too few spikes : amplitude_cutoff is set to NaN - warnings.warn(f"Some units have too few spikes : amplitude_cutoff is set to NaN") - - - .. raw:: html diff --git a/doc/images/overview.png b/doc/images/overview.png index e367c4b6e4..1ddca00381 100644 Binary files a/doc/images/overview.png and b/doc/images/overview.png differ diff --git a/doc/modules/index.rst b/doc/modules/index.rst index 189bf56196..473f2eeb4d 100644 --- a/doc/modules/index.rst +++ b/doc/modules/index.rst @@ -9,7 +9,7 @@ Modules documentation preprocessing sorters postprocessing - qualitymetrics + metrics comparison exporters widgets diff --git a/doc/modules/metrics.rst b/doc/modules/metrics.rst new file mode 100644 index 0000000000..069fa340f3 --- /dev/null +++ b/doc/modules/metrics.rst @@ -0,0 +1,20 @@ +Metrics module +============== + + +The :py:mod:`~spikeinterface.metrics` module includes functions to compute various metrics related to spike sorting. + +Currently, it contains the following submodules: + +- **template metrics**: Computes commonly used waveform/template metrics. +- **quality metrics**: Computes a variety of quality metrics to assess the goodness of spike sorting outputs. +- **spiketrain metrics**: Computes metrics based on spike train statistics and correlogram shapes. + + +.. toctree:: + :caption: Metrics submodules + :maxdepth: 1 + + metrics/template_metrics + metrics/quality_metrics + metrics/spiketrain_metrics diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/metrics/quality_metrics.rst similarity index 88% rename from doc/modules/qualitymetrics.rst rename to doc/modules/metrics/quality_metrics.rst index 7625d4db01..fd5a5ca0e4 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/metrics/quality_metrics.rst @@ -2,7 +2,7 @@ Quality Metrics module ====================== Quality metrics allows one to quantitatively assess the *goodness* of a spike sorting output. -The :py:mod:`~spikeinterface.qualitymetrics` module includes functions to compute a large variety of available metrics. +The :py:mod:`~spikeinterface.metrics.quality` module includes functions to compute a large variety of available metrics. All of the metrics currently implemented in spikeInterface are *per unit* (pairwise metrics do appear in the literature). Each metric aims to identify some quality of the unit. @@ -12,7 +12,7 @@ Completeness metrics (or 'false negative'/'type II' metrics) aim to identify whe Examples include: presence ratio, amplitude cutoff, NN-miss rate. Drift metrics aim to identify changes in waveforms which occur when spike sorters fail to successfully track neurons in the case of electrode drift. -The quality metrics are saved as an extension of a :doc:`SortingAnalyzer `. Some metrics can only be computed if certain extensions have been computed first. For example the drift metrics can only be computed the spike locations extension has been computed. By default, as many metrics as possible are computed. Which ones are computed depends on which other extensions have +The quality metrics are saved as an extension of a :doc:`SortingAnalyzer <../postprocessing>`. Some metrics can only be computed if certain extensions have been computed first. For example the drift metrics can only be computed the spike locations extension has been computed. By default, as many metrics as possible are computed. Which ones are computed depends on which other extensions have been computed. In detail, the default metrics are (click on each metric to find out more about them!): @@ -69,7 +69,7 @@ You can compute the default metrics using the following code snippet: Some metrics are very slow to compute when the number of units it large. So by default, the following metrics are not computed: -- The ``nn_noise_overlap`` from :doc:`qualitymetrics/nearest_neighbor` +- The ``nn_advanced`` from :doc:`qualitymetrics/nearest_neighbor` Some metrics make use of :ref:`principal component analysis ` (PCA) to reduce the dimensionality of computations. Various approaches to computing the principal components are possible, and choice should be carefully considered in relation to the recording equipment used. @@ -94,7 +94,7 @@ To save the result in your analyzer, you can use the ``compute`` method: } ) -Note that if you request a specific metric using ``metric_names`` and you do not have the required extension computed, this will error. +Note that if you request a specific metric using ``metric_names`` and you do not have the required extension computed, the metric will be skipped. For more information about quality metrics, check out this excellent `documentation `_ diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/metrics/qualitymetrics/amplitude_cutoff.rst similarity index 96% rename from doc/modules/qualitymetrics/amplitude_cutoff.rst rename to doc/modules/metrics/qualitymetrics/amplitude_cutoff.rst index ef2749cd8b..155b1b6e2a 100644 --- a/doc/modules/qualitymetrics/amplitude_cutoff.rst +++ b/doc/modules/metrics/qualitymetrics/amplitude_cutoff.rst @@ -22,7 +22,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine sorting and recording into a sorting_analyzer # It is also recommended to run sorting_analyzer.compute(input="spike_amplitudes") @@ -34,7 +34,7 @@ Example code Reference --------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cutoffs +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_amplitude_cutoffs Links to original implementations diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/metrics/qualitymetrics/amplitude_cv.rst similarity index 92% rename from doc/modules/qualitymetrics/amplitude_cv.rst rename to doc/modules/metrics/qualitymetrics/amplitude_cv.rst index 2ad51aab2a..675dcf9237 100644 --- a/doc/modules/qualitymetrics/amplitude_cv.rst +++ b/doc/modules/metrics/qualitymetrics/amplitude_cv.rst @@ -32,7 +32,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine a sorting and recording into a sorting_analyzer # It is required to run sorting_analyzer.compute(input="spike_amplitudes") or @@ -46,7 +46,7 @@ Example code References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cv_metrics +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_amplitude_cv_metrics Literature diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/metrics/qualitymetrics/amplitude_median.rst similarity index 88% rename from doc/modules/qualitymetrics/amplitude_median.rst rename to doc/modules/metrics/qualitymetrics/amplitude_median.rst index 1e4eec2e40..10990014f6 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/metrics/qualitymetrics/amplitude_median.rst @@ -20,7 +20,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # It is also recommended to run sorting_analyzer.compute(input="spike_amplitudes") # in order to use amplitude values from all spikes. @@ -31,7 +31,7 @@ Example code Reference --------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_medians +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_amplitude_medians Links to original implementations diff --git a/doc/modules/qualitymetrics/amplitudes.png b/doc/modules/metrics/qualitymetrics/amplitudes.png similarity index 100% rename from doc/modules/qualitymetrics/amplitudes.png rename to doc/modules/metrics/qualitymetrics/amplitudes.png diff --git a/doc/modules/qualitymetrics/contamination.png b/doc/modules/metrics/qualitymetrics/contamination.png similarity index 100% rename from doc/modules/qualitymetrics/contamination.png rename to doc/modules/metrics/qualitymetrics/contamination.png diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/metrics/qualitymetrics/d_prime.rst similarity index 91% rename from doc/modules/qualitymetrics/d_prime.rst rename to doc/modules/metrics/qualitymetrics/d_prime.rst index 9b540be743..cc591c1629 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/metrics/qualitymetrics/d_prime.rst @@ -32,7 +32,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm d_prime = sqm.lda_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) @@ -40,7 +40,7 @@ Example code Reference --------- -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.lda_metrics +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.lda_metrics Literature diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/metrics/qualitymetrics/drift.rst similarity index 95% rename from doc/modules/qualitymetrics/drift.rst rename to doc/modules/metrics/qualitymetrics/drift.rst index 8f95f74695..82144176b7 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/metrics/qualitymetrics/drift.rst @@ -40,7 +40,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine sorting and recording into sorting_analyzer # It is required to run sorting_analyzer.compute(input="spike_locations") first @@ -53,7 +53,7 @@ Example code Reference --------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_drift_metrics +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_drift_metrics Links to original implementations --------------------------------- diff --git a/doc/modules/qualitymetrics/example_cutoff.png b/doc/modules/metrics/qualitymetrics/example_cutoff.png similarity index 100% rename from doc/modules/qualitymetrics/example_cutoff.png rename to doc/modules/metrics/qualitymetrics/example_cutoff.png diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/metrics/qualitymetrics/firing_range.rst similarity index 87% rename from doc/modules/qualitymetrics/firing_range.rst rename to doc/modules/metrics/qualitymetrics/firing_range.rst index d059f4eac6..9ddc03b57f 100644 --- a/doc/modules/qualitymetrics/firing_range.rst +++ b/doc/modules/metrics/qualitymetrics/firing_range.rst @@ -21,7 +21,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine a sorting and recording into a sorting_analyzer firing_range = sqm.compute_firing_ranges(sorting_analyzer=sorting_analyzer) @@ -31,7 +31,7 @@ Example code References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_firing_ranges Literature diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/metrics/qualitymetrics/firing_rate.rst similarity index 88% rename from doc/modules/qualitymetrics/firing_rate.rst rename to doc/modules/metrics/qualitymetrics/firing_rate.rst index 953901dd38..55efeda4d1 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/metrics/qualitymetrics/firing_rate.rst @@ -37,17 +37,17 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as sqm + from spikeinterface.metrics.spiketrain import compute_firing_rates # Combine a sorting and recording into a sorting_analyzer - firing_rate = sqm.compute_firing_rates(sorting_analyzer) + firing_rate = compute_firing_rates(sorting_analyzer) # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_rates +.. autofunction:: spikeinterface.metrics.spiketrain.compute_firing_rates Links to original implementations diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/metrics/qualitymetrics/isi_violations.rst similarity index 96% rename from doc/modules/qualitymetrics/isi_violations.rst rename to doc/modules/metrics/qualitymetrics/isi_violations.rst index 4527cdffe9..2a52612650 100644 --- a/doc/modules/qualitymetrics/isi_violations.rst +++ b/doc/modules/metrics/qualitymetrics/isi_violations.rst @@ -81,7 +81,7 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine sorting and recording into sorting_analyzer @@ -93,12 +93,12 @@ References UMS implementation (:code:`isi_violation`) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_isi_violations +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_isi_violations LLobet implementation (:code:`rp_violation`) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_refrac_period_violations +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_refrac_period_violations Examples with plots diff --git a/doc/modules/qualitymetrics/l_ratio.rst b/doc/modules/metrics/qualitymetrics/mahalanobis.rst similarity index 50% rename from doc/modules/qualitymetrics/l_ratio.rst rename to doc/modules/metrics/qualitymetrics/mahalanobis.rst index ae31ab40a4..44c3bb856e 100644 --- a/doc/modules/qualitymetrics/l_ratio.rst +++ b/doc/modules/metrics/qualitymetrics/mahalanobis.rst @@ -1,9 +1,33 @@ -L-ratio (:code:`l_ratio`) -========================= +Mahalanobis metrics (:code:`isolation_distance`, :code:`l_ratio`) +================================================================= + +Mahalanobis metrics are quality metrics based on the Mahalanobis distance between spikes and cluster centres in the PCA space. + +They include: +- Isolation distance (:code:`isolation_distance`) +- L-ratio (:code:`l_ratio`) Calculation ----------- +Isolation distance +~~~~~~~~~~~~~~~~~~ + +- :math:`C` : cluster of interest. +- :math:`N_s` : number of spikes within cluster :math:`C`. +- :math:`N_n` : number of spikes outside of cluster :math:`C`. +- :math:`N_{min}` : minimum of :math:`N_s` and :math:`N_n`. +- :math:`\mu_C`, :math:`\Sigma_C` : mean vector and covariance matrix for spikes within :math:`C` (where each spike within :math:`C` is represented by a vector of principal components (PCs)). +- :math:`D_{i,C}^2` : for every spike :math:`i` (represented by vector :math:`x_i`) outside of cluster :math:`C`, the Mahalanobis distance (as below) between :math:`\mu_c` and :math:`x_i` is calculated. These distances are ordered from smallest to largest. The :math:`N_{min}`'th entry in this list is the isolation distance. + +.. math:: + D_{i,C}^2 = (x_i - \mu_C)^T \Sigma_C^{-1} (x_i - \mu_C) + +Geometrically, the isolation distance for cluster :math:`C` is the radius of the circle which contains :math:`N_{min}` spikes from cluster :math:`C` and :math:`N_{min}` spikes outside of the cluster :math:`C`. + +L-ratio +~~~~~~~ + This example assumes use of a tetrode. L-ratio uses 4 principal components (PCs) for each tetrode channel (the first being energy, the square root of the sum of squares of each sample in the waveform, followed by the first 3 PCs of the energy normalised waveform). @@ -31,7 +55,10 @@ This yields L-ratio, which can be expressed as: Expectation and use ------------------- -Since this metric identifies unit separation, a high value indicates a highly contaminated unit (type I error) +Isolation distance can be interpreted as a measure of distance from the cluster to the nearest other cluster. +A well isolated unit should have a large isolation distance. + +L-ratio quantifies unit separation, so a high value indicates a highly contaminated unit (type I error) ([Schmitzer-Torbert]_ et al.). [Jackson]_ et al. suggests that this measure is also correlated with type II errors (although more strongly with type I errors). @@ -43,19 +70,20 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + from spikeinterface.metrics.quality.pca_metrics import mahalanobis_metrics - _, l_ratio = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + isolation_distance, l_ratio = mahalanobis_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) References ---------- -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.mahalanobis_metrics +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.mahalanobis_metrics :noindex: Literature ---------- -Introduced by [Schmitzer-Torbert]_ et al.. +Isolation distance introduced by [Harris]_. +L-ratio introduced by [Schmitzer-Torbert]_ et al.. Early discussion and comparison with isolation distance by [Jackson]_ et al.. diff --git a/doc/modules/qualitymetrics/nearest_neighbor.rst b/doc/modules/metrics/qualitymetrics/nearest_neighbor.rst similarity index 82% rename from doc/modules/qualitymetrics/nearest_neighbor.rst rename to doc/modules/metrics/qualitymetrics/nearest_neighbor.rst index bbd8f6628a..d5b59c0481 100644 --- a/doc/modules/qualitymetrics/nearest_neighbor.rst +++ b/doc/modules/metrics/qualitymetrics/nearest_neighbor.rst @@ -9,10 +9,9 @@ There are several implementations of nearest neighbor metrics which can be used When calling the :code:`compute_quality_metrics()` function, the following options are available to calculate NN metrics: - The :code:`nearest_neighbor` option will return :code:`nn_hit_rate` and :code:`nn_miss_rate` (based on [Siegle]_ inspired by [Chung]_). -- The :code:`nn_isolation` option will return the nearest neighbor isolation metric (adapted from [Chung]_). -- The :code:`nn_noise_overlap` option will return the nearest neighbor isolation metric (adapted from [Chung]_). +- The :code:`nn_advanced` option will return the advanced nearest neighbor metrics (:code:`nn_isolation` and :code:`nn_noise_overlap`) (adapted from [Chung]_). -All options involve non-parametric calculations in PCA space. +All options involve non-parametric calculations in PCA space. Note that the :code:`nn_advanced` metrics can be very slow to compute, so they are not computed by default. :code:`nearest_neighbor` ------------------------ @@ -38,8 +37,11 @@ NN-hit rate gives an estimate of contamination (an uncontaminated unit should ha NN-miss rate gives an estimate of completeness. A more complete unit should have a low NN-miss rate. +:code:`nn_advanced` +------------------- + :code:`nn_isolation` --------------------- +^^^^^^^^^^^^^^^^^^^^ The overall logic of this approach is to choose a cluster for which the isolation is to be computed, and compute the pairwise isolation score between the chosen cluster and every other cluster. The isolation score is then the minimum of the pairwise scores (the worst case). @@ -58,7 +60,7 @@ The pairwise isolation between clusters A and B is then: Note that nn_isolation is affected by the size of the clusters, so setting the :code:`max_spikes_for_nn` may aid downstream comparison of scores. :code:`nn_noise_overlap` ------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^ A noise cluster is generated by randomly sampling voltage snippets from the recording. Following a similar procedure to that of the nn_isolation method, compute isolation between the cluster of interest and the generated noise cluster. @@ -71,11 +73,11 @@ This metric gives an indication of the contamination present in the unit cluster References ---------- -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.nearest_neighbors_metrics +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.nearest_neighbors_metrics -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.nearest_neighbors_isolation +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.nearest_neighbors_isolation -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.nearest_neighbors_noise_overlap +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.nearest_neighbors_noise_overlap Literature diff --git a/doc/modules/qualitymetrics/noise_cutoff.rst b/doc/modules/metrics/qualitymetrics/noise_cutoff.rst similarity index 98% rename from doc/modules/qualitymetrics/noise_cutoff.rst rename to doc/modules/metrics/qualitymetrics/noise_cutoff.rst index 10384dd637..6a1c9900f1 100644 --- a/doc/modules/qualitymetrics/noise_cutoff.rst +++ b/doc/modules/metrics/qualitymetrics/noise_cutoff.rst @@ -101,7 +101,7 @@ Example code Reference --------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_noise_cutoffs +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_noise_cutoffs Examples with plots ------------------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/metrics/qualitymetrics/presence_ratio.rst similarity index 89% rename from doc/modules/qualitymetrics/presence_ratio.rst rename to doc/modules/metrics/qualitymetrics/presence_ratio.rst index e925c6e325..bf252fdb44 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/metrics/qualitymetrics/presence_ratio.rst @@ -23,7 +23,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine sorting and recording into a sorting_analyzer @@ -40,7 +40,7 @@ Links to original implementations References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_presence_ratios +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_presence_ratios Literature ---------- diff --git a/doc/modules/qualitymetrics/sd_ratio.rst b/doc/modules/metrics/qualitymetrics/sd_ratio.rst similarity index 92% rename from doc/modules/qualitymetrics/sd_ratio.rst rename to doc/modules/metrics/qualitymetrics/sd_ratio.rst index 260a2ec38e..14f1b32d23 100644 --- a/doc/modules/qualitymetrics/sd_ratio.rst +++ b/doc/modules/metrics/qualitymetrics/sd_ratio.rst @@ -26,7 +26,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # In this case we need to combine our sorting and recording into a sorting_analyzer sd_ratio = sqm.compute_sd_ratio(sorting_analyzer=sorting_analyzer censored_period_ms=4.0) @@ -35,7 +35,7 @@ Example code References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_sd_ratio +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_sd_ratio Literature ---------- diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/metrics/qualitymetrics/silhouette_score.rst similarity index 92% rename from doc/modules/qualitymetrics/silhouette_score.rst rename to doc/modules/metrics/qualitymetrics/silhouette_score.rst index 7da01e0476..f179356cd3 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/metrics/qualitymetrics/silhouette_score.rst @@ -55,7 +55,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm simple_sil_score = sqm.simplified_silhouette_score(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) @@ -63,9 +63,9 @@ Example code References ---------- -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.simplified_silhouette_score +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.simplified_silhouette_score -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_score +.. autofunction:: spikeinterface.metrics.quality.pca_metrics.silhouette_score Literature diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/metrics/qualitymetrics/sliding_rp_violations.rst similarity index 92% rename from doc/modules/qualitymetrics/sliding_rp_violations.rst rename to doc/modules/metrics/qualitymetrics/sliding_rp_violations.rst index 1913062cd9..eaa1831a47 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/metrics/qualitymetrics/sliding_rp_violations.rst @@ -27,7 +27,7 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine sorting and recording into a sorting_analyzer @@ -36,7 +36,7 @@ With SpikeInterface: References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_sliding_rp_violations +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_sliding_rp_violations Links to original implementations diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/metrics/qualitymetrics/snr.rst similarity index 93% rename from doc/modules/qualitymetrics/snr.rst rename to doc/modules/metrics/qualitymetrics/snr.rst index e640ec026f..ff669e447e 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/metrics/qualitymetrics/snr.rst @@ -41,7 +41,7 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combining sorting and recording into a sorting_analzyer SNRs = sqm.compute_snrs(sorting_analzyer=sorting_analzyer) @@ -56,7 +56,7 @@ Links to original implementations References ---------- -.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_snrs +.. autofunction:: spikeinterface.metrics.quality.misc_metrics.compute_snrs Literature ---------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/metrics/qualitymetrics/synchrony.rst similarity index 93% rename from doc/modules/qualitymetrics/synchrony.rst rename to doc/modules/metrics/qualitymetrics/synchrony.rst index 696dacbd3c..7b40449ee8 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/metrics/qualitymetrics/synchrony.rst @@ -27,7 +27,7 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as sqm + import spikeinterface.metrics.quality as sqm # Combine a sorting and recording into a sorting_analyzer synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer) # synchrony is a tuple of dicts with the synchrony metrics for each unit @@ -41,7 +41,7 @@ The SpikeInterface implementation is a partial port of the low-level complexity References ---------- -.. automodule:: spikeinterface.qualitymetrics.misc_metrics +.. automodule:: spikeinterface.metrics.quality.misc_metrics .. autofunction:: compute_synchrony_metrics diff --git a/doc/modules/metrics/spiketrain_metrics.rst b/doc/modules/metrics/spiketrain_metrics.rst new file mode 100644 index 0000000000..867af567d7 --- /dev/null +++ b/doc/modules/metrics/spiketrain_metrics.rst @@ -0,0 +1,10 @@ +Spike Train Metrics +=================== + +The :py:mod:`~spikeinterface.metrics.spiketrain_metrics` module includes functions to compute metrics based on spike train statistics and correlogram shapes. +Currently, the following metrics are implemented: + +- "num_spikes": number of spikes in the spike train. +- "firing_rate": firing rate of the spike train (spikes per second). + +# TODO: Add more metrics such as ISI distribution, CV, etc. diff --git a/doc/modules/metrics/template_metrics.rst b/doc/modules/metrics/template_metrics.rst new file mode 100644 index 0000000000..6c2acbc21b --- /dev/null +++ b/doc/modules/metrics/template_metrics.rst @@ -0,0 +1,43 @@ +Template Metrics +================ + +This extension computes commonly used waveform/template metrics. +By default, the following metrics are computed: + +* "peak_to_valley": duration in :math:`s` between negative and positive peaks +* "halfwidth": duration in :math:`s` at 50% of the amplitude +* "peak_to_trough_ratio": ratio between negative and positive peaks +* "recovery_slope": speed to recover from the negative peak to 0 +* "repolarization_slope": speed to repolarize from the positive peak to 0 +* "num_positive_peaks": the number of positive peaks +* "num_negative_peaks": the number of negative peaks + +The units of :code:`recovery_slope` and :code:`repolarization_slope` depend on the +input. Voltages are based on the units of the template. By default this is :math:`\mu V` +but can be the raw output from the recording device (this depends on the +:code:`return_in_uV` parameter, read more here: :ref:`modules/core:SortingAnalyzer`). +Distances are in :math:`\mu m` and times are in seconds. So, for example, if the +templates are in units of :math:`\mu V` then: :code:`repolarization_slope` is in +:math:`mV / s`; :code:`peak_to_trough_ratio` is in :math:`\mu m` and the +:code:`halfwidth` is in :math:`s`. + +Optionally, the following multi-channel metrics can be computed by setting: +:code:`include_multi_channel_metrics=True` + +* "velocity_above": the velocity in :math:`\mu m/s` above the max channel of the template +* "velocity_below": the velocity in :math:`\mu m/s` below the max channel of the template +* "exp_decay": the exponential decay in :math:`\mu m` of the template amplitude over distance +* "spread": the spread in :math:`\mu m` of the template amplitude over distance + +.. figure:: ../../images/1d_waveform_features.png + + Visualization of template metrics. Image from `ecephys_spike_sorting `_ + from the Allen Institute. + + +.. code-block:: python + + tm = sorting_analyzer.compute(input="template_metrics", include_multi_channel_metrics=True) + + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_metrics` diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 41fbe18865..5442b4728c 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -313,51 +313,6 @@ based on individual waveforms, it calculates at the unit level using templates. For more information, see :py:func:`~spikeinterface.postprocessing.compute_unit_locations` -template_metrics -^^^^^^^^^^^^^^^^ - -This extension computes commonly used waveform/template metrics. -By default, the following metrics are computed: - -* "peak_to_valley": duration in :math:`s` between negative and positive peaks -* "halfwidth": duration in :math:`s` at 50% of the amplitude -* "peak_to_trough_ratio": ratio between negative and positive peaks -* "recovery_slope": speed to recover from the negative peak to 0 -* "repolarization_slope": speed to repolarize from the positive peak to 0 -* "num_positive_peaks": the number of positive peaks -* "num_negative_peaks": the number of negative peaks - -The units of :code:`recovery_slope` and :code:`repolarization_slope` depend on the -input. Voltages are based on the units of the template. By default this is :math:`\mu V` -but can be the raw output from the recording device (this depends on the -:code:`return_in_uV` parameter, read more here: :ref:`modules/core:SortingAnalyzer`). -Distances are in :math:`\mu m` and times are in seconds. So, for example, if the -templates are in units of :math:`\mu V` then: :code:`repolarization_slope` is in -:math:`mV / s`; :code:`peak_to_trough_ratio` is in :math:`\mu m` and the -:code:`halfwidth` is in :math:`s`. - -Optionally, the following multi-channel metrics can be computed by setting: -:code:`include_multi_channel_metrics=True` - -* "velocity_above": the velocity in :math:`\mu m/s` above the max channel of the template -* "velocity_below": the velocity in :math:`\mu m/s` below the max channel of the template -* "exp_decay": the exponential decay in :math:`\mu m` of the template amplitude over distance -* "spread": the spread in :math:`\mu m` of the template amplitude over distance - -.. figure:: ../images/1d_waveform_features.png - - Visualization of template metrics. Image from `ecephys_spike_sorting `_ - from the Allen Institute. - - -.. code-block:: python - - tm = sorting_analyzer.compute(input="template_metrics", include_multi_channel_metrics=True) - - -For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_metrics` - - correlograms ^^^^^^^^^^^^ diff --git a/doc/modules/qualitymetrics/isolation_distance.rst b/doc/modules/qualitymetrics/isolation_distance.rst deleted file mode 100644 index 6ba0d0b1ec..0000000000 --- a/doc/modules/qualitymetrics/isolation_distance.rst +++ /dev/null @@ -1,45 +0,0 @@ -Isolation distance (:code:`isolation_distance`) -=============================================== - -Calculation ------------ - -- :math:`C` : cluster of interest. -- :math:`N_s` : number of spikes within cluster :math:`C`. -- :math:`N_n` : number of spikes outside of cluster :math:`C`. -- :math:`N_{min}` : minimum of :math:`N_s` and :math:`N_n`. -- :math:`\mu_C`, :math:`\Sigma_C` : mean vector and covariance matrix for spikes within :math:`C` (where each spike within :math:`C` is represented by a vector of principal components (PCs)). -- :math:`D_{i,C}^2` : for every spike :math:`i` (represented by vector :math:`x_i`) outside of cluster :math:`C`, the Mahalanobis distance (as below) between :math:`\mu_c` and :math:`x_i` is calculated. These distances are ordered from smallest to largest. The :math:`N_{min}`'th entry in this list is the isolation distance. - -.. math:: - D_{i,C}^2 = (x_i - \mu_C)^T \Sigma_C^{-1} (x_i - \mu_C) - -Geometrically, the isolation distance for cluster :math:`C` is the radius of the circle which contains :math:`N_{min}` spikes from cluster :math:`C` and :math:`N_{min}` spikes outside of the cluster :math:`C`. - - -Expectation and use -------------------- - -Isolation distance can be interpreted as a measure of distance from the cluster to the nearest other cluster. -A well isolated unit should have a large isolation distance. - -Example code ------------- - -.. code-block:: python - - import spikeinterface.qualitymetrics as sqm - - iso_distance, _ = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) - - -References ----------- - -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.mahalanobis_metrics - - -Literature ----------- - -Introduced by [Harris]_. diff --git a/doc/overview.rst b/doc/overview.rst index f8626347a5..f91a16e215 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -27,7 +27,7 @@ SpikeInterface consists of several sub-packages which encapsulate all steps in a - :py:mod:`spikeinterface.preprocessing` - :py:mod:`spikeinterface.sorters` - :py:mod:`spikeinterface.postprocessing` -- :py:mod:`spikeinterface.qualitymetrics` +- :py:mod:`spikeinterface.metrics` - :py:mod:`spikeinterface.widgets` - :py:mod:`spikeinterface.exporters` - :py:mod:`spikeinterface.comparison` diff --git a/doc/references.rst b/doc/references.rst index 1179afa509..383a53b158 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -57,13 +57,12 @@ methods: - :code:`acgs_3d` [Beau]_ - :code:`unit_locations` or :code:`spike_locations` with :code:`monopolar_triangulation` based on work from [Boussard]_ - :code:`unit_locations` or :code:`spike_locations` with :code:`grid_convolution` based on work from [Pachitariu]_ - - :code:`template_metrics` [Jia]_ -Qualitymetrics Module ---------------------- -If you use the :code:`qualitymetrics` module, i.e. you use the :code:`analyzer.compute()` -or :code:`compute_quality_metrics()` methods, please include the citations for the :code:`metric_names` that were particularly +Metrics Module +-------------- +If you use the :code:`metrics.quality` module, i.e. you use the :code:`analyzer.compute("quality_metrics")` +method, please include the citations for the :code:`metric_names` that were particularly important for your research: - :code:`amplitude_cutoff` [Hill]_ @@ -75,16 +74,16 @@ important for your research: - :code:`sd_ratio` [Pouzat]_ - :code:`snr` [Lemon]_ [Jackson]_ - :code:`synchrony` [GrĂ¼n]_ - -If you use the :code:`qualitymetrics.pca_metrics` module, i.e. you use the -:code:`compute_pc_metrics()` method, please include the citations for the :code:`metric_names` that were particularly -important for your research: - - :code:`d_prime` [Hill]_ - :code:`isolation_distance` or :code:`l_ratio` [Schmitzer-Torbert]_ - :code:`nearest_neighbor` or :code:`nn_isolation` or :code:`nn_noise_overlap` [Chung]_ [Siegle]_ - :code:`silhouette` [Rousseeuw]_ [Hruschka]_ +If you use the :code:`metrics.template` module, i.e. you use the :code:`analyzer.compute("template_metrics")` method, +please following citations: + +- [Jia]_ + Curation Module --------------- diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index f146f4a4d0..65b9c84543 100755 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -118,23 +118,23 @@ The :py:mod:`spikeinterface.extractors` module is designed to load and save reco Quality metrics tutorial ------------------------ -The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. +The :code:`spikeinterface.metrics.quality` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. .. grid:: 1 2 2 3 :gutter: 2 .. grid-item-card:: Quality Metrics :link-type: ref - :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_metrics.py - :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png + :link: sphx_glr_tutorials_metrics_plot_3_quality_metrics.py + :img-top: /tutorials/metrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png :img-alt: Quality Metrics :class-card: gallery-card :text-align: center .. grid-item-card:: Curation Tutorial :link-type: ref - :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py - :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :link: sphx_glr_tutorials_metrics_plot_4_curation.py + :img-top: /tutorials/metrics/images/thumb/sphx_glr_plot_4_curation_thumb.png :img-alt: Curation Tutorial :class-card: gallery-card :text-align: center diff --git a/examples/get_started/quickstart.py b/examples/get_started/quickstart.py index 3bbcb371fa..9eff361925 100644 --- a/examples/get_started/quickstart.py +++ b/examples/get_started/quickstart.py @@ -43,7 +43,7 @@ # - `preprocessing` : preprocessing # - `sorters` : Python wrappers of spike sorters # - `postprocessing` : postprocessing -# - `qualitymetrics` : quality metrics on units found by sorters +# - `metrics` : quality, template, and spiketrain metrics on units found by sorters # - `curation` : automatic curation of spike sorting output # - `comparison` : comparison of spike sorting outputs # - `widgets` : visualization @@ -53,7 +53,7 @@ import spikeinterface.preprocessing as spre import spikeinterface.sorters as ss import spikeinterface.postprocessing as spost -import spikeinterface.qualitymetrics as sqm +import spikeinterface.metrics as sm import spikeinterface.comparison as sc import spikeinterface.exporters as sexp import spikeinterface.curation as scur @@ -61,7 +61,7 @@ # Alternatively, we can import all submodules at once with `import spikeinterface.full as si` which # internally imports core+extractors+preprocessing+sorters+postprocessing+ -# qualitymetrics+comparison+widgets+exporters. In this case all aliases in the following tutorial +# metrics+comparison+widgets+exporters. In this case all aliases in the following tutorial # would be `si`. # This is useful for notebooks, but it is a heavier import because internally many more dependencies @@ -329,7 +329,7 @@ # Once we have computed all of the postprocessing information, we can compute quality # metrics (some quality metrics require certain extensions - e.g., drift metrics require `spike_locations`): -qm_params = sqm.get_default_qm_params() +qm_params = sm.get_default_quality_metrics_params() pprint(qm_params) # Since the recording is very short, let's change some parameters to accommodate the duration: diff --git a/examples/how_to/analyze_neuropixels.py b/examples/how_to/analyze_neuropixels.py index 11aefe786e..1acf8c55e9 100644 --- a/examples/how_to/analyze_neuropixels.py +++ b/examples/how_to/analyze_neuropixels.py @@ -290,13 +290,10 @@ # ## Quality metrics # -# We have a single function `compute_quality_metrics(SortingAnalyzer)` that returns a `pandas.Dataframe` with the desired metrics. +# The `analyzer.compute("quality_metrics").get_data()` returns a `pandas.Dataframe` with the desired metrics. # -# Note that this function is also an extension and so can be saved. And so this is equivalent to do : -# `metrics = analyzer.compute("quality_metrics").get_data()` # -# -# Please visit the [metrics documentation](https://spikeinterface.readthedocs.io/en/latest/modules/qualitymetrics.html) for more information and a list of all supported metrics. +# Please visit the [metrics documentation](https://spikeinterface.readthedocs.io/en/latest/modules/metrics/qualitymetrics.html) for more information and a list of all supported metrics. # # Some metrics are based on PCA (like `'isolation_distance', 'l_ratio', 'd_prime'`) and require to estimate PCA for their computation. This can be achieved with: # @@ -308,9 +305,8 @@ metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff'] -# metrics = analyzer.compute("quality_metrics").get_data() -# equivalent to -metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names) +metrics_ext = analyzer.compute("quality_metrics") +metrics = metrics_ext.get_data() metrics # - diff --git a/examples/tutorials/qualitymetrics/README.rst b/examples/tutorials/metrics/README.rst similarity index 100% rename from examples/tutorials/qualitymetrics/README.rst rename to examples/tutorials/metrics/README.rst diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py b/examples/tutorials/metrics/plot_3_quality_metrics.py similarity index 77% rename from examples/tutorials/qualitymetrics/plot_3_quality_metrics.py rename to examples/tutorials/metrics/plot_3_quality_metrics.py index fe71368845..96f0fa090e 100644 --- a/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py +++ b/examples/tutorials/metrics/plot_3_quality_metrics.py @@ -8,11 +8,10 @@ """ import spikeinterface.core as si -from spikeinterface.qualitymetrics import ( +from spikeinterface.metrics import ( compute_snrs, - compute_firing_rates, + compute_presence_ratios, compute_isi_violations, - compute_quality_metrics, ) ############################################################################## @@ -48,8 +47,8 @@ # metrics in a compact and easy way. To compute a single metric, one can simply run one of the # quality metric functions as shown below. Each function has a variety of adjustable parameters that can be tuned. -firing_rates = compute_firing_rates(analyzer) -print(firing_rates) +presence_ratios = compute_presence_ratios(analyzer) +print(presence_ratios) isi_violation_ratio, isi_violations_count = compute_isi_violations(analyzer) print(isi_violation_ratio) snrs = compute_snrs(analyzer) @@ -57,23 +56,32 @@ ############################################################################## -# To compute more than one metric at once, we can use the :code:`compute_quality_metrics` function and indicate -# which metrics we want to compute. This will return a pandas dataframe: - -metrics = compute_quality_metrics(analyzer, metric_names=["firing_rate", "snr", "amplitude_cutoff"]) +# To compute more than one metric at once, we can use the :code:`SortingAnalyzer.compute("quality_metrics")` +# function and indicate which metrics we want to compute. Then we can retrieve the results using the :code:`get_data()` +# method as a ``pandas.DataFrame``. + +metrics_ext = analyzer.compute( + "quality_metrics", + metric_names=["presence_ratio", "snr", "amplitude_cutoff"], + metric_params={ + "presence_ratio": {"bin_duration_s": 2.0}, + } +) +metrics = metrics_ext.get_data() print(metrics) ############################################################################## -# Some metrics are based on the principal component scores, so the exwtension +# Some metrics are based on the principal component scores, so the extension # must be computed before. For instance: analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True) -metrics = compute_quality_metrics( - analyzer, +metrics_ext = analyzer.compute( + "quality_metrics", metric_names=[ - "isolation_distance", + "mahalanobis_metrics", "d_prime", ], ) +metrics = metrics_ext.get_data() print(metrics) diff --git a/examples/tutorials/qualitymetrics/plot_4_curation.py b/examples/tutorials/metrics/plot_4_curation.py similarity index 89% rename from examples/tutorials/qualitymetrics/plot_4_curation.py rename to examples/tutorials/metrics/plot_4_curation.py index 328ebf8f2b..b556adc4c5 100644 --- a/examples/tutorials/qualitymetrics/plot_4_curation.py +++ b/examples/tutorials/metrics/plot_4_curation.py @@ -12,8 +12,6 @@ import spikeinterface.core as si -from spikeinterface.qualitymetrics import compute_quality_metrics - ############################################################################## # Let's generate a simulated dataset, and imagine that the ground-truth @@ -41,7 +39,8 @@ ############################################################################## # Then we compute some quality metrics: -metrics = compute_quality_metrics(analyzer, metric_names=["snr", "isi_violation", "nearest_neighbor"]) +metrics_ext = analyzer.compute("quality_metrics", metric_names=["snr", "isi_violation", "nearest_neighbor"]) +metrics = metrics_ext.get_data() print(metrics) ############################################################################## @@ -51,7 +50,7 @@ # # Then create a list of unit ids that we want to keep -keep_mask = (metrics["snr"] > 7.5) & (metrics["isi_violations_ratio"] < 0.2) & (metrics["nn_hit_rate"] > 0.90) +keep_mask = (metrics["snr"] > 7.5) & (metrics["isi_violations_ratio"] < 0.2) & (metrics["nn_hit_rate"] > 0.80) print(keep_mask) keep_unit_ids = keep_mask[keep_mask].index.values @@ -65,7 +64,7 @@ print(curated_sorting) -curated_sorting.save(folder="curated_sorting") +curated_sorting.save(folder="curated_sorting", overwrite=True) ############################################################################## # We can also save the analyzer with only theses units diff --git a/pyproject.toml b/pyproject.toml index 788b1aeaf4..3728aa22cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,7 @@ widgets = [ "sortingview>=0.12.0", ] -qualitymetrics = [ +metrics = [ "scikit-learn", "scipy", "pandas", @@ -232,7 +232,7 @@ markers = [ "extractors", "preprocessing", "postprocessing", - "qualitymetrics", + "mertrics", "sorters", "sorters_external", "sorters_internal", diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fea3f3618e..2e6c225a91 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -11,12 +11,14 @@ import warnings import numpy as np +from collections import namedtuple -from .sortinganalyzer import AnalyzerExtension, register_result_extension +from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels from .template import Templates from .sorting_tools import random_spikes_selection +from .job_tools import fix_job_kwargs, split_job_kwargs class ComputeRandomSpikes(AnalyzerExtension): @@ -806,3 +808,472 @@ def _handle_backward_compatibility_on_load(self): register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() + + +class BaseMetric: + """ + Base class for metric-based extension + """ + + metric_name = None # to be defined in subclass + metric_params = {} # to be defined in subclass + metric_columns = {} # column names and their dtypes of the dataframe + needs_recording = False # to be defined in subclass + needs_tmp_data = False # to be defined in subclass + needs_job_kwargs = False + depend_on = [] # to be defined in subclass + + # the metric function must have the signature: + # def metric_function(sorting_analyzer, unit_ids, **metric_params) + # or if needs_tmp_data=True + # def metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params) + # or if needs_job_kwargs=True + # def metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params) + # and must return a dict ({unit_id: values}) or namedtuple with fields matching metric_columns keys + metric_function = None # to be defined in subclass + + @classmethod + def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs): + """Compute the metric. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The input sorting analyzer + unit_ids : list + List of unit ids to compute the metric for + metric_params : dict + Parameters to override the default metric parameters + tmp_data : dict + Temporary data to pass to the metric function + job_kwargs : dict + Job keyword arguments to control parallelization + + Returns + ------- + results: namedtuple + The results of the metric function + """ + args = (sorting_analyzer, unit_ids) + if cls.needs_tmp_data: + args += (tmp_data,) + if cls.needs_job_kwargs: + args += (job_kwargs,) + + results = cls.metric_function(*args, **metric_params) + + # if namedtuple, check that columns are correct + if isinstance(results, tuple) and hasattr(results, "_fields"): + assert set(results._fields) == set(list(cls.metric_columns.keys())), ( + f"Metric {cls.metric_name} returned columns {results._fields} " + f"but expected columns are {cls.metric_columns.keys()}" + ) + return results + + +class BaseMetricExtension(AnalyzerExtension): + """ + AnalyzerExtension that computes a metric and store the results in a dataframe. + + This depends on one or more extensions (see `depend_on` attribute of the `BaseMetric` subclass). + + Returns + ------- + metric_dataframe : pd.DataFrame + The computed metric dataframe. + """ + + extension_name = None # to be defined in subclass + metric_class = None # to be defined in subclass + need_recording = False + use_nodepipeline = False + need_job_kwargs = True + need_backward_compatibility_on_load = False + metric_list: list[BaseMetric] = None # list of BaseMetric + + @classmethod + def get_default_metric_params(cls): + """Get the default metric parameters. + + Returns + ------- + default_metric_params : dict + Dictionary of default metric parameters for each metric. + """ + default_metric_params = {m.metric_name: m.metric_params for m in cls.metric_list} + return default_metric_params + + @classmethod + def get_metric_columns(cls, metric_names=None): + """Get the default metric columns. + + Parameters + ---------- + metric_names : list[str] | None + List of metric names to get columns for. If None, all metrics are considered. + + Returns + ------- + default_metric_columns : dict + Dictionary of default metric columns and their dtypes for each metric. + """ + default_metric_columns = [] + for m in cls.metric_list: + if metric_names is not None and m.metric_name not in metric_names: + continue + default_metric_columns.extend(m.metric_columns) + return default_metric_columns + + def _cast_metrics(self, metrics_df): + metric_dtypes = {} + for m in self.metric_list: + metric_dtypes.update(m.metric_columns) + + for col in metrics_df.columns: + if col in metric_dtypes: + try: + metrics_df[col] = metrics_df[col].astype(metric_dtypes[col]) + except Exception as e: + print(f"Error casting column {col}: {e}") + return metrics_df + + def _set_params( + self, + metric_names: list[str] | None = None, + metric_params: dict | None = None, + delete_existing_metrics: bool = False, + metrics_to_compute: list[str] | None = None, + **other_params, + ): + """ + Sets parameters for metric computation. + + Parameters + ---------- + metric_names : list[str] | None + List of metric names to compute. If None, all available metrics are computed. + metric_params : dict | None + Dictionary of metric parameters to override default parameters for specific metrics. + If None, default parameters for all metrics are used. + delete_existing_metrics : bool, default: False + If True, existing metrics in the extension will be deleted before computing new ones. + other_params : dict + Additional parameters for metric computation. + + Returns + ------- + params : dict + Dictionary of parameters for metric computation. + + Raises + ------ + ValueError + If any of the metric names are not in the available metrics. + """ + # check metric names + if metric_names is None: + metric_names = [m.metric_name for m in self.metric_list] + else: + for metric_name in metric_names: + if metric_name not in [m.metric_name for m in self.metric_list]: + raise ValueError( + f"Metric {metric_name} not in available metrics {[m.metric_name for m in self.metric_list]}" + ) + # check dependencies + metrics_to_remove = [] + for metric_name in metric_names: + metric = [m for m in self.metric_list if m.metric_name == metric_name][0] + depend_on = metric.depend_on + for dep in depend_on: + if "|" in dep: + # at least one of the dependencies must be present + dep_options = dep.split("|") + if not any([self.sorting_analyzer.has_extension(d) for d in dep_options]): + metrics_to_remove.append(metric_name) + else: + if not self.sorting_analyzer.has_extension(dep): + metrics_to_remove.append(metric_name) + if metric.needs_recording and not self.sorting_analyzer.has_recording(): + warnings.warn( + f"Metric {metric_name} requires a recording. " + f"Since the SortingAnalyzer has no recording, the metric will not be computed." + ) + metrics_to_remove.append(metric_name) + + metrics_to_remove = list(set(metrics_to_remove)) + if len(metrics_to_remove) > 0: + warnings.warn( + f"The following metrics will not be computed due to missing dependencies: {metrics_to_remove}" + ) + + for metric_name in metrics_to_remove: + metric_names.remove(metric_name) + + default_metric_params = {m.metric_name: m.metric_params for m in self.metric_list} + if metric_params is None: + metric_params = default_metric_params + else: + for metric, params in metric_params.items(): + default_metric_params[metric].update(params) + metric_params = default_metric_params + + # TODO: check behavior here!!! + if metrics_to_compute is None: + metrics_to_compute = metric_names + extension = self.sorting_analyzer.get_extension(self.extension_name) + if delete_existing_metrics is False and extension is not None: + existing_metric_names = extension.params["metric_names"] + existing_metric_names_propagated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propagated + + params = dict( + metric_names=metric_names, + metrics_to_compute=metrics_to_compute, + delete_existing_metrics=delete_existing_metrics, + metric_params=metric_params, + **other_params, + ) + return params + + def _prepare_data(self, sorting_analyzer, unit_ids=None): + """Optional function to prepare shared data for metric computation.""" + # useful function to compute data that is shared across metrics (e.g., PCA) + return {} + + def _compute_metrics( + self, + sorting_analyzer: SortingAnalyzer, + unit_ids: list[int | str] | None = None, + metric_names: list[str] | None = None, + **job_kwargs, + ): + """ + Compute metrics. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object. + unit_ids : list[int | str] | None, default: None + List of unit ids to compute metrics for. If None, all units are used. + metric_names : list[str] | None, default: None + List of metric names to compute. If None, all metrics in params["metric_names"] + are used. + + Returns + ------- + metrics : pd.DataFrame + DataFrame containing the computed metrics for each unit. + """ + import pandas as pd + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + tmp_data = self._prepare_data(sorting_analyzer=sorting_analyzer, unit_ids=unit_ids) + if metric_names is None: + metric_names = self.params["metric_names"] + + column_names_dtypes = {} + for metric_name in metric_names: + metric = [m for m in self.metric_list if m.metric_name == metric_name][0] + column_names_dtypes.update(metric.metric_columns) + + metrics = pd.DataFrame(index=unit_ids, columns=list(column_names_dtypes.keys())) + + for metric_name in metric_names: + metric = [m for m in self.metric_list if m.metric_name == metric_name][0] + column_names = list(metric.metric_columns.keys()) + try: + metric_params = self.params["metric_params"].get(metric_name, {}) + res = metric.compute( + sorting_analyzer, + unit_ids=unit_ids, + metric_params=metric_params, + tmp_data=tmp_data, + job_kwargs=job_kwargs, + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name}: {e}") + if len(column_names) == 1: + res = {unit_id: np.nan for unit_id in unit_ids} + else: + res = namedtuple("MetricResult", column_names)(*([np.nan] * len(column_names))) + + # res is a namedtuple with several dictionary entries (one per column) + if isinstance(res, dict): + column_name = column_names[0] + metrics.loc[unit_ids, column_name] = pd.Series(res) + else: + for i, col in enumerate(res._fields): + metrics.loc[unit_ids, col] = pd.Series(res[i]) + + metrics = self._cast_metrics(metrics) + + return metrics + + def _run(self, **job_kwargs): + + metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] + + _, job_kwargs = split_job_kwargs(job_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, metric_names=metrics_to_compute, **job_kwargs + ) + + existing_metrics = [] + + # Check if we need to propagate any old metrics. If so, we'll do that. + # Otherwise, we'll avoid attempting to load an empty metrics. + if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]): + + extension = self.sorting_analyzer.get_extension(self.extension_name) + if delete_existing_metrics is False and extension is not None and extension.data.get("metrics") is not None: + existing_metrics = extension.params["metric_names"] + + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + extension = self.sorting_analyzer.extensions.get(self.extension_name, None) + if delete_existing_metrics is False and extension is not None and extension.data.get("metrics") is not None: + existing_metrics = extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + metric = [m for m in self.metric_list if m.metric_name == metric_name][0] + # some metrics names produce data columns with other names. This deals with that. + for column_name in metric.metric_columns: + computed_metrics[column_name] = extension.data["metrics"][column_name] + + self.data["metrics"] = computed_metrics + + def _get_data(self): + # convert to correct dtype + return self.data["metrics"] + + def set_data(self, ext_data_name, data): + import pandas as pd + + if ext_data_name != "metrics": + return + if not isinstance(data, pd.DataFrame): + return + metrics = self._cast_metrics(data) + self.data[ext_data_name] = metrics + + def _select_extension_data(self, unit_ids: list[int | str]): + """ + Select data for a subset of unit ids. + + Parameters + ---------- + unit_ids : list[int | str] + List of unit ids to select data for. + + Returns + ------- + dict + Dictionary containing the selected metrics DataFrame. + """ + new_metrics = self.data["metrics"].loc[np.array(unit_ids)] + return dict(metrics=new_metrics) + + def _merge_extension_data( + self, + merge_unit_groups: list[list[int | str]], + new_unit_ids: list[int | str], + new_sorting_analyzer: SortingAnalyzer, + keep_mask: np.ndarray | None = None, + verbose: bool = False, + **job_kwargs, + ): + """ + Merge extension data from the old metrics DataFrame into the new one. + + Parameters + ---------- + merge_unit_groups : list[list[int | str]] + List of lists of unit ids to merge. + new_unit_ids : list[int | str] + List of new unit ids after merging. + new_sorting_analyzer : SortingAnalyzer + The new SortingAnalyzer object after merging. + keep_mask : np.ndarray | None, default: None + Mask to keep certain spikes (not used here). + verbose : bool, default: False + Whether to print verbose output. + job_kwargs : dict + Additional job keyword arguments. + + Returns + ------- + dict + Dictionary containing the merged metrics DataFrame. + """ + import pandas as pd + + available_metric_names = [m.metric_name for m in self.metric_list] + metric_names = [m for m in self.params["metric_names"] if m in available_metric_names] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids, :] = self._compute_metrics( + sorting_analyzer=new_sorting_analyzer, unit_ids=new_unit_ids, metric_names=metric_names, **job_kwargs + ) + metrics = self._cast_metrics(metrics) + + new_data = dict(metrics=metrics) + return new_data + + def _split_extension_data( + self, + split_units: dict[int | str, list[list[int]]], + new_unit_ids: list[list[int | str]], + new_sorting_analyzer: SortingAnalyzer, + verbose: bool = False, + **job_kwargs, + ): + """ + Split extension data from the old metrics DataFrame into the new one. + + Parameters + ---------- + split_units : dict[int | str, list[list[int]]] + List of unit ids to split. + new_unit_ids : list[list[int | str]] + List of lists of new unit ids after splitting. + new_sorting_analyzer : SortingAnalyzer + The new SortingAnalyzer object after splitting. + verbose : bool, default: False + Whether to print verbose output. + """ + import pandas as pd + from itertools import chain + + available_metric_names = [m.metric_name for m in self.metric_list] + metric_names = [m for m in self.params["metric_names"] if m in available_metric_names] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + sorting_analyzer=new_sorting_analyzer, unit_ids=new_unit_ids_f, metric_names=metric_names, **job_kwargs + ) + metrics = self._cast_metrics(metrics) + + new_data = dict(metrics=metrics) + return new_data diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 58bc48f72b..f4d78fa41a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1975,6 +1975,30 @@ def get_default_extension_params(self, extension_name: str) -> dict: """ return get_default_analyzer_extension_params(extension_name) + def get_metrics_extension_data(self): + """ + Get all metrics data into a single dataframe. + + Returns + ------- + metrics_df : pandas.DataFrame + A concatenated dataframe with all available metrics. + """ + import pandas as pd + from spikeinterface.core.analyzer_extension_core import BaseMetricExtension + + all_metrics_data = [] + for extension_name, ext in self.extensions.items(): + if isinstance(ext, BaseMetricExtension): + metric_data = ext.get_data() + all_metrics_data.append(metric_data) + + if len(all_metrics_data) > 0: + metrics_df = pd.concat(all_metrics_data, axis=1) + else: + metrics_df = pd.DataFrame(index=self.unit_ids) + return metrics_df + def _sort_extensions_by_dependency(extensions): """ @@ -2425,8 +2449,7 @@ def load_data(self): ext_data = pickle.load(f) else: continue - self.data[ext_data_name] = ext_data - + self.set_data(ext_data_name, ext_data) elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): @@ -2447,7 +2470,7 @@ def load_data(self): else: # this load in memmory ext_data = np.array(ext_data_) - self.data[ext_data_name] = ext_data + self.set_data(ext_data_name, ext_data) if len(self.data) == 0: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") @@ -2745,6 +2768,9 @@ def get_data(self, *args, **kwargs): assert len(self.data) > 0, "Extension has been run but no data found." return self._get_data(*args, **kwargs) + def set_data(self, ext_data_name, ext_data): + self.data[ext_data_name] = ext_data + # this is a hardcoded list to to improve error message and auto_import mechanism # this is important because extension are registered when the submodule is imported @@ -2762,9 +2788,10 @@ def get_data(self, *args, **kwargs): "principal_components": "spikeinterface.postprocessing", "spike_amplitudes": "spikeinterface.postprocessing", "spike_locations": "spikeinterface.postprocessing", - "template_metrics": "spikeinterface.postprocessing", "template_similarity": "spikeinterface.postprocessing", "unit_locations": "spikeinterface.postprocessing", - # from quality metrics - "quality_metrics": "spikeinterface.qualitymetrics", + # from metrics + "quality_metrics": "spikeinterface.metrics", + "template_metrics": "spikeinterface.metrics", + "spiketrain_metrics": "spikeinterface.metrics", } diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 9cce5e3f23..1acc041057 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -14,7 +14,8 @@ HAVE_NUMBA = False from spikeinterface.core import SortingAnalyzer -from spikeinterface.qualitymetrics import compute_refrac_period_violations, compute_firing_rates +from spikeinterface.metrics.quality.misc_metrics import compute_refrac_period_violations +from spikeinterface.metrics.spiketrain.metrics import compute_firing_rates from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index 98f8b4073f..e779e13182 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -174,7 +174,9 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in if model_metric_params is None or metric not in model_metric_params: inconsistent_metrics.append(metric) else: - if metric_params[metric] != model_metric_params[metric]: + if metric not in metric_params: + inconsistent_metrics.append(metric) + elif metric_params[metric] != model_metric_params[metric]: warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." if enforce_metric_params is True: raise Exception(warning_message) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index e2452c1d54..6156062a7c 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -136,7 +136,8 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) -def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): +@pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation") +def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation): """We track whether the metric parameters used to compute the metrics used to train a model are the same as the parameters used to compute the metrics in the sorting analyzer which is being curated. If they are different, an error or warning will @@ -161,7 +162,11 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) # Now test the positive case. Recompute using the default parameters - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) + sorting_analyzer_for_curation.compute( + "quality_metrics", + metric_names=["num_spikes", "snr"], + metric_params={"snr": {"peak_sign": "neg", "peak_mode": "extremum"}}, + ) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index fee222047a..4c9da1a430 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -4,15 +4,16 @@ import json import spikeinterface from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.qualitymetrics import ( + +# TODO fix with new metrics +from spikeinterface.metrics import ( + ComputeQualityMetrics, + ComputeTemplateMetrics, get_quality_metric_list, get_quality_pca_metric_list, - qm_compute_name_to_column_names, + get_template_metric_list, ) -from spikeinterface.postprocessing import get_template_metric_names -from spikeinterface.postprocessing.template_metrics import tm_compute_name_to_column_names from pathlib import Path -from copy import deepcopy def get_default_classifier_search_spaces(): @@ -236,7 +237,7 @@ def __init__( def get_default_metrics_list(self): """Returns the default list of metrics.""" - return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_names() + return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_list() def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): """ @@ -325,10 +326,8 @@ def load_and_preprocess_csv(self, paths): def get_metric_params_csv(self): - from itertools import chain - - qm_metric_names = list(chain.from_iterable(qm_compute_name_to_column_names.values())) - tm_metric_names = list(chain.from_iterable(tm_compute_name_to_column_names.values())) + qm_metric_names = ComputeQualityMetrics.get_metric_columns() + tm_metric_names = ComputeTemplateMetrics.get_metric_columns() quality_metric_names = [] template_metric_names = [] diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 2b191bc1e3..89867a917f 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -45,6 +45,7 @@ def make_sorting_analyzer(sparse=True, with_group=False): sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("principal_components") sorting_analyzer.compute("template_similarity") + sorting_analyzer.compute("spike_amplitudes") sorting_analyzer.compute( "quality_metrics", metric_names=["snr", "amplitude_median", "isi_violation", "amplitude_cutoff"] ) diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index b9410bc021..4c52f37e9f 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -19,7 +19,7 @@ from .sorters import * from .preprocessing import * from .postprocessing import * -from .qualitymetrics import * +from .metrics import * from .curation import * from .comparison import * from .widgets import * diff --git a/src/spikeinterface/metrics/__init__.py b/src/spikeinterface/metrics/__init__.py new file mode 100644 index 0000000000..472de809fa --- /dev/null +++ b/src/spikeinterface/metrics/__init__.py @@ -0,0 +1,3 @@ +from .template import * +from .quality import * +from .spiketrain import * diff --git a/src/spikeinterface/metrics/conftest.py b/src/spikeinterface/metrics/conftest.py new file mode 100644 index 0000000000..8d32c103fa --- /dev/null +++ b/src/spikeinterface/metrics/conftest.py @@ -0,0 +1,8 @@ +import pytest + +from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git a/src/spikeinterface/metrics/quality/__init__.py b/src/spikeinterface/metrics/quality/__init__.py new file mode 100644 index 0000000000..1edcd9221f --- /dev/null +++ b/src/spikeinterface/metrics/quality/__init__.py @@ -0,0 +1,23 @@ +from .quality_metrics import ( + get_quality_metric_list, + get_quality_pca_metric_list, + get_default_quality_metrics_params, + get_default_qm_params, + ComputeQualityMetrics, + compute_quality_metrics, +) + +from .misc_metrics import ( + compute_snrs, + compute_isi_violations, + compute_amplitude_cutoffs, + compute_presence_ratios, + compute_drift_metrics, + compute_amplitude_cv_metrics, + compute_amplitude_medians, + compute_noise_cutoffs, + compute_firing_ranges, + compute_sliding_rp_violations, + compute_sd_ratio, + compute_synchrony_metrics, +) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py similarity index 88% rename from src/spikeinterface/qualitymetrics/misc_metrics.py rename to src/spikeinterface/metrics/quality/misc_metrics.py index e7b9dee2c7..c4d8941ccc 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -9,7 +9,6 @@ from __future__ import annotations -from .utils import _has_required_extensions from collections import namedtuple import math import warnings @@ -17,6 +16,7 @@ import numpy as np +from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import correlogram_for_one_segment from spikeinterface.core import SortingAnalyzer, get_noise_levels @@ -26,6 +26,8 @@ get_dense_templates_array, ) +from ..spiketrain.metrics import NumSpikes, FiringRate + numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: HAVE_NUMBA = True @@ -33,233 +35,7 @@ HAVE_NUMBA = False -_default_params = dict() - - -def compute_noise_cutoffs(sorting_analyzer, high_quantile=0.25, low_quantile=0.1, n_bins=100, unit_ids=None): - """ - A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. - - Based on the histogram of the (transformed) amplitude: - - 1. This method compares counts in the lower-amplitude bins to counts in the top 'high_quantile' of the amplitude range. - It computes the mean and std of an upper quantile of the distribution, and calculates how many standard deviations away - from that mean the lower-quantile bins lie. - - 2. The method also compares the counts in the lower-amplitude bins to the count in the highest bin and return their ratio. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - high_quantile : float, default: 0.25 - Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. - low_quantile : int, default: 0.1 - Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. - n_bins: int, default: 100 - The number of bins to use to compute the amplitude histogram. - unit_ids : list or None - List of unit ids to compute the amplitude cutoffs. If None, all units are used. - - Returns - ------- - noise_cutoff_dict : dict of floats - Estimated metrics based on the amplitude distribution, for each unit ID. - - References - ---------- - Inspired by metric described in [IBL2024]_ - - """ - res = namedtuple("cutoff_metrics", ["noise_cutoff", "noise_ratio"]) - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids - - noise_cutoff_dict = {} - noise_ratio_dict = {} - - _has_required_extensions(sorting_analyzer, metric_name="noise_cutoff") - - amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - peak_sign = amplitude_extension.params["peak_sign"] - if peak_sign == "both": - raise TypeError( - '`peak_sign` should either be "pos" or "neg". You can set `peak_sign` as an argument when you compute spike_amplitudes.' - ) - - amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) - - for unit_id in unit_ids: - amplitudes = amplitudes_by_units[unit_id] - - # We assume the noise (zero values) is on the lower tail of the amplitude distribution. - # But if peak_sign == 'neg', the noise will be on the higher tail, so we flip the distribution. - if peak_sign == "neg": - amplitudes = -amplitudes - - cutoff, ratio = _noise_cutoff(amplitudes, high_quantile=high_quantile, low_quantile=low_quantile, n_bins=n_bins) - noise_cutoff_dict[unit_id] = cutoff - noise_ratio_dict[unit_id] = ratio - - return res(noise_cutoff_dict, noise_ratio_dict) - - -_default_params["noise_cutoff"] = dict(high_quantile=0.25, low_quantile=0.1, n_bins=100) - - -def _noise_cutoff(amps, high_quantile=0.25, low_quantile=0.1, n_bins=100): - """ - A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. - - Based on the histogram of the (transformed) amplitude: - - 1. This method compares counts in the lower-amplitude bins to counts in the higher_amplitude bins. - It computes the mean and std of an upper quantile of the distribution, and calculates how many standard deviations away - from that mean the lower-quantile bins lie. - - 2. The method also compares the counts in the lower-amplitude bins to the count in the highest bin and return their ratio. - - Parameters - ---------- - amps : array-like - Spike amplitudes. - high_quantile : float, default: 0.25 - Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. - low_quantile : int, default: 0.1 - Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. - n_bins: int, default: 100 - The number of bins to use to compute the amplitude histogram. - - Returns - ------- - cutoff : float - (mean(lower_bins_count) - mean(high_bins_count)) / std(high_bins_count) - ratio: float - mean(lower_bins_count) / highest_bin_count - - """ - n_per_bin, bin_edges = np.histogram(amps, bins=n_bins) - - maximum_bin_height = np.max(n_per_bin) - - low_quantile_value = np.quantile(amps, q=low_quantile) - - # the indices for low-amplitude bins - low_indices = np.where(bin_edges[1:] <= low_quantile_value)[0] - - high_quantile_value = np.quantile(amps, q=1 - high_quantile) - - # the indices for high-amplitude bins - high_indices = np.where(bin_edges[:-1] >= high_quantile_value)[0] - - if len(low_indices) == 0: - warnings.warn( - "No bin is selected to test cutoff. Please increase low_quantile. Setting noise cutoff and ratio to NaN" - ) - return np.nan, np.nan - - # compute ratio between low-amplitude bins and the largest bin - low_counts = n_per_bin[low_indices] - mean_low_counts = np.mean(low_counts) - ratio = mean_low_counts / maximum_bin_height - - if len(high_indices) == 0: - warnings.warn( - "No bin is selected as the reference region. Please increase high_quantile. Setting noise cutoff to NaN" - ) - return np.nan, ratio - - if len(high_indices) == 1: - warnings.warn( - "Only one bin is selected as the reference region, and thus the standard deviation cannot be computed. " - "Please increase high_quantile. Setting noise cutoff to NaN" - ) - return np.nan, ratio - - # compute cutoff from low-amplitude and high-amplitude bins - high_counts = n_per_bin[high_indices] - mean_high_counts = np.mean(high_counts) - std_high_counts = np.std(high_counts) - if std_high_counts == 0: - warnings.warn( - "All the high-amplitude bins have the same size. Please consider changing n_bins. " - "Setting noise cutoff to NaN" - ) - return np.nan, ratio - - cutoff = (mean_low_counts - mean_high_counts) / std_high_counts - return cutoff, ratio - - -def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): - """ - Compute the number of spike across segments. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - unit_ids : list or None - The list of unit ids to compute the number of spikes. If None, all units are used. - - Returns - ------- - num_spikes : dict - The number of spikes, across all segments, for each unit ID. - """ - - sorting = sorting_analyzer.sorting - if unit_ids is None: - unit_ids = sorting.unit_ids - num_segs = sorting.get_num_segments() - - num_spikes = {} - for unit_id in unit_ids: - n = 0 - for segment_index in range(num_segs): - st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - n += st.size - num_spikes[unit_id] = n - - return num_spikes - - -_default_params["num_spikes"] = {} - - -def compute_firing_rates(sorting_analyzer, unit_ids=None): - """ - Compute the firing rate across segments. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - unit_ids : list or None - The list of unit ids to compute the firing rate. If None, all units are used. - - Returns - ------- - firing_rates : dict of floats - The firing rate, across all segments, for each unit ID. - """ - - sorting = sorting_analyzer.sorting - if unit_ids is None: - unit_ids = sorting.unit_ids - total_duration = sorting_analyzer.get_total_duration() - - firing_rates = {} - num_spikes = compute_num_spikes(sorting_analyzer) - for unit_id in unit_ids: - firing_rates[unit_id] = num_spikes[unit_id] / total_duration - return firing_rates - - -_default_params["firing_rate"] = {} - - -def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): +def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -267,14 +43,14 @@ def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the presence ratio. If None, all units are used. bin_duration_s : float, default: 60 The duration of each bin in seconds. If the duration is less than this value, presence_ratio is set to NaN. mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. - unit_ids : list or None - The list of unit ids to compute the presence ratio. If None, all units are used. Returns ------- @@ -334,17 +110,18 @@ def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio return presence_ratios -_default_params["presence_ratio"] = dict( - bin_duration_s=60, - mean_fr_ratio_thresh=0.0, -) +class PresenceRatio(BaseMetric): + metric_name = "presence_ratio" + metric_function = compute_presence_ratios + metric_params = {"bin_duration_s": 60, "mean_fr_ratio_thresh": 0.0} + metric_columns = {"presence_ratio": float} def compute_snrs( sorting_analyzer, + unit_ids=None, peak_sign: str = "neg", peak_mode: str = "extremum", - unit_ids=None, ): """ Compute signal to noise ratio. @@ -353,26 +130,25 @@ def compute_snrs( ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the SNR. If None, all units are used. peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima At_index takes the value at t=sorting_analyzer.nbefore. - unit_ids : list or None - The list of unit ids to compute the SNR. If None, all units are used. Returns ------- snrs : dict Computed signal to noise ratio for each unit. """ + check_has_required_extensions("snr", sorting_analyzer) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - _has_required_extensions(sorting_analyzer, metric_name="snr") - noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() assert peak_sign in ("neg", "pos", "both") @@ -396,10 +172,15 @@ def compute_snrs( return snrs -_default_params["snr"] = dict(peak_sign="neg", peak_mode="extremum") +class SNR(BaseMetric): + metric_name = "snr" + metric_function = compute_snrs + metric_params = {"peak_sign": "neg", "peak_mode": "extremum"} + metric_columns = {"snr": float} + depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -412,6 +193,8 @@ def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the ISI violations. If None, all units are used. isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period. @@ -419,8 +202,6 @@ def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. - unit_ids : list or None - List of unit ids to compute the ISI violations. If None, all units are used. Returns ------- @@ -485,11 +266,15 @@ def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, return res(isi_violations_ratio, isi_violations_count) -_default_params["isi_violation"] = dict(isi_threshold_ms=1.5, min_isi_ms=0) +class ISIViolation(BaseMetric): + metric_name = "isi_violation" + metric_function = compute_isi_violations + metric_params = {"isi_threshold_ms": 1.5, "min_isi_ms": 0} + metric_columns = {"isi_violations_ratio": float, "isi_violations_count": int} def compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -502,13 +287,13 @@ def compute_refrac_period_violations( ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the refractory period violations. If None, all units are used. refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). - unit_ids : list or None - List of unit ids to compute the refractory period violations. If None, all units are used. Returns ------- @@ -531,6 +316,8 @@ def compute_refrac_period_violations( ---------- Based on metrics described in [Llobet]_ """ + from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes + res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) if not HAVE_NUMBA: @@ -579,18 +366,22 @@ def compute_refrac_period_violations( return res(rp_contamination, nb_violations) -_default_params["rp_violation"] = dict(refractory_period_ms=1.0, censored_period_ms=0.0) +class RPViolation(BaseMetric): + metric_name = "rp_violation" + metric_function = compute_refrac_period_violations + metric_params = {"refractory_period_ms": 1.0, "censored_period_ms": 0.0} + metric_columns = {"rp_contamination": float, "rp_violations": int} def compute_sliding_rp_violations( sorting_analyzer, + unit_ids=None, min_spikes=0, bin_size_ms=0.25, window_size_s=1, exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - unit_ids=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -601,6 +392,8 @@ def compute_sliding_rp_violations( ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the sliding RP violations. If None, all units are used. min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. @@ -614,8 +407,6 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). - unit_ids : list or None - List of unit ids to compute the sliding RP violations. If None, all units are used. Returns ------- @@ -668,33 +459,38 @@ def compute_sliding_rp_violations( return contamination -_default_params["sliding_rp_violation"] = dict( - min_spikes=0, - bin_size_ms=0.25, - window_size_s=1, - exclude_ref_period_below_ms=0.5, - max_ref_period_ms=10, - contamination_values=None, -) +class SlidingRPViolation(BaseMetric): + metric_name = "sliding_rp_violation" + metric_function = compute_sliding_rp_violations + metric_params = { + "min_spikes": 0, + "bin_size_ms": 0.25, + "window_size_s": 1, + "exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10, + "contamination_values": None, + } + metric_columns = {"sliding_rp_violation": float} -def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ - Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. + Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of + spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. Parameters ---------- - spikes : np.array - Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). - all_unit_ids : list or None, default: None - List of unit ids to compute the synchrony metrics. Expecting all units. - synchrony_sizes : None or np.array, default: None - The synchrony sizes to compute. Should be pre-sorted. + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. + synchrony_sizes: None, default: None + Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. Returns ------- - synchrony_counts : np.ndarray - The synchrony counts for the synchrony sizes. + sync_spike_{X} : dict + The synchrony metric for synchrony size X. References ---------- @@ -702,55 +498,9 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64) - - # compute the occurrence of each sample_index. Count >2 means there's synchrony - _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True) - - sync_indices = unique_spike_index[counts >= 2] - sync_counts = counts[counts >= 2] - - for i, sync_index in enumerate(sync_indices): - - num_of_syncs = sync_counts[i] - units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)] - - # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added - # to the 2 simultaneous spike bins. - how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs]) - synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1 - - return synchrony_counts - - -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): - """ - Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of - spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - unit_ids : list or None, default: None - List of unit ids to compute the synchrony metrics. If None, all units are used. - synchrony_sizes: None, default: None - Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. - - Returns - ------- - sync_spike_{X} : dict - The synchrony metric for synchrony size X. - - References - ---------- - Based on concepts described in [GrĂ¼n]_ - This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ - """ - - if synchrony_sizes is not None: - warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" - warnings.warn(warning_message, DeprecationWarning, stacklevel=2) + if synchrony_sizes is not None: + warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) synchrony_sizes = np.array([2, 4, 8]) @@ -782,10 +532,13 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N return res(**synchrony_metrics_dict) -_default_params["synchrony"] = dict() +class Synchrony(BaseMetric): + metric_name = "synchrony" + metric_function = compute_synchrony_metrics + metric_columns = {"sync_spike_2": float, "sync_spike_4": float, "sync_spike_8": float} -def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -793,13 +546,13 @@ def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), u Parameters ---------- sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object + A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the firing range. If None, all units are used. bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. - unit_ids : list or None - List of unit ids to compute the firing range. If None, all units are used. Returns ------- @@ -847,16 +600,20 @@ def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), u return firing_ranges -_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) +class FiringRange(BaseMetric): + metric_name = "firing_range" + metric_function = compute_firing_ranges + metric_params = {"bin_size_s": 5, "percentiles": (5, 95)} + metric_columns = {"firing_range": float} def compute_amplitude_cv_metrics( sorting_analyzer, + unit_ids=None, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", - unit_ids=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -867,6 +624,8 @@ def compute_amplitude_cv_metrics( ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the amplitude spread. If None, all units are used. average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -878,8 +637,6 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". - unit_ids : list or None - List of unit ids to compute the amplitude spread. If None, all units are used. Returns ------- @@ -892,6 +649,7 @@ def compute_amplitude_cv_metrics( ----- Designed by Simon Musall and Alessio Buccino. """ + check_has_required_extensions("amplitude_cv", sorting_analyzer) res = namedtuple("amplitude_cv", ["amplitude_cv_median", "amplitude_cv_range"]) assert amplitude_extension in ( "spike_amplitudes", @@ -904,8 +662,6 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - _has_required_extensions(sorting_analyzer, metric_name="amplitude_cv") - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() # precompute segment slice @@ -952,41 +708,25 @@ def compute_amplitude_cv_metrics( return res(amplitude_cv_medians, amplitude_cv_ranges) -_default_params["amplitude_cv"] = dict( - average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" -) - - -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units +class AmplitudeCV(BaseMetric): + metric_name = "amplitude_cv" + metric_function = compute_amplitude_cv_metrics + metric_params = { + "average_num_spikes_per_bin": 50, + "percentiles": (5, 95), + "min_num_bins": 10, + "amplitude_extension": "spike_amplitudes", + } + metric_columns = {"amplitude_cv_median": float, "amplitude_cv_range": float} + depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_amplitude_cutoffs( sorting_analyzer, - peak_sign="neg", + unit_ids=None, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, - unit_ids=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -995,8 +735,8 @@ def compute_amplitude_cutoffs( ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the peaks. + unit_ids : list or None + List of unit ids to compute the amplitude cutoffs. If None, all units are used. num_histogram_bins : int, default: 100 The number of bins to use to compute the amplitude histogram. histogram_smoothing_value : int, default: 3 @@ -1005,8 +745,6 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - unit_ids : list or None - List of unit ids to compute the amplitude cutoffs. If None, all units are used. Returns ------- @@ -1017,9 +755,7 @@ def compute_amplitude_cutoffs( Notes ----- This approach assumes the amplitude histogram is symmetric (not valid in the presence of drift). - If available, amplitudes are extracted from the "spike_amplitude" extension (recommended). - If the "spike_amplitude" extension is not available, the amplitudes are extracted from the SortingAnalyzer, - which usually has waveforms for a small subset of spikes (500 by default). + If available, amplitudes are extracted from the "spike_amplitude" or "amplitude_scalings" extensions. References ---------- @@ -1029,22 +765,22 @@ def compute_amplitude_cutoffs( https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/quality_metrics """ + check_has_required_extensions("amplitude_cutoff", sorting_analyzer) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids all_fraction_missing = {} - _has_required_extensions(sorting_analyzer, metric_name="amplitude_cutoff") - invert_amplitudes = False - if ( - sorting_analyzer.has_extension("spike_amplitudes") - and sorting_analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "pos" - ): - invert_amplitudes = True - elif sorting_analyzer.has_extension("waveforms") and peak_sign == "pos": + if sorting_analyzer.has_extension("spike_amplitudes"): + extension = sorting_analyzer.get_extension("spike_amplitudes") + all_amplitudes = extension.get_data() + invert_amplitudes = np.median(all_amplitudes) > 0 + elif sorting_analyzer.has_extension("amplitude_scalings"): + # amplitude scalings are positive, we need to invert them invert_amplitudes = True + extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -1061,12 +797,19 @@ def compute_amplitude_cutoffs( return all_fraction_missing -_default_params["amplitude_cutoff"] = dict( - peak_sign="neg", num_histogram_bins=100, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5 -) +class AmplitudeCutoff(BaseMetric): + metric_name = "amplitude_cutoff" + metric_function = compute_amplitude_cutoffs + metric_params = { + "num_histogram_bins": 100, + "histogram_smoothing_value": 3, + "amplitudes_bins_min_ratio": 5, + } + metric_columns = {"amplitude_cutoff": float} + depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None): """ Compute median of the amplitude distributions (in absolute value). @@ -1074,8 +817,6 @@ def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the peaks. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. @@ -1090,32 +831,109 @@ def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): This code is ported from: https://github.com/int-brain-lab/ibllib/blob/master/brainbox/metrics/single_units.py """ - sorting = sorting_analyzer.sorting + check_has_required_extensions("amplitude_median", sorting_analyzer) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - _has_required_extensions(sorting_analyzer, metric_name="amplitude_median") - all_amplitude_medians = {} - amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) + amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) return all_amplitude_medians -_default_params["amplitude_median"] = dict(peak_sign="neg") +class AmplitudeMedian(BaseMetric): + metric_name = "amplitude_median" + metric_function = compute_amplitude_medians + metric_columns = {"amplitude_median": float} + depend_on = ["spike_amplitudes"] + + +def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): + """ + A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. + + Based on the histogram of the (transformed) amplitude: + + 1. This method compares counts in the lower-amplitude bins to counts in the top 'high_quantile' of the amplitude range. + It computes the mean and std of an upper quantile of the distribution, and calculates how many standard deviations away + from that mean the lower-quantile bins lie. + + 2. The method also compares the counts in the lower-amplitude bins to the count in the highest bin and return their ratio. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the amplitude cutoffs. If None, all units are used. + high_quantile : float, default: 0.25 + Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. + low_quantile : int, default: 0.1 + Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. + n_bins: int, default: 100 + The number of bins to use to compute the amplitude histogram. + + Returns + ------- + noise_cutoff_dict : dict of floats + Estimated metrics based on the amplitude distribution, for each unit ID. + + References + ---------- + Inspired by metric described in [IBL2024]_ + + """ + check_has_required_extensions("noise_cutoff", sorting_analyzer) + res = namedtuple("cutoff_metrics", ["noise_cutoff", "noise_ratio"]) + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + noise_cutoff_dict = {} + noise_ratio_dict = {} + + if sorting_analyzer.has_extension("spike_amplitudes"): + extension = sorting_analyzer.get_extension("spike_amplitudes") + all_amplitudes = extension.get_data() + invert_amplitudes = np.median(all_amplitudes) > 0 + elif sorting_analyzer.has_extension("amplitude_scalings"): + # amplitude scalings are positive, we need to invert them + invert_amplitudes = True + extension = sorting_analyzer.get_extension("amplitude_scalings") + + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + + for unit_id in unit_ids: + amplitudes = amplitudes_by_units[unit_id] + if invert_amplitudes: + amplitudes = -amplitudes + + cutoff, ratio = _noise_cutoff(amplitudes, high_quantile=high_quantile, low_quantile=low_quantile, n_bins=n_bins) + noise_cutoff_dict[unit_id] = cutoff + noise_ratio_dict[unit_id] = ratio + + return res(noise_cutoff_dict, noise_ratio_dict) + + +class NoiseCutoff(BaseMetric): + metric_name = "noise_cutoff" + metric_function = compute_noise_cutoffs + metric_params = {"high_quantile": 0.25, "low_quantile": 0.1, "n_bins": 100} + metric_columns = {"noise_cutoff": float, "noise_ratio": float} + depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_drift_metrics( sorting_analyzer, + unit_ids=None, interval_s=60, min_spikes_per_interval=100, direction="y", min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, - unit_ids=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1135,6 +953,8 @@ def compute_drift_metrics( ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. + unit_ids : list or None, default: None + List of unit ids to compute the drift metrics. If None, all units are used. interval_s : int, default: 60 Interval length is seconds for computing spike depth. min_spikes_per_interval : int, default: 100 @@ -1150,8 +970,6 @@ def compute_drift_metrics( less bins, the metric values are set to NaN. return_positions : bool, default: False If True, median positions are returned (for debugging). - unit_ids : list or None, default: None - List of unit ids to compute the drift metrics. If None, all units are used. Returns ------- @@ -1169,13 +987,12 @@ def compute_drift_metrics( For multi-segment object, segments are concatenated before the computation. This means that if there are large displacements in between segments, the resulting metric values will be very high. """ + check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids - _has_required_extensions(sorting_analyzer, metric_name="drift") - spike_locations_ext = sorting_analyzer.get_extension("spike_locations") spike_locations = spike_locations_ext.get_data() # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") @@ -1272,76 +1089,267 @@ def compute_drift_metrics( return outs -_default_params["drift"] = dict(interval_s=60, min_spikes_per_interval=100, direction="y", min_num_bins=2) +class Drift(BaseMetric): + metric_name = "drift" + metric_function = compute_drift_metrics + metric_params = { + "interval_s": 60, + "min_spikes_per_interval": 100, + "direction": "y", + "min_num_bins": 2, + } + metric_columns = {"drift_ptp": float, "drift_std": float, "drift_mad": float} + depend_on = ["spike_locations"] -### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def compute_sd_ratio( + sorting_analyzer: SortingAnalyzer, + unit_ids=None, + censored_period_ms: float = 4.0, + correct_for_drift: bool = True, + correct_for_template_itself: bool = True, + **kwargs, +): """ - Calculate the presence ratio for a single unit. + Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise. + In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit. + (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself). + + TODO: Take jitter into account. Parameters ---------- - spike_train : np.ndarray - Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. - bin_edges : np.array, optional - Pre-computed bin edges (mutually exclusive with num_bin_edges). - num_bin_edges : int, optional - The number of bins edges to use to compute the presence ratio. - (mutually exclusive with bin_edges). - bin_n_spikes_thres : int, default: 0 - Minimum number of spikes within a bin to consider the unit active. + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None, default: None + The list of unit ids to compute this metric. If None, all units are used. + censored_period_ms : float, default: 4.0 + The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. + correct_for_drift : bool, default: True + If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift. + correct_for_template_itself : bool, default: True + If true, will take into account that the template itself impacts the standard deviation of the noise, + and will make a rough estimation of what that impact is (and remove it). + **kwargs : dict, default: {} + Keyword arguments for computing spike amplitudes and extremum channel. Returns ------- - presence_ratio : float - The presence ratio for one unit. - + num_spikes : dict + The number of spikes, across all segments, for each unit ID. """ - assert bin_edges is not None or num_bin_edges is not None, "Use either bin_edges or num_bin_edges" - assert bin_n_spikes_thres >= 0 - if bin_edges is not None: - bins = bin_edges - num_bin_edges = len(bin_edges) - else: - bins = num_bin_edges - h, _ = np.histogram(spike_train, bins=bins) - return np.sum(h > bin_n_spikes_thres) / (num_bin_edges - 1) + from spikeinterface.curation.curation_tools import find_duplicated_spikes + check_has_required_extensions("sd_ratio", sorting_analyzer) + kwargs, job_kwargs = split_job_kwargs(kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) -def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_isi_s=0): - """ - Calculate Inter-Spike Interval (ISI) violations. + sorting = sorting_analyzer.sorting - See compute_isi_violations for additional documentation + censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids - Parameters - ---------- - spike_trains : list of np.ndarrays - The spike times for each recording segment for one unit, in seconds. - total_duration_s : float - The total duration of the recording (in seconds). - isi_threshold_s : float, default: 0.0015 - Threshold for classifying adjacent spikes as an ISI violation, in seconds. - This is the biophysical refractory period. - min_isi_s : float, default: 0 - Minimum possible inter-spike interval, in seconds. - This is the artificial refractory period enforced - by the data acquisition system or post-processing algorithms. + if not sorting_analyzer.has_recording(): + warnings.warn( + "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" + "SD ratio metric will be set to NaN" + ) + return {unit_id: np.nan for unit_id in unit_ids} - Returns - ------- - isi_violations_ratio : float - The isi violation ratio described in [1]. - isi_violations_rate : float - Rate of contaminating spikes as a fraction of overall rate. - Higher values indicate more contamination. - isi_violation_count : int - Number of violations. - """ + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + + if not HAVE_NUMBA: + warnings.warn( + "'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. " + "SD ratio metric will be set to NaN" + ) + return {unit_id: np.nan for unit_id in unit_ids} + job_kwargs["progress_bar"] = False + noise_levels = get_noise_levels( + sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs + ) + best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs) + n_spikes = sorting.count_num_spikes_per_unit() + + if correct_for_template_itself: + tamplates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV) + + spikes = sorting.to_spike_vector() + sd_ratio = {} + for unit_id in unit_ids: + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) + + spk_amp = [] + + for segment_index in range(sorting_analyzer.get_num_segments()): + + spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) + spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False) + amplitudes = spike_amplitudes[spike_mask] + + censored_indices = find_duplicated_spikes( + spike_train, + censored_period, + method="keep_first_iterative", + ) + + spk_amp.append(np.delete(amplitudes, censored_indices)) + + spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))]) + + if len(spk_amp) == 0: + sd_ratio[unit_id] = np.nan + elif len(spk_amp) == 1: + sd_ratio[unit_id] = 0.0 + else: + if correct_for_drift: + unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2) + else: + unit_std = np.std(spk_amp) + + best_channel = best_channels[unit_id] + std_noise = noise_levels[best_channel] + + if correct_for_template_itself: + # template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel] + + template = tamplates_array[unit_index, :, :][:, best_channel] + nsamples = template.shape[0] + + # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. + # TODO: Take into account that templates for different segments might differ. + p = nsamples * n_spikes[unit_id] / sorting_analyzer.get_total_samples() + total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2 + + std_noise = np.sqrt(std_noise**2 - total_variance) + + sd_ratio[unit_id] = unit_std / std_noise + + return sd_ratio + + +class SDRatio(BaseMetric): + metric_name = "sd_ratio" + metric_function = compute_sd_ratio + metric_params = { + "censored_period_ms": 4.0, + "correct_for_drift": True, + "correct_for_template_itself": True, + } + metric_columns = {"sd_ratio": float} + needs_recording = True + depend_on = ["templates", "spike_amplitudes"] + + +# Group metrics into categories +misc_metrics_list = [ + NumSpikes, + FiringRate, + PresenceRatio, + SNR, + ISIViolation, + RPViolation, + SlidingRPViolation, + Synchrony, + FiringRange, + AmplitudeCV, + AmplitudeCutoff, + NoiseCutoff, + AmplitudeMedian, + Drift, + SDRatio, +] + + +def check_has_required_extensions(metric_name, sorting_analyzer): + metric = [m for m in misc_metrics_list if m.metric_name == metric_name][0] + dependencies = metric.depend_on + has_required_extensions = True + for dep in dependencies: + if "|" in dep: + # at least one of the extensions is required + ext_names = dep.split("|") + if not any([sorting_analyzer.has_extension(ext_name) for ext_name in ext_names]): + has_required_extensions = False + else: + if not sorting_analyzer.has_extension(dep): + has_required_extensions = False + if not has_required_extensions: + raise ValueError( + f"The metric '{metric_name}' requires the following extensions: {dependencies}. " + f"Please make sure your SortingAnalyzer has the required extensions." + ) + + +### LOW-LEVEL FUNCTIONS ### +def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): + """ + Calculate the presence ratio for a single unit. + + Parameters + ---------- + spike_train : np.ndarray + Spike times for this unit, in samples. + total_length : int + Total length of the recording in samples. + bin_edges : np.array, optional + Pre-computed bin edges (mutually exclusive with num_bin_edges). + num_bin_edges : int, optional + The number of bins edges to use to compute the presence ratio. + (mutually exclusive with bin_edges). + bin_n_spikes_thres : int, default: 0 + Minimum number of spikes within a bin to consider the unit active. + + Returns + ------- + presence_ratio : float + The presence ratio for one unit. + + """ + assert bin_edges is not None or num_bin_edges is not None, "Use either bin_edges or num_bin_edges" + assert bin_n_spikes_thres >= 0 + if bin_edges is not None: + bins = bin_edges + num_bin_edges = len(bin_edges) + else: + bins = num_bin_edges + h, _ = np.histogram(spike_train, bins=bins) + + return np.sum(h > bin_n_spikes_thres) / (num_bin_edges - 1) + + +def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_isi_s=0): + """ + Calculate Inter-Spike Interval (ISI) violations. + + See compute_isi_violations for additional documentation + + Parameters + ---------- + spike_trains : list of np.ndarrays + The spike times for each recording segment for one unit, in seconds. + total_duration_s : float + The total duration of the recording (in seconds). + isi_threshold_s : float, default: 0.0015 + Threshold for classifying adjacent spikes as an ISI violation, in seconds. + This is the biophysical refractory period. + min_isi_s : float, default: 0 + Minimum possible inter-spike interval, in seconds. + This is the artificial refractory period enforced + by the data acquisition system or post-processing algorithms. + + Returns + ------- + isi_violations_ratio : float + The isi violation ratio described in [1]. + isi_violations_rate : float + Rate of contaminating spikes as a fraction of overall rate. + Higher values indicate more contamination. + isi_violation_count : int + Number of violations. + """ num_violations = 0 num_spikes = 0 @@ -1526,173 +1534,190 @@ def _compute_violations(obs_viol, firing_rate, spike_count, ref_period_dur, cont return confidence_score -if HAVE_NUMBA: - import numba - - @numba.jit(nopython=True, nogil=True, cache=False) - def _compute_nb_violations_numba(spike_train, t_r): - n_v = 0 - N = len(spike_train) - - for i in range(N): - for j in range(i + 1, N): - diff = spike_train[j] - spike_train[i] - - if diff > t_r: - break - - # if diff < t_c: - # continue - - n_v += 1 - - return n_v - - @numba.jit( - nopython=True, - nogil=True, - cache=False, - parallel=True, - ) - def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, t_c, t_r): - n_units = len(nb_rp_violations) - - for i in numba.prange(n_units): - spike_train = spike_trains[spike_clusters == i] - n_v = _compute_nb_violations_numba(spike_train, t_r) - nb_rp_violations[i] += n_v +def _noise_cutoff(amps, high_quantile=0.25, low_quantile=0.1, n_bins=100): + """ + A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. + Based on the histogram of the (transformed) amplitude: -def compute_sd_ratio( - sorting_analyzer: SortingAnalyzer, - censored_period_ms: float = 4.0, - correct_for_drift: bool = True, - correct_for_template_itself: bool = True, - unit_ids=None, - **kwargs, -): - """ - Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise. - In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit. - (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself). + 1. This method compares counts in the lower-amplitude bins to counts in the higher_amplitude bins. + It computes the mean and std of an upper quantile of the distribution, and calculates how many standard deviations away + from that mean the lower-quantile bins lie. - TODO: Take jitter into account. + 2. The method also compares the counts in the lower-amplitude bins to the count in the highest bin and return their ratio. Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - censored_period_ms : float, default: 4.0 - The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. - correct_for_drift : bool, default: True - If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift. - correct_for_template_itself : bool, default: True - If true, will take into account that the template itself impacts the standard deviation of the noise, - and will make a rough estimation of what that impact is (and remove it). - unit_ids : list or None, default: None - The list of unit ids to compute this metric. If None, all units are used. - **kwargs : dict, default: {} - Keyword arguments for computing spike amplitudes and extremum channel. + amps : array-like + Spike amplitudes. + high_quantile : float, default: 0.25 + Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. + low_quantile : int, default: 0.1 + Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. + n_bins: int, default: 100 + The number of bins to use to compute the amplitude histogram. Returns ------- - num_spikes : dict - The number of spikes, across all segments, for each unit ID. + cutoff : float + (mean(lower_bins_count) - mean(high_bins_count)) / std(high_bins_count) + ratio: float + mean(lower_bins_count) / highest_bin_count + """ + n_per_bin, bin_edges = np.histogram(amps, bins=n_bins) - from spikeinterface.curation.curation_tools import find_duplicated_spikes + maximum_bin_height = np.max(n_per_bin) - kwargs, job_kwargs = split_job_kwargs(kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) + low_quantile_value = np.quantile(amps, q=low_quantile) - sorting = sorting_analyzer.sorting + # the indices for low-amplitude bins + low_indices = np.where(bin_edges[1:] <= low_quantile_value)[0] - censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids + high_quantile_value = np.quantile(amps, q=1 - high_quantile) - if not sorting_analyzer.has_recording(): + # the indices for high-amplitude bins + high_indices = np.where(bin_edges[:-1] >= high_quantile_value)[0] + + if len(low_indices) == 0: warnings.warn( - "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" - "SD ratio metric will be set to NaN" + "No bin is selected to test cutoff. Please increase low_quantile. Setting noise cutoff and ratio to NaN" ) - return {unit_id: np.nan for unit_id in unit_ids} + return np.nan, np.nan - _has_required_extensions(sorting_analyzer, metric_name="sd_ratio") + # compute ratio between low-amplitude bins and the largest bin + low_counts = n_per_bin[low_indices] + mean_low_counts = np.mean(low_counts) + ratio = mean_low_counts / maximum_bin_height - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + if len(high_indices) == 0: + warnings.warn( + "No bin is selected as the reference region. Please increase high_quantile. Setting noise cutoff to NaN" + ) + return np.nan, ratio - if not HAVE_NUMBA: + if len(high_indices) == 1: warnings.warn( - "'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. " - "SD ratio metric will be set to NaN" + "Only one bin is selected as the reference region, and thus the standard deviation cannot be computed. " + "Please increase high_quantile. Setting noise cutoff to NaN" ) - return {unit_id: np.nan for unit_id in unit_ids} - noise_levels = get_noise_levels( - sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs - ) - best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs) - n_spikes = sorting.count_num_spikes_per_unit() + return np.nan, ratio - if correct_for_template_itself: - tamplates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV) + # compute cutoff from low-amplitude and high-amplitude bins + high_counts = n_per_bin[high_indices] + mean_high_counts = np.mean(high_counts) + std_high_counts = np.std(high_counts) + if std_high_counts == 0: + warnings.warn( + "All the high-amplitude bins have the same size. Please consider changing n_bins. " + "Setting noise cutoff to NaN" + ) + return np.nan, ratio - spikes = sorting.to_spike_vector() - sd_ratio = {} - for unit_id in unit_ids: - unit_index = sorting_analyzer.sorting.id_to_index(unit_id) + cutoff = (mean_low_counts - mean_high_counts) / std_high_counts + return cutoff, ratio - spk_amp = [] - for segment_index in range(sorting_analyzer.get_num_segments()): +def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): + """ + Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. - spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) - spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False) - amplitudes = spike_amplitudes[spike_mask] + Parameters + ---------- + spikes : np.array + Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). + all_unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. Expecting all units. + synchrony_sizes : None or np.array, default: None + The synchrony sizes to compute. Should be pre-sorted. - censored_indices = find_duplicated_spikes( - spike_train, - censored_period, - method="keep_first_iterative", - ) + Returns + ------- + synchrony_counts : np.ndarray + The synchrony counts for the synchrony sizes. - spk_amp.append(np.delete(amplitudes, censored_indices)) + References + ---------- + Based on concepts described in [GrĂ¼n]_ + This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ + """ - spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))]) + synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64) - if len(spk_amp) == 0: - sd_ratio[unit_id] = np.nan - elif len(spk_amp) == 1: - sd_ratio[unit_id] = 0.0 - else: - if correct_for_drift: - unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2) + # compute the occurrence of each sample_index. Count >2 means there's synchrony + _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True) + + sync_indices = unique_spike_index[counts >= 2] + sync_counts = counts[counts >= 2] + + for i, sync_index in enumerate(sync_indices): + + num_of_syncs = sync_counts[i] + units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)] + + # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added + # to the 2 simultaneous spike bins. + how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs]) + synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1 + + return synchrony_counts + + +def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): + # used by compute_amplitude_cutoffs and compute_amplitude_medians + + if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: + return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) + + elif sorting_analyzer.has_extension("waveforms"): + amplitudes_by_units = {} + waveforms_ext = sorting_analyzer.get_extension("waveforms") + before = waveforms_ext.nbefore + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) + for unit_id in unit_ids: + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + chan_id = extremum_channels_ids[unit_id] + if sorting_analyzer.is_sparse(): + chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] else: - unit_std = np.std(spk_amp) + chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] + amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - best_channel = best_channels[unit_id] - std_noise = noise_levels[best_channel] + return amplitudes_by_units - if correct_for_template_itself: - # template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel] - template = tamplates_array[unit_index, :, :][:, best_channel] - nsamples = template.shape[0] +if HAVE_NUMBA: + import numba - # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. - # TODO: Take into account that templates for different segments might differ. - p = nsamples * n_spikes[unit_id] / sorting_analyzer.get_total_samples() - total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2 + @numba.jit(nopython=True, nogil=True, cache=False) + def _compute_nb_violations_numba(spike_train, t_r): + n_v = 0 + N = len(spike_train) - std_noise = np.sqrt(std_noise**2 - total_variance) + for i in range(N): + for j in range(i + 1, N): + diff = spike_train[j] - spike_train[i] - sd_ratio[unit_id] = unit_std / std_noise + if diff > t_r: + break - return sd_ratio + # if diff < t_c: + # continue + n_v += 1 -_default_params["sd_ratio"] = dict( - censored_period_ms=4.0, - correct_for_drift=True, - correct_for_template_itself=True, -) + return n_v + + @numba.jit( + nopython=True, + nogil=True, + cache=False, + parallel=True, + ) + def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, t_c, t_r): + n_units = len(nb_rp_violations) + + for i in numba.prange(n_units): + spike_train = spike_trains[spike_clusters == i] + n_v = _compute_nb_violations_numba(spike_train, t_r) + nb_rp_violations[i] += n_v diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/metrics/quality/pca_metrics.py similarity index 75% rename from src/spikeinterface/qualitymetrics/pca_metrics.py rename to src/spikeinterface/metrics/quality/pca_metrics.py index f3c95f7fd7..71469d84b4 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/metrics/quality/pca_metrics.py @@ -3,236 +3,354 @@ from __future__ import annotations import warnings -from copy import deepcopy -import platform -from tqdm.auto import tqdm -from warnings import warn - +from collections import namedtuple +from pathlib import Path import numpy as np +# Parallel processing +import platform import multiprocessing as mp from concurrent.futures import ProcessPoolExecutor -from threadpoolctl import threadpool_limits -from .misc_metrics import compute_num_spikes, compute_firing_rates -from spikeinterface.core import get_random_data_chunks, compute_sparsity -from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.analyzer_extension_core import BaseMetric +from spikeinterface.core import get_random_data_chunks, compute_sparsity, load +from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates -_possible_pc_metric_names = [ - "isolation_distance", - "l_ratio", - "d_prime", - "nearest_neighbor", - "nn_isolation", - "nn_noise_overlap", - "silhouette", -] +def _mahalanobis_metrics_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + mahalanobis_result = namedtuple("MahalanobisResult", ["isolation_distance", "l_ratio"]) -_default_params = dict( - nearest_neighbor=dict( - max_spikes=10000, - n_neighbors=5, - ), - nn_isolation=dict( - max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" - ), - nn_noise_overlap=dict( - max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" - ), - silhouette=dict(method=("simplified",)), - isolation_distance=dict(), - l_ratio=dict(), - d_prime=dict(), -) + # Use pre-computed PCA data + pca_data_per_unit = tmp_data["pca_data_per_unit"] + isolation_distance_dict = {} + l_ratio_dict = {} -def get_quality_pca_metric_list(): - """Get a list of the available PCA-based quality metrics.""" - return deepcopy(_possible_pc_metric_names) + for unit_id in unit_ids: + pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"] + labels = pca_data_per_unit[unit_id]["labels"] + try: + isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) + except: + isolation_distance = np.nan + l_ratio = np.nan -def compute_pc_metrics( - sorting_analyzer, - metric_names=None, - metric_params=None, - qm_params=None, - unit_ids=None, - seed=None, - n_jobs=1, - progress_bar=False, - mp_context=None, - max_threads_per_worker=None, -) -> dict: - """ - Calculate principal component derived metrics. + isolation_distance_dict[unit_id] = isolation_distance + l_ratio_dict[unit_id] = l_ratio - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - metric_names : list of str, default: None - The list of PC metrics to compute. - If not provided, defaults to all PC metrics. - metric_params : dict or None - Dictionary with parameters for each PC metric function. - unit_ids : list of int or None - List of unit ids to compute metrics for. - seed : int, default: None - Random seed value. - n_jobs : int - Number of jobs to parallelize metric computations. - progress_bar : bool - If True, progress bar is shown. + return mahalanobis_result(isolation_distance=isolation_distance_dict, l_ratio=l_ratio_dict) - Returns - ------- - pc_metrics : dict - The computed PC metrics. - """ - if qm_params is not None and metric_params is None: - deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" - ) - warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - metric_params = qm_params +class MahalanobisMetrics(BaseMetric): + metric_name = "mahalanobis_metrics" + metric_function = _mahalanobis_metrics_function + metric_params = {} + metric_columns = {"isolation_distance": float, "l_ratio": float} + depend_on = ["principal_components"] + needs_tmp_data = True - pca_ext = sorting_analyzer.get_extension("principal_components") - assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" - sorting = sorting_analyzer.sorting +def _d_prime_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + # Use pre-computed PCA data + pca_data_per_unit = tmp_data["pca_data_per_unit"] + + d_prime_dict = {} + + for unit_id in unit_ids: + if len(unit_ids) == 1: + d_prime_dict[unit_id] = np.nan + continue + + pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"] + labels = pca_data_per_unit[unit_id]["labels"] + + try: + d_prime = lda_metrics(pcs_flat, labels, unit_id) + except: + d_prime = np.nan + + d_prime_dict[unit_id] = d_prime - if metric_names is None: - metric_names = _possible_pc_metric_names.copy() - if metric_params is None: - metric_params = _default_params + return d_prime_dict - extremum_channels = get_template_extremum_channel(sorting_analyzer) - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids - channel_ids = sorting_analyzer.channel_ids +class DPrimeMetrics(BaseMetric): + metric_name = "d_prime" + metric_function = _d_prime_metric_function + metric_params = {} + metric_columns = {"d_prime": float} + depend_on = ["principal_components"] + needs_tmp_data = True - # create output dict of dict pc_metrics['metric_name'][unit_id] - pc_metrics = {k: {} for k in metric_names} - if "nearest_neighbor" in metric_names: - pc_metrics.pop("nearest_neighbor") - pc_metrics["nn_hit_rate"] = {} - pc_metrics["nn_miss_rate"] = {} - if "nn_isolation" in metric_names: - pc_metrics["nn_unit_id"] = {} +def _nn_one_unit(args): + unit_id, pcs_flat, labels, metric_params = args - possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"] + try: + nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params) + except: + nn_hit_rate = np.nan + nn_miss_rate = np.nan + + return unit_id, nn_hit_rate, nn_miss_rate + + +def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params): + nn_result = namedtuple("NearestNeighborResult", ["nn_hit_rate", "nn_miss_rate"]) + + # Use pre-computed PCA data + pca_data_per_unit = tmp_data["pca_data_per_unit"] - nn_metrics = list(set(metric_names).intersection(possible_nn_metrics)) - non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics)) + # Extract job parameters + n_jobs = job_kwargs.get("n_jobs", 1) + mp_context = job_kwargs.get("mp_context", None) - # Compute nspikes and firing rate outside of main loop for speed - if nn_metrics: - n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) - fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) + nn_hit_rate_dict = {} + nn_miss_rate_dict = {} + + if n_jobs == 1: + # Sequential processing + units_loop = unit_ids + + for unit_id in units_loop: + pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"] + labels = pca_data_per_unit[unit_id]["labels"] + + try: + nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params) + except: + nn_hit_rate = np.nan + nn_miss_rate = np.nan + + nn_hit_rate_dict[unit_id] = nn_hit_rate + nn_miss_rate_dict[unit_id] = nn_miss_rate else: - n_spikes_all_units = None - fr_all_units = None + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') - run_in_parallel = n_jobs > 1 + # Prepare arguments - only pass pickle-able data + args_list = [] + for unit_id in unit_ids: + pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"] + labels = pca_data_per_unit[unit_id]["labels"] + args_list.append((unit_id, pcs_flat, labels, metric_params)) - # this get dense projection for selected unit_ids - dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids) - all_labels = sorting.unit_ids[spike_unit_indices] + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context) if mp_context else None, + ) as executor: + results = executor.map(_nn_one_unit, args_list) - items = [] - for unit_id in unit_ids: - if sorting_analyzer.is_sparse(): - neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] - neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids - ] - else: - neighbor_channel_ids = channel_ids - neighbor_unit_ids = unit_ids - neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) + for unit_id, nn_hit_rate, nn_miss_rate in results: + nn_hit_rate_dict[unit_id] = nn_hit_rate + nn_miss_rate_dict[unit_id] = nn_miss_rate - labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] - if pca_ext.params["mode"] == "concatenated": - pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)] - else: - pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] - pcs_flat = pcs.reshape(pcs.shape[0], -1) + return nn_result(nn_hit_rate=nn_hit_rate_dict, nn_miss_rate=nn_miss_rate_dict) + + +class NearestNeighborMetrics(BaseMetric): + metric_name = "nearest_neighbor" + metric_function = _nearest_neighbor_metric_function + metric_params = {"max_spikes": 10000, "n_neighbors": 5} + metric_columns = {"nn_hit_rate": float, "nn_miss_rate": float} + depend_on = ["principal_components"] + needs_tmp_data = True + needs_job_kwargs = True - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_worker) - items.append(func_args) +def _nn_advanced_one_unit(args): + unit_id, sorting_analyzer_or_folder, n_spikes_all_units, fr_all_units, metric_params, seed = args + if isinstance(sorting_analyzer_or_folder, (str, Path)): + sorting_analyzer = load(sorting_analyzer_or_folder) + else: + sorting_analyzer = sorting_analyzer_or_folder + + nn_isolation_params = { + k: v + for k, v in metric_params.items() + if k + in [ + "max_spikes", + "min_spikes", + "min_fr", + "n_neighbors", + "n_components", + "radius_um", + "peak_sign", + "min_spatial_overlap", + ] + } + nn_noise_params = { + k: v + for k, v in metric_params.items() + if k in ["max_spikes", "min_spikes", "min_fr", "n_neighbors", "n_components", "radius_um", "peak_sign"] + } + + # NN Isolation + try: + nn_isolation, nn_unit_id = nearest_neighbors_isolation( + sorting_analyzer, + unit_id, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + seed=seed, + **nn_isolation_params, + ) + except: + nn_isolation, nn_unit_id = np.nan, np.nan - if not run_in_parallel and non_nn_metrics: - units_loop = enumerate(unit_ids) + # NN Noise Overlap + try: + nn_noise_overlap = nearest_neighbors_noise_overlap( + sorting_analyzer, + unit_id, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + seed=seed, + **nn_noise_params, + ) + except: + nn_noise_overlap = np.nan + + return unit_id, nn_isolation, nn_unit_id, nn_noise_overlap + + +def _nn_advanced_metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params): + nn_advanced_result = namedtuple("NNAdvancedResult", ["nn_isolation", "nn_noise_overlap"]) + + # Use pre-computed data + n_spikes_all_units = tmp_data["n_spikes_all_units"] + fr_all_units = tmp_data["fr_all_units"] + + # Extract job parameters + n_jobs = job_kwargs.get("n_jobs", 1) + progress_bar = False + mp_context = job_kwargs.get("mp_context", None) + seed = metric_params.get("seed", None) + + if sorting_analyzer.format == "memory" and n_jobs > 1: + warnings.warn( + "Computing 'nn_advanced' metric in parallel with a SortingAnalyzer in memory is not supported. " + "Falling back to n_jobs=1." + ) + n_jobs = 1 + + nn_isolation_dict = {} + nn_unit_id_dict = {} + nn_noise_overlap_dict = {} + + if n_jobs == 1: + # Sequential processing + units_loop = unit_ids if progress_bar: - units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids)) + from tqdm.auto import tqdm + + units_loop = tqdm(units_loop, desc="Advanced NN metrics") - for i, unit_id in units_loop: - pca_metrics_unit = pca_metrics_one_unit(items[i]) - for metric_name, metric in pca_metrics_unit.items(): - pc_metrics[metric_name][unit_id] = metric - elif run_in_parallel and non_nn_metrics: + for unit_id in units_loop: + _, nn_isolation, nn_unit_id, nn_noise_overlap = _nn_advanced_one_unit( + (unit_id, sorting_analyzer, n_spikes_all_units, fr_all_units, metric_params, seed) + ) + nn_isolation_dict[unit_id] = nn_isolation + nn_noise_overlap_dict[unit_id] = nn_noise_overlap + else: if mp_context is not None and platform.system() == "Windows": assert mp_context != "fork", "'fork' mp_context not supported on Windows!" elif mp_context == "fork" and platform.system() == "Darwin": warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + # Prepare arguments + # If we got here, we are sure the sorting_analyzer is saved on disk + args_list = [] + for unit_id in unit_ids: + args_list.append((unit_id, sorting_analyzer.folder, n_spikes_all_units, fr_all_units, metric_params, seed)) + with ProcessPoolExecutor( max_workers=n_jobs, - mp_context=mp.get_context(mp_context), + mp_context=mp.get_context(mp_context) if mp_context else None, ) as executor: - results = executor.map(pca_metrics_one_unit, items) + results = executor.map(_nn_advanced_one_unit, args_list) if progress_bar: - results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics") + from tqdm.auto import tqdm - for ui, pca_metrics_unit in enumerate(results): - unit_id = unit_ids[ui] - for metric_name, metric in pca_metrics_unit.items(): - pc_metrics[metric_name][unit_id] = metric + results = tqdm(results, total=len(unit_ids), desc="Advanced NN metrics") - for metric_name in nn_metrics: - units_loop = enumerate(unit_ids) - if progress_bar: - units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) + for unit_id, nn_isolation, nn_unit_id, nn_noise_overlap in results: + nn_isolation_dict[unit_id] = nn_isolation + nn_unit_id_dict[unit_id] = nn_unit_id + nn_noise_overlap_dict[unit_id] = nn_noise_overlap - func = _nn_metric_name_to_func[metric_name] - metric_params = metric_params[metric_name] if metric_name in metric_params else {} + return nn_advanced_result(nn_isolation=nn_isolation_dict, nn_noise_overlap=nn_noise_overlap_dict) - for _, unit_id in units_loop: - try: - res = func( - sorting_analyzer, - unit_id, - seed=seed, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - **metric_params, - ) - except: - if metric_name == "nn_isolation": - res = (np.nan, np.nan) - elif metric_name == "nn_noise_overlap": - res = np.nan - if metric_name == "nn_isolation": - nn_isolation, nn_unit_id = res - pc_metrics["nn_isolation"][unit_id] = nn_isolation - pc_metrics["nn_unit_id"][unit_id] = nn_unit_id - elif metric_name == "nn_noise_overlap": - pc_metrics["nn_noise_overlap"][unit_id] = res +class NearestNeighborAdvancedMetrics(BaseMetric): + metric_name = "nn_advanced" + metric_function = _nn_advanced_metric_function + metric_params = { + "max_spikes": 1000, + "min_spikes": 10, + "min_fr": 0.0, + "n_neighbors": 4, + "n_components": 10, + "radius_um": 100, + "peak_sign": "neg", + "min_spatial_overlap": 0.5, + "seed": None, + } + metric_columns = {"nn_isolation": float, "nn_noise_overlap": float} + depend_on = ["principal_components", "waveforms", "templates"] + needs_tmp_data = True + needs_job_kwargs = True - return pc_metrics +def _silhouette_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + # Use pre-computed PCA data + pca_data_per_unit = tmp_data["pca_data_per_unit"] -################################################################# -# Code from spikemetrics + silhouette_dict = {} + method = metric_params.get("method", "simplified") + + for unit_id in unit_ids: + pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"] + labels = pca_data_per_unit[unit_id]["labels"] + + try: + if method == "simplified": + silhouette_value = simplified_silhouette_score(pcs_flat, labels, unit_id) + else: # method == "full" + silhouette_value = silhouette_score(pcs_flat, labels, unit_id) + except: + silhouette_value = np.nan + silhouette_dict[unit_id] = silhouette_value + return silhouette_dict + + +class SilhouetteMetrics(BaseMetric): + metric_name = "silhouette" + metric_function = _silhouette_metric_function + metric_params = {"method": "simplified"} + metric_columns = {"silhouette": float} + depend_on = ["principal_components"] + needs_tmp_data = True + + +pca_metrics_list = [ + MahalanobisMetrics, + DPrimeMetrics, + NearestNeighborMetrics, + SilhouetteMetrics, + NearestNeighborAdvancedMetrics, +] + + +################################################################# +# Code from spikemetrics def mahalanobis_metrics(all_pcs, all_labels, this_unit_id): """ Calculate isolation distance and L-ratio (metrics computed from Mahalanobis distance). @@ -969,75 +1087,3 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): isolation = (target_nn_in_target + other_nn_in_other) / (n_spikes_target + n_spikes_other) / n_neighbors_adjusted return isolation - - -def pca_metrics_one_unit(args): - - (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_worker) = args - - if max_threads_per_worker is None: - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) - else: - with threadpool_limits(limits=int(max_threads_per_worker)): - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) - - -def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): - pc_metrics = {} - # metrics - if "isolation_distance" in metric_names or "l_ratio" in metric_names: - try: - isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) - except: - isolation_distance = np.nan - l_ratio = np.nan - - if "isolation_distance" in metric_names: - pc_metrics["isolation_distance"] = isolation_distance - if "l_ratio" in metric_names: - pc_metrics["l_ratio"] = l_ratio - - if "d_prime" in metric_names: - if len(unit_ids) == 1: - d_prime = np.nan - else: - try: - d_prime = lda_metrics(pcs_flat, labels, unit_id) - except: - d_prime = np.nan - - pc_metrics["d_prime"] = d_prime - - if "nearest_neighbor" in metric_names: - try: - nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] - ) - except: - nn_hit_rate = np.nan - nn_miss_rate = np.nan - pc_metrics["nn_hit_rate"] = nn_hit_rate - pc_metrics["nn_miss_rate"] = nn_miss_rate - - if "silhouette" in metric_names: - silhouette_method = metric_params["silhouette"]["method"] - if "simplified" in silhouette_method: - try: - unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) - except: - unit_silhouette_score = np.nan - pc_metrics["silhouette"] = unit_silhouette_score - if "full" in silhouette_method: - try: - unit_silhouette_score = silhouette_score(pcs_flat, labels, unit_id) - except: - unit_silhouette_score = np.nan - pc_metrics["silhouette_full"] = unit_silhouette_score - - return pc_metrics - - -_nn_metric_name_to_func = { - "nn_isolation": nearest_neighbors_isolation, - "nn_noise_overlap": nearest_neighbors_noise_overlap, -} diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py new file mode 100644 index 0000000000..239669173a --- /dev/null +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -0,0 +1,205 @@ +"""Classes and functions for computing multiple quality metrics.""" + +from __future__ import annotations + +import warnings +import numpy as np + +from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseMetricExtension + +from .misc_metrics import misc_metrics_list +from .pca_metrics import pca_metrics_list + + +class ComputeQualityMetrics(BaseMetricExtension): + """ + Compute quality metrics on a `sorting_analyzer`. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + metric_names : list or None + List of quality metrics to compute. + metric_params : dict of dicts or None + Dictionary with parameters for quality metrics calculation. + Default parameters can be obtained with: `si.qualitymetrics.get_default_quality_metrics_params()` + skip_pc_metrics : bool, default: False + If True, PC metrics computation is skipped. + delete_existing_metrics : bool, default: False + If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept. + + Returns + ------- + metrics: pandas.DataFrame + Data frame with the computed metrics. + + Notes + ----- + principal_components are loaded automatically if already computed. + """ + + extension_name = "quality_metrics" + depend_on = [] + need_recording = False + use_nodepipeline = False + need_job_kwargs = True + need_backward_compatibility_on_load = True + metric_list = misc_metrics_list + pca_metrics_list + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] + + def _set_params( + self, + metric_names: list[str] | None = None, + metric_params: dict | None = None, + delete_existing_metrics: bool = False, + metrics_to_compute: list[str] | None = None, + # common extension kwargs + peak_sign=None, + seed=None, + skip_pc_metrics=False, + ): + if metric_names is None: + metric_names = [m.metric_name for m in self.metric_list] + # if PC is available, PC metrics are automatically added to the list + if "nn_advanced" in metric_names: + # remove nn_advanced because too slow + metric_names.remove("nn_advanced") + if skip_pc_metrics: + pc_metric_names = [m.metric_name for m in pca_metrics_list] + metric_names = [m for m in metric_names if m not in pc_metric_names] + + return super()._set_params( + metric_names=metric_names, + metric_params=metric_params, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, + peak_sign=peak_sign, + seed=seed, + skip_pc_metrics=skip_pc_metrics, + ) + + def _prepare_data(self, sorting_analyzer, unit_ids=None): + """Prepare shared data for quality metrics computation.""" + # Pre-compute shared PCA data + from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates + + tmp_data = {} + + # Check if any PCA metrics are requested + pca_metric_names = [m.metric_name for m in pca_metrics_list] + requested_pca_metrics = [m for m in self.params["metric_names"] if m in pca_metric_names] + + if not requested_pca_metrics: + return tmp_data + + # Check if PCA extension is available + pca_ext = sorting_analyzer.get_extension("principal_components") + if pca_ext is None: + return tmp_data + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + # Get dense PCA projections for all requested units + dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids) + all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices] + + # Get extremum channels for neighbor selection in sparse mode + extremum_channels = get_template_extremum_channel(sorting_analyzer) + + # Pre-compute spike counts and firing rates if advanced NN metrics are requested + advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric + if any(m in advanced_nn_metrics for m in requested_pca_metrics): + tmp_data["n_spikes_all_units"] = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) + tmp_data["fr_all_units"] = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) + + # Pre-compute per-unit PCA data and neighbor information + pca_data_per_unit = {} + for unit_id in unit_ids: + # Determine neighbor units based on sparsity + if sorting_analyzer.is_sparse(): + neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] + neighbor_unit_ids = [ + other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + ] + neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) + else: + neighbor_channel_ids = sorting_analyzer.channel_ids + neighbor_unit_ids = unit_ids + neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) + + # Filter projections to neighbor units + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + if pca_ext.params["mode"] == "concatenated": + pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)] + else: + pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + pcs_flat = pcs.reshape(pcs.shape[0], -1) + + pca_data_per_unit[unit_id] = { + "pcs_flat": pcs_flat, + "labels": labels, + "neighbor_unit_ids": neighbor_unit_ids, + "neighbor_channel_ids": neighbor_channel_ids, + "neighbor_channel_indices": neighbor_channel_indices, + } + + tmp_data["pca_data_per_unit"] = pca_data_per_unit + + return tmp_data + + +register_result_extension(ComputeQualityMetrics) +compute_quality_metrics = ComputeQualityMetrics.function_factory() + + +def get_quality_metric_list(): + """ + Return a list of the available quality metrics. + """ + + return [m.metric_name for m in ComputeQualityMetrics.metric_list] + + +def get_quality_pca_metric_list(): + """ + Return a list of the available quality PCA metrics. + """ + + return [m.metric_name for m in pca_metrics_list] + + +def get_default_quality_metrics_params(metric_names=None): + """ + Return default dictionary of quality metrics parameters. + + Returns + ------- + dict + Default qm parameters with metric name as key and parameter dictionary as values. + """ + default_params = ComputeQualityMetrics.get_default_metric_params() + if metric_names is None: + return default_params + else: + metric_names = list(set(metric_names) & set(default_params.keys())) + metric_params = {m: default_params[m] for m in metric_names} + return metric_params + + +def get_default_qm_params(metric_names=None): + warnings.warn( + "`get_default_qm_params` is deprecated and will be removed in a version 0.105.0. " + "Please use `get_default_quality_metrics_params` instead.", + DeprecationWarning, + stacklevel=2, + ) + return get_default_quality_metrics_params(metric_names=metric_names) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py similarity index 100% rename from src/spikeinterface/qualitymetrics/tests/conftest.py rename to src/spikeinterface/metrics/quality/tests/conftest.py diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py similarity index 94% rename from src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py rename to src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 79f25ac772..023c6629ff 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,24 +12,24 @@ synthesize_random_firings, ) -from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions -from spikeinterface.qualitymetrics.quality_metric_list import ( - _misc_metric_name_to_func, -) +# from spikeinterface.metrics.quality_metric_list import ( +# _misc_metric_name_to_func, +# ) -from spikeinterface.qualitymetrics import ( +from spikeinterface.metrics.quality import ( get_quality_metric_list, - mahalanobis_metrics, - lda_metrics, - nearest_neighbors_metrics, - silhouette_score, - simplified_silhouette_score, + get_quality_pca_metric_list, + compute_quality_metrics, +) +from spikeinterface.metrics.quality.misc_metrics import ( + misc_metrics_list, compute_amplitude_cutoffs, compute_presence_ratios, compute_isi_violations, - compute_firing_rates, - compute_num_spikes, + # compute_firing_rates, + # compute_num_spikes, compute_snrs, compute_refrac_period_violations, compute_sliding_rp_violations, @@ -39,11 +39,19 @@ compute_firing_ranges, compute_amplitude_cv_metrics, compute_sd_ratio, + _noise_cutoff, _get_synchrony_counts, - compute_quality_metrics, ) -from spikeinterface.qualitymetrics.misc_metrics import _noise_cutoff +from spikeinterface.metrics.quality.pca_metrics import ( + pca_metrics_list, + mahalanobis_metrics, + lda_metrics, + nearest_neighbors_metrics, + silhouette_score, + simplified_silhouette_score, +) + from spikeinterface.core.basesorting import minimum_spike_dtype @@ -220,17 +228,16 @@ def test_unit_structure_in_output(small_sorting_analyzer): "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, } - for metric_name in get_quality_metric_list(): - + for metric in misc_metrics_list: + metric_name = metric.metric_name + metric_fun = metric.metric_function try: qm_param = qm_params[metric_name] except: qm_param = {} - result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param) - result_sub = _misc_metric_name_to_func[metric_name]( - sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param - ) + result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) if isinstance(result_all, dict): assert list(result_all.keys()) == ["#3", "#9", "#4"] @@ -283,10 +290,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True ) for metric, metric_2_data in quality_metrics_2.items(): @@ -479,16 +486,16 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - firing_rates = compute_firing_rates(sorting_analyzer) - num_spikes = compute_num_spikes(sorting_analyzer) +# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): +# sorting_analyzer = sorting_analyzer_simple +# firing_rates = compute_firing_rates(sorting_analyzer) +# num_spikes = compute_num_spikes(sorting_analyzer) - # testing method accuracy with magic number is not a good pratcice, I remove this. - # firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} - # num_spikes_gt = {0: 1001, 1: 503, 2: 509} - # assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) - # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) +# testing method accuracy with magic number is not a good pratcice, I remove this. +# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} +# num_spikes_gt = {0: 1001, 1: 503, 2: 509} +# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) +# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) def test_calculate_firing_range(sorting_analyzer_simple): diff --git a/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py b/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py new file mode 100644 index 0000000000..8227ad5156 --- /dev/null +++ b/src/spikeinterface/metrics/quality/tests/test_pca_metrics.py @@ -0,0 +1,60 @@ +import pytest +import warnings +import numpy as np + +from spikeinterface.metrics import compute_quality_metrics, get_quality_pca_metric_list + + +def test_compute_pc_metrics_multi_processing(small_sorting_analyzer, tmp_path): + import pandas as pd + + sorting_analyzer = small_sorting_analyzer + metric_names = get_quality_pca_metric_list() + metric_params = dict(nn_advanced=dict(seed=2308)) + res1 = compute_quality_metrics( + sorting_analyzer, metric_names=metric_names, n_jobs=1, progress_bar=True, seed=1205, metric_params=metric_params + ) + + # this should raise a warning, since nn_advanced can be parallelized only if not in memory + with pytest.warns(UserWarning): + res2 = compute_quality_metrics( + sorting_analyzer, + metric_names=metric_names, + n_jobs=2, + progress_bar=True, + seed=1205, + metric_params=metric_params, + ) + + # now we cache the analyzer and there should be no warning + sorting_analyzer_saved = sorting_analyzer.save_as(folder=tmp_path / "analyzer", format="binary_folder") + # assert no warnings this time + with warnings.catch_warnings(): + warnings.filterwarnings("error", message="Falling back to n_jobs=1.") + res2 = compute_quality_metrics( + sorting_analyzer_saved, metric_names=metric_names, n_jobs=2, progress_bar=True, seed=1205 + ) + + for metric_name in res1.columns: + values1 = res1[metric_name].values + values2 = res2[metric_name].values + + if values1.dtype.kind == "f": + np.testing.assert_almost_equal(values1, values2, decimal=4) + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, share=True) + # ax =a xs[0] + # ax.plot(res1[metric_name].values) + # ax.plot(res2[metric_name].values) + # ax =a xs[1] + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + else: + assert np.array_equal(values1, values2) + + +if __name__ == "__main__": + from spikeinterface.metrics.tests.conftest import make_small_analyzer + + small_sorting_analyzer = make_small_analyzer() + test_compute_pc_metrics_multi_processing(small_sorting_analyzer) diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py new file mode 100644 index 0000000000..ec72fdc178 --- /dev/null +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -0,0 +1,181 @@ +import pytest +from pathlib import Path +import numpy as np + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, + NumpySorting, + aggregate_units, +) + +from spikeinterface.metrics.quality.misc_metrics import compute_snrs, compute_drift_metrics + + +from spikeinterface.metrics import ( + compute_quality_metrics, +) + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def test_warnings_errors_when_missing_deps(): + """ + If the user requests to compute a quality metric which depends on an extension + that has not been computed, this should error. If the user uses the default + quality metrics (i.e. they do not explicitly request the specific metrics), + this should report a warning about which metrics could not be computed. + We check this behavior in this test. + """ + + recording, sorting = generate_ground_truth_recording() + analyzer = create_sorting_analyzer(sorting=sorting, recording=recording) + + # user tries to use `compute_snrs` without templates. Should error + with pytest.raises(ValueError): + compute_snrs(analyzer) + + # user asks for drift metrics without spike_locations. Should error + with pytest.raises(ValueError): + compute_drift_metrics(analyzer) + + # user doesn't specify which metrics to compute. Should return a warning + # about which metrics have not been computed. + with pytest.warns(Warning): + analyzer.compute("quality_metrics") + + +def test_compute_quality_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + + # without PCs + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=["snr"], + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205, + ) + # print(metrics) + + qm = sorting_analyzer.get_extension("quality_metrics") + assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert "snr" in metrics.columns + assert "isolation_distance" not in metrics.columns + + # with PCs + sorting_analyzer.compute("principal_components") + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + print(metrics.columns) + assert "isolation_distance" in metrics.columns + + +def test_merging_quality_metrics(sorting_analyzer_simple): + + sorting_analyzer = sorting_analyzer_simple + + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + + # sorting_analyzer_simple has ten units + new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]]) + new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data() + + # we should copy over the metrics after merge + for column in metrics.columns: + assert column in new_metrics.columns + # should copy dtype too + assert metrics[column].dtype == new_metrics[column].dtype + + # 10 units vs 9 units + assert len(metrics.index) > len(new_metrics.index) + + +def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): + + sorting_analyzer = sorting_analyzer_simple + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + + # make a copy and make it recordingless + sorting_analyzer_norec = sorting_analyzer.save_as(format="memory") + sorting_analyzer_norec.delete_extension("quality_metrics") + sorting_analyzer_norec._recording = None + assert not sorting_analyzer_norec.has_recording() + + metrics_norec = compute_quality_metrics( + sorting_analyzer_norec, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + + for metric_name in metrics.columns: + if metric_name == "sd_ratio": + # this one need recording!!! + continue + assert np.allclose(metrics[metric_name].values, metrics_norec[metric_name].values, rtol=1e-02) + + +def test_empty_units(sorting_analyzer_simple): + from pandas import isnull + + sorting_analyzer = sorting_analyzer_simple + + empty_spike_train = np.array([], dtype="int64") + empty_sorting = NumpySorting.from_unit_dict( + {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, + sampling_frequency=sorting_analyzer.sampling_frequency, + ) + sorting_empty = aggregate_units([sorting_analyzer.sorting, empty_sorting]) + assert len(sorting_empty.get_empty_unit_ids()) == 3 + + sorting_analyzer_empty = create_sorting_analyzer(sorting_empty, sorting_analyzer.recording, format="memory") + sorting_analyzer_empty.compute("random_spikes", max_spikes_per_unit=300, seed=2205) + sorting_analyzer_empty.compute("noise_levels") + sorting_analyzer_empty.compute("waveforms", **job_kwargs) + sorting_analyzer_empty.compute("templates") + sorting_analyzer_empty.compute("spike_amplitudes", **job_kwargs) + + metrics_empty = compute_quality_metrics( + sorting_analyzer_empty, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205, + ) + + # test that metrics are either NaN or zero for empty units + empty_unit_ids = sorting_empty.get_empty_unit_ids() + + for col in metrics_empty.columns: + all_nans = np.all(isnull(metrics_empty.loc[empty_unit_ids, col].values)) + all_zeros = np.all(metrics_empty.loc[empty_unit_ids, col].values == 0) + assert all_nans or all_zeros + + +if __name__ == "__main__": + + sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) + + test_compute_quality_metrics(sorting_analyzer) + test_compute_quality_metrics_recordingless(sorting_analyzer) + test_empty_units(sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/utils.py b/src/spikeinterface/metrics/quality/utils.py similarity index 61% rename from src/spikeinterface/qualitymetrics/utils.py rename to src/spikeinterface/metrics/quality/utils.py index 90faf1a602..844a7da7f5 100644 --- a/src/spikeinterface/qualitymetrics/utils.py +++ b/src/spikeinterface/metrics/quality/utils.py @@ -2,29 +2,6 @@ import numpy as np -from spikeinterface.qualitymetrics.quality_metric_list import metric_extension_dependencies - - -def _has_required_extensions(sorting_analyzer, metric_name): - - required_extensions = metric_extension_dependencies[metric_name] - - not_computed_required_extensions = [] - for ext in required_extensions: - if all(sorting_analyzer.has_extension(name) is False for name in ext.split("|")): - not_computed_required_extensions.append(ext) - - if len(not_computed_required_extensions) > 0: - warnings_string = f"The `{metric_name}` metric requires the {not_computed_required_extensions} extensions.\n" - warnings_string += "Use the sorting_analyzer.compute([" - for count, ext in enumerate(not_computed_required_extensions): - if count == len(not_computed_required_extensions) - 1: - warnings_string += f"'{ext}'" - else: - warnings_string += f"'{ext}', " - warnings_string += f"]) method to compute." - raise ValueError(warnings_string) - def create_ground_truth_pc_distributions(center_locations, total_points): """ diff --git a/src/spikeinterface/metrics/spiketrain/__init__.py b/src/spikeinterface/metrics/spiketrain/__init__.py new file mode 100644 index 0000000000..ffbc8e1625 --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/__init__.py @@ -0,0 +1,8 @@ +from .spiketrain_metrics import ( + ComputeSpikeTrainMetrics, + compute_spiketrain_metrics, + get_default_spiketrain_metrics_params, + get_spiketrain_metric_list, +) + +from .metrics import compute_firing_rates, compute_num_spikes diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py new file mode 100644 index 0000000000..39e244bb67 --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -0,0 +1,84 @@ +import numpy as np +from spikeinterface.core.analyzer_extension_core import BaseMetric + + +def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): + """ + Compute the number of spike across segments. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the number of spikes. If None, all units are used. + + Returns + ------- + num_spikes : dict + The number of spikes, across all segments, for each unit ID. + """ + + sorting = sorting_analyzer.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + num_segs = sorting.get_num_segments() + + num_spikes = {} + for unit_id in unit_ids: + n = 0 + for segment_index in range(num_segs): + st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + n += st.size + num_spikes[unit_id] = n + + return num_spikes + + +class NumSpikes(BaseMetric): + metric_name = "num_spikes" + metric_function = compute_num_spikes + metric_params = {} + metric_columns = {"num_spikes": int} + + +def compute_firing_rates(sorting_analyzer, unit_ids=None): + """ + Compute the firing rate across segments. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the firing rate. If None, all units are used. + + Returns + ------- + firing_rates : dict of floats + The firing rate, across all segments, for each unit ID. + """ + + sorting = sorting_analyzer.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + total_duration = sorting_analyzer.get_total_duration() + + firing_rates = {} + num_spikes = compute_num_spikes(sorting_analyzer) + for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + firing_rates[unit_id] = np.nan + else: + firing_rates[unit_id] = num_spikes[unit_id] / total_duration + return firing_rates + + +class FiringRate(BaseMetric): + metric_name = "firing_rate" + metric_function = compute_firing_rates + metric_params = {} + metric_columns = {"firing_rate": float} + + +spiketrain_metrics = [NumSpikes, FiringRate] diff --git a/src/spikeinterface/metrics/spiketrain/spiketrain_metrics.py b/src/spikeinterface/metrics/spiketrain/spiketrain_metrics.py new file mode 100644 index 0000000000..41aee3d74f --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/spiketrain_metrics.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseMetricExtension + +from .metrics import spiketrain_metrics + + +class ComputeSpikeTrainMetrics(BaseMetricExtension): + """ + Compute spike train metrics including: + * num_spikes + * firing_rate + * TODO: add ACG/ISI metrics + * TODO: add burst metrics + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + metric_names : list or None, default: None + List of metrics to compute (see si.metrics.get_spiketrain_metric_names()) + delete_existing_metrics : bool, default: False + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict of dicts or None, default: None + Dictionary with parameters for template metrics calculation. + Default parameters can be obtained with: `si.metrics.get_default_spiketrain_metrics_params()` + + Returns + ------- + spiketrain_metrics : pd.DataFrame + Dataframe with the computed spike train metrics. + + Notes + ----- + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". + """ + + extension_name = "spiketrain_metrics" + depend_on = [] + need_backward_compatibility_on_load = True + metric_list = spiketrain_metrics + + +register_result_extension(ComputeSpikeTrainMetrics) +compute_spiketrain_metrics = ComputeSpikeTrainMetrics.function_factory() + + +def get_spiketrain_metric_list(): + return [m.metric_name for m in spiketrain_metrics] + + +def get_default_spiketrain_metrics_params(metric_names=None): + default_params = ComputeSpikeTrainMetrics.get_default_metric_params() + if metric_names is None: + return default_params + else: + metric_names = list(set(metric_names) & set(default_params.keys())) + metric_params = {m: default_params[m] for m in metric_names} + return metric_params diff --git a/src/spikeinterface/metrics/template/__init__.py b/src/spikeinterface/metrics/template/__init__.py new file mode 100644 index 0000000000..15912c1bf6 --- /dev/null +++ b/src/spikeinterface/metrics/template/__init__.py @@ -0,0 +1,10 @@ +from .template_metrics import ( + ComputeTemplateMetrics, + compute_template_metrics, + get_template_metric_list, + get_template_metric_names, + get_single_channel_template_metric_names, + get_multi_channel_template_metric_names, + get_default_template_metrics_params, + get_default_tm_params, +) diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py new file mode 100644 index 0000000000..a9ba2f2602 --- /dev/null +++ b/src/spikeinterface/metrics/template/metrics.py @@ -0,0 +1,727 @@ +from __future__ import annotations + +import numpy as np +from collections import namedtuple + +from spikeinterface.core.analyzer_extension_core import BaseMetric + + +def get_trough_and_peak_idx(template): + """ + Return the indices into the input template of the detected trough + (minimum of template) and peak (maximum of template, after trough). + Assumes negative trough and positive peak. + + Parameters + ---------- + template: numpy.ndarray + The 1D template waveform + + Returns + ------- + trough_idx: int + The index of the trough + peak_idx: int + The index of the peak + """ + assert template.ndim == 1 + trough_idx = np.argmin(template) + peak_idx = trough_idx + np.argmax(template[trough_idx:]) + return trough_idx, peak_idx + + +######################################################################################### +# Single-channel metrics +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the peak to valley duration in seconds of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptv: float + The peak to valley duration in seconds + """ + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptv = (peak_idx - trough_idx) / sampling_frequency + return ptv + + +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the peak to trough ratio of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptratio: float + The peak to trough ratio + """ + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptratio = template_single[peak_idx] / template_single[trough_idx] + return ptratio + + +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the half width of input waveforms in seconds. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + hw: float + The half width in seconds + """ + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + + if peak_idx == 0: + return np.nan + + trough_val = template_single[trough_idx] + # threshold is half of peak height (assuming baseline is 0) + threshold = 0.5 * trough_val + + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) + + if len(cpre_idx) == 0 or len(cpost_idx) == 0: + hw = np.nan + + else: + # last occurence of template lower than thr, before peak + cross_pre_pk = cpre_idx[0] - 1 + # first occurence of template lower than peak, after peak + cross_post_pk = cpost_idx[-1] + 1 + trough_idx + + hw = (cross_post_pk - cross_pre_pk) / sampling_frequency + return hw + + +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): + """ + Return slope of repolarization period between trough and baseline + + After reaching it's maximum polarization, the neuron potential will + recover. The repolarization slope is defined as the dV/dT of the action potential + between trough and baseline. The returned slope is in units of (unit of template) + per second. By default traces are scaled to units of uV, controlled + by `sorting_analyzer.return_in_uV`. In this case this function returns the slope + in uV/s. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + + Returns + ------- + slope: float + The repolarization slope + """ + if trough_idx is None: + trough_idx, _ = get_trough_and_peak_idx(template_single) + + times = np.arange(template_single.shape[0]) / sampling_frequency + + if trough_idx == 0: + return np.nan + + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) + if len(rtrn_idx) == 0: + return np.nan + # first time after trough, where template is at baseline + return_to_base_idx = rtrn_idx[0] + trough_idx + + if return_to_base_idx - trough_idx < 3: + return np.nan + + import scipy.stats + + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) + return res.slope + + +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): + """ + Return the recovery slope of input waveforms. After repolarization, + the neuron hyperpolarizes until it peaks. The recovery slope is the + slope of the action potential after the peak, returning to the baseline + in dV/dT. The returned slope is in units of (unit of template) + per second. By default traces are scaled to units of uV, controlled + by `sorting_analyzer.return_in_uV`. In this case this function returns the slope + in uV/s. The slope is computed within a user-defined window after the peak. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - recovery_window_ms: the window in ms after the peak to compute the recovery_slope + + Returns + ------- + res.slope: float + The recovery slope + """ + import scipy.stats + + assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" + recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, peak_idx = get_trough_and_peak_idx(template_single) + + times = np.arange(template_single.shape[0]) / sampling_frequency + + if peak_idx == 0: + return np.nan + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) + + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) + return res.slope + + +def get_number_of_peaks(template_single, sampling_frequency, **kwargs): + """ + Count the total number of peaks (positive + negative) in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + + Returns + ------- + number_of_peaks: int + the total number of peaks (positive + negative) + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) + + pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + num_positive = len(pos_peaks[0]) + num_negative = len(neg_peaks[0]) + return num_positive, num_negative + + +######################################################################################### +# Multi-channel metrics +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + """ + Transform template and channel locations based on column range. + """ + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range, channel_locations_column_range + + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + """ + Sort template and locations. + """ + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + +def fit_velocity(peak_times, channel_dist): + """ + Fit velocity from peak times and channel distances using robust Theilsen estimator. + """ + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) + + from sklearn.linear_model import TheilSenRegressor + + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score + + +def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute both velocity above and below the max channel of the template in units um/s. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels: the minimum number of channels above or below to compute velocity + - min_r2: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity + + Returns + ------- + velocity_above : float + The velocity above the max channel + velocity_below : float + The velocity below the max channel + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels" in kwargs, "min_channels must be given as kwarg" + assert "min_r2" in kwargs, "min_r2 must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels"] + min_r2 = kwargs["min_r2"] + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 + max_channel_location = channel_locations[max_channel_idx] + + # Compute velocity above + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] + if np.sum(channels_above) < min_channels_for_velocity: + velocity_above = np.nan + else: + template_above = template[:, channels_above] + channel_locations_above = channel_locations[channels_above] + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) + velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above) + if score < min_r2: + velocity_above = np.nan + + # Compute velocity below + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] + if np.sum(channels_below) < min_channels_for_velocity: + velocity_below = np.nan + else: + template_below = template[:, channels_below] + channel_locations_below = channel_locations[channels_below] + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) + velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below) + if score < min_r2: + velocity_below = np.nan + + return velocity_above, velocity_below + + +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): + """ + Compute the exponential decay of the template amplitude over distance in units um/s. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2: the minimum r2 to accept the exp decay fit + + Returns + ------- + exp_decay_value : float + The exponential decay of the template amplitude + """ + from scipy.optimize import curve_fit + from sklearn.metrics import r2_score + + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset + + assert "peak_function" in kwargs, "peak_function must be given as kwarg" + peak_function = kwargs["peak_function"] + assert "min_r2" in kwargs, "min_r2 must be given as kwarg" + min_r2 = kwargs["min_r2"] + # exp decay fit + if peak_function == "ptp": + fun = np.ptp + elif peak_function == "min": + fun = np.min + peak_amplitudes = np.abs(fun(template, axis=0)) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + + # longdouble is float128 when the platform supports it, otherwise it is float64 + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + + try: + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2: + exp_decay_value = np.nan + except: + exp_decay_value = np.nan + + return exp_decay_value + + +def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> float: + """ + Compute the spread of the template amplitude over distance in units um/s. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity + + Returns + ------- + spread : float + Spread of the template amplitude + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + depth_direction = kwargs["depth_direction"] + assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" + spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + MM = np.ptp(template, 0) + channel_depths = channel_locations[:, depth_dim] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + + MM = MM / np.max(MM) + + channel_locations_above_threshold = channel_locations[MM > spread_threshold] + channel_depth_above_threshold = channel_locations_above_threshold[:, depth_dim] + + spread = np.ptp(channel_depth_above_threshold) + + return spread + + +def single_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + troughs = tmp_data.get("troughs", None) + peaks = tmp_data.get("peaks", None) + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + trough_idx = troughs[unit_id] if troughs is not None else None + peak_idx = peaks[unit_id] if peaks is not None else None + metric_params["trough_idx"] = trough_idx + metric_params["peak_idx"] = peak_idx + value = unit_function(template_single, sampling_frequency, **metric_params) + result[unit_id] = value + return result + + +class PeakToValley(BaseMetric): + metric_name = "peak_to_valley" + metric_params = {} + metric_columns = {"peak_to_valley": float} + needs_tmp_data = True + + @staticmethod + def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_peak_to_valley, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _peak_to_valley_metric_function + + +class PeakToTroughRatio(BaseMetric): + metric_name = "peak_trough_ratio" + metric_params = {} + metric_columns = {"peak_trough_ratio": float} + needs_tmp_data = True + + @staticmethod + def _peak_to_trough_ratio_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_peak_trough_ratio, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _peak_to_trough_ratio_metric_function + + +class HalfWidth(BaseMetric): + metric_name = "half_width" + metric_params = {} + metric_columns = {"half_width": float} + needs_tmp_data = True + + @staticmethod + def _half_width_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_half_width, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _half_width_metric_function + + +class RepolarizationSlope(BaseMetric): + metric_name = "repolarization_slope" + metric_params = {} + metric_columns = {"repolarization_slope": float} + needs_tmp_data = True + + @staticmethod + def _repolarization_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_repolarization_slope, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _repolarization_slope_metric_function + + +class RecoverySlope(BaseMetric): + metric_name = "recovery_slope" + metric_params = {"recovery_window_ms": 0.7} + metric_columns = {"recovery_slope": float} + needs_tmp_data = True + + @staticmethod + def _recovery_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_recovery_slope, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _recovery_slope_metric_function + + +def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"]) + num_positive_peaks_dict = {} + num_negative_peaks_dict = {} + sampling_frequency = sorting_analyzer.sampling_frequency + templates_single = tmp_data["templates_single"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + num_positive, num_negative = get_number_of_peaks(template_single, sampling_frequency, **metric_params) + num_positive_peaks_dict[unit_id] = num_positive + num_negative_peaks_dict[unit_id] = num_negative + return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict) + + +class NumberOfPeaks(BaseMetric): + metric_name = "number_of_peaks" + metric_function = _number_of_peaks_metric_function + metric_params = {"peak_relative_threshold": 0.2, "peak_width_ms": 0.1} + metric_columns = {"num_positive_peaks": int, "num_negative_peaks": int} + needs_tmp_data = True + + +single_channel_metrics = [ + PeakToValley, + PeakToTroughRatio, + HalfWidth, + RepolarizationSlope, + RecoverySlope, + NumberOfPeaks, +] + + +def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + velocity_above_result = namedtuple("Velocities", ["velocity_above", "velocity_below"]) + velocity_above_dict = {} + velocity_below_dict = {} + templates_multi = tmp_data["templates_multi"] + channel_locations_multi = tmp_data["channel_locations_multi"] + sampling_frequency = tmp_data["sampling_frequency"] + metric_params["depth_direction"] = tmp_data["depth_direction"] + for unit_index, unit_id in enumerate(unit_ids): + channel_locations = channel_locations_multi[unit_index] + template = templates_multi[unit_index] + vel_above, vel_below = get_velocity_fits(template, channel_locations, sampling_frequency, **metric_params) + velocity_above_dict[unit_id] = vel_above + velocity_below_dict[unit_id] = vel_below + return velocity_above_result(velocity_above=velocity_above_dict, velocity_below=velocity_below_dict) + + +class VelocityFits(BaseMetric): + metric_name = "velocity_fits" + metric_function = _get_velocity_fits_metric_function + metric_params = { + "min_channels": 3, + "min_r2": 0.2, + "column_range": None, + } + metric_columns = {"velocity_above": float, "velocity_below": float} + needs_tmp_data = True + + +def multi_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_multi = tmp_data["templates_multi"] + channel_locations_multi = tmp_data["channel_locations_multi"] + sampling_frequency = tmp_data["sampling_frequency"] + metric_params["depth_direction"] = tmp_data["depth_direction"] + for unit_index, unit_id in enumerate(unit_ids): + channel_locations = channel_locations_multi[unit_index] + template = templates_multi[unit_index] + value = unit_function(template, channel_locations, sampling_frequency, **metric_params) + result[unit_id] = value + return result + + +class ExpDecay(BaseMetric): + metric_name = "exp_decay" + metric_params = {"peak_function": "ptp", "min_r2": 0.2} + metric_columns = {"exp_decay": float} + needs_tmp_data = True + + @staticmethod + def _exp_decay_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return multi_channel_metric( + unit_function=get_exp_decay, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _exp_decay_metric_function + + +class Spread(BaseMetric): + metric_name = "spread" + metric_params = {"spread_threshold": 0.5, "spread_smooth_um": 20, "column_range": None} + metric_columns = {"spread": float} + needs_tmp_data = True + + @staticmethod + def _spread_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return multi_channel_metric( + unit_function=get_spread, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _spread_metric_function + + +multi_channel_metrics = [ + VelocityFits, + ExpDecay, + Spread, +] diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py new file mode 100644 index 0000000000..da6187f355 --- /dev/null +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -0,0 +1,260 @@ +""" +Functions based on +https://github.com/AllenInstitute/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py +22/04/2020 +""" + +from __future__ import annotations + +import numpy as np +import warnings +from copy import deepcopy + +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseMetricExtension +from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array + +from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics + + +MIN_SPARSE_CHANNELS_FOR_MULTI_CHANNEL_WARNING = 10 +MIN_CHANNELS_FOR_MULTI_CHANNEL_METRICS = 64 + + +def get_single_channel_template_metric_names(): + return [m.metric_name for m in single_channel_metrics] + + +def get_multi_channel_template_metric_names(): + return [m.metric_name for m in multi_channel_metrics] + + +def get_template_metric_list(): + return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() + + +def get_template_metric_names(): + warnings.warn( + "get_template_metric_names is deprecated and will be removed in a version 0.105.0. " + "Please use get_template_metric_list instead.", + DeprecationWarning, + stacklevel=2, + ) + return get_template_metric_list() + + +class ComputeTemplateMetrics(BaseMetricExtension): + """ + Compute template metrics including: + * peak_to_valley + * peak_trough_ratio + * halfwidth + * repolarization_slope + * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + metric_names : list or None, default: None + List of metrics to compute (see si.metrics.get_template_metric_names()) + delete_existing_metrics : bool, default: False + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict of dicts or None, default: None + Dictionary with parameters for template metrics calculation. + Default parameters can be obtained with: `si.metrics.template_metrics.get_default_template_metrics_params()` + peak_sign : {"neg", "pos"}, default: "neg" + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + include_multi_channel_metrics : bool, default: False + Whether to compute multi-channel metrics + + Returns + ------- + template_metrics : pd.DataFrame + Dataframe with the computed template metrics. + + Notes + ----- + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". + """ + + extension_name = "template_metrics" + depend_on = ["templates"] + need_backward_compatibility_on_load = True + metric_list = single_channel_metrics + multi_channel_metrics + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + + def _set_params( + self, + metric_names: list[str] | None = None, + metric_params: dict | None = None, + delete_existing_metrics: bool = False, + metrics_to_compute: list[str] | None = None, + # common extension kwargs + peak_sign="neg", + upsampling_factor=10, + include_multi_channel_metrics=False, + depth_direction="y", + ): + # Auto-detect if multi-channel metrics should be included based on number of channels + num_channels = self.sorting_analyzer.get_num_channels() + if not include_multi_channel_metrics and num_channels >= MIN_CHANNELS_FOR_MULTI_CHANNEL_METRICS: + include_multi_channel_metrics = True + + # Validate channel locations if multi-channel metrics are to be computed + if include_multi_channel_metrics or ( + metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) + ): + assert ( + self.sorting_analyzer.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." + + if metric_names is None: + metric_names = get_single_channel_template_metric_names() + if include_multi_channel_metrics: + metric_names += get_multi_channel_template_metric_names() + + return super()._set_params( + metric_names=metric_names, + metric_params=metric_params, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, + peak_sign=peak_sign, + upsampling_factor=upsampling_factor, + include_multi_channel_metrics=include_multi_channel_metrics, + depth_direction=depth_direction, + ) + + def _prepare_data(self, sorting_analyzer, unit_ids): + from scipy.signal import resample_poly + + # compute templates_single and templates_multi (if include_multi_channel_metrics is True) + tmp_data = {} + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + peak_sign = self.params["peak_sign"] + upsampling_factor = self.params["upsampling_factor"] + sampling_frequency = sorting_analyzer.sampling_frequency + if self.params["upsampling_factor"] > 1: + sampling_frequency_up = upsampling_factor * sampling_frequency + else: + sampling_frequency_up = sampling_frequency + tmp_data["sampling_frequency"] = sampling_frequency_up + + include_multi_channel_metrics = self.params.get("include_multi_channel_metrics") or any( + m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] + ) + + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index") + all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) + channel_locations = sorting_analyzer.get_channel_locations() + + templates_single = [] + troughs = {} + peaks = {} + templates_multi = [] + channel_locations_multi = [] + for unit_id in unit_ids: + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) + template_all_chans = all_templates[unit_index] + template_single = template_all_chans[:, extremum_channel_indices[unit_id]] + + # compute single_channel metrics + if upsampling_factor > 1: + template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1) + else: + template_upsampled = template_single + sampling_frequency_up = sampling_frequency + trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + + templates_single.append(template_upsampled) + troughs[unit_id] = trough_idx + peaks[unit_id] = peak_idx + + if include_multi_channel_metrics: + if sorting_analyzer.is_sparse(): + mask = sorting_analyzer.sparsity.mask[unit_index, :] + template_multi = template_all_chans[:, mask] + channel_location_multi = channel_locations[mask] + else: + template_multi = template_all_chans + channel_location_multi = channel_locations + if template_multi.shape[1] < MIN_SPARSE_CHANNELS_FOR_MULTI_CHANNEL_WARNING: + warnings.warn( + f"With less than {MIN_SPARSE_CHANNELS_FOR_MULTI_CHANNEL_WARNING} channels, " + "multi-channel metrics might not be reliable." + ) + + if upsampling_factor > 1: + template_multi_upsampled = resample_poly(template_multi, up=upsampling_factor, down=1, axis=0) + else: + template_multi_upsampled = template_multi + templates_multi.append(template_multi_upsampled) + channel_locations_multi.append(channel_location_multi) + + tmp_data["troughs"] = troughs + tmp_data["peaks"] = peaks + tmp_data["templates_single"] = np.array(templates_single) + + if include_multi_channel_metrics: + # templates_multi is a list of 2D arrays of shape (n_times, n_channels) + tmp_data["templates_multi"] = templates_multi + tmp_data["channel_locations_multi"] = channel_locations_multi + tmp_data["depth_direction"] = self.params["depth_direction"] + + return tmp_data + + +register_result_extension(ComputeTemplateMetrics) +compute_template_metrics = ComputeTemplateMetrics.function_factory() + + +def get_default_template_metrics_params(metric_names=None): + default_params = ComputeTemplateMetrics.get_default_metric_params() + if metric_names is None: + return default_params + else: + metric_names = list(set(metric_names) & set(default_params.keys())) + metric_params = {m: default_params[m] for m in metric_names} + return metric_params + + +def get_default_tm_params(metric_names=None): + """ + Return default dictionary of template metrics parameters. + + Returns + ------- + metric_params : dict + Dictionary with default parameters for template metrics. + """ + warnings.warn( + "get_default_tm_params is deprecated and will be removed in a version 0.105.0. " + "Please use get_default_template_metrics_params instead.", + DeprecationWarning, + stacklevel=2, + ) + return get_default_template_metrics_params(metric_names) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/metrics/template/tests/test_template_metrics.py similarity index 62% rename from src/spikeinterface/postprocessing/tests/test_template_metrics.py rename to src/spikeinterface/metrics/template/tests/test_template_metrics.py index f5f34635e7..d42fd12b4c 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/metrics/template/tests/test_template_metrics.py @@ -1,11 +1,15 @@ -from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest -import csv -from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +from spikeinterface.metrics.template import ( + ComputeTemplateMetrics, + compute_template_metrics, + get_single_channel_template_metric_names, +) +from spikeinterface.metrics.template.metrics import single_channel_metrics, multi_channel_metrics + -template_metrics = list(_single_channel_metric_name_to_func.keys()) +template_metrics = get_single_channel_template_metric_names() def test_different_params_template_metrics(small_sorting_analyzer): @@ -16,39 +20,14 @@ def test_different_params_template_metrics(small_sorting_analyzer): compute_template_metrics( sorting_analyzer=small_sorting_analyzer, metric_names=["exp_decay", "spread", "half_width"], - metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + metric_params={"exp_decay": {"peak_function": "min"}, "spread": {"spread_smooth_um": 15}}, ) tm_extension = small_sorting_analyzer.get_extension("template_metrics") tm_params = tm_extension.params["metric_params"] - assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 - assert tm_params["spread"]["recovery_window_ms"] == 0.7 - assert tm_params["half_width"]["recovery_window_ms"] == 0.7 - + assert tm_params["exp_decay"]["peak_function"] == "min" assert tm_params["spread"]["spread_smooth_um"] == 15 - assert tm_params["exp_decay"]["spread_smooth_um"] == 20 - assert tm_params["half_width"]["spread_smooth_um"] == 20 - - -def test_backwards_compat_params_template_metrics(small_sorting_analyzer): - """ - Computes template metrics using the metrics_kwargs keyword - """ - compute_template_metrics( - sorting_analyzer=small_sorting_analyzer, - metric_names=["exp_decay", "spread"], - metrics_kwargs={"recovery_window_ms": 0.8}, - ) - - tm_extension = small_sorting_analyzer.get_extension("template_metrics") - tm_params = tm_extension.params["metric_params"] - - assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 - assert tm_params["spread"]["recovery_window_ms"] == 0.8 - - assert tm_params["spread"]["spread_smooth_um"] == 20 - assert tm_params["exp_decay"]["spread_smooth_um"] == 20 def test_compute_new_template_metrics(small_sorting_analyzer): @@ -92,7 +71,12 @@ def test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} + { + "template_metrics": { + "metric_names": ["exp_decay"], + "metric_params": {"recovery_slope": {"recovery_window_ms": 0.6}}, + } + } ) @@ -100,11 +84,13 @@ def test_metric_names_in_same_order(small_sorting_analyzer): """ Computes sepecified template metrics and checks order is propagated. """ - specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"] - small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names) - tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys() - for i in range(3): - assert specified_metric_names[i] == tm_keys[i] + specified_metric_names = ["peak_trough_ratio", "half_width", "peak_to_valley"] + small_sorting_analyzer.compute( + "template_metrics", metric_names=specified_metric_names, delete_existing_metrics=True + ) + tm_columns = small_sorting_analyzer.get_extension("template_metrics").get_data().columns + for specified_name, column in zip(specified_metric_names, tm_columns): + assert specified_name == column def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): @@ -112,7 +98,11 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): Computes template metrics in binary folder format. Then computes subsets of template metrics and checks if they are saved correctly. """ + import pandas as pd + column_names = [] + for m in single_channel_metrics: + column_names.extend(list(m.metric_columns.keys())) small_sorting_analyzer.compute("template_metrics") cache_folder = create_cache_folder @@ -121,29 +111,24 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" - with open(template_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) + saved_metrics = pd.read_csv(template_metrics_filename) + metric_names = saved_metrics.columns - for metric_name in template_metrics: + for metric_name in column_names: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) - with open(template_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in template_metrics: + saved_metrics = pd.read_csv(template_metrics_filename) + metric_names = saved_metrics.columns + for metric_name in column_names: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) - with open(template_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in template_metrics: + saved_metrics = pd.read_csv(template_metrics_filename) + metric_names = saved_metrics.columns + for metric_name in column_names: if metric_name == "half_width": assert metric_name in metric_names else: diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index b1adbff281..dca9711ccd 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -1,9 +1,3 @@ -from .template_metrics import ( - ComputeTemplateMetrics, - compute_template_metrics, - get_template_metric_names, -) - from .template_similarity import ( ComputeTemplateSimilarity, compute_template_similarity, @@ -45,3 +39,8 @@ from .alignsorting import align_sorting, AlignSortingExtractor from .noise_level import compute_noise_levels, ComputeNoiseLevels + +from .template_metrics import ( + ComputeTemplateMetrics, + compute_template_metrics, +) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 8583114d86..403e690b8c 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -1,1089 +1,26 @@ -""" -Functions based on -https://github.com/AllenInstitute/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py -22/04/2020 -""" - -from __future__ import annotations - -import numpy as np import warnings -from itertools import chain -from copy import deepcopy - -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.template_tools import get_dense_templates_array - -# DEBUG = False - - -def get_single_channel_template_metric_names(): - return deepcopy(list(_single_channel_metric_name_to_func.keys())) - - -def get_multi_channel_template_metric_names(): - return deepcopy(list(_multi_channel_metric_name_to_func.keys())) - - -def get_template_metric_names(): - return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() - - -class ComputeTemplateMetrics(AnalyzerExtension): - """ - Compute template metrics including: - * peak_to_valley - * peak_trough_ratio - * halfwidth - * repolarization_slope - * recovery_slope - * num_positive_peaks - * num_negative_peaks - - Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): - * velocity_above - * velocity_below - * exp_decay - * spread - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The SortingAnalyzer object - metric_names : list or None, default: None - List of metrics to compute (see si.postprocessing.get_template_metric_names()) - peak_sign : {"neg", "pos"}, default: "neg" - Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. - upsampling_factor : int, default: 10 - The upsampling factor to upsample the templates - sparsity : ChannelSparsity or None, default: None - If None, template metrics are computed on the extremum channel only. - If sparsity is given, template metrics are computed on all sparse channels of each unit. - For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. - include_multi_channel_metrics : bool, default: False - Whether to compute multi-channel metrics - delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. - metric_params : dict of dicts or None, default: None - Dictionary with parameters for template metrics calculation. - Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` - - Returns - ------- - template_metrics : pd.DataFrame - Dataframe with the computed template metrics. - If "sparsity" is None, the index is the unit_id. - If "sparsity" is given, the index is a multi-index (unit_id, channel_id) - - Notes - ----- - If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, - so that one metric value will be computed per unit. - For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". - """ - - extension_name = "template_metrics" - depend_on = ["templates"] - need_recording = False - use_nodepipeline = False - need_job_kwargs = False - need_backward_compatibility_on_load = True - - min_channels_for_multi_channel_warning = 10 - - def _handle_backward_compatibility_on_load(self): - - # For backwards compatibility - this reformats metrics_kwargs as metric_params - if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: - - metric_params = {} - for metric_name in self.params["metric_names"]: - metric_params[metric_name] = deepcopy(metrics_kwargs) - self.params["metric_params"] = metric_params - - del self.params["metrics_kwargs"] - - def _set_params( - self, - metric_names=None, - peak_sign="neg", - upsampling_factor=10, - sparsity=None, - metric_params=None, - metrics_kwargs=None, - include_multi_channel_metrics=False, - delete_existing_metrics=False, - **other_kwargs, - ): - - # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() - if include_multi_channel_metrics or ( - metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) - ): - assert sparsity is None, ( - "If multi-channel metrics are computed, sparsity must be None, " - "so that each unit will correspond to 1 row of the output dataframe." - ) - assert ( - self.sorting_analyzer.get_channel_locations().shape[1] == 2 - ), "If multi-channel metrics are computed, channel locations must be 2D." - - if metric_names is None: - metric_names = get_single_channel_template_metric_names() - if include_multi_channel_metrics: - metric_names += get_multi_channel_template_metric_names() - - if metrics_kwargs is not None and metric_params is None: - deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" - deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" - metric_params = {} - for metric_name in metric_names: - metric_params[metric_name] = deepcopy(metrics_kwargs) - metric_params_ = get_default_tm_params(metric_names) - for k in metric_params_: - if metric_params is not None and k in metric_params: - metric_params_[k].update(metric_params[k]) +from spikeinterface.metrics.template import ComputeTemplateMetrics as ComputeTemplateMetricsNew +from spikeinterface.metrics.template import compute_template_metrics as compute_template_metrics_new - metrics_to_compute = metric_names - tm_extension = self.sorting_analyzer.get_extension("template_metrics") - if delete_existing_metrics is False and tm_extension is not None: - existing_metric_names = tm_extension.params["metric_names"] - existing_metric_names_propagated = [ - metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute - ] - metric_names = metrics_to_compute + existing_metric_names_propagated - - params = dict( - metric_names=metric_names, - sparsity=sparsity, - peak_sign=peak_sign, - upsampling_factor=int(upsampling_factor), - metric_params=metric_params_, - delete_existing_metrics=delete_existing_metrics, - metrics_to_compute=metrics_to_compute, - ) - - return params - - def _select_extension_data(self, unit_ids): - new_metrics = self.data["metrics"].loc[np.array(unit_ids)] - return dict(metrics=new_metrics) - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - import pandas as pd - - metric_names = self.params["metric_names"] - old_metrics = self.data["metrics"] - - all_unit_ids = new_sorting_analyzer.unit_ids - not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] - - metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics( - new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs - ) - - new_data = dict(metrics=metrics) - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - import pandas as pd - - metric_names = self.params["metric_names"] - old_metrics = self.data["metrics"] - - all_unit_ids = new_sorting_analyzer.unit_ids - new_unit_ids_f = list(chain(*new_unit_ids)) - not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] - - metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids_f, :] = self._compute_metrics( - new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs +class ComputeTemplateMetrics(ComputeTemplateMetricsNew): + def __init__(self, *args, **kwargs): + warnings.warn( + "The module 'spikeinterface.postprocessing.template_metrics' is deprecated and will be removed in 0.105.0." + "Please use 'spikeinterface.metrics.template' instead.", + DeprecationWarning, + stacklevel=2, ) + super().__init__(*args, **kwargs) - new_data = dict(metrics=metrics) - return new_data - - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): - """ - Compute template metrics. - """ - import pandas as pd - from scipy.signal import resample_poly - - sparsity = self.params["sparsity"] - peak_sign = self.params["peak_sign"] - upsampling_factor = self.params["upsampling_factor"] - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids - sampling_frequency = sorting_analyzer.sampling_frequency - - metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] - metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] - - if sparsity is None: - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") - - template_metrics = pd.DataFrame(index=unit_ids, columns=metric_names) - else: - extremum_channels_ids = sparsity.unit_id_to_channel_ids - index_unit_ids = [] - index_channel_ids = [] - for unit_id, sparse_channels in extremum_channels_ids.items(): - index_unit_ids += [unit_id] * len(sparse_channels) - index_channel_ids += list(sparse_channels) - multi_index = pd.MultiIndex.from_tuples( - list(zip(index_unit_ids, index_channel_ids)), names=["unit_id", "channel_id"] - ) - template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) - - all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) - - channel_locations = sorting_analyzer.get_channel_locations() - - for unit_id in unit_ids: - unit_index = sorting_analyzer.sorting.id_to_index(unit_id) - template_all_chans = all_templates[unit_index] - chan_ids = np.array(extremum_channels_ids[unit_id]) - if chan_ids.ndim == 0: - chan_ids = [chan_ids] - chan_ind = sorting_analyzer.channel_ids_to_indices(chan_ids) - template = template_all_chans[:, chan_ind] - - # compute single_channel metrics - for i, template_single in enumerate(template.T): - if sparsity is None: - index = unit_id - else: - index = (unit_id, chan_ids[i]) - if upsampling_factor > 1: - assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" - template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1) - sampling_frequency_up = upsampling_factor * sampling_frequency - else: - template_upsampled = template_single - sampling_frequency_up = sampling_frequency - - trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) - - for metric_name in metrics_single_channel: - func = _metric_name_to_func[metric_name] - try: - value = func( - template_upsampled, - sampling_frequency=sampling_frequency_up, - trough_idx=trough_idx, - peak_idx=peak_idx, - **self.params["metric_params"][metric_name], - ) - except Exception as e: - warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") - value = np.nan - template_metrics.at[index, metric_name] = value - - # compute metrics multi_channel - for metric_name in metrics_multi_channel: - # retrieve template (with sparsity if waveform extractor is sparse) - template = all_templates[unit_index, :, :] - if sorting_analyzer.is_sparse(): - mask = sorting_analyzer.sparsity.mask[unit_index, :] - template = template[:, mask] - - if template.shape[1] < self.min_channels_for_multi_channel_warning: - warnings.warn( - f"With less than {self.min_channels_for_multi_channel_warning} channels, " - "multi-channel metrics might not be reliable." - ) - if sorting_analyzer.is_sparse(): - channel_locations_sparse = channel_locations[sorting_analyzer.sparsity.mask[unit_index]] - else: - channel_locations_sparse = channel_locations - - if upsampling_factor > 1: - assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" - template_upsampled = resample_poly(template, up=upsampling_factor, down=1, axis=0) - sampling_frequency_up = upsampling_factor * sampling_frequency - else: - template_upsampled = template - sampling_frequency_up = sampling_frequency - - func = _metric_name_to_func[metric_name] - try: - value = func( - template_upsampled, - channel_locations=channel_locations_sparse, - sampling_frequency=sampling_frequency_up, - **self.params["metric_params"][metric_name], - ) - except Exception as e: - warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") - value = np.nan - template_metrics.at[index, metric_name] = value - - # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns - # (in case of NaN values) - template_metrics = template_metrics.convert_dtypes() - return template_metrics - - def _run(self, verbose=False): - - metrics_to_compute = self.params["metrics_to_compute"] - delete_existing_metrics = self.params["delete_existing_metrics"] - - # compute the metrics which have been specified by the user - computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute - ) - - existing_metrics = [] - - # Check if we need to propagate any old metrics. If so, we'll do that. - # Otherwise, we'll avoid attempting to load an empty template_metrics. - if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]): - - tm_extension = self.sorting_analyzer.get_extension("template_metrics") - if ( - delete_existing_metrics is False - and tm_extension is not None - and tm_extension.data.get("metrics") is not None - ): - existing_metrics = tm_extension.params["metric_names"] - - existing_metrics = [] - # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) - tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) - if ( - delete_existing_metrics is False - and tm_extension is not None - and tm_extension.data.get("metrics") is not None - ): - existing_metrics = tm_extension.params["metric_names"] - - # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metrics_to_compute): - # some metrics names produce data columns with other names. This deals with that. - for column_name in tm_compute_name_to_column_names[metric_name]: - computed_metrics[column_name] = tm_extension.data["metrics"][column_name] - - self.data["metrics"] = computed_metrics - - def _get_data(self): - return self.data["metrics"] - - -register_result_extension(ComputeTemplateMetrics) -compute_template_metrics = ComputeTemplateMetrics.function_factory() - - -_default_function_kwargs = dict( - recovery_window_ms=0.7, - peak_relative_threshold=0.2, - peak_width_ms=0.1, - depth_direction="y", - min_channels_for_velocity=5, - min_r2_velocity=0.5, - exp_peak_function="ptp", - min_r2_exp_decay=0.5, - spread_threshold=0.2, - spread_smooth_um=20, - column_range=None, -) - - -def get_default_tm_params(metric_names): - if metric_names is None: - metric_names = get_template_metric_names() - - base_tm_params = _default_function_kwargs - - metric_params = {} - for metric_name in metric_names: - metric_params[metric_name] = deepcopy(base_tm_params) - - return metric_params - - -# a dict converting the name of the metric for computation to the output of that computation -tm_compute_name_to_column_names = { - "peak_to_valley": ["peak_to_valley"], - "peak_trough_ratio": ["peak_trough_ratio"], - "half_width": ["half_width"], - "repolarization_slope": ["repolarization_slope"], - "recovery_slope": ["recovery_slope"], - "num_positive_peaks": ["num_positive_peaks"], - "num_negative_peaks": ["num_negative_peaks"], - "velocity_above": ["velocity_above"], - "velocity_below": ["velocity_below"], - "exp_decay": ["exp_decay"], - "spread": ["spread"], -} - - -def get_trough_and_peak_idx(template): - """ - Return the indices into the input template of the detected trough - (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak. - - Parameters - ---------- - template: numpy.ndarray - The 1D template waveform - - Returns - ------- - trough_idx: int - The index of the trough - peak_idx: int - The index of the peak - """ - assert template.ndim == 1 - trough_idx = np.argmin(template) - peak_idx = trough_idx + np.argmax(template[trough_idx:]) - return trough_idx, peak_idx - - -######################################################################################### -# Single-channel metrics -def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: - """ - Return the peak to valley duration in seconds of input waveforms. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - peak_idx: int, default: None - The index of the peak - - Returns - ------- - ptv: float - The peak to valley duration in seconds - """ - if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptv = (peak_idx - trough_idx) / sampling_frequency - return ptv - - -def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: - """ - Return the peak to trough ratio of input waveforms. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - peak_idx: int, default: None - The index of the peak - - Returns - ------- - ptratio: float - The peak to trough ratio - """ - if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptratio = template_single[peak_idx] / template_single[trough_idx] - return ptratio - - -def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: - """ - Return the half width of input waveforms in seconds. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - peak_idx: int, default: None - The index of the peak - - Returns - ------- - hw: float - The half width in seconds - """ - if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - - if peak_idx == 0: - return np.nan - - trough_val = template_single[trough_idx] - # threshold is half of peak height (assuming baseline is 0) - threshold = 0.5 * trough_val - - (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) - (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) - - if len(cpre_idx) == 0 or len(cpost_idx) == 0: - hw = np.nan - - else: - # last occurence of template lower than thr, before peak - cross_pre_pk = cpre_idx[0] - 1 - # first occurence of template lower than peak, after peak - cross_post_pk = cpost_idx[-1] + 1 + trough_idx - - hw = (cross_post_pk - cross_pre_pk) / sampling_frequency - return hw - - -def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): - """ - Return slope of repolarization period between trough and baseline - - After reaching it's maximum polarization, the neuron potential will - recover. The repolarization slope is defined as the dV/dT of the action potential - between trough and baseline. The returned slope is in units of (unit of template) - per second. By default traces are scaled to units of uV, controlled - by `sorting_analyzer.return_in_uV`. In this case this function returns the slope - in uV/s. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - - Returns - ------- - slope: float - The repolarization slope - """ - if trough_idx is None: - trough_idx, _ = get_trough_and_peak_idx(template_single) - - times = np.arange(template_single.shape[0]) / sampling_frequency - - if trough_idx == 0: - return np.nan - - (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) - if len(rtrn_idx) == 0: - return np.nan - # first time after trough, where template is at baseline - return_to_base_idx = rtrn_idx[0] + trough_idx - - if return_to_base_idx - trough_idx < 3: - return np.nan - - import scipy.stats - - res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) - return res.slope - - -def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): - """ - Return the recovery slope of input waveforms. After repolarization, - the neuron hyperpolarizes until it peaks. The recovery slope is the - slope of the action potential after the peak, returning to the baseline - in dV/dT. The returned slope is in units of (unit of template) - per second. By default traces are scaled to units of uV, controlled - by `sorting_analyzer.return_in_uV`. In this case this function returns the slope - in uV/s. The slope is computed within a user-defined window after the peak. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - peak_idx: int, default: None - The index of the peak - **kwargs: Required kwargs: - - recovery_window_ms: the window in ms after the peak to compute the recovery_slope - - Returns - ------- - res.slope: float - The recovery slope - """ - import scipy.stats - - assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" - recovery_window_ms = kwargs["recovery_window_ms"] - if peak_idx is None: - _, peak_idx = get_trough_and_peak_idx(template_single) - - times = np.arange(template_single.shape[0]) / sampling_frequency - - if peak_idx == 0: - return np.nan - max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, template_single.shape[0]]) - - res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) - return res.slope - - -def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): - """ - Count the number of positive peaks in the template. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - peak_relative_threshold: the relative threshold to detect positive and negative peaks - - peak_width_ms: the width in samples to detect peaks - - Returns - ------- - number_positive_peaks: int - the number of positive peaks - """ - from scipy.signal import find_peaks - - assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" - assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" - peak_relative_threshold = kwargs["peak_relative_threshold"] - peak_width_ms = kwargs["peak_width_ms"] - max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) - - pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - - return len(pos_peaks[0]) - - -def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): - """ - Count the number of negative peaks in the template. - - Parameters - ---------- - template_single: numpy.ndarray - The 1D template waveform - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - peak_relative_threshold: the relative threshold to detect positive and negative peaks - - peak_width_ms: the width in samples to detect peaks - - Returns - ------- - num_negative_peaks: int - the number of negative peaks - """ - from scipy.signal import find_peaks - - assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" - assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" - peak_relative_threshold = kwargs["peak_relative_threshold"] - peak_width_ms = kwargs["peak_width_ms"] - max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) - - neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - - return len(neg_peaks[0]) - - -_single_channel_metric_name_to_func = { - "peak_to_valley": get_peak_to_valley, - "peak_trough_ratio": get_peak_trough_ratio, - "half_width": get_half_width, - "repolarization_slope": get_repolarization_slope, - "recovery_slope": get_recovery_slope, - "num_positive_peaks": get_num_positive_peaks, - "num_negative_peaks": get_num_negative_peaks, -} - - -######################################################################################### -# Multi-channel metrics - - -def transform_column_range(template, channel_locations, column_range, depth_direction="y"): - """ - Transform template and channel locations based on column range. - """ - column_dim = 0 if depth_direction == "y" else 1 - if column_range is None: - template_column_range = template - channel_locations_column_range = channel_locations - else: - max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] - column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range - template_column_range = template[:, column_mask] - channel_locations_column_range = channel_locations[column_mask] - return template_column_range, channel_locations_column_range - - -def sort_template_and_locations(template, channel_locations, depth_direction="y"): - """ - Sort template and locations. - """ - depth_dim = 1 if depth_direction == "y" else 0 - sort_indices = np.argsort(channel_locations[:, depth_dim]) - return template[:, sort_indices], channel_locations[sort_indices, :] - - -def fit_velocity(peak_times, channel_dist): - """ - Fit velocity from peak times and channel distances using robust Theilsen estimator. - """ - # from scipy.stats import linregress - # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) - - from sklearn.linear_model import TheilSenRegressor - - theil = TheilSenRegressor() - theil.fit(peak_times.reshape(-1, 1), channel_dist) - slope = theil.coef_[0] - intercept = theil.intercept_ - score = theil.score(peak_times.reshape(-1, 1), channel_dist) - return slope, intercept, score - - -def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): - """ - Compute the velocity above the max channel of the template in units um/s. - - Parameters - ---------- - template: numpy.ndarray - The template waveform (num_samples, num_channels) - channel_locations: numpy.ndarray - The channel locations (num_channels, 2) - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_velocity: the minimum r2 to accept the velocity fit - - column_range: the range in um in the x-direction to consider channels for velocity - """ - assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" - assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "column_range" in kwargs, "column_range must be given as kwarg" - - depth_direction = kwargs["depth_direction"] - min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_velocity = kwargs["min_r2_velocity"] - column_range = kwargs["column_range"] - - depth_dim = 1 if depth_direction == "y" else 0 - template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) - template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - - # find location of max channel - max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) - max_peak_time = max_sample_idx / sampling_frequency * 1000 - max_channel_location = channel_locations[max_channel_idx] - - channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] - - # if not enough channels return NaN - if np.sum(channels_above) < min_channels_for_velocity: - return np.nan - - template_above = template[:, channels_above] - channel_locations_above = channel_locations[channels_above] - - peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time - distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) - velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) - - # if r2 score is to low return NaN - if score < min_r2_velocity: - return np.nan - - # if DEBUG: - # import matplotlib.pyplot as plt - - # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - # offset = 1.2 * np.max(np.ptp(template, axis=0)) - # ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time - # (channel_indices_above,) = np.nonzero(channels_above) - # for i, single_template in enumerate(template.T): - # color = "r" if i in channel_indices_above else "k" - # axs[0].plot(ts, single_template + i * offset, color=color) - # axs[0].axvline(0, color="g", ls="--") - # axs[1].plot(peak_times_ms_above, distances_um_above, "o") - # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) - # axs[1].plot(x, intercept + x * velocity_above) - # axs[1].set_xlabel("Peak time (ms)") - # axs[1].set_ylabel("Distance from max channel (um)") - # fig.suptitle( - # f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" - # ) - # plt.show() - - return velocity_above - - -def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): - """ - Compute the velocity below the max channel of the template in units um/s. - - Parameters - ---------- - template: numpy.ndarray - The template waveform (num_samples, num_channels) - channel_locations: numpy.ndarray - The channel locations (num_channels, 2) - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_velocity: the minimum r2 to accept the velocity fit - - column_range: the range in um in the x-direction to consider channels for velocity - """ - assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" - assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "column_range" in kwargs, "column_range must be given as kwarg" - - depth_direction = kwargs["depth_direction"] - min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_velocity = kwargs["min_r2_velocity"] - column_range = kwargs["column_range"] - - depth_dim = 1 if depth_direction == "y" else 0 - template, channel_locations = transform_column_range(template, channel_locations, column_range) - template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - - # find location of max channel - max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) - max_peak_time = max_sample_idx / sampling_frequency * 1000 - max_channel_location = channel_locations[max_channel_idx] - - channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] - - # if not enough channels return NaN - if np.sum(channels_below) < min_channels_for_velocity: - return np.nan - - template_below = template[:, channels_below] - channel_locations_below = channel_locations[channels_below] - - peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time - distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) - velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) - - # if r2 score is to low return NaN - if score < min_r2_velocity: - return np.nan - - # if DEBUG: - # import matplotlib.pyplot as plt - - # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - # offset = 1.2 * np.max(np.ptp(template, axis=0)) - # ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time - # (channel_indices_below,) = np.nonzero(channels_below) - # for i, single_template in enumerate(template.T): - # color = "r" if i in channel_indices_below else "k" - # axs[0].plot(ts, single_template + i * offset, color=color) - # axs[0].axvline(0, color="g", ls="--") - # axs[1].plot(peak_times_ms_below, distances_um_below, "o") - # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) - # axs[1].plot(x, intercept + x * velocity_below) - # axs[1].set_xlabel("Peak time (ms)") - # axs[1].set_ylabel("Distance from max channel (um)") - # fig.suptitle( - # f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" - # ) - # plt.show() - - return velocity_below - - -def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): - """ - Compute the exponential decay of the template amplitude over distance in units um/s. - - Parameters - ---------- - template: numpy.ndarray - The template waveform (num_samples, num_channels) - channel_locations: numpy.ndarray - The channel locations (num_channels, 2) - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - - min_r2_exp_decay: the minimum r2 to accept the exp decay fit - - Returns - ------- - exp_decay_value : float - The exponential decay of the template amplitude - """ - from scipy.optimize import curve_fit - from sklearn.metrics import r2_score - - def exp_decay(x, decay, amp0, offset): - return amp0 * np.exp(-decay * x) + offset - - assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" - exp_peak_function = kwargs["exp_peak_function"] - assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" - min_r2_exp_decay = kwargs["min_r2_exp_decay"] - # exp decay fit - if exp_peak_function == "ptp": - fun = np.ptp - elif exp_peak_function == "min": - fun = np.min - peak_amplitudes = np.abs(fun(template, axis=0)) - max_channel_location = channel_locations[np.argmax(peak_amplitudes)] - channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) - distances_sort_indices = np.argsort(channel_distances) - - # longdouble is float128 when the platform supports it, otherwise it is float64 - channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) - - try: - amp0 = peak_amplitudes_sorted[0] - offset0 = np.min(peak_amplitudes_sorted) - - popt, _ = curve_fit( - exp_decay, - channel_distances_sorted, - peak_amplitudes_sorted, - bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), - p0=[1e-3, peak_amplitudes_sorted[0], offset0], - ) - r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) - exp_decay_value = popt[0] - - if r2 < min_r2_exp_decay: - exp_decay_value = np.nan - - # if DEBUG: - # import matplotlib.pyplot as plt - - # fig, ax = plt.subplots() - # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") - # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) - # ax.plot(x, exp_decay(x, *popt)) - # ax.set_xlabel("Distance from max channel (um)") - # ax.set_ylabel("Peak amplitude") - # ax.set_title( - # f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " - # f"R2: {np.round(r2, 4)}" - # ) - # fig.suptitle("Exp decay") - # plt.show() - except: - exp_decay_value = np.nan - - return exp_decay_value - - -def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> float: - """ - Compute the spread of the template amplitude over distance in units um/s. - - Parameters - ---------- - template: numpy.ndarray - The template waveform (num_samples, num_channels) - channel_locations: numpy.ndarray - The channel locations (num_channels, 2) - sampling_frequency : float - The sampling frequency of the template - **kwargs: Required kwargs: - - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - - spread_threshold: the threshold to compute the spread - - column_range: the range in um in the x-direction to consider channels for velocity - - Returns - ------- - spread : float - Spread of the template amplitude - """ - assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" - depth_direction = kwargs["depth_direction"] - assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" - spread_threshold = kwargs["spread_threshold"] - assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" - spread_smooth_um = kwargs["spread_smooth_um"] - assert "column_range" in kwargs, "column_range must be given as kwarg" - column_range = kwargs["column_range"] - - depth_dim = 1 if depth_direction == "y" else 0 - template, channel_locations = transform_column_range(template, channel_locations, column_range) - template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - - MM = np.ptp(template, 0) - channel_depths = channel_locations[:, depth_dim] - - if spread_smooth_um is not None and spread_smooth_um > 0: - from scipy.ndimage import gaussian_filter1d - - spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) - MM = gaussian_filter1d(MM, spread_sigma) - - MM = MM / np.max(MM) - - channel_locations_above_threshold = channel_locations[MM > spread_threshold] - channel_depth_above_threshold = channel_locations_above_threshold[:, depth_dim] - - spread = np.ptp(channel_depth_above_threshold) - - # if DEBUG: - # import matplotlib.pyplot as plt - - # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - # axs[0].imshow( - # template.T, - # aspect="auto", - # origin="lower", - # extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], - # ) - # axs[1].plot(channel_depths, MM, "o-") - # axs[1].axhline(spread_threshold, ls="--", color="r") - # axs[1].set_xlabel("Depth (um)") - # axs[1].set_ylabel("Amplitude") - # axs[1].set_title(f"Spread: {np.round(spread, 3)} um") - # fig.suptitle("Spread") - # plt.show() - - return spread - - -_multi_channel_metric_name_to_func = { - "velocity_above": get_velocity_above, - "velocity_below": get_velocity_below, - "exp_decay": get_exp_decay, - "spread": get_spread, -} -_metric_name_to_func = {**_single_channel_metric_name_to_func, **_multi_channel_metric_name_to_func} +def compute_template_metrics(*args, **kwargs): + warnings.warn( + "The module 'spikeinterface.postprocessing.template_metrics' is deprecated and will be removed in 0.105.0." + "Please use 'spikeinterface.metrics.template' instead.", + DeprecationWarning, + stacklevel=2, + ) + return compute_template_metrics_new(*args, **kwargs) diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py index 754c82d8e3..58db27a719 100644 --- a/src/spikeinterface/qualitymetrics/__init__.py +++ b/src/spikeinterface/qualitymetrics/__init__.py @@ -1,9 +1,10 @@ -from .quality_metric_list import * -from .quality_metric_calculator import ( - compute_quality_metrics, - get_quality_metric_list, - ComputeQualityMetrics, - get_default_qm_params, +import warnings + +warnings.warn( + "The module 'spikeinterface.qualitymetrics' is deprecated and will be removed in 0.105.0." + "Please use 'spikeinterface.metrics.quality' instead.", + DeprecationWarning, + stacklevel=2, ) -from .pca_metrics import get_quality_pca_metric_list -from .misc_metrics import _get_synchrony_counts + +from spikeinterface.metrics.quality import * # noqa: F403 diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py deleted file mode 100644 index 5d338a990b..0000000000 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Classes and functions for computing multiple quality metrics.""" - -from __future__ import annotations - -import warnings -from itertools import chain -from copy import deepcopy, copy - -import numpy as np -from warnings import warn - -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension - - -from .quality_metric_list import ( - compute_pc_metrics, - _misc_metric_name_to_func, - _possible_pc_metric_names, - qm_compute_name_to_column_names, - column_name_to_column_dtype, - metric_extension_dependencies, -) -from .misc_metrics import _default_params as misc_metrics_params -from .pca_metrics import _default_params as pca_metrics_params - - -class ComputeQualityMetrics(AnalyzerExtension): - """ - Compute quality metrics on a `sorting_analyzer`. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. - metric_names : list or None - List of quality metrics to compute. - metric_params : dict of dicts or None - Dictionary with parameters for quality metrics calculation. - Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - skip_pc_metrics : bool, default: False - If True, PC metrics computation is skipped. - delete_existing_metrics : bool, default: False - If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept. - - Returns - ------- - metrics: pandas.DataFrame - Data frame with the computed metrics. - - Notes - ----- - principal_components are loaded automatically if already computed. - """ - - extension_name = "quality_metrics" - depend_on = [] - need_recording = False - use_nodepipeline = False - need_job_kwargs = True - need_backward_compatibility_on_load = True - - def _handle_backward_compatibility_on_load(self): - # For backwards compatibility - this renames qm_params as metric_params - if (qm_params := self.params.get("qm_params")) is not None: - self.params["metric_params"] = qm_params - del self.params["qm_params"] - - def _set_params( - self, - metric_names=None, - metric_params=None, - qm_params=None, - peak_sign=None, - seed=None, - skip_pc_metrics=False, - delete_existing_metrics=False, - metrics_to_compute=None, - ): - if qm_params is not None and metric_params is None: - deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" - ) - metric_params = qm_params - warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - - metric_names_is_none = False - if metric_names is None: - metric_names_is_none = True - metric_names = list(_misc_metric_name_to_func.keys()) - # if PC is available, PC metrics are automatically added to the list - if self.sorting_analyzer.has_extension("principal_components") and not skip_pc_metrics: - # by default 'nearest_neightbor' is removed because too slow - pc_metrics = _possible_pc_metric_names.copy() - pc_metrics.remove("nn_isolation") - pc_metrics.remove("nn_noise_overlap") - metric_names += pc_metrics - - metric_params_ = get_default_qm_params() - for k in metric_params_: - if metric_params is not None and k in metric_params: - metric_params_[k].update(metric_params[k]) - if "peak_sign" in metric_params_[k] and peak_sign is not None: - metric_params_[k]["peak_sign"] = peak_sign - - metrics_to_compute = metric_names - qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if delete_existing_metrics is False and qm_extension is not None: - - existing_metric_names = qm_extension.params["metric_names"] - existing_metric_names_propagated = [ - metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute - ] - metric_names = metrics_to_compute + existing_metric_names_propagated - - ## Deal with dependencies - computable_metrics_to_compute = copy(metrics_to_compute) - if metric_names_is_none: - need_more_extensions = False - warning_text = "Some metrics you are trying to compute depend on other extensions:\n" - for metric in metrics_to_compute: - metric_dependencies = metric_extension_dependencies.get(metric) - if metric_dependencies is not None: - for extension_name in metric_dependencies: - if all( - self.sorting_analyzer.has_extension(name) is False for name in extension_name.split("|") - ): - need_more_extensions = True - if metric in computable_metrics_to_compute: - computable_metrics_to_compute.remove(metric) - warning_text += f" {metric} requires {metric_dependencies}\n" - warning_text += "To include these metrics, compute the required extensions using `sorting_analyzer.compute('extension_name')" - if need_more_extensions: - warnings.warn(warning_text) - - params = dict( - metric_names=metric_names, - peak_sign=peak_sign, - seed=seed, - metric_params=metric_params_, - skip_pc_metrics=skip_pc_metrics, - delete_existing_metrics=delete_existing_metrics, - metrics_to_compute=computable_metrics_to_compute, - ) - - return params - - def _select_extension_data(self, unit_ids): - new_metrics = self.data["metrics"].loc[np.array(unit_ids)] - new_data = dict(metrics=new_metrics) - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - import pandas as pd - - metric_names = self.params["metric_names"] - old_metrics = self.data["metrics"] - - all_unit_ids = new_sorting_analyzer.unit_ids - not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] - - # this creates a new metrics dictionary, but the dtype for everything will be - # object. So we will need to fix this later after computing metrics - metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics( - new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs - ) - - # we need to fix the dtypes after we compute everything because we have nans - # we can iterate through the columns and convert them back to the dtype - # of the original quality dataframe. - for column in old_metrics.columns: - metrics[column] = metrics[column].astype(old_metrics[column].dtype) - - new_data = dict(metrics=metrics) - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - import pandas as pd - - metric_names = self.params["metric_names"] - old_metrics = self.data["metrics"] - - all_unit_ids = new_sorting_analyzer.unit_ids - new_unit_ids_f = list(chain(*new_unit_ids)) - not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] - - # this creates a new metrics dictionary, but the dtype for everything will be - # object. So we will need to fix this later after computing metrics - metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids_f, :] = self._compute_metrics( - new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs - ) - - # we need to fix the dtypes after we compute everything because we have nans - # we can iterate through the columns and convert them back to the dtype - # of the original quality dataframe. - for column in old_metrics.columns: - metrics[column] = metrics[column].astype(old_metrics[column].dtype) - - new_data = dict(metrics=metrics) - return new_data - - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): - """ - Compute quality metrics. - """ - import pandas as pd - - metric_params = self.params["metric_params"] - # sparsity = self.params["sparsity"] - seed = self.params["seed"] - - # update job_kwargs with global ones - job_kwargs = fix_job_kwargs(job_kwargs) - n_jobs = job_kwargs["n_jobs"] - progress_bar = job_kwargs["progress_bar"] - - if unit_ids is None: - sorting = sorting_analyzer.sorting - unit_ids = sorting.unit_ids - non_empty_unit_ids = sorting.get_non_empty_unit_ids() - empty_unit_ids = unit_ids[~np.isin(unit_ids, non_empty_unit_ids)] - if len(empty_unit_ids) > 0: - warnings.warn( - f"Units {empty_unit_ids} are empty. Quality metrics will be set to NaN " - f"for these units.\n To remove empty units, use `sorting.remove_empty_units()`." - ) - else: - non_empty_unit_ids = unit_ids - empty_unit_ids = [] - - metrics = pd.DataFrame(index=unit_ids) - - # simple metrics not based on PCs - for metric_name in metric_names: - # keep PC metrics for later - if metric_name in _possible_pc_metric_names: - continue - if verbose: - if metric_name not in _possible_pc_metric_names: - print(f"Computing {metric_name}") - - func = _misc_metric_name_to_func[metric_name] - - params = metric_params[metric_name] if metric_name in metric_params else {} - res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) - # QM with uninstall dependencies might return None - if res is not None: - if isinstance(res, dict): - # res is a dict convert to series - metrics.loc[non_empty_unit_ids, metric_name] = pd.Series(res) - else: - # res is a namedtuple with several dict - # so several columns - for i, col in enumerate(res._fields): - metrics.loc[non_empty_unit_ids, col] = pd.Series(res[i]) - - # metrics based on PCs - pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] - if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: - if not sorting_analyzer.has_extension("principal_components"): - raise ValueError( - "To compute principal components base metrics, the principal components " - "extension must be computed first." - ) - pc_metrics = compute_pc_metrics( - sorting_analyzer, - unit_ids=non_empty_unit_ids, - metric_names=pc_metric_names, - # sparsity=sparsity, - progress_bar=progress_bar, - n_jobs=n_jobs, - metric_params=metric_params, - seed=seed, - ) - for col, values in pc_metrics.items(): - metrics.loc[non_empty_unit_ids, col] = pd.Series(values) - - # add NaN for empty units - if len(empty_unit_ids) > 0: - metrics.loc[empty_unit_ids] = np.nan - # num_spikes is an int and should be 0 - if "num_spikes" in metrics.columns: - metrics.loc[empty_unit_ids, ["num_spikes"]] = 0 - - # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns - # (in case of NaN values) - metrics = metrics.convert_dtypes() - - # we do this because the convert_dtypes infers the wrong types sometimes. - # the actual types for columns can be found in column_name_to_column_dtype dictionary. - for column in metrics.columns: - if column in column_name_to_column_dtype: - metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) - - return metrics - - def _run(self, verbose=False, **job_kwargs): - - metrics_to_compute = self.params["metrics_to_compute"] - delete_existing_metrics = self.params["delete_existing_metrics"] - - computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, - unit_ids=None, - verbose=verbose, - metric_names=metrics_to_compute, - **job_kwargs, - ) - - existing_metrics = [] - # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) - qm_extension = self.sorting_analyzer.extensions.get("quality_metrics", None) - if ( - delete_existing_metrics is False - and qm_extension is not None - and qm_extension.data.get("metrics") is not None - ): - existing_metrics = qm_extension.params["metric_names"] - - # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metrics_to_compute): - # some metrics names produce data columns with other names. This deals with that. - for column_name in qm_compute_name_to_column_names[metric_name]: - computed_metrics[column_name] = qm_extension.data["metrics"][column_name] - - self.data["metrics"] = computed_metrics - - def _get_data(self): - return self.data["metrics"] - - -register_result_extension(ComputeQualityMetrics) -compute_quality_metrics = ComputeQualityMetrics.function_factory() - - -def get_quality_metric_list(): - """ - Return a list of the available quality metrics. - """ - - return deepcopy(list(_misc_metric_name_to_func.keys())) - - -def get_default_qm_params(): - """ - Return default dictionary of quality metrics parameters. - - Returns - ------- - dict - Default qm parameters with metric name as key and parameter dictionary as values. - """ - default_params = {} - default_params.update(misc_metrics_params) - default_params.update(pca_metrics_params) - return deepcopy(default_params) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py deleted file mode 100644 index 5e769ab8eb..0000000000 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Lists of quality metrics.""" - -from __future__ import annotations - -# a dict containing the extension dependencies for each metric -metric_extension_dependencies = { - "snr": ["noise_levels", "templates"], - "amplitude_cutoff": ["spike_amplitudes|waveforms", "templates"], - "amplitude_median": ["spike_amplitudes|waveforms", "templates"], - "amplitude_cv": ["spike_amplitudes|amplitude_scalings", "templates"], - "drift": ["spike_locations"], - "sd_ratio": ["templates", "spike_amplitudes"], - "noise_cutoff": ["spike_amplitudes"], -} - - -from .misc_metrics import ( - compute_num_spikes, - compute_firing_rates, - compute_presence_ratios, - compute_snrs, - compute_isi_violations, - compute_refrac_period_violations, - compute_sliding_rp_violations, - compute_amplitude_cutoffs, - compute_amplitude_medians, - compute_drift_metrics, - compute_synchrony_metrics, - compute_firing_ranges, - compute_amplitude_cv_metrics, - compute_sd_ratio, - compute_noise_cutoffs, -) - -from .pca_metrics import ( - compute_pc_metrics, - mahalanobis_metrics, - lda_metrics, - nearest_neighbors_metrics, - nearest_neighbors_isolation, - nearest_neighbors_noise_overlap, - silhouette_score, - simplified_silhouette_score, -) - -from .pca_metrics import _possible_pc_metric_names - - -# list of all available metrics and mapping to function -# this list MUST NOT contain pca metrics, which are handled separately -_misc_metric_name_to_func = { - "num_spikes": compute_num_spikes, - "firing_rate": compute_firing_rates, - "presence_ratio": compute_presence_ratios, - "snr": compute_snrs, - "isi_violation": compute_isi_violations, - "rp_violation": compute_refrac_period_violations, - "sliding_rp_violation": compute_sliding_rp_violations, - "amplitude_cutoff": compute_amplitude_cutoffs, - "amplitude_median": compute_amplitude_medians, - "amplitude_cv": compute_amplitude_cv_metrics, - "synchrony": compute_synchrony_metrics, - "firing_range": compute_firing_ranges, - "drift": compute_drift_metrics, - "sd_ratio": compute_sd_ratio, - "noise_cutoff": compute_noise_cutoffs, -} - - -# a dict converting the name of the metric for computation to the output of that computation -qm_compute_name_to_column_names = { - "num_spikes": ["num_spikes"], - "firing_rate": ["firing_rate"], - "presence_ratio": ["presence_ratio"], - "snr": ["snr"], - "isi_violation": ["isi_violations_ratio", "isi_violations_count"], - "rp_violation": ["rp_violations", "rp_contamination"], - "sliding_rp_violation": ["sliding_rp_violation"], - "amplitude_cutoff": ["amplitude_cutoff"], - "amplitude_median": ["amplitude_median"], - "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], - "synchrony": [ - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - ], - "firing_range": ["firing_range"], - "drift": ["drift_ptp", "drift_std", "drift_mad"], - "sd_ratio": ["sd_ratio"], - "isolation_distance": ["isolation_distance"], - "l_ratio": ["l_ratio"], - "d_prime": ["d_prime"], - "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"], - "nn_isolation": ["nn_isolation", "nn_unit_id"], - "nn_noise_overlap": ["nn_noise_overlap"], - "silhouette": ["silhouette"], - "silhouette_full": ["silhouette_full"], - "noise_cutoff": ["noise_cutoff", "noise_ratio"], -} - -# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them -column_name_to_column_dtype = { - "num_spikes": int, - "firing_rate": float, - "presence_ratio": float, - "snr": float, - "isi_violations_ratio": float, - "isi_violations_count": float, - "rp_violations": float, - "rp_contamination": float, - "sliding_rp_violation": float, - "amplitude_cutoff": float, - "amplitude_median": float, - "amplitude_cv_median": float, - "amplitude_cv_range": float, - "sync_spike_2": float, - "sync_spike_4": float, - "sync_spike_8": float, - "firing_range": float, - "drift_ptp": float, - "drift_std": float, - "drift_mad": float, - "sd_ratio": float, - "isolation_distance": float, - "l_ratio": float, - "d_prime": float, - "nn_hit_rate": float, - "nn_miss_rate": float, - "nn_isolation": float, - "nn_unit_id": float, - "nn_noise_overlap": float, - "silhouette": float, - "silhouette_full": float, - "noise_cutoff": float, - "noise_ratio": float, -} diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py deleted file mode 100644 index 1491b9eac1..0000000000 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import numpy as np - -from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list - - -def test_compute_pc_metrics(small_sorting_analyzer): - import pandas as pd - - sorting_analyzer = small_sorting_analyzer - res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) - res1 = pd.DataFrame(res1) - - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) - res2 = pd.DataFrame(res2) - - for metric_name in res1.columns: - values1 = res1[metric_name].values - values2 = res1[metric_name].values - - if metric_name != "nn_unit_id": - assert not np.all(np.isnan(values1)) - assert not np.all(np.isnan(values2)) - - if values1.dtype.kind == "f": - np.testing.assert_almost_equal(values1, values2, decimal=4) - # import matplotlib.pyplot as plt - # fig, axs = plt.subplots(nrows=2, share=True) - # ax =a xs[0] - # ax.plot(res1[metric_name].values) - # ax.plot(res2[metric_name].values) - # ax =a xs[1] - # ax.plot(res2[metric_name].values - res1[metric_name].values) - # plt.show() - else: - assert np.array_equal(values1, values2) - - -def test_pca_metrics_multi_processing(small_sorting_analyzer): - sorting_analyzer = small_sorting_analyzer - - metric_names = get_quality_pca_metric_list() - metric_names.remove("nn_isolation") - metric_names.remove("nn_noise_overlap") - - print(f"Computing PCA metrics with 1 thread per process") - res1 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True - ) - print(f"Computing PCA metrics with 2 thread per process") - res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True - ) - print("Computing PCA metrics with spawn context") - res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True - ) - - -if __name__ == "__main__": - from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer - - small_sorting_analyzer = make_small_analyzer() - test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py deleted file mode 100644 index 36f2e0785a..0000000000 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ /dev/null @@ -1,315 +0,0 @@ -import pytest -from pathlib import Path -import numpy as np - -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, - NumpySorting, - aggregate_units, -) - -from spikeinterface.qualitymetrics import compute_snrs - - -from spikeinterface.qualitymetrics import ( - compute_quality_metrics, -) - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def test_warnings_errors_when_missing_deps(): - """ - If the user requests to compute a quality metric which depends on an extension - that has not been computed, this should error. If the user uses the default - quality metrics (i.e. they do not explicitly request the specific metrics), - this should report a warning about which metrics could not be computed. - We check this behavior in this test. - """ - - recording, sorting = generate_ground_truth_recording() - analyzer = create_sorting_analyzer(sorting=sorting, recording=recording) - - # user tries to use `compute_snrs` without templates. Should error - with pytest.raises(ValueError): - compute_snrs(analyzer) - - # user asks for drift metrics without spike_locations. Should error - with pytest.raises(ValueError): - analyzer.compute("quality_metrics", metric_names=["drift"]) - - # user doesn't specify which metrics to compute. Should return a warning - # about which metrics have not been computed. - with pytest.warns(Warning): - analyzer.compute("quality_metrics") - - -def test_compute_quality_metrics(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - - # without PCs - metrics = compute_quality_metrics( - sorting_analyzer, - metric_names=["snr"], - metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=True, - seed=2205, - ) - # print(metrics) - - qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 - assert "snr" in metrics.columns - assert "isolation_distance" not in metrics.columns - - # with PCs - sorting_analyzer.compute("principal_components") - metrics = compute_quality_metrics( - sorting_analyzer, - metric_names=None, - metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205, - ) - print(metrics.columns) - assert "isolation_distance" in metrics.columns - - -def test_merging_quality_metrics(sorting_analyzer_simple): - - sorting_analyzer = sorting_analyzer_simple - - metrics = compute_quality_metrics( - sorting_analyzer, - metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205, - ) - - # sorting_analyzer_simple has ten units - new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]]) - - new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data() - - # we should copy over the metrics after merge - for column in metrics.columns: - assert column in new_metrics.columns - # should copy dtype too - assert metrics[column].dtype == new_metrics[column].dtype - - # 10 units vs 9 units - assert len(metrics.index) > len(new_metrics.index) - - -def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): - - sorting_analyzer = sorting_analyzer_simple - metrics = compute_quality_metrics( - sorting_analyzer, - metric_names=None, - metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205, - ) - - # make a copy and make it recordingless - sorting_analyzer_norec = sorting_analyzer.save_as(format="memory") - sorting_analyzer_norec.delete_extension("quality_metrics") - sorting_analyzer_norec._recording = None - assert not sorting_analyzer_norec.has_recording() - - metrics_norec = compute_quality_metrics( - sorting_analyzer_norec, - metric_names=None, - metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205, - ) - - for metric_name in metrics.columns: - if metric_name == "sd_ratio": - # this one need recording!!! - continue - assert np.allclose(metrics[metric_name].values, metrics_norec[metric_name].values, rtol=1e-02) - - -def test_empty_units(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - - empty_spike_train = np.array([], dtype="int64") - empty_sorting = NumpySorting.from_unit_dict( - {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, - sampling_frequency=sorting_analyzer.sampling_frequency, - ) - sorting_empty = aggregate_units([sorting_analyzer.sorting, empty_sorting]) - assert len(sorting_empty.get_empty_unit_ids()) == 3 - - sorting_analyzer_empty = create_sorting_analyzer(sorting_empty, sorting_analyzer.recording, format="memory") - sorting_analyzer_empty.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer_empty.compute("noise_levels") - sorting_analyzer_empty.compute("waveforms", **job_kwargs) - sorting_analyzer_empty.compute("templates") - sorting_analyzer_empty.compute("spike_amplitudes", **job_kwargs) - - metrics_empty = compute_quality_metrics( - sorting_analyzer_empty, - metric_names=None, - metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=True, - seed=2205, - ) - - # num_spikes are ints not nans so we confirm empty units are nans for everything except - # num_spikes which should be 0 - nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"] - for empty_unit_ids in sorting_empty.get_empty_unit_ids(): - from pandas import isnull - - assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values)) - if "num_spikes" in metrics_empty.columns: - assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0 - - -# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() - -# def test_amplitude_cutoff(self): -# we = self.we_short -# _ = compute_spike_amplitudes(we, peak_sign="neg") - -# # If too few spikes, should raise a warning and set amplitude cutoffs to nans -# with pytest.warns(UserWarning) as w: -# metrics = self.extension_class.get_extension_function()( -# we, metric_names=["amplitude_cutoff"], peak_sign="neg" -# ) -# assert all(np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) - -# # now we decrease the number of bins and check that amplitude cutoffs are correctly computed -# qm_params = dict(amplitude_cutoff=dict(num_histogram_bins=5)) -# with warnings.catch_warnings(): -# warnings.simplefilter("error") -# metrics = self.extension_class.get_extension_function()( -# we, metric_names=["amplitude_cutoff"], peak_sign="neg", qm_params=qm_params -# ) -# assert all(not np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) - -# def test_presence_ratio(self): -# we = self.we_long - -# total_duration = we.get_total_duration() -# # If bin_duration_s is larger than total duration, should raise a warning and set presence ratios to nans -# qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration + 1)) -# with pytest.warns(UserWarning) as w: -# metrics = self.extension_class.get_extension_function()( -# we, metric_names=["presence_ratio"], qm_params=qm_params -# ) -# assert all(np.isnan(ratio) for ratio in metrics["presence_ratio"].values) - -# # now we decrease the bin_duration_s and check that presence ratios are correctly computed -# qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration // 10)) -# with warnings.catch_warnings(): -# warnings.simplefilter("error") -# metrics = self.extension_class.get_extension_function()( -# we, metric_names=["presence_ratio"], qm_params=qm_params -# ) -# assert all(not np.isnan(ratio) for ratio in metrics["presence_ratio"].values) - -# def test_drift_metrics(self): -# we = self.we_long # is also multi-segment - -# # if spike_locations is not an extension, raise a warning and set values to NaN -# with pytest.warns(UserWarning) as w: -# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"]) -# assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) -# assert all(np.isnan(metric) for metric in metrics["drift_std"].values) -# assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) - -# # now we compute spike locations, but use an interval_s larger than half the total duration -# _ = compute_spike_locations(we) -# total_duration = we.get_total_duration() -# qm_params = dict(drift=dict(interval_s=total_duration // 2 + 1, min_spikes_per_interval=10, min_num_bins=2)) -# with pytest.warns(UserWarning) as w: -# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) -# assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) -# assert all(np.isnan(metric) for metric in metrics["drift_std"].values) -# assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) - -# # finally let's use an interval compatible with segment durations -# qm_params = dict(drift=dict(interval_s=total_duration // 10, min_spikes_per_interval=10)) -# with warnings.catch_warnings(): -# warnings.simplefilter("error") -# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) -# # print(metrics) -# assert all(not np.isnan(metric) for metric in metrics["drift_ptp"].values) -# assert all(not np.isnan(metric) for metric in metrics["drift_std"].values) -# assert all(not np.isnan(metric) for metric in metrics["drift_mad"].values) - -# def test_peak_sign(self): -# we = self.we_long -# rec = we.recording -# sort = we.sorting - -# # invert recording -# rec_inv = scale(rec, gain=-1.0) - -# we_inv = extract_waveforms(rec_inv, sort, cache_folder / "toy_waveforms_inv", seed=0) - -# # compute amplitudes -# _ = compute_spike_amplitudes(we, peak_sign="neg") -# _ = compute_spike_amplitudes(we_inv, peak_sign="pos") - -# # without PC -# metrics = self.extension_class.get_extension_function()( -# we, metric_names=["snr", "amplitude_cutoff"], peak_sign="neg" -# ) -# metrics_inv = self.extension_class.get_extension_function()( -# we_inv, metric_names=["snr", "amplitude_cutoff"], peak_sign="pos" -# ) -# # print(metrics) -# # print(metrics_inv) -# # for SNR we allow a 5% tollerance because of waveform sub-sampling -# assert np.allclose(metrics["snr"].values, metrics_inv["snr"].values, rtol=0.05) -# # for amplitude_cutoff, since spike amplitudes are computed, values should be exactly the same -# assert np.allclose(metrics["amplitude_cutoff"].values, metrics_inv["amplitude_cutoff"].values, atol=1e-3) - -# def test_nn_metrics(self): -# we_dense = self.we1 -# we_sparse = self.we_sparse -# sparsity = self.sparsity1 -# # print(sparsity) - -# metric_names = ["nearest_neighbor", "nn_isolation", "nn_noise_overlap"] - -# # with external sparsity on dense waveforms -# _ = compute_principal_components(we_dense, n_components=5, mode="by_channel_local") -# metrics = self.extension_class.get_extension_function()( -# we_dense, metric_names=metric_names, sparsity=sparsity, seed=0 -# ) -# # print(metrics) - -# # with sparse waveforms -# _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") -# metrics = self.extension_class.get_extension_function()( -# we_sparse, metric_names=metric_names, sparsity=None, seed=0 -# ) -# # print(metrics) - -# # with 2 jobs -# # with sparse waveforms -# _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") -# metrics_par = self.extension_class.get_extension_function()( -# we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 -# ) -# for metric_name in metrics.columns: -# # NaNs are skipped -# assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - -if __name__ == "__main__": - - sorting_analyzer = get_sorting_analyzer() - print(sorting_analyzer) - - test_compute_quality_metrics(sorting_analyzer) - test_compute_quality_metrics_recordingless(sorting_analyzer) - test_empty_units(sorting_analyzer) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index d4d1f0df67..b3942b071a 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -20,17 +20,22 @@ from spikeinterface.core.recording_tools import get_noise_levels -def make_multi_method_doc(methods, ident=" "): +def make_multi_method_doc(methods, indent=" "): doc = "" doc += "method : " + ", ".join(f"'{method.name}'" for method in methods) + "\n" - doc += ident + " Method to use.\n" + doc += indent + "Method to use.\n" for method in methods: doc += "\n" - doc += ident + ident + f"arguments for method='{method.name}'" + doc += indent + indent + f"* arguments for method='{method.name}'\n" for line in method.params_doc.splitlines(): - doc += ident + ident + line + "\n" + # add '* ' before the start of the text of each line + if len(line.strip()) == 0: + continue + line = line.lstrip() + line = "* " + line + doc += indent + indent + indent + line + "\n" return doc