Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,12 +970,16 @@ def show(
assert isinstance(params_copy.color, str)
colors = sc.get.obs_df(sdata[table], [params_copy.color])
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=params_copy.color,
palette=params_copy.palette,
)
# Avoid mutating `.uns` by generating new colors implicitly.
# Only copy colors if they already exist in `.uns`.
color_key = f"{params_copy.color}_colors"
if color_key in sdata[table].uns:
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=params_copy.color,
palette=params_copy.palette,
)

rasterize = (params_copy.scale is None) or (
isinstance(params_copy.scale, str)
Expand Down
190 changes: 177 additions & 13 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,6 @@ def _set_color_source_vec(
)[value_to_plot]

# numerical case, return early
# TODO temporary split until refactor is complete
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
if (
not isinstance(element, GeoDataFrame)
Expand All @@ -777,18 +776,50 @@ def _set_color_source_vec(

color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`

# TODO check why table_name is not passed here.
color_mapping = _get_categorical_color_mapping(
adata=sdata["table"],
cluster_key=value_to_plot,
color_source_vector=color_source_vector,
cmap_params=cmap_params,
alpha=alpha,
groups=groups,
palette=palette,
na_color=na_color,
render_type=render_type,
)
# Use the provided table_name parameter, fall back to only one present
if table_name is not None:
table_to_use = table_name
else:
table_to_use = list(sdata.tables.keys())[0]
logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.")

# Check if custom colors exist in the table's .uns slot
if value_to_plot is not None and _has_colors_in_uns(sdata, table_name, value_to_plot):
# Extract colors directly from the table's .uns slot
color_mapping = _extract_colors_from_table_uns(
sdata=sdata,
table_name=table_name,
col_to_colorby=value_to_plot,
color_source_vector=color_source_vector,
na_color=na_color,
)
if color_mapping is None:
logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.")
# Fall back to the existing method if extraction fails
color_mapping = _get_categorical_color_mapping(
adata=sdata[table_to_use],
cluster_key=value_to_plot,
color_source_vector=color_source_vector,
cmap_params=cmap_params,
alpha=alpha,
groups=groups,
palette=palette,
na_color=na_color,
render_type=render_type,
)
else:
# Use the existing color mapping method
color_mapping = _get_categorical_color_mapping(
adata=sdata[table_to_use],
cluster_key=value_to_plot,
color_source_vector=color_source_vector,
cmap_params=cmap_params,
alpha=alpha,
groups=groups,
palette=palette,
na_color=na_color,
render_type=render_type,
)

color_source_vector = color_source_vector.set_categories(color_mapping.keys())
if color_mapping is None:
Expand Down Expand Up @@ -897,6 +928,139 @@ def _generate_base_categorial_color_mapping(
return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params)


def _has_colors_in_uns(
sdata: sd.SpatialData,
table_name: str | None,
col_to_colorby: str,
) -> bool:
"""
Check if <column_name>_colors exists in the specified table's .uns slot.

Parameters
----------
sdata
SpatialData object containing tables
table_name
Name of the table to check. If None, uses the first available table.
col_to_colorby
Name of the categorical column (e.g., "celltype")

Returns
-------
True if <col_to_colorby>_colors exists in the table's .uns, False otherwise
"""
# Determine which table to use
if table_name is not None:
if table_name not in sdata.tables:
return False
table_to_use = table_name
else:
if len(sdata.tables.keys()) == 0:
return False
table_to_use = list(sdata.tables.keys())[0]

adata = sdata.tables[table_to_use]
color_key = f"{col_to_colorby}_colors"
return color_key in adata.uns


def _extract_colors_from_table_uns(
sdata: sd.SpatialData,
table_name: str | None,
col_to_colorby: str,
color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
na_color: ColorLike,
) -> Mapping[str, str] | None:
"""
Extract categorical colors from the <column_name>_colors pattern in adata.uns.

This function looks for colors stored in the format <col_to_colorby>_colors in the
specified table's .uns slot and creates a mapping from categories to colors.

Parameters
----------
sdata
SpatialData object containing tables
table_name
Name of the table to look in. If None, uses the first available table.
col_to_colorby
Name of the categorical column (e.g., "celltype")
color_source_vector
Categorical vector containing the categories to map
na_color
Color to use for NaN/missing values

Returns
-------
Mapping from category names to hex colors, or None if colors not found
"""
# Determine which table to use
if table_name is not None:
if table_name not in sdata.tables:
logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}")
return None
table_to_use = table_name
else:
if len(sdata.tables) == 0:
logger.warning("No tables found in sdata.")
return None
table_to_use = list(sdata.tables.keys())[0]
logger.info(f"No table name provided, using '{table_to_use}' for color extraction.")

adata = sdata.tables[table_to_use]
color_key = f"{col_to_colorby}_colors"

# Check if the color pattern exists
if color_key not in adata.uns:
logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.")
return None

# Extract colors and categories
stored_colors = adata.uns[color_key]
categories = color_source_vector.categories.tolist()

# Validate na_color format
if "#" not in str(na_color):
logger.warning("Expected `na_color` to be a hex color, converting...")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, do we really expect that the user passes a hex?

na_color = to_hex(to_rgba(na_color)[:3])

# Strip alpha channel from na_color if present
if len(str(na_color)) == 9: # #rrggbbaa format
na_color = str(na_color)[:7] # Keep only #rrggbb

# Convert stored colors to hex format (without alpha channel)
try:
hex_colors = []
for color in stored_colors:
rgba = to_rgba(color)[:3] # Take only RGB, drop alpha
hex_color = to_hex(rgba)
# Ensure we strip alpha channel if present
if len(hex_color) == 9: # #rrggbbaa format
hex_color = hex_color[:7] # Keep only #rrggbb
hex_colors.append(hex_color)
except (TypeError, ValueError) as e:
logger.warning(f"Error converting colors to hex format: {e}")
return None

# Create the mapping
color_mapping = {}

# Map categories to colors
for i, category in enumerate(categories):
if i < len(hex_colors):
color_mapping[category] = hex_colors[i]
else:
# Not enough colors provided, use na_color for extra categories
logger.warning(f"Not enough colors provided for category '{category}', using na_color.")
color_mapping[category] = na_color

# Add NaN category
color_mapping["NaN"] = na_color

logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.")
return color_mapping


def _modify_categorical_color_mapping(
mapping: Mapping[str, str],
groups: list[str] | str | None = None,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,30 @@ def test_plot_can_handle_dropping_small_labels_after_rasterize_categorical(self,

sdata_blobs.pl.render_labels("blobs_labels_large", color="category", table_name="table").pl.show()

def test_plot_respects_custom_colors_from_uns(self, sdata_blobs: SpatialData):
labels_name = "blobs_labels"
instances = get_element_instances(sdata_blobs[labels_name])
n_obs = len(instances)
adata = AnnData(
get_standard_RNG().normal(size=(n_obs, 10)),
obs=pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]),
)
adata.obs["instance_id"] = instances.values
adata.obs["category"] = get_standard_RNG().choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["category"][:3] = ["a", "b", "c"]
adata.obs["region"] = labels_name
table = TableModel.parse(
adata=adata,
region_key="region",
instance_key="instance_id",
region=labels_name,
)
sdata_blobs["other_table"] = table
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"] # purple, green ,yellow

sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()


def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
Expand Down
26 changes: 25 additions & 1 deletion tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from anndata import AnnData
from matplotlib.colors import Normalize
from shapely.geometry import MultiPolygon, Point, Polygon
from spatialdata import SpatialData, deepcopy
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata.models import ShapesModel, TableModel
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation
from spatialdata.transformations._utils import _set_transformations
Expand Down Expand Up @@ -562,6 +562,30 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat

sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show()

def test_plot_respects_custom_colors_from_uns(self, sdata_blobs: SpatialData):
labels_name = "blobs_labels"
instances = get_element_instances(sdata_blobs[labels_name])
n_obs = len(instances)
adata = AnnData(
get_standard_RNG().normal(size=(n_obs, 10)),
obs=pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]),
)
adata.obs["instance_id"] = instances.values
adata.obs["category"] = get_standard_RNG().choice(["a", "b", "c"], size=adata.n_obs)
adata.obs["category"][:3] = ["a", "b", "c"]
adata.obs["region"] = labels_name
table = TableModel.parse(
adata=adata,
region_key="region",
instance_key="instance_id",
region=labels_name,
)
sdata_blobs["other_table"] = table
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"] # purple, green ,yellow

sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong but shouldn't this all be about shapes and not plotting the labels again?



def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
Expand Down