Skip to content

Commit f2bfa3f

Browse files
timtreisclaude
andauthored
Fix categorical color mapping across coordinate systems (#547)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4fe68de commit f2bfa3f

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

src/spatialdata_plot/pl/utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,10 +1227,24 @@ def _generate_base_categorial_color_mapping(
12271227
cmap_params: CmapParams | None = None,
12281228
) -> Mapping[str, str]:
12291229
if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns:
1230-
colors = adata.uns[f"{cluster_key}_colors"]
1231-
categories = color_source_vector.categories.tolist() + ["NaN"]
1230+
all_colors = adata.uns[f"{cluster_key}_colors"]
1231+
1232+
# When plotting per-coordinate-system, the color_source_vector may carry
1233+
# categories from other coordinate systems that aren't present in the
1234+
# current subset. Drop them so that categories and colors stay aligned.
1235+
color_source_vector = color_source_vector.remove_unused_categories()
1236+
1237+
# The stored colors in .uns correspond 1-to-1 to the *full* set of
1238+
# categories in adata.obs[cluster_key]. Subset to the categories that
1239+
# are still present after removing unused ones.
1240+
if cluster_key in adata.obs and hasattr(adata.obs[cluster_key], "cat"):
1241+
all_cats = adata.obs[cluster_key].cat.categories.tolist()
1242+
keep_idx = [i for i, c in enumerate(all_cats) if c in color_source_vector.categories]
1243+
colors = [to_hex(to_rgba(all_colors[i])[:3]) for i in keep_idx]
1244+
else:
1245+
colors = [to_hex(to_rgba(c)[:3]) for c in all_colors]
12321246

1233-
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
1247+
categories = color_source_vector.categories.tolist() + ["NaN"]
12341248

12351249
if len(categories) > len(colors):
12361250
return dict(zip(categories, colors + [na_color.get_hex_with_alpha()], strict=True))
@@ -1345,6 +1359,9 @@ def _extract_colors_from_table_uns(
13451359

13461360
# Extract colors and categories
13471361
stored_colors = adata.uns[color_key]
1362+
# Drop categories not present in the current subset (e.g. when plotting
1363+
# per-coordinate-system) so that positional color lookups stay aligned.
1364+
color_source_vector = color_source_vector.remove_unused_categories()
13481365
categories = color_source_vector.categories.tolist()
13491366

13501367
# Validate na_color format and convert to hex string
@@ -1392,9 +1409,18 @@ def _to_hex_no_alpha(color_value: Any) -> str | None:
13921409
logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.")
13931410
return None
13941411

1395-
for i, category in enumerate(categories):
1396-
if i < len(hex_colors) and hex_colors[i] is not None:
1397-
hex_color = hex_colors[i]
1412+
# Map by the category's position in the *full* table, not in the
1413+
# (possibly subset) color_source_vector, so colors stay consistent
1414+
# across coordinate systems.
1415+
all_cats = (
1416+
adata.obs[col_to_colorby].cat.categories.tolist()
1417+
if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat")
1418+
else categories
1419+
)
1420+
for category in categories:
1421+
idx = all_cats.index(category) if category in all_cats else None
1422+
if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None:
1423+
hex_color = hex_colors[idx]
13981424
assert hex_color is not None # type narrowing for mypy
13991425
color_mapping[category] = hex_color
14001426
else:

0 commit comments

Comments
 (0)