@@ -526,6 +526,8 @@ def _prepare_cmap_norm(
526
526
527
527
cmap = copy (cmap )
528
528
529
+ assert isinstance (cmap , Colormap ), f"Invalid type of `cmap`: { type (cmap )} , expected `Colormap`."
530
+
529
531
if norm is None :
530
532
norm = Normalize (vmin = None , vmax = None , clip = False )
531
533
@@ -2045,30 +2047,41 @@ def _validate_image_render_params(
2045
2047
spatial_element_ch = (
2046
2048
spatial_element .c if isinstance (spatial_element , DataArray ) else spatial_element ["scale0" ].c
2047
2049
)
2048
- if (channel := param_dict ["channel" ]) is not None and (
2049
- (isinstance (channel [0 ], int ) and max ([abs (ch ) for ch in channel ]) <= len (spatial_element_ch ))
2050
- or all (ch in spatial_element_ch for ch in channel )
2050
+
2051
+ channel = param_dict ["channel" ]
2052
+ channel_list : list [str ] | list [int ] | None
2053
+ if isinstance (channel , list ):
2054
+ type_ = type (channel [0 ])
2055
+ assert all (isinstance (ch , type_ ) for ch in channel ), "All channels must be of the same type."
2056
+ # mypy complains that channel_list can be also of type list[str | int]
2057
+ channel_list = [channel ] if isinstance (channel , int | str ) else channel # type: ignore[assignment]
2058
+
2059
+ if channel_list is not None and (
2060
+ (isinstance (channel_list [0 ], int ) and max ([abs (ch ) for ch in channel_list ]) <= len (spatial_element_ch )) # type: ignore[arg-type]
2061
+ or all (ch in spatial_element_ch for ch in channel_list )
2051
2062
):
2052
- element_params [el ]["channel" ] = channel
2063
+ element_params [el ]["channel" ] = channel_list
2053
2064
else :
2054
2065
element_params [el ]["channel" ] = None
2055
2066
2056
2067
element_params [el ]["alpha" ] = param_dict ["alpha" ]
2057
2068
2058
2069
if isinstance (palette := param_dict ["palette" ], list ):
2059
2070
if len (palette ) == 1 :
2060
- palette_length = len (channel ) if channel is not None else len (spatial_element_ch )
2071
+ palette_length = len (channel_list ) if channel_list is not None else len (spatial_element_ch )
2061
2072
palette = palette * palette_length
2062
- if (channel is not None and len (palette ) != len (channel )) and len (palette ) != len (spatial_element_ch ):
2073
+ if (channel_list is not None and len (palette ) != len (channel_list )) and len (palette ) != len (
2074
+ spatial_element_ch
2075
+ ):
2063
2076
palette = None
2064
2077
element_params [el ]["palette" ] = palette
2065
2078
element_params [el ]["na_color" ] = param_dict ["na_color" ]
2066
2079
2067
2080
if (cmap := param_dict ["cmap" ]) is not None :
2068
2081
if len (cmap ) == 1 :
2069
- cmap_length = len (channel ) if channel is not None else len (spatial_element_ch )
2082
+ cmap_length = len (channel_list ) if channel_list is not None else len (spatial_element_ch )
2070
2083
cmap = cmap * cmap_length
2071
- if (channel is not None and len (cmap ) != len (channel )) or len (cmap ) != len (spatial_element_ch ):
2084
+ if (channel_list is not None and len (cmap ) != len (channel_list )) or len (cmap ) != len (spatial_element_ch ):
2072
2085
cmap = None
2073
2086
element_params [el ]["cmap" ] = cmap
2074
2087
element_params [el ]["norm" ] = param_dict ["norm" ]
@@ -2364,7 +2377,9 @@ def _get_datashader_trans_matrix_of_single_element(
2364
2377
# no flipping needed
2365
2378
return tm
2366
2379
# for a Translation, we need the transposed transformation matrix
2367
- return tm .T
2380
+ tm_T = tm .T
2381
+ assert isinstance (tm_T , np .ndarray )
2382
+ return tm_T
2368
2383
2369
2384
2370
2385
def _get_transformation_matrix_for_datashader (
0 commit comments