From 5ab0df7d32f54a06be95ef2cc3c6663896f8ae2e Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 14 Oct 2025 18:03:35 +0200
Subject: [PATCH 1/6] Add plot_performances_vs_firing_rate

---
 .../benchmark/benchmark_clustering.py         |   6 +
 .../benchmark/benchmark_plot_tools.py         | 194 ++++++++++++++----
 2 files changed, 157 insertions(+), 43 deletions(-)

diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py
index 8a3c9e78d9..23e6c4d435 100644
--- a/src/spikeinterface/benchmark/benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/benchmark_clustering.py
@@ -192,6 +192,12 @@ def plot_performances_vs_snr(self, **kwargs):
 
         return plot_performances_vs_snr(self, **kwargs)
 
+    def plot_performances_vs_firing_rate(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_vs_firing_rate
+
+        return plot_performances_vs_firing_rate(self, **kwargs)
+
+
     def plot_performances_comparison(self, *args, **kwargs):
         from .benchmark_plot_tools import plot_performances_comparison
 
diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py
index 4a21029e03..7e629ecdf5 100644
--- a/src/spikeinterface/benchmark/benchmark_plot_tools.py
+++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -398,13 +398,13 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None):
 
     return fig
 
-
-def plot_performances_vs_snr(
+def _plot_performances_vs_metric(
     study,
+    metric_name,
     case_keys=None,
     figsize=None,
     performance_names=("accuracy", "recall", "precision"),
-    snr_dataset_reference=None,
+    metric_dataset_reference=None,
     levels_to_group_by=None,
     orientation="vertical",
     show_legend=True,
     with_sigmoid_fit=False,
     show_average_by_bin=True,
     scatter_size=4,
     num_bin_average=20,
     axs=None,
 ):
-    """
-    Plots performance metrics against signal-to-noise ratio (SNR) for different cases in a study.
-
-    Parameters
-    ----------
-    study : object
-        The study object containing the cases and results.
-    case_keys : list | None, default: None
-        List of case keys to include in the plot. If None, all cases in the study are included.
-    figsize : tuple | None, default: None
-        Size of the figure.
-    performance_names : tuple, default: ("accuracy", "recall", "precision")
-        Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision").
-    snr_dataset_reference : str | None, default: None
-        Reference dataset key to use for SNR. If None, the SNR of each dataset is used.
-    levels_to_group_by : list | None, default: None
-        Levels to group by when mapping case keys.
-    orientation : "vertical" | "horizontal", default: "vertical"
-        The orientation of the plot.
-    show_legend : bool, default True
-        Show legend or not
-    show_sigmoid_fit : bool, default True
-        Show sigmoid that fit the performances.
-    show_average_by_bin : bool, default False
-        Instead of the sigmoid an average by bins can be plotted.
-    scatter_size : int, default 4
-        scatter size
-    num_bin_average : int, default 2
-        Num bin for average
-    axs : matplotlib.axes.Axes | None, default: None
-        The axs to use for plotting. Should be the same size as len(performance_names).
-
-    Returns
-    -------
-    fig : matplotlib.figure.Figure
-        The resulting figure containing the plots.
- """ import matplotlib.pyplot as plt if case_keys is None: @@ -499,15 +462,15 @@ def plot_performances_vs_snr( all_xs = [] all_ys = [] for sub_key in key_list: - if snr_dataset_reference is None: + if metric_dataset_reference is None: # use the SNR of each dataset analyzer = study.get_sorting_analyzer(sub_key) else: # use the same SNR from a reference dataset - analyzer = study.get_sorting_analyzer(dataset_key=snr_dataset_reference) + analyzer = study.get_sorting_analyzer(dataset_key=metric_dataset_reference) quality_metrics = analyzer.get_extension("quality_metrics").get_data() - x = quality_metrics["snr"].to_numpy(dtype="float64") + x = quality_metrics[metric_name].to_numpy(dtype="float64") y = ( study.get_result(sub_key)["gt_comparison"] .get_performance()[performance_name] @@ -564,6 +527,151 @@ def plot_performances_vs_snr( return fig +def plot_performances_vs_snr( + study, + case_keys=None, + figsize=None, + performance_names=("accuracy", "recall", "precision"), + metric_dataset_reference=None, + levels_to_group_by=None, + orientation="vertical", + show_legend=True, + with_sigmoid_fit=False, + show_average_by_bin=True, + scatter_size=4, + num_bin_average=20, + axs=None, +): + """ + Plots performance metrics against signal-to-noise ratio (SNR) for different cases in a study. + + Parameters + ---------- + study : object + The study object containing the cases and results. + case_keys : list | None, default: None + List of case keys to include in the plot. If None, all cases in the study are included. + figsize : tuple | None, default: None + Size of the figure. + performance_names : tuple, default: ("accuracy", "recall", "precision") + Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision"). + metric_dataset_reference : str | None, default: None + Reference dataset key to use for SNR. If None, the SNR of each dataset is used. + levels_to_group_by : list | None, default: None + Levels to group by when mapping case keys. + orientation : "vertical" | "horizontal", default: "vertical" + The orientation of the plot. + show_legend : bool, default True + Show legend or not + show_sigmoid_fit : bool, default True + Show sigmoid that fit the performances. + show_average_by_bin : bool, default False + Instead of the sigmoid an average by bins can be plotted. + scatter_size : int, default 4 + scatter size + num_bin_average : int, default 2 + Num bin for average + axs : matplotlib.axes.Axes | None, default: None + The axs to use for plotting. Should be the same size as len(performance_names). + + Returns + ------- + fig : matplotlib.figure.Figure + The resulting figure containing the plots. 
+ """ + + return _plot_performances_vs_metric( + study, + "snr", + case_keys=case_keys, + figsize=figsize, + performance_names=performance_names, + metric_dataset_reference=metric_dataset_reference, + levels_to_group_by=levels_to_group_by, + orientation=orientation, + show_legend=show_legend, + with_sigmoid_fit=with_sigmoid_fit, + show_average_by_bin=show_average_by_bin, + scatter_size=scatter_size, + num_bin_average=num_bin_average, + axs=axs, + ) + +def plot_performances_vs_firing_rate( + study, + case_keys=None, + figsize=None, + performance_names=("accuracy", "recall", "precision"), + metric_dataset_reference=None, + levels_to_group_by=None, + orientation="vertical", + show_legend=True, + with_sigmoid_fit=False, + show_average_by_bin=True, + scatter_size=4, + num_bin_average=20, + axs=None, +): + """ + Plots performance metrics against firing rate for different cases in a study. + + Parameters + ---------- + study : object + The study object containing the cases and results. + case_keys : list | None, default: None + List of case keys to include in the plot. If None, all cases in the study are included. + figsize : tuple | None, default: None + Size of the figure. + performance_names : tuple, default: ("accuracy", "recall", "precision") + Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision"). + metric_dataset_reference : str | None, default: None + Reference dataset key to use for SNR. If None, the SNR of each dataset is used. + levels_to_group_by : list | None, default: None + Levels to group by when mapping case keys. + orientation : "vertical" | "horizontal", default: "vertical" + The orientation of the plot. + show_legend : bool, default True + Show legend or not + show_sigmoid_fit : bool, default True + Show sigmoid that fit the performances. + show_average_by_bin : bool, default False + Instead of the sigmoid an average by bins can be plotted. + scatter_size : int, default 4 + scatter size + num_bin_average : int, default 2 + Num bin for average + axs : matplotlib.axes.Axes | None, default: None + The axs to use for plotting. Should be the same size as len(performance_names). + + Returns + ------- + fig : matplotlib.figure.Figure + The resulting figure containing the plots. 
+ """ + + return _plot_performances_vs_metric( + study, + "firing_rate", + case_keys=case_keys, + figsize=figsize, + performance_names=performance_names, + metric_dataset_reference=metric_dataset_reference, + levels_to_group_by=levels_to_group_by, + orientation=orientation, + show_legend=show_legend, + with_sigmoid_fit=with_sigmoid_fit, + show_average_by_bin=show_average_by_bin, + scatter_size=scatter_size, + num_bin_average=num_bin_average, + axs=axs, + ) + + + + + + def plot_performances_ordered( study, case_keys=None, From 8fcb1754e5e76c8c80849001674feacccf6c3cf8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:01:25 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/benchmark/benchmark_clustering.py | 1 - src/spikeinterface/benchmark/benchmark_plot_tools.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index 23e6c4d435..3885fe073c 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -197,7 +197,6 @@ def plot_performances_vs_firing_rate(self, **kwargs): return plot_performances_vs_firing_rate(self, **kwargs) - def plot_performances_comparison(self, *args, **kwargs): from .benchmark_plot_tools import plot_performances_comparison diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 7e629ecdf5..28b7575429 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -398,6 +398,7 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None): return fig + def _plot_performances_vs_metric( study, metric_name, @@ -597,6 +598,7 @@ def plot_performances_vs_snr( axs=axs, ) + def plot_performances_vs_firing_rate( study, case_keys=None, @@ -668,10 +670,6 @@ def plot_performances_vs_firing_rate( ) - - - - def plot_performances_ordered( study, case_keys=None, From 22c3df54d4d690fc0d4a20c364b1f0be0c1e9f71 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 24 Oct 2025 09:46:28 +0200 Subject: [PATCH 3/6] Better handling of subplots() squeeze in plot_benchmark --- .../benchmark/benchmark_plot_tools.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 28b7575429..8086e8e17e 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -373,7 +373,8 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None): num_axes = len(case_keys) if axs is None: - fig, axs = plt.subplots(ncols=num_axes, squeeze=True) + fig, axs = plt.subplots(ncols=num_axes, squeeze=False) + axs = axs[0, :] else: assert len(axs) == num_axes, "axs should have the same number of axes as case_keys" fig = axs[0].get_figure() @@ -720,9 +721,11 @@ def plot_performances_ordered( if axs is None: if orientation == "vertical": - fig, axs = plt.subplots(nrows=num_axes, figsize=figsize, squeeze=True) + fig, axs = plt.subplots(nrows=num_axes, figsize=figsize, squeeze=False) + axs = axs[:, 0] elif orientation == "horizontal": - fig, axs = plt.subplots(ncols=num_axes, 
+            fig, axs = plt.subplots(ncols=num_axes, figsize=figsize, squeeze=False)
+            axs = axs[0, :]
         else:
             raise ValueError("orientation must be 'vertical' or 'horizontal'")
     else:
@@ -956,7 +959,8 @@ def plot_performances_vs_depth_and_snr(
     case_keys, labels = study.get_grouped_keys_mapping(levels_to_group_by=levels_to_group_by, case_keys=case_keys)
 
     if axs is None:
-        fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=True)
+        fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False)
+        axs = axs[0, :]
     else:
         assert len(axs) == len(case_keys), "axs should have the same number of axes as case_keys"
         fig = axs[0].get_figure()
@@ -1033,7 +1037,8 @@ def plot_performance_losses(
     import matplotlib.pyplot as plt
 
     if axs is None:
-        fig, axs = plt.subplots(nrows=len(performance_names), figsize=figsize, squeeze=True)
+        fig, axs = plt.subplots(nrows=len(performance_names), figsize=figsize, squeeze=False)
+        axs = axs[:, 0]
     else:
         assert len(axs) == len(performance_names), "axs should have the same number of axes as performance_names"
         fig = axs[0].get_figure()

From ecc7efca14dcbb2200d97efc3d5875d97f539a8f Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 24 Oct 2025 09:46:42 +0200
Subject: [PATCH 4/6] plot_performances_vs_firing_rate in sorter study

---
 src/spikeinterface/benchmark/benchmark_sorter.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py
index 88bff28ecd..f068e6b5d1 100644
--- a/src/spikeinterface/benchmark/benchmark_sorter.py
+++ b/src/spikeinterface/benchmark/benchmark_sorter.py
@@ -88,6 +88,11 @@ def plot_performances_vs_snr(self, **kwargs):
         from .benchmark_plot_tools import plot_performances_vs_snr
 
         return plot_performances_vs_snr(self, **kwargs)
+    
+    def plot_performances_vs_firing_rate(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_vs_firing_rate
+
+        return plot_performances_vs_firing_rate(self, **kwargs)
 
     def plot_performances_ordered(self, **kwargs):
         from .benchmark_plot_tools import plot_performances_ordered

From 018d48d4cf94770156fdf7940a0281fcd7efbc44 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 Oct 2025 07:48:26 +0000
Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/benchmark/benchmark_sorter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py
index f068e6b5d1..6284b1c82d 100644
--- a/src/spikeinterface/benchmark/benchmark_sorter.py
+++ b/src/spikeinterface/benchmark/benchmark_sorter.py
@@ -88,7 +88,7 @@ def plot_performances_vs_snr(self, **kwargs):
         from .benchmark_plot_tools import plot_performances_vs_snr
 
         return plot_performances_vs_snr(self, **kwargs)
-    
+
     def plot_performances_vs_firing_rate(self, **kwargs):
         from .benchmark_plot_tools import plot_performances_vs_firing_rate
 

From 28a532cf396b7b23af8456b3a60191196eadfac1 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Fri, 24 Oct 2025 10:29:19 +0200
Subject: [PATCH 6/6] merci alessio

Co-authored-by: Alessio Buccino
---
 src/spikeinterface/benchmark/benchmark_plot_tools.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py
index 8086e8e17e..b32cf3df45 100644
--- a/src/spikeinterface/benchmark/benchmark_plot_tools.py
+++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -558,7 +558,7 @@ def plot_performances_vs_snr(
     performance_names : tuple, default: ("accuracy", "recall", "precision")
         Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision").
     metric_dataset_reference : str | None, default: None
-        Reference dataset key to use for SNR. If None, the SNR of each dataset is used.
+        Reference dataset metric key to use. If None, the SNR of each dataset is used.
     levels_to_group_by : list | None, default: None
         Levels to group by when mapping case keys.
     orientation : "vertical" | "horizontal", default: "vertical"
@@ -629,7 +629,7 @@ def plot_performances_vs_firing_rate(
     performance_names : tuple, default: ("accuracy", "recall", "precision")
         Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision").
     metric_dataset_reference : str | None, default: None
-        Reference dataset key to use for SNR. If None, the SNR of each dataset is used.
+        Reference dataset metric key to use. If None, the firing rate of each dataset is used.
     levels_to_group_by : list | None, default: None
         Levels to group by when mapping case keys.
     orientation : "vertical" | "horizontal", default: "vertical"
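
Usage sketch of the entry point added by this series. This is illustrative only: the `SorterStudy` import path, the study folder, and the prior computation of the "quality_metrics" extension (including "firing_rate") on the study's sorting analyzers are assumptions about the surrounding setup, not part of the patches.

    # Illustrative sketch: study loading is assumed, not taken from the patches.
    from spikeinterface.benchmark import SorterStudy  # assumed import path

    study = SorterStudy("my_study_folder")  # hypothetical existing study folder

    # Mirrors plot_performances_vs_snr, but bins performances by the
    # "firing_rate" quality metric of each ground-truth unit.
    fig = study.plot_performances_vs_firing_rate(
        performance_names=("accuracy", "recall", "precision"),
        with_sigmoid_fit=False,
        show_average_by_bin=True,
        num_bin_average=20,
    )
    fig.savefig("performances_vs_firing_rate.png")

Because both public functions delegate to `_plot_performances_vs_metric`, the two plots accept the same options; only the metric column read from the quality-metrics table differs.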