Skip to content

Commit f45e98f

Browse files
timtreis and claude committed
Refactor: extract groups filtering helper, fix deprecated dtype checks
- Extract _filter_groups_transparent_na() to deduplicate shapes/points filtering
- Replace pd.api.types.is_categorical_dtype with isinstance(dtype, pd.CategoricalDtype)
- Add TypeError for unsupported types in _coerce_categorical_source
- Simplify np.ndarray branch and use Categorical.isin() directly

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d6b2441 commit f45e98f

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,32 @@
7070

7171

7272
def _coerce_categorical_source(cat_source: Any) -> pd.Categorical:
    """Convert a supported container into a ``pd.Categorical``.

    A dask series is materialised first; if it is categorical with unknown
    categories, they are resolved via ``.cat.as_known()`` before computing.
    The resulting pandas series is then handled like any other ``pd.Series``.

    Raises
    ------
    TypeError
        If *cat_source* is not a ``dd.Series``, ``pd.Series``,
        ``pd.Categorical``, or ``np.ndarray``.
    """
    if isinstance(cat_source, dd.Series):
        lazy_is_categorical = isinstance(cat_source.dtype, pd.CategoricalDtype)
        # ``.cat`` is only touched for categorical dtypes (short-circuit).
        if lazy_is_categorical and getattr(cat_source.cat, "known", True) is False:
            cat_source = cat_source.cat.as_known()
        cat_source = cat_source.compute()

    # Already the target type: hand it back untouched.
    if isinstance(cat_source, pd.Categorical):
        return cat_source

    if isinstance(cat_source, pd.Series):
        already_categorical = isinstance(cat_source.dtype, pd.CategoricalDtype)
        # A categorical series exposes its backing Categorical via ``.array``.
        return cat_source.array if already_categorical else pd.Categorical(cat_source)

    if isinstance(cat_source, np.ndarray):
        return pd.Categorical(cat_source)

    raise TypeError(
        f"Cannot coerce {type(cat_source).__name__} to pd.Categorical. "
        "Expected dd.Series, pd.Series, pd.Categorical, or np.ndarray."
    )
8799

88100

89101
def _build_datashader_color_key(
@@ -106,6 +118,22 @@ def _build_datashader_color_key(
106118
return color_key
107119

108120

121+
def _filter_groups_transparent_na(
122+
groups: str | list[str],
123+
color_source_vector: pd.Categorical,
124+
color_vector: pd.Series | np.ndarray | list[str],
125+
) -> tuple[np.ndarray, pd.Categorical, np.ndarray]:
126+
"""Return a boolean mask and filtered color vectors for groups filtering.
127+
128+
Used when ``na_color=None`` (fully transparent) so that non-matching
129+
elements are removed entirely instead of rendered invisibly.
130+
"""
131+
keep = color_source_vector.isin(groups)
132+
filtered_csv = color_source_vector[keep]
133+
filtered_cv = np.asarray(color_vector)[keep]
134+
return keep, filtered_csv, filtered_cv
135+
136+
109137
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
110138
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
111139
layout: dict[str, object] = {}
@@ -212,16 +240,11 @@ def _render_shapes(
212240
# When groups are specified and na_color is fully transparent (na_color=None),
213241
# filter out non-matching elements instead of showing them as invisible geometry.
214242
if groups is not None and values_are_categorical and render_params.cmap_params.na_color.alpha == "00":
215-
csv_series = pd.Series(color_source_vector)
216-
keep = csv_series.isin(groups).values
243+
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
244+
groups, color_source_vector, color_vector
245+
)
217246
shapes = shapes[keep].reset_index(drop=True)
218247
sdata_filt[element] = shapes
219-
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
220-
color_vector = (
221-
np.asarray(color_vector)[keep]
222-
if not hasattr(color_vector, "reset_index")
223-
else (color_vector[keep].reset_index(drop=True))
224-
)
225248

226249
# color_source_vector is None when the values aren't categorical
227250
if values_are_categorical and render_params.transfunc is not None:
@@ -352,7 +375,7 @@ def _render_shapes(
352375
color_by_categorical = col_for_color is not None and color_source_vector is not None
353376
if color_by_categorical:
354377
cat_series = transformed_element[col_for_color]
355-
if not pd.api.types.is_categorical_dtype(cat_series):
378+
if not isinstance(cat_series.dtype, pd.CategoricalDtype):
356379
cat_series = cat_series.astype("category")
357380
transformed_element[col_for_color] = cat_series
358381

@@ -845,13 +868,8 @@ def _render_points(
845868
# When groups are specified and na_color is fully transparent (na_color=None),
846869
# filter out non-matching points instead of rendering invisible geometry.
847870
if groups is not None and color_source_vector is not None and render_params.cmap_params.na_color.alpha == "00":
848-
csv_series = pd.Series(color_source_vector)
849-
keep = csv_series.isin(groups).values
850-
color_source_vector = pd.Categorical(csv_series[keep].reset_index(drop=True))
851-
color_vector = (
852-
np.asarray(color_vector)[keep]
853-
if not hasattr(color_vector, "reset_index")
854-
else (color_vector[keep].reset_index(drop=True))
871+
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
872+
groups, color_source_vector, color_vector
855873
)
856874
# filter the materialized points, adata, and re-register in sdata_filt
857875
points = points[keep].reset_index(drop=True)
@@ -931,11 +949,11 @@ def _render_points(
931949
color_dtype = transformed_element[col_for_color].dtype if col_for_color is not None else None
932950
color_by_categorical = col_for_color is not None and (
933951
color_source_vector is not None
934-
or pd.api.types.is_categorical_dtype(color_dtype)
952+
or isinstance(color_dtype, pd.CategoricalDtype)
935953
or pd.api.types.is_object_dtype(color_dtype)
936954
or pd.api.types.is_string_dtype(color_dtype)
937955
)
938-
if color_by_categorical and not pd.api.types.is_categorical_dtype(color_dtype):
956+
if color_by_categorical and not isinstance(color_dtype, pd.CategoricalDtype):
939957
transformed_element[col_for_color] = transformed_element[col_for_color].astype("category")
940958

941959
aggregate_with_reduction = None
@@ -944,7 +962,7 @@ def _render_points(
944962
if color_by_categorical:
945963
# add nan as category so that nan points are shown in the nan color
946964
cat_series = transformed_element[col_for_color]
947-
if not pd.api.types.is_categorical_dtype(cat_series):
965+
if not isinstance(cat_series.dtype, pd.CategoricalDtype):
948966
cat_series = cat_series.astype("category")
949967
if hasattr(cat_series.cat, "as_known"):
950968
cat_series = cat_series.cat.as_known()

0 commit comments

Comments
 (0)