Skip to content

Commit e770345

Browse files
Sonja-StockhausSonja Stockhaustimtreis
authored
fix image rendering (clipping warning) (#471)
Co-authored-by: Sonja Stockhaus <[email protected]> Co-authored-by: Tim Treis <[email protected]> Co-authored-by: Tim Treis <[email protected]>
1 parent a3aa9e4 commit e770345

File tree

3 files changed

+98
-40
lines changed

3 files changed

+98
-40
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def _render_shapes(
338338
cax = None
339339
if aggregate_with_reduction is not None:
340340
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
341-
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
341+
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
342342
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
343343
assert norm.vmin is not None
344344
assert norm.vmax is not None
@@ -850,20 +850,22 @@ def _render_images(
850850
# 2) Image has any number of channels but 1
851851
else:
852852
layers = {}
853-
for ch_index, c in enumerate(channels):
854-
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
855-
856-
if not isinstance(render_params.cmap_params, list):
857-
if render_params.cmap_params.norm is not None:
858-
layers[c] = render_params.cmap_params.norm(layers[c])
853+
for ch_idx, ch in enumerate(channels):
854+
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
855+
if isinstance(render_params.cmap_params, list):
856+
ch_norm = render_params.cmap_params[ch_idx].norm
857+
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
859858
else:
860-
if render_params.cmap_params[ch_index].norm is not None:
861-
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
859+
ch_norm = render_params.cmap_params.norm
860+
ch_cmap_is_default = render_params.cmap_params.cmap_is_default
861+
862+
if not ch_cmap_is_default and ch_norm is not None:
863+
layers[ch_idx] = ch_norm(layers[ch_idx])
862864

863865
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
864866
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
865867
if render_params.cmap_params.cmap_is_default: # -> use RGB
866-
stacked = np.stack([layers[c] for c in channels], axis=-1)
868+
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
867869
else: # -> use given cmap for each channel
868870
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
869871
stacked = (
@@ -896,12 +898,54 @@ def _render_images(
896898
# overwrite if n_channels == 2 for intuitive result
897899
if n_channels == 2:
898900
seed_colors = ["#ff0000ff", "#00ff00ff"]
899-
else:
901+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
902+
colored = np.stack(
903+
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
904+
0,
905+
).sum(0)
906+
colored = colored[:, :, :3]
907+
elif n_channels == 3:
900908
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
909+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
910+
colored = np.stack(
911+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
912+
0,
913+
).sum(0)
914+
colored = colored[:, :, :3]
915+
else:
916+
if isinstance(render_params.cmap_params, list):
917+
cmap_is_default = render_params.cmap_params[0].cmap_is_default
918+
else:
919+
cmap_is_default = render_params.cmap_params.cmap_is_default
901920

902-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
903-
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
904-
colored = colored[:, :, :3]
921+
if cmap_is_default:
922+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
923+
else:
924+
# Sample n_channels colors evenly from the colormap
925+
if isinstance(render_params.cmap_params, list):
926+
seed_colors = [
927+
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
928+
]
929+
else:
930+
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
931+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
932+
933+
# Stack (n_channels, height, width) → (height*width, n_channels)
934+
H, W = next(iter(layers.values())).shape
935+
comp_rgb = np.zeros((H, W, 3), dtype=float)
936+
937+
# For each channel: map to RGBA, apply constant alpha, then add
938+
for ch_idx, ch in enumerate(channels):
939+
layer_arr = layers[ch]
940+
rgba = channel_cmaps[ch_idx](layer_arr)
941+
rgba[..., 3] = render_params.alpha
942+
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
943+
944+
colored = np.clip(comp_rgb, 0, 1)
945+
logger.info(
946+
f"Your image has {n_channels} channels. Sampling categorical colors and using "
947+
f"multichannel strategy 'stack' to render."
948+
) # TODO: update when pca is added as strategy
905949

906950
_ax_show_and_transform(
907951
colored,
@@ -947,6 +991,7 @@ def _render_images(
947991
zorder=render_params.zorder,
948992
)
949993

994+
# 2D) Image has n channels, no palette but cmap info
950995
elif palette is not None and got_multiple_cmaps:
951996
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
952997

src/spatialdata_plot/pl/utils.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,7 +2008,7 @@ def _validate_col_for_column_table(
20082008
table_name = next(iter(tables))
20092009
if len(tables) > 1:
20102010
warnings.warn(
2011-
f"Multiple tables contain color column, using {table_name}",
2011+
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
20122012
UserWarning,
20132013
stacklevel=2,
20142014
)
@@ -2044,44 +2044,57 @@ def _validate_image_render_params(
20442044
element_params[el] = {}
20452045
spatial_element = param_dict["sdata"][el]
20462046

2047+
# robustly get channel names from image or multiscale image
20472048
spatial_element_ch = (
2048-
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
2049+
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
20492050
)
2050-
20512051
channel = param_dict["channel"]
2052-
channel_list: list[str] | list[int] | None
2053-
if isinstance(channel, list):
2054-
type_ = type(channel[0])
2055-
assert all(isinstance(ch, type_) for ch in channel), "All channels must be of the same type."
2056-
# mypy complains that channel_list can be also of type list[str | int]
2057-
channel_list = [channel] if isinstance(channel, int | str) else channel # type: ignore[assignment]
2058-
2059-
if channel_list is not None and (
2060-
(isinstance(channel_list[0], int) and max([abs(ch) for ch in channel_list]) <= len(spatial_element_ch)) # type: ignore[arg-type]
2061-
or all(ch in spatial_element_ch for ch in channel_list)
2062-
):
2063-
element_params[el]["channel"] = channel_list
2052+
if channel is not None:
2053+
# Normalize channel to always be a list of str or a list of int
2054+
if isinstance(channel, str):
2055+
channel = [channel]
2056+
2057+
if isinstance(channel, int):
2058+
channel = [channel]
2059+
2060+
# If channel is a list, ensure all elements are the same type
2061+
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
2062+
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
2063+
2064+
invalid = [c for c in channel if c not in spatial_element_ch]
2065+
if invalid:
2066+
raise ValueError(
2067+
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
2068+
)
2069+
element_params[el]["channel"] = channel
20642070
else:
20652071
element_params[el]["channel"] = None
20662072

20672073
element_params[el]["alpha"] = param_dict["alpha"]
20682074

2069-
if isinstance(palette := param_dict["palette"], list):
2075+
palette = param_dict["palette"]
2076+
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure
2077+
2078+
if isinstance(palette, list):
2079+
# case A: single palette for all channels
20702080
if len(palette) == 1:
2071-
palette_length = len(channel_list) if channel_list is not None else len(spatial_element_ch)
2081+
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
20722082
palette = palette * palette_length
2073-
if (channel_list is not None and len(palette) != len(channel_list)) and len(palette) != len(
2074-
spatial_element_ch
2075-
):
2076-
palette = None
2083+
# case B: one palette per channel (either given or derived from channel length)
2084+
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
2085+
if channels_to_use is not None and len(palette) != len(channels_to_use):
2086+
raise ValueError(
2087+
f"Palette length ({len(palette)}) does not match channel length "
2088+
f"({', '.join(str(c) for c in channels_to_use)})."
2089+
)
20772090
element_params[el]["palette"] = palette
20782091
element_params[el]["na_color"] = param_dict["na_color"]
20792092

20802093
if (cmap := param_dict["cmap"]) is not None:
20812094
if len(cmap) == 1:
2082-
cmap_length = len(channel_list) if channel_list is not None else len(spatial_element_ch)
2095+
cmap_length = len(channel) if channel is not None else len(spatial_element_ch)
20832096
cmap = cmap * cmap_length
2084-
if (channel_list is not None and len(cmap) != len(channel_list)) or len(cmap) != len(spatial_element_ch):
2097+
if (channel is not None and len(cmap) != len(channel)) or len(cmap) != len(spatial_element_ch):
20852098
cmap = None
20862099
element_params[el]["cmap"] = cmap
20872100
element_params[el]["norm"] = param_dict["norm"]
@@ -2099,7 +2112,7 @@ def _validate_image_render_params(
20992112
def _get_wanted_render_elements(
21002113
sdata: SpatialData,
21012114
sdata_wanted_elements: list[str],
2102-
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
2115+
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
21032116
cs: str,
21042117
element_type: Literal["images", "labels", "points", "shapes"],
21052118
) -> tuple[list[str], list[str], bool]:
@@ -2256,7 +2269,7 @@ def _create_image_from_datashader_result(
22562269

22572270

22582271
def _datashader_aggregate_with_function(
2259-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2272+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
22602273
cvs: Canvas,
22612274
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
22622275
col_for_color: str | None,
@@ -2320,7 +2333,7 @@ def _datashader_aggregate_with_function(
23202333

23212334

23222335
def _datshader_get_how_kw_for_spread(
2323-
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
2336+
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
23242337
) -> str:
23252338
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
23262339
reduction = reduction or "sum"
-11.2 KB
Loading

0 commit comments

Comments
 (0)