5 changes: 5 additions & 0 deletions src/spikeinterface/benchmark/benchmark_clustering.py
@@ -192,6 +192,11 @@ def plot_performances_vs_snr(self, **kwargs):
 
         return plot_performances_vs_snr(self, **kwargs)
 
+    def plot_performances_vs_firing_rate(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_vs_firing_rate
+
+        return plot_performances_vs_firing_rate(self, **kwargs)
+
     def plot_performances_comparison(self, *args, **kwargs):
         from .benchmark_plot_tools import plot_performances_comparison
 
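For orientation, a minimal usage sketch of the new method, assuming an existing clustering study (the `ClusteringStudy` class of this module) whose sorting analyzers already have the `quality_metrics` extension computed, so that the `firing_rate` column is available; `study_folder` is a hypothetical path and the kwargs are those from the signature added in this PR:

# Minimal usage sketch, not part of the PR. Assumes a previously created
# and run benchmark study; `study_folder` is a hypothetical path.
from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy

study_folder = "path/to/my_study"  # hypothetical
study = ClusteringStudy(study_folder)

# New method added by this PR: same kwargs as plot_performances_vs_snr,
# but the x axis is the "firing_rate" column of the quality_metrics extension.
fig = study.plot_performances_vs_firing_rate(
    performance_names=("accuracy", "recall"),
    show_average_by_bin=True,
    num_bin_average=20,
)
fig.savefig("performances_vs_firing_rate.png")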
205 changes: 158 additions & 47 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -373,7 +373,8 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None):

     num_axes = len(case_keys)
     if axs is None:
-        fig, axs = plt.subplots(ncols=num_axes, squeeze=True)
+        fig, axs = plt.subplots(ncols=num_axes, squeeze=False)
+        axs = axs[0, :]
     else:
         assert len(axs) == num_axes, "axs should have the same number of axes as case_keys"
         fig = axs[0].get_figure()
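The `squeeze=False` change here (repeated in the hunks below) guards against matplotlib's shape-dependent return type: with `squeeze=True`, `plt.subplots` returns a bare `Axes` for a 1x1 grid, so iterating or indexing `axs` breaks when there is a single case. A short self-contained illustration of the behavior being fixed:

import matplotlib.pyplot as plt

# With squeeze=True (the default), the return type depends on the grid shape:
fig, ax = plt.subplots(ncols=1)    # bare Axes object: ax[0] raises TypeError
fig, axs = plt.subplots(ncols=3)   # 1D numpy array of 3 Axes

# With squeeze=False the result is always a 2D array, whatever the grid:
fig, axs = plt.subplots(ncols=1, squeeze=False)
assert axs.shape == (1, 1)
axs = axs[0, :]                    # uniform 1D row of Axes, even for one column
assert len(axs) == 1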
@@ -399,12 +400,13 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None):
     return fig
 
 
-def plot_performances_vs_snr(
+def _plot_performances_vs_metric(
     study,
+    metric_name,
     case_keys=None,
     figsize=None,
     performance_names=("accuracy", "recall", "precision"),
-    snr_dataset_reference=None,
+    metric_dataset_reference=None,
     levels_to_group_by=None,
     orientation="vertical",
     show_legend=True,
@@ -414,43 +416,6 @@ def plot_performances_vs_snr(
     num_bin_average=20,
     axs=None,
 ):
"""
Plots performance metrics against signal-to-noise ratio (SNR) for different cases in a study.

Parameters
----------
study : object
The study object containing the cases and results.
case_keys : list | None, default: None
List of case keys to include in the plot. If None, all cases in the study are included.
figsize : tuple | None, default: None
Size of the figure.
performance_names : tuple, default: ("accuracy", "recall", "precision")
Names of the performance metrics to plot. Default is ("accuracy", "recall", "precision").
snr_dataset_reference : str | None, default: None
Reference dataset key to use for SNR. If None, the SNR of each dataset is used.
levels_to_group_by : list | None, default: None
Levels to group by when mapping case keys.
orientation : "vertical" | "horizontal", default: "vertical"
The orientation of the plot.
show_legend : bool, default True
Show legend or not
show_sigmoid_fit : bool, default True
Show sigmoid that fit the performances.
show_average_by_bin : bool, default False
Instead of the sigmoid an average by bins can be plotted.
scatter_size : int, default 4
scatter size
num_bin_average : int, default 2
Num bin for average
axs : matplotlib.axes.Axes | None, default: None
The axs to use for plotting. Should be the same size as len(performance_names).

Returns
-------
fig : matplotlib.figure.Figure
The resulting figure containing the plots.
"""
     import matplotlib.pyplot as plt
 
     if case_keys is None:
@@ -499,15 +464,15 @@ def plot_performances_vs_snr(
             all_xs = []
             all_ys = []
             for sub_key in key_list:
-                if snr_dataset_reference is None:
+                if metric_dataset_reference is None:
                     # use the metric of each dataset
                     analyzer = study.get_sorting_analyzer(sub_key)
                 else:
                     # use the same metric from a reference dataset
-                    analyzer = study.get_sorting_analyzer(dataset_key=snr_dataset_reference)
+                    analyzer = study.get_sorting_analyzer(dataset_key=metric_dataset_reference)
 
                 quality_metrics = analyzer.get_extension("quality_metrics").get_data()
-                x = quality_metrics["snr"].to_numpy(dtype="float64")
+                x = quality_metrics[metric_name].to_numpy(dtype="float64")
                 y = (
                     study.get_result(sub_key)["gt_comparison"]
                     .get_performance()[performance_name]
@@ -564,6 +529,148 @@ def plot_performances_vs_snr(
     return fig
 
 
+def plot_performances_vs_snr(
+    study,
+    case_keys=None,
+    figsize=None,
+    performance_names=("accuracy", "recall", "precision"),
+    metric_dataset_reference=None,
+    levels_to_group_by=None,
+    orientation="vertical",
+    show_legend=True,
+    with_sigmoid_fit=False,
+    show_average_by_bin=True,
+    scatter_size=4,
+    num_bin_average=20,
+    axs=None,
+):
+    """
+    Plots performance metrics against signal-to-noise ratio (SNR) for different cases in a study.
+
+    Parameters
+    ----------
+    study : object
+        The study object containing the cases and results.
+    case_keys : list | None, default: None
+        List of case keys to include in the plot. If None, all cases in the study are included.
+    figsize : tuple | None, default: None
+        Size of the figure.
+    performance_names : tuple, default: ("accuracy", "recall", "precision")
+        Names of the performance metrics to plot.
+    metric_dataset_reference : str | None, default: None
+        Reference dataset key to use for the metric. If None, the SNR of each dataset is used.
+    levels_to_group_by : list | None, default: None
+        Levels to group by when mapping case keys.
+    orientation : "vertical" | "horizontal", default: "vertical"
+        The orientation of the plot.
+    show_legend : bool, default: True
+        Whether to show the legend.
+    with_sigmoid_fit : bool, default: False
+        Whether to show a sigmoid fit of the performances.
+    show_average_by_bin : bool, default: True
+        Instead of the sigmoid fit, an average by bins can be plotted.
+    scatter_size : int, default: 4
+        Size of the scatter points.
+    num_bin_average : int, default: 20
+        Number of bins used for the average.
+    axs : matplotlib.axes.Axes | None, default: None
+        The axs to use for plotting. Should be the same size as len(performance_names).
+
+    Returns
+    -------
+    fig : matplotlib.figure.Figure
+        The resulting figure containing the plots.
+    """
+
+    return _plot_performances_vs_metric(
+        study,
+        "snr",
+        case_keys=case_keys,
+        figsize=figsize,
+        performance_names=performance_names,
+        metric_dataset_reference=metric_dataset_reference,
+        levels_to_group_by=levels_to_group_by,
+        orientation=orientation,
+        show_legend=show_legend,
+        with_sigmoid_fit=with_sigmoid_fit,
+        show_average_by_bin=show_average_by_bin,
+        scatter_size=scatter_size,
+        num_bin_average=num_bin_average,
+        axs=axs,
+    )
+
+
+def plot_performances_vs_firing_rate(
+    study,
+    case_keys=None,
+    figsize=None,
+    performance_names=("accuracy", "recall", "precision"),
+    metric_dataset_reference=None,
+    levels_to_group_by=None,
+    orientation="vertical",
+    show_legend=True,
+    with_sigmoid_fit=False,
+    show_average_by_bin=True,
+    scatter_size=4,
+    num_bin_average=20,
+    axs=None,
+):
+    """
+    Plots performance metrics against firing rate for different cases in a study.
+
+    Parameters
+    ----------
+    study : object
+        The study object containing the cases and results.
+    case_keys : list | None, default: None
+        List of case keys to include in the plot. If None, all cases in the study are included.
+    figsize : tuple | None, default: None
+        Size of the figure.
+    performance_names : tuple, default: ("accuracy", "recall", "precision")
+        Names of the performance metrics to plot.
+    metric_dataset_reference : str | None, default: None
+        Reference dataset key to use for the metric. If None, the firing rate of each dataset is used.
+    levels_to_group_by : list | None, default: None
+        Levels to group by when mapping case keys.
+    orientation : "vertical" | "horizontal", default: "vertical"
+        The orientation of the plot.
+    show_legend : bool, default: True
+        Whether to show the legend.
+    with_sigmoid_fit : bool, default: False
+        Whether to show a sigmoid fit of the performances.
+    show_average_by_bin : bool, default: True
+        Instead of the sigmoid fit, an average by bins can be plotted.
+    scatter_size : int, default: 4
+        Size of the scatter points.
+    num_bin_average : int, default: 20
+        Number of bins used for the average.
+    axs : matplotlib.axes.Axes | None, default: None
+        The axs to use for plotting. Should be the same size as len(performance_names).
+
+    Returns
+    -------
+    fig : matplotlib.figure.Figure
+        The resulting figure containing the plots.
+    """
+
+    return _plot_performances_vs_metric(
+        study,
+        "firing_rate",
+        case_keys=case_keys,
+        figsize=figsize,
+        performance_names=performance_names,
+        metric_dataset_reference=metric_dataset_reference,
+        levels_to_group_by=levels_to_group_by,
+        orientation=orientation,
+        show_legend=show_legend,
+        with_sigmoid_fit=with_sigmoid_fit,
+        show_average_by_bin=show_average_by_bin,
+        scatter_size=scatter_size,
+        num_bin_average=num_bin_average,
+        axs=axs,
+    )
+
+
 def plot_performances_ordered(
     study,
     case_keys=None,
@@ -614,9 +721,11 @@ def plot_performances_ordered(

     if axs is None:
         if orientation == "vertical":
-            fig, axs = plt.subplots(nrows=num_axes, figsize=figsize, squeeze=True)
+            fig, axs = plt.subplots(nrows=num_axes, figsize=figsize, squeeze=False)
+            axs = axs[:, 0]
         elif orientation == "horizontal":
-            fig, axs = plt.subplots(ncols=num_axes, figsize=figsize, squeeze=True)
+            fig, axs = plt.subplots(ncols=num_axes, figsize=figsize, squeeze=False)
+            axs = axs[0, :]
         else:
             raise ValueError("orientation must be 'vertical' or 'horizontal'")
     else:
@@ -850,7 +959,8 @@ def plot_performances_vs_depth_and_snr(
     case_keys, labels = study.get_grouped_keys_mapping(levels_to_group_by=levels_to_group_by, case_keys=case_keys)
 
     if axs is None:
-        fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=True)
+        fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False)
+        axs = axs[0, :]
     else:
         assert len(axs) == len(case_keys), "axs should have the same number of axes as case_keys"
         fig = axs[0].get_figure()
@@ -927,7 +1037,8 @@ def plot_performance_losses(
     import matplotlib.pyplot as plt
 
     if axs is None:
-        fig, axs = plt.subplots(nrows=len(performance_names), figsize=figsize, squeeze=True)
+        fig, axs = plt.subplots(nrows=len(performance_names), figsize=figsize, squeeze=False)
+        axs = axs[:, 0]
     else:
         assert len(axs) == len(performance_names), "axs should have the same number of axes as performance_names"
         fig = axs[0].get_figure()
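The refactor above follows a thin-wrapper pattern: all plotting logic lives in the generic `_plot_performances_vs_metric`, and the public functions only fix which column of the `quality_metrics` extension supplies the x axis. Any other metric column could be exposed the same way; a hypothetical sketch (not part of this PR, and `amplitude_median` is only an illustrative column name):

# Hypothetical sketch, not part of this PR: exposing another quality-metrics
# column through the same generic helper. The column must exist in the
# analyzer's quality_metrics extension for this to work.
def plot_performances_vs_amplitude(study, **kwargs):
    return _plot_performances_vs_metric(study, "amplitude_median", **kwargs)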
5 changes: 5 additions & 0 deletions src/spikeinterface/benchmark/benchmark_sorter.py
@@ -89,6 +89,11 @@ def plot_performances_vs_snr(self, **kwargs):
 
         return plot_performances_vs_snr(self, **kwargs)
 
+    def plot_performances_vs_firing_rate(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_vs_firing_rate
+
+        return plot_performances_vs_firing_rate(self, **kwargs)
+
     def plot_performances_ordered(self, **kwargs):
         from .benchmark_plot_tools import plot_performances_ordered
 