Skip to content

Commit ebb9f68

Browse files
timtreisclaude
andcommitted
Filter non-matching labels when groups is set without explicit na_color
Consistent with shapes/points: when groups is specified, non-matching label IDs are zeroed out (rendered as background) by default. Setting na_color explicitly restores the old "show in grey" behavior. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2aa0e9 commit ebb9f68

File tree

5 files changed

+26
-5
lines changed

5 files changed

+26
-5
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ def render_shapes(
213213
`fill_alpha` will overwrite the value present in the cmap.
214214
groups : list[str] | str | None
215215
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
216-
them. By default, non-matching elements are filtered out entirely (shapes and points only; labels are
217-
raster-based and cannot be filtered). To show non-matching elements, set ``na_color`` explicitly.
216+
them. By default, non-matching elements are hidden. To show non-matching elements, set ``na_color``
217+
explicitly.
218218
If element is None, broadcasting behaviour is attempted (use the same values for all elements).
219219
palette : list[str] | str | None
220220
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
@@ -399,8 +399,9 @@ def render_points(
399399
value is used instead.
400400
groups : list[str] | str | None
401401
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
402-
them. Other values are set to NA. If `element` is `None`, broadcasting behaviour is attempted (use the same
403-
values for all elements).
402+
them. By default, non-matching points are filtered out entirely. To show non-matching points, set
403+
``na_color`` explicitly.
404+
If element is None, broadcasting behaviour is attempted (use the same values for all elements).
404405
palette : list[str] | str | None
405406
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
406407
match the number of groups. If `element` is `None`, broadcasting behaviour is attempted (use the same values
@@ -672,7 +673,7 @@ def render_labels(
672673
table_name to be used for the element if you would like a specific table to be used.
673674
groups : list[str] | str | None
674675
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
675-
them. Other values are set to NA. The list can contain multiple discrete labels to be visualized.
676+
them. By default, non-matching labels are hidden. To show non-matching labels, set ``na_color`` explicitly.
676677
palette : list[str] | str | None
677678
Palette for discrete annotations. List of valid color names that should be used for the categories. Must
678679
match the number of groups. The list can contain multiple palettes (one per group) to be visualized. If

src/spatialdata_plot/pl/render.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,26 @@ def _render_labels(
15311531
else:
15321532
assert color_source_vector is None
15331533

1534+
# When groups are specified, zero out non-matching label IDs so they render as background.
1535+
# Only show non-matching labels if the user explicitly sets na_color.
1536+
_na = render_params.cmap_params.na_color
1537+
if (
1538+
groups is not None
1539+
and categorical
1540+
and color_source_vector is not None
1541+
and (_na.default_color_set or _na.alpha == "00")
1542+
):
1543+
keep_vec = color_source_vector.isin(groups)
1544+
matching_ids = instance_id[keep_vec]
1545+
keep_mask = np.isin(label.values, matching_ids)
1546+
label = label.copy()
1547+
label.values[~keep_mask] = 0
1548+
instance_id = instance_id[keep_vec]
1549+
color_source_vector = color_source_vector[keep_vec]
1550+
color_vector = color_vector[keep_vec]
1551+
if isinstance(color_vector.dtype, pd.CategoricalDtype):
1552+
color_vector = color_vector.remove_unused_categories()
1553+
15341554
def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
15351555
labels = _map_color_seg(
15361556
seg=label.values,
Binary file not shown.
-24.3 KB
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)