@@ -1227,10 +1227,24 @@ def _generate_base_categorial_color_mapping(
12271227 cmap_params : CmapParams | None = None ,
12281228) -> Mapping [str , str ]:
12291229 if adata is not None and cluster_key in adata .uns and f"{ cluster_key } _colors" in adata .uns :
1230- colors = adata .uns [f"{ cluster_key } _colors" ]
1231- categories = color_source_vector .categories .tolist () + ["NaN" ]
1230+ all_colors = adata .uns [f"{ cluster_key } _colors" ]
1231+
1232+ # When plotting per-coordinate-system, the color_source_vector may carry
1233+ # categories from other coordinate systems that aren't present in the
1234+ # current subset. Drop them so that categories and colors stay aligned.
1235+ color_source_vector = color_source_vector .remove_unused_categories ()
1236+
1237+ # The stored colors in .uns correspond 1-to-1 to the *full* set of
1238+ # categories in adata.obs[cluster_key]. Subset to the categories that
1239+ # are still present after removing unused ones.
1240+ if cluster_key in adata .obs and hasattr (adata .obs [cluster_key ], "cat" ):
1241+ all_cats = adata .obs [cluster_key ].cat .categories .tolist ()
1242+ keep_idx = [i for i , c in enumerate (all_cats ) if c in color_source_vector .categories ]
1243+ colors = [to_hex (to_rgba (all_colors [i ])[:3 ]) for i in keep_idx ]
1244+ else :
1245+ colors = [to_hex (to_rgba (c )[:3 ]) for c in all_colors ]
12321246
1233- colors = [ to_hex ( to_rgba ( color )[: 3 ]) for color in colors ]
1247+ categories = color_source_vector . categories . tolist () + [ "NaN" ]
12341248
12351249 if len (categories ) > len (colors ):
12361250 return dict (zip (categories , colors + [na_color .get_hex_with_alpha ()], strict = True ))
@@ -1345,6 +1359,9 @@ def _extract_colors_from_table_uns(
13451359
13461360 # Extract colors and categories
13471361 stored_colors = adata .uns [color_key ]
1362+ # Drop categories not present in the current subset (e.g. when plotting
1363+ # per-coordinate-system) so that positional color lookups stay aligned.
1364+ color_source_vector = color_source_vector .remove_unused_categories ()
13481365 categories = color_source_vector .categories .tolist ()
13491366
13501367 # Validate na_color format and convert to hex string
@@ -1392,9 +1409,18 @@ def _to_hex_no_alpha(color_value: Any) -> str | None:
13921409 logger .warning (f"Unsupported color storage for '{ color_key } '. Expected sequence or mapping." )
13931410 return None
13941411
1395- for i , category in enumerate (categories ):
1396- if i < len (hex_colors ) and hex_colors [i ] is not None :
1397- hex_color = hex_colors [i ]
1412+ # Map by the category's position in the *full* table, not in the
1413+ # (possibly subset) color_source_vector, so colors stay consistent
1414+ # across coordinate systems.
1415+ all_cats = (
1416+ adata .obs [col_to_colorby ].cat .categories .tolist ()
1417+ if col_to_colorby in adata .obs and hasattr (adata .obs [col_to_colorby ], "cat" )
1418+ else categories
1419+ )
1420+ for category in categories :
1421+ idx = all_cats .index (category ) if category in all_cats else None
1422+ if idx is not None and idx < len (hex_colors ) and hex_colors [idx ] is not None :
1423+ hex_color = hex_colors [idx ]
13981424 assert hex_color is not None # type narrowing for mypy
13991425 color_mapping [category ] = hex_color
14001426 else :
0 commit comments