Skip to content

Commit 6cef5df

Browse files
Coloring labels by a continuous variable fixed (#344)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 393315b commit 6cef5df

File tree

36 files changed

+105
-74
lines changed

36 files changed

+105
-74
lines changed

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,22 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11+
## [0.2.6] - tbd
12+
13+
### Added
14+
15+
-
16+
17+
### Changed
18+
19+
- Lowered RMSE-threshold for plot-based tests from 45 to 15 (#344)
20+
- When subsetting to `groups`, `NA` isn't automatically added to legend (#344)
21+
22+
### Fixed
23+
24+
- Filtering with `groups` now preserves original cmap (#344)
25+
- Non-selected `groups` are now not shown in `na_color` (#344)
26+
1127
## [0.2.5] - 2024-08-23
1228

1329
### Added

src/spatialdata_plot/pl/render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
855855
legend_fontweight=legend_params.legend_fontweight,
856856
legend_loc=legend_params.legend_loc,
857857
legend_fontoutline=legend_params.legend_fontoutline,
858-
na_in_legend=legend_params.na_in_legend,
858+
na_in_legend=legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector)),
859859
colorbar=legend_params.colorbar,
860860
scalebar_dx=scalebar_params.scalebar_dx,
861861
scalebar_units=scalebar_params.scalebar_units,

src/spatialdata_plot/pl/utils.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,6 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int =
567567
Union[plt.Figure, plt.Axes]
568568
Matplotlib figure and axes object.
569569
"""
570-
# if num_images <= 1:
571-
# raise ValueError("Number of images must be greater than 1.")
572-
573570
if num_images < ncols:
574571
nrows = 1
575572
ncols = num_images
@@ -733,8 +730,6 @@ def _set_color_source_vec(
733730
color = np.full(len(element), na_color)
734731
return color, color, False
735732

736-
# model = get_model(sdata[element_name])
737-
738733
# Figure out where to get the color from
739734
origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name)
740735

@@ -778,16 +773,13 @@ def _set_color_source_vec(
778773
palette=palette,
779774
na_color=na_color,
780775
)
776+
781777
color_source_vector = color_source_vector.set_categories(color_mapping.keys())
782778
if color_mapping is None:
783779
raise ValueError("Unable to create color palette.")
784780

785781
# do not rename categories, as colors need not be unique
786782
color_vector = color_source_vector.map(color_mapping)
787-
if color_vector.isna().any():
788-
if (na_cat_color := to_hex(na_color)) not in color_vector.categories:
789-
color_vector = color_vector.add_categories([na_cat_color])
790-
color_vector = color_vector.fillna(to_hex(na_color))
791783

792784
return color_source_vector, color_vector, True
793785

@@ -808,44 +800,43 @@ def _map_color_seg(
808800
seg_boundaries: bool = False,
809801
) -> ArrayLike:
810802
cell_id = np.array(cell_id)
811-
if color_vector is not None and isinstance(color_vector.dtype, pd.CategoricalDtype):
812-
# users wants to plot a categorical column
803+
804+
if pd.api.types.is_categorical_dtype(color_vector.dtype):
805+
# Case A: users wants to plot a categorical column
813806
if np.any(color_source_vector.isna()):
814807
cell_id[color_source_vector.isna()] = 0
815-
val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
808+
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
816809
cols = colors.to_rgba_array(color_vector.categories)
817-
818810
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
819-
# user wants to plot a continous column
811+
# Case B: user wants to plot a continous column
820812
if isinstance(color_vector, pd.Series):
821813
color_vector = color_vector.to_numpy()
822-
val_im = map_array(seg, cell_id, color_vector)
823814
cols = cmap_params.cmap(cmap_params.norm(color_vector))
824-
815+
val_im = map_array(seg.copy(), cell_id, cell_id)
825816
else:
826-
val_im = map_array(seg.copy(), cell_id, cell_id) # replace with same seg id to remove missing segs
827-
828-
if val_im.shape[0] == 1:
829-
val_im = np.squeeze(val_im, axis=0)
830-
if "#" in str(color_vector[0]):
831-
# we have hex colors
832-
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
833-
cols = colors.to_rgba_array(color_vector)
817+
# Case C: User didn't specify any colors
818+
if color_source_vector is not None and (
819+
set(color_vector) == set(color_source_vector)
820+
and len(set(color_vector)) == 1
821+
and set(color_vector) == {na_color}
822+
and not na_color_modified_by_user
823+
):
824+
val_im = map_array(seg.copy(), cell_id, cell_id)
825+
RNG = default_rng(42)
826+
cols = RNG.random((len(color_vector), 3))
834827
else:
835-
cols = cmap_params.cmap(cmap_params.norm(color_vector))
828+
# Case D: User didn't specify a column to color by, but modified the na_color
829+
val_im = map_array(seg.copy(), cell_id, cell_id)
830+
if "#" in str(color_vector[0]):
831+
# we have hex colors
832+
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
833+
cols = colors.to_rgba_array(color_vector)
834+
else:
835+
cols = cmap_params.cmap(cmap_params.norm(color_vector))
836836

837837
if seg_erosionpx is not None:
838838
val_im[val_im == erosion(val_im, square(seg_erosionpx))] = 0
839839

840-
if color_source_vector is not None and (
841-
set(color_vector) == set(color_source_vector)
842-
and len(set(color_vector)) == 1
843-
and set(color_vector) == {na_color}
844-
and not na_color_modified_by_user
845-
):
846-
RNG = default_rng(42)
847-
cols = RNG.random((len(cols), 3))
848-
849840
seg_im: ArrayLike = label2rgb(
850841
label=val_im,
851842
colors=cols,
@@ -948,7 +939,7 @@ def _get_categorical_color_mapping(
948939
else:
949940
base_mapping = _generate_base_categorial_color_mapping(adata, cluster_key, color_source_vector, na_color)
950941

951-
return _modify_categorical_color_mapping(base_mapping, groups, palette)
942+
return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette)
952943

953944

954945
def _maybe_set_colors(
@@ -1587,19 +1578,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15871578

15881579
palette = param_dict["palette"]
15891580

1590-
if (groups := param_dict.get("groups")) is not None and palette is None:
1591-
warnings.warn(
1592-
"Groups is specified but palette is not. Setting palette to default 'lightgray'", UserWarning, stacklevel=2
1593-
)
1594-
param_dict["palette"] = ["lightgray" for _ in range(len(groups))]
1595-
15961581
if isinstance((palette := param_dict["palette"]), list):
15971582
if not all(isinstance(p, str) for p in palette):
15981583
raise ValueError("If specified, parameter 'palette' must contain only strings.")
15991584
elif isinstance(palette, (str, type(None))) and "palette" in param_dict:
16001585
param_dict["palette"] = [palette] if palette is not None else None
16011586

16021587
if element_type in ["shapes", "points", "labels"] and (palette := param_dict.get("palette")) is not None:
1588+
groups = param_dict.get("groups")
16031589
if groups is None:
16041590
raise ValueError("When specifying 'palette', 'groups' must also be specified.")
16051591
if len(groups) != len(palette):
4.78 KB
Loading
23.3 KB
Loading
466 Bytes
Loading
-7.42 KB
Loading
15 KB
Loading
8 Bytes
Loading
278 Bytes
Loading

0 commit comments

Comments
 (0)