Skip to content

Commit fc95972

Browse files
Sonja-Stockhaustimtreisclaude
authored
nan handling for continuous and categorical coloring (#427)
Co-authored-by: Tim Treis <tim.treis@stud.uni-heidelberg.de> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 059b51a commit fc95972

42 files changed

Lines changed: 465 additions & 94 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/spatialdata_plot/pl/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def render_shapes(
213213
`fill_alpha` will overwrite the value present in the cmap.
214214
groups : list[str] | str | None
215215
When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of
216-
them. Other values are set to NA. If elment is None, broadcasting behaviour is attempted (use the same
216+
them. Other values are set to NA. If element is None, broadcasting behaviour is attempted (use the same
217217
values for all elements).
218218
palette : list[str] | str | None
219219
Palette for discrete annotations. List of valid color names that should be used for the categories. Must

src/spatialdata_plot/pl/render.py

Lines changed: 241 additions & 88 deletions
Large diffs are not rendered by default.

src/spatialdata_plot/pl/utils.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,40 @@ def _infer_color_data_kind(
983983
return "numeric", pd.to_numeric(series, errors="coerce")
984984

985985

986+
def _build_alignment_dtype_hint(
987+
sdata: sd.SpatialData | None,
988+
element: object,
989+
color_series: pd.Series,
990+
table_name: str | None,
991+
) -> str:
992+
"""Build a diagnostic hint string for dtype mismatches between element and table indices."""
993+
hints: list[str] = []
994+
color_index_dtype = getattr(color_series.index, "dtype", None)
995+
element_index_dtype = getattr(getattr(element, "index", None), "dtype", None) if element is not None else None
996+
997+
table_instance_dtype = None
998+
instance_key = None
999+
if table_name is not None and sdata is not None and table_name in sdata.tables:
1000+
table = sdata.tables[table_name]
1001+
try:
1002+
_, _, instance_key = get_table_keys(table)
1003+
except (KeyError, ValueError, TypeError, AttributeError):
1004+
instance_key = None
1005+
if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs:
1006+
table_instance_dtype = table.obs[instance_key].dtype
1007+
1008+
if (
1009+
element_index_dtype is not None
1010+
and table_instance_dtype is not None
1011+
and element_index_dtype != table_instance_dtype
1012+
):
1013+
hints.append(f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}")
1014+
if color_index_dtype is not None and element_index_dtype is not None and color_index_dtype != element_index_dtype:
1015+
hints.append(f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}")
1016+
1017+
return f" (hint: {'; '.join(hints)})" if hints else ""
1018+
1019+
9861020
def _set_color_source_vec(
9871021
sdata: sd.SpatialData,
9881022
element: SpatialElement | None,
@@ -1012,7 +1046,8 @@ def _set_color_source_vec(
10121046

10131047
if len(origins) > 1:
10141048
raise ValueError(
1015-
f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}."
1049+
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
1050+
"Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity."
10161051
)
10171052

10181053
if len(origins) == 1 and value_to_plot is not None:
@@ -1035,6 +1070,17 @@ def _set_color_source_vec(
10351070
color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector)
10361071
)
10371072

1073+
if color_series.isna().all():
1074+
element_label = _format_element_name(element_name)
1075+
location = f"table '{table_name}'" if table_name is not None else "the element"
1076+
dtype_hint = _build_alignment_dtype_hint(sdata, element, color_series, table_name)
1077+
raise ValueError(
1078+
f"Column '{value_to_plot}' for element '{element_label}' contains only missing values after aligning "
1079+
f"with {location}. This usually means the instance ids/indices could not be aligned or converted, so "
1080+
"colors cannot be determined. Please ensure the table annotates the element with matching instance ids."
1081+
f"{dtype_hint}"
1082+
)
1083+
10381084
kind, processed = _infer_color_data_kind(
10391085
series=color_series,
10401086
value_to_plot=value_to_plot,
@@ -1059,6 +1105,9 @@ def _set_color_source_vec(
10591105
return None, numeric_vector, False
10601106

10611107
assert isinstance(processed, pd.Categorical)
1108+
if not processed.ordered:
1109+
# ensure deterministic category order when the source is unordered (e.g., from a Python set)
1110+
processed = processed.reorder_categories(sorted(processed.categories))
10621111
color_source_vector = processed # convert, e.g., `pd.Series`
10631112

10641113
# Use the provided table_name parameter, fall back to only one present
@@ -1138,6 +1187,12 @@ def _set_color_source_vec(
11381187
# (e.g. two categories share a color). Wrapping back in pd.Categorical ensures
11391188
# downstream consumers always receive a Categorical for categorical data.
11401189
color_vector = pd.Categorical(color_source_vector.map(color_mapping, na_action="ignore"))
1190+
# nan handling: only add the NA category if needed, and store it as a hex string
1191+
na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color)
1192+
if color_vector.isna().any():
1193+
if na_color_hex not in color_vector.categories:
1194+
color_vector = color_vector.add_categories(na_color_hex)
1195+
color_vector[pd.isna(color_vector)] = na_color_hex
11411196

11421197
return color_source_vector, color_vector, True
11431198

@@ -1165,15 +1220,18 @@ def _map_color_seg(
11651220

11661221
if isinstance(color_vector.dtype, pd.CategoricalDtype):
11671222
# Case A: users wants to plot a categorical column
1168-
if np.any(color_source_vector.isna()):
1169-
cell_id[color_source_vector.isna()] = 0
11701223
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
11711224
cols = colors.to_rgba_array(color_vector.categories)
11721225
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
11731226
# Case B: user wants to plot a continous column
11741227
if isinstance(color_vector, pd.Series):
11751228
color_vector = color_vector.to_numpy()
1176-
cols = cmap_params.cmap(cmap_params.norm(color_vector))
1229+
# normalize only the not nan values, else the whole array would contain only nan values
1230+
normed_color_vector = color_vector.copy().astype(float)
1231+
normed_color_vector[~np.isnan(normed_color_vector)] = cmap_params.norm(
1232+
normed_color_vector[~np.isnan(normed_color_vector)]
1233+
)
1234+
cols = cmap_params.cmap(normed_color_vector)
11771235
val_im = map_array(seg.copy(), cell_id, cell_id)
11781236
else:
11791237
# Case C: User didn't specify any colors
@@ -2688,6 +2746,7 @@ def _validate_col_for_column_table(
26882746
elif table_name is not None:
26892747
tables = get_element_annotators(sdata, element_name)
26902748
if table_name not in tables:
2749+
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")
26912750
raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.")
26922751
if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names:
26932752
raise KeyError(
@@ -3084,7 +3143,7 @@ def _prepare_transformation(
30843143
def _datashader_map_aggregate_to_color(
30853144
agg: DataArray,
30863145
cmap: str | list[str] | ListedColormap,
3087-
color_key: None | list[str] = None,
3146+
color_key: list[str] | dict[str, str] | None = None,
30883147
min_alpha: float = 40,
30893148
span: None | list[float] = None,
30903149
clip: bool = True,
44.4 KB
43.6 KB
46.2 KB
151 Bytes
4.66 KB
1.72 KB
1.77 KB

0 commit comments

Comments
 (0)