Skip to content

Commit c709932

Browse files
committed
fix
1 parent ed66c85 commit c709932

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def _render_shapes(
204204

205205
# Handle circles encoded as points with radius
206206
if is_point.any():
207-
scale = shapes[is_point]["radius"] * render_params.scale
207+
radius_values = shapes[is_point]["radius"]
208+
# Convert to numeric, replacing non-numeric values with NaN
209+
radius_numeric = pd.to_numeric(radius_values, errors='coerce')
210+
scale = radius_numeric * render_params.scale
208211
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
209212

210213
# apply transformations to the individual points
@@ -227,6 +230,20 @@ def _render_shapes(
227230

228231
# in case we are coloring by a column in table
229232
if col_for_color is not None and col_for_color not in transformed_element.columns:
233+
# Ensure color vector length matches the number of shapes
234+
if len(color_vector) != len(transformed_element):
235+
if len(color_vector) == 1:
236+
# If single color, broadcast to all shapes
237+
color_vector = [color_vector[0]] * len(transformed_element)
238+
else:
239+
# If lengths don't match, pad or truncate to match
240+
if len(color_vector) > len(transformed_element):
241+
color_vector = color_vector[:len(transformed_element)]
242+
else:
243+
# Pad with the last color or na_color
244+
na_color = render_params.cmap_params.na_color.get_hex_with_alpha()
245+
color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector))
246+
230247
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
231248
# Render shapes with datashader
232249
color_by_categorical = col_for_color is not None and color_source_vector is not None

src/spatialdata_plot/pl/utils.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,50 @@
9393
ColorLike = tuple[float, ...] | list[float] | str
9494

9595

96+
def _extract_scalar_value(value: Any, default: float = 0.0) -> float:
97+
"""
98+
Extract a scalar float value from various data types.
99+
100+
Handles pandas Series, arrays, lists, and other iterables by taking the first element.
101+
Converts non-numeric values to the default value.
102+
103+
Parameters
104+
----------
105+
value : Any
106+
The value to extract a scalar from
107+
default : float, default 0.0
108+
Default value to return if conversion fails
109+
110+
Returns
111+
-------
112+
float
113+
The extracted scalar value
114+
"""
115+
try:
116+
# Handle pandas Series or similar objects with iloc
117+
if hasattr(value, 'iloc'):
118+
if len(value) > 0:
119+
value = value.iloc[0]
120+
else:
121+
return default
122+
123+
# Handle other array-like objects
124+
elif hasattr(value, '__len__') and not isinstance(value, (str, bytes)):
125+
if len(value) > 0:
126+
value = value[0]
127+
else:
128+
return default
129+
130+
# Convert to float, handling NaN values
131+
if pd.isna(value):
132+
return default
133+
134+
return float(value)
135+
136+
except (TypeError, ValueError, IndexError):
137+
return default
138+
139+
96140
def _verify_plotting_tree(sdata: SpatialData) -> SpatialData:
97141
"""Verify that the plotting tree exists, and if not, create it."""
98142
if not hasattr(sdata, "plotting_tree"):
@@ -285,9 +329,10 @@ def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, fl
285329

286330

287331
def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None:
332+
scale_value = _extract_scalar_value(scale_factor, default=1.0)
288333
centroid = _get_centroid_of_pathpatch(pathpatch)
289334
vertices = pathpatch.get_path().vertices
290-
scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices])
335+
scaled_vertices = np.array([centroid + (vertex - centroid) * scale_value for vertex in vertices])
291336
pathpatch.get_path().vertices = scaled_vertices
292337

293338

@@ -421,7 +466,8 @@ def _assign_fill_and_outline_to_row(
421466
def _process_polygon(row: pd.Series, scale: float) -> dict[str, Any]:
422467
coords = np.array(row["geometry"].exterior.coords)
423468
centroid = np.mean(coords, axis=0)
424-
scaled = (centroid + (coords - centroid) * scale).tolist()
469+
scale_value = _extract_scalar_value(scale, default=1.0)
470+
scaled = (centroid + (coords - centroid) * scale_value).tolist()
425471
return {**row.to_dict(), "geometry": mpatches.Polygon(scaled, closed=True)}
426472

427473
def _process_multipolygon(row: pd.Series, scale: float) -> list[dict[str, Any]]:
@@ -432,9 +478,13 @@ def _process_multipolygon(row: pd.Series, scale: float) -> list[dict[str, Any]]:
432478
return [{**row_dict, "geometry": m} for m in mp]
433479

434480
def _process_point(row: pd.Series, scale: float) -> dict[str, Any]:
481+
radius_value = _extract_scalar_value(row["radius"], default=0.0)
482+
scale_value = _extract_scalar_value(scale, default=1.0)
483+
radius = radius_value * scale_value
484+
435485
return {
436486
**row.to_dict(),
437-
"geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=row["radius"] * scale),
487+
"geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=radius),
438488
}
439489

440490
def _create_patches(

0 commit comments

Comments
 (0)