@@ -567,9 +567,6 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int =
567
567
Union[plt.Figure, plt.Axes]
568
568
Matplotlib figure and axes object.
569
569
"""
570
- # if num_images <= 1:
571
- # raise ValueError("Number of images must be greater than 1.")
572
-
573
570
if num_images < ncols :
574
571
nrows = 1
575
572
ncols = num_images
@@ -733,8 +730,6 @@ def _set_color_source_vec(
733
730
color = np .full (len (element ), na_color )
734
731
return color , color , False
735
732
736
- # model = get_model(sdata[element_name])
737
-
738
733
# Figure out where to get the color from
739
734
origins = _locate_value (value_key = value_to_plot , sdata = sdata , element_name = element_name , table_name = table_name )
740
735
@@ -778,16 +773,13 @@ def _set_color_source_vec(
778
773
palette = palette ,
779
774
na_color = na_color ,
780
775
)
776
+
781
777
color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
782
778
if color_mapping is None :
783
779
raise ValueError ("Unable to create color palette." )
784
780
785
781
# do not rename categories, as colors need not be unique
786
782
color_vector = color_source_vector .map (color_mapping )
787
- if color_vector .isna ().any ():
788
- if (na_cat_color := to_hex (na_color )) not in color_vector .categories :
789
- color_vector = color_vector .add_categories ([na_cat_color ])
790
- color_vector = color_vector .fillna (to_hex (na_color ))
791
783
792
784
return color_source_vector , color_vector , True
793
785
@@ -808,44 +800,43 @@ def _map_color_seg(
808
800
seg_boundaries : bool = False ,
809
801
) -> ArrayLike :
810
802
cell_id = np .array (cell_id )
811
- if color_vector is not None and isinstance (color_vector .dtype , pd .CategoricalDtype ):
812
- # users wants to plot a categorical column
803
+
804
+ if pd .api .types .is_categorical_dtype (color_vector .dtype ):
805
+ # Case A: users wants to plot a categorical column
813
806
if np .any (color_source_vector .isna ()):
814
807
cell_id [color_source_vector .isna ()] = 0
815
- val_im : ArrayLike = map_array (seg , cell_id , color_vector .codes + 1 )
808
+ val_im : ArrayLike = map_array (seg . copy () , cell_id , color_vector .codes + 1 )
816
809
cols = colors .to_rgba_array (color_vector .categories )
817
-
818
810
elif pd .api .types .is_numeric_dtype (color_vector .dtype ):
819
- # user wants to plot a continous column
811
+ # Case B: user wants to plot a continous column
820
812
if isinstance (color_vector , pd .Series ):
821
813
color_vector = color_vector .to_numpy ()
822
- val_im = map_array (seg , cell_id , color_vector )
823
814
cols = cmap_params .cmap (cmap_params .norm (color_vector ))
824
-
815
+ val_im = map_array ( seg . copy (), cell_id , cell_id )
825
816
else :
826
- val_im = map_array (seg .copy (), cell_id , cell_id ) # replace with same seg id to remove missing segs
827
-
828
- if val_im .shape [0 ] == 1 :
829
- val_im = np .squeeze (val_im , axis = 0 )
830
- if "#" in str (color_vector [0 ]):
831
- # we have hex colors
832
- assert all (_is_color_like (c ) for c in color_vector ), "Not all values are color-like."
833
- cols = colors .to_rgba_array (color_vector )
817
+ # Case C: User didn't specify any colors
818
+ if color_source_vector is not None and (
819
+ set (color_vector ) == set (color_source_vector )
820
+ and len (set (color_vector )) == 1
821
+ and set (color_vector ) == {na_color }
822
+ and not na_color_modified_by_user
823
+ ):
824
+ val_im = map_array (seg .copy (), cell_id , cell_id )
825
+ RNG = default_rng (42 )
826
+ cols = RNG .random ((len (color_vector ), 3 ))
834
827
else :
835
- cols = cmap_params .cmap (cmap_params .norm (color_vector ))
828
+ # Case D: User didn't specify a column to color by, but modified the na_color
829
+ val_im = map_array (seg .copy (), cell_id , cell_id )
830
+ if "#" in str (color_vector [0 ]):
831
+ # we have hex colors
832
+ assert all (_is_color_like (c ) for c in color_vector ), "Not all values are color-like."
833
+ cols = colors .to_rgba_array (color_vector )
834
+ else :
835
+ cols = cmap_params .cmap (cmap_params .norm (color_vector ))
836
836
837
837
if seg_erosionpx is not None :
838
838
val_im [val_im == erosion (val_im , square (seg_erosionpx ))] = 0
839
839
840
- if color_source_vector is not None and (
841
- set (color_vector ) == set (color_source_vector )
842
- and len (set (color_vector )) == 1
843
- and set (color_vector ) == {na_color }
844
- and not na_color_modified_by_user
845
- ):
846
- RNG = default_rng (42 )
847
- cols = RNG .random ((len (cols ), 3 ))
848
-
849
840
seg_im : ArrayLike = label2rgb (
850
841
label = val_im ,
851
842
colors = cols ,
@@ -948,7 +939,7 @@ def _get_categorical_color_mapping(
948
939
else :
949
940
base_mapping = _generate_base_categorial_color_mapping (adata , cluster_key , color_source_vector , na_color )
950
941
951
- return _modify_categorical_color_mapping (base_mapping , groups , palette )
942
+ return _modify_categorical_color_mapping (mapping = base_mapping , groups = groups , palette = palette )
952
943
953
944
954
945
def _maybe_set_colors (
@@ -1587,19 +1578,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
1587
1578
1588
1579
palette = param_dict ["palette" ]
1589
1580
1590
- if (groups := param_dict .get ("groups" )) is not None and palette is None :
1591
- warnings .warn (
1592
- "Groups is specified but palette is not. Setting palette to default 'lightgray'" , UserWarning , stacklevel = 2
1593
- )
1594
- param_dict ["palette" ] = ["lightgray" for _ in range (len (groups ))]
1595
-
1596
1581
if isinstance ((palette := param_dict ["palette" ]), list ):
1597
1582
if not all (isinstance (p , str ) for p in palette ):
1598
1583
raise ValueError ("If specified, parameter 'palette' must contain only strings." )
1599
1584
elif isinstance (palette , (str , type (None ))) and "palette" in param_dict :
1600
1585
param_dict ["palette" ] = [palette ] if palette is not None else None
1601
1586
1602
1587
if element_type in ["shapes" , "points" , "labels" ] and (palette := param_dict .get ("palette" )) is not None :
1588
+ groups = param_dict .get ("groups" )
1603
1589
if groups is None :
1604
1590
raise ValueError ("When specifying 'palette', 'groups' must also be specified." )
1605
1591
if len (groups ) != len (palette ):
0 commit comments