@@ -2008,7 +2008,7 @@ def _validate_col_for_column_table(
2008
2008
table_name = next (iter (tables ))
2009
2009
if len (tables ) > 1 :
2010
2010
warnings .warn (
2011
- f"Multiple tables contain color column, using { table_name } " ,
2011
+ f"Multiple tables contain column ' { col_for_color } ' , using table ' { table_name } '. " ,
2012
2012
UserWarning ,
2013
2013
stacklevel = 2 ,
2014
2014
)
@@ -2044,44 +2044,57 @@ def _validate_image_render_params(
2044
2044
element_params [el ] = {}
2045
2045
spatial_element = param_dict ["sdata" ][el ]
2046
2046
2047
+ # robustly get channel names from image or multiscale image
2047
2048
spatial_element_ch = (
2048
- spatial_element .c if isinstance (spatial_element , DataArray ) else spatial_element ["scale0" ].c
2049
+ spatial_element .c . values if isinstance (spatial_element , DataArray ) else spatial_element ["scale0" ].c . values
2049
2050
)
2050
-
2051
2051
channel = param_dict ["channel" ]
2052
- channel_list : list [str ] | list [int ] | None
2053
- if isinstance (channel , list ):
2054
- type_ = type (channel [0 ])
2055
- assert all (isinstance (ch , type_ ) for ch in channel ), "All channels must be of the same type."
2056
- # mypy complains that channel_list can be also of type list[str | int]
2057
- channel_list = [channel ] if isinstance (channel , int | str ) else channel # type: ignore[assignment]
2058
-
2059
- if channel_list is not None and (
2060
- (isinstance (channel_list [0 ], int ) and max ([abs (ch ) for ch in channel_list ]) <= len (spatial_element_ch )) # type: ignore[arg-type]
2061
- or all (ch in spatial_element_ch for ch in channel_list )
2062
- ):
2063
- element_params [el ]["channel" ] = channel_list
2052
+ if channel is not None :
2053
+ # Normalize channel to always be a list of str or a list of int
2054
+ if isinstance (channel , str ):
2055
+ channel = [channel ]
2056
+
2057
+ if isinstance (channel , int ):
2058
+ channel = [channel ]
2059
+
2060
+ # If channel is a list, ensure all elements are the same type
2061
+ if not (isinstance (channel , list ) and channel and all (isinstance (c , type (channel [0 ])) for c in channel )):
2062
+ raise TypeError ("Each item in 'channel' list must be of the same type, either string or integer." )
2063
+
2064
+ invalid = [c for c in channel if c not in spatial_element_ch ]
2065
+ if invalid :
2066
+ raise ValueError (
2067
+ f"Invalid channel(s): { ', ' .join (str (c ) for c in invalid )} . Valid choices are: { spatial_element_ch } "
2068
+ )
2069
+ element_params [el ]["channel" ] = channel
2064
2070
else :
2065
2071
element_params [el ]["channel" ] = None
2066
2072
2067
2073
element_params [el ]["alpha" ] = param_dict ["alpha" ]
2068
2074
2069
- if isinstance (palette := param_dict ["palette" ], list ):
2075
+ palette = param_dict ["palette" ]
2076
+ assert isinstance (palette , list | type (None )) # if present, was converted to list, just to make sure
2077
+
2078
+ if isinstance (palette , list ):
2079
+ # case A: single palette for all channels
2070
2080
if len (palette ) == 1 :
2071
- palette_length = len (channel_list ) if channel_list is not None else len (spatial_element_ch )
2081
+ palette_length = len (channel ) if channel is not None else len (spatial_element_ch )
2072
2082
palette = palette * palette_length
2073
- if (channel_list is not None and len (palette ) != len (channel_list )) and len (palette ) != len (
2074
- spatial_element_ch
2075
- ):
2076
- palette = None
2083
+ # case B: one palette per channel (either given or derived from channel length)
2084
+ channels_to_use = spatial_element_ch if element_params [el ]["channel" ] is None else channel
2085
+ if channels_to_use is not None and len (palette ) != len (channels_to_use ):
2086
+ raise ValueError (
2087
+ f"Palette length ({ len (palette )} ) does not match channel length "
2088
+ f"({ ', ' .join (str (c ) for c in channels_to_use )} )."
2089
+ )
2077
2090
element_params [el ]["palette" ] = palette
2078
2091
element_params [el ]["na_color" ] = param_dict ["na_color" ]
2079
2092
2080
2093
if (cmap := param_dict ["cmap" ]) is not None :
2081
2094
if len (cmap ) == 1 :
2082
- cmap_length = len (channel_list ) if channel_list is not None else len (spatial_element_ch )
2095
+ cmap_length = len (channel ) if channel is not None else len (spatial_element_ch )
2083
2096
cmap = cmap * cmap_length
2084
- if (channel_list is not None and len (cmap ) != len (channel_list )) or len (cmap ) != len (spatial_element_ch ):
2097
+ if (channel is not None and len (cmap ) != len (channel )) or len (cmap ) != len (spatial_element_ch ):
2085
2098
cmap = None
2086
2099
element_params [el ]["cmap" ] = cmap
2087
2100
element_params [el ]["norm" ] = param_dict ["norm" ]
@@ -2099,7 +2112,7 @@ def _validate_image_render_params(
2099
2112
def _get_wanted_render_elements (
2100
2113
sdata : SpatialData ,
2101
2114
sdata_wanted_elements : list [str ],
2102
- params : ( ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams ) ,
2115
+ params : ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams ,
2103
2116
cs : str ,
2104
2117
element_type : Literal ["images" , "labels" , "points" , "shapes" ],
2105
2118
) -> tuple [list [str ], list [str ], bool ]:
@@ -2256,7 +2269,7 @@ def _create_image_from_datashader_result(
2256
2269
2257
2270
2258
2271
def _datashader_aggregate_with_function (
2259
- reduction : ( Literal ["sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None ) ,
2272
+ reduction : Literal ["sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None ,
2260
2273
cvs : Canvas ,
2261
2274
spatial_element : GeoDataFrame | dask .dataframe .core .DataFrame ,
2262
2275
col_for_color : str | None ,
@@ -2320,7 +2333,7 @@ def _datashader_aggregate_with_function(
2320
2333
2321
2334
2322
2335
def _datshader_get_how_kw_for_spread (
2323
- reduction : ( Literal ["sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None ) ,
2336
+ reduction : Literal ["sum" , "mean" , "any" , "count" , "std" , "var" , "max" , "min" ] | None ,
2324
2337
) -> str :
2325
2338
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
2326
2339
reduction = reduction or "sum"
0 commit comments