Skip to content

Commit c31d2e9

Browse files
timtreisclaude
andcommitted
Add col_for_color to labels, enabling literal color values like color='red'
Labels now use the same color/col_for_color split as shapes and points, so `render_labels(color="red")` is correctly recognized as a literal color instead of being treated as a column name. Fixes #470 and #478. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6be1b41 commit c31d2e9

5 files changed

Lines changed: 34 additions & 25 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def render_labels(
749749
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
750750
element=element,
751751
color=param_values["color"],
752+
col_for_color=param_values["col_for_color"],
752753
groups=param_values["groups"],
753754
contour_px=param_values["contour_px"],
754755
cmap_params=cmap_params,
@@ -1130,14 +1131,13 @@ def _draw_colorbar(
11301131

11311132
if wanted_labels_on_this_cs:
11321133
table = params_copy.table_name
1133-
if table is not None:
1134-
assert isinstance(params_copy.color, str)
1135-
colors = sc.get.obs_df(sdata[table], [params_copy.color])
1136-
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
1134+
if table is not None and params_copy.col_for_color is not None:
1135+
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1136+
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
11371137
_maybe_set_colors(
11381138
source=sdata[table],
11391139
target=sdata[table],
1140-
key=params_copy.color,
1140+
key=params_copy.col_for_color,
11411141
palette=params_copy.palette,
11421142
)
11431143

src/spatialdata_plot/pl/render.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ def _render_labels(
12651265
table_name = render_params.table_name
12661266
table_layer = render_params.table_layer
12671267
palette = render_params.palette
1268-
color = render_params.color
1268+
col_for_color = render_params.col_for_color
12691269
groups = render_params.groups
12701270
scale = render_params.scale
12711271

@@ -1314,23 +1314,25 @@ def _render_labels(
13141314

13151315
_, trans_data = _prepare_transformation(label, coordinate_system, ax)
13161316

1317+
na_color = render_params.color if render_params.color else render_params.cmap_params.na_color
13171318
color_source_vector, color_vector, categorical = _set_color_source_vec(
13181319
sdata=sdata_filt,
13191320
element=label,
13201321
element_name=element,
1321-
value_to_plot=color,
1322+
value_to_plot=col_for_color,
13221323
groups=groups,
13231324
palette=palette,
1324-
na_color=render_params.cmap_params.na_color,
1325+
na_color=na_color,
13251326
cmap_params=render_params.cmap_params,
13261327
table_name=table_name,
13271328
table_layer=table_layer,
1329+
render_type="labels",
13281330
coordinate_system=coordinate_system,
13291331
)
13301332

13311333
# rasterize could have removed labels from label
13321334
# only problematic if color is specified
1333-
if rasterize and color is not None:
1335+
if rasterize and col_for_color is not None:
13341336
labels_in_rasterized_image = np.unique(label.values)
13351337
mask = np.isin(instance_id, labels_in_rasterized_image)
13361338
instance_id = instance_id[mask]
@@ -1408,15 +1410,15 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14081410
colorbar_requested = _should_request_colorbar(
14091411
render_params.colorbar,
14101412
has_mappable=cax is not None,
1411-
is_continuous=color is not None and color_source_vector is None and not categorical,
1413+
is_continuous=col_for_color is not None and color_source_vector is None and not categorical,
14121414
)
14131415

14141416
_ = _decorate_axs(
14151417
ax=ax,
14161418
cax=cax,
14171419
fig_params=fig_params,
14181420
adata=table,
1419-
value_to_plot=color,
1421+
value_to_plot=col_for_color,
14201422
color_source_vector=color_source_vector,
14211423
color_vector=color_vector,
14221424
palette=palette,
@@ -1432,7 +1434,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
14321434
colorbar_requests=colorbar_requests,
14331435
colorbar_label=_resolve_colorbar_label(
14341436
render_params.colorbar_params,
1435-
color if isinstance(color, str) else None,
1437+
col_for_color if isinstance(col_for_color, str) else None,
14361438
),
14371439
scalebar_dx=scalebar_params.scalebar_dx,
14381440
scalebar_units=scalebar_params.scalebar_units,

src/spatialdata_plot/pl/render_params.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ class LabelsRenderParams:
278278

279279
cmap_params: CmapParams
280280
element: str
281-
color: str | None = None
281+
color: Color | None = None
282+
col_for_color: str | None = None
282283
groups: str | list[str] | None = None
283284
contour_px: int | None = None
284285
outline: bool = False

src/spatialdata_plot/pl/utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ def _set_color_source_vec(
981981
alpha: float = 1.0,
982982
table_name: str | None = None,
983983
table_layer: str | None = None,
984-
render_type: Literal["points"] | None = None,
984+
render_type: Literal["points", "labels"] | None = None,
985985
coordinate_system: str | None = None,
986986
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
987987
if value_to_plot is None and element is not None:
@@ -1454,7 +1454,7 @@ def _get_categorical_color_mapping(
14541454
alpha: float = 1,
14551455
groups: list[str] | str | None = None,
14561456
palette: list[str] | str | None = None,
1457-
render_type: Literal["points"] | None = None,
1457+
render_type: Literal["points", "labels"] | None = None,
14581458
) -> Mapping[str, str]:
14591459
if not isinstance(color_source_vector, Categorical):
14601460
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
@@ -2145,15 +2145,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21452145
}:
21462146
if not isinstance(color, str | tuple | list):
21472147
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
2148-
if element_type in {"shapes", "points"}:
2148+
if element_type in {"shapes", "points", "labels"}:
21492149
if _is_color_like(color):
21502150
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
21512151
param_dict["col_for_color"] = None
21522152
param_dict["color"] = Color(color)
21532153
if param_dict["color"].alpha_is_user_defined():
21542154
if element_type == "points" and param_dict.get("alpha") is None:
21552155
param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
2156-
elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
2156+
elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None:
21572157
param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
21582158
else:
21592159
logger.info(
@@ -2165,7 +2165,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
21652165
param_dict["color"] = None
21662166
else:
21672167
raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
2168-
elif "color" in param_dict and element_type != "labels":
2168+
elif "color" in param_dict and element_type != "images":
21692169
param_dict["col_for_color"] = None
21702170

21712171
outline_width = param_dict.get("outline_width")
@@ -2462,15 +2462,18 @@ def _validate_label_render_params(
24622462
element_params[el]["table_layer"] = param_dict["table_layer"]
24632463

24642464
element_params[el]["table_name"] = None
2465-
element_params[el]["color"] = None
2466-
color = param_dict["color"]
2467-
if color is not None:
2468-
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
2465+
element_params[el]["color"] = param_dict["color"] # literal Color or None
2466+
element_params[el]["col_for_color"] = None
2467+
if (col_for_color := param_dict["col_for_color"]) is not None:
2468+
col_for_color, table_name = _validate_col_for_column_table(
2469+
sdata, el, col_for_color, param_dict["table_name"], labels=True
2470+
)
24692471
element_params[el]["table_name"] = table_name
2470-
element_params[el]["color"] = color
2472+
element_params[el]["col_for_color"] = col_for_color
24712473

2472-
element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None
2473-
element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None
2474+
has_col = element_params[el]["col_for_color"] is not None
2475+
element_params[el]["palette"] = param_dict["palette"] if has_col else None
2476+
element_params[el]["groups"] = param_dict["groups"] if has_col else None
24742477
element_params[el]["colorbar"] = param_dict["colorbar"]
24752478
element_params[el]["colorbar_params"] = param_dict["colorbar_params"]
24762479

tests/pl/test_render_labels.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
8484
.pl.show()
8585
)
8686

87+
def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData):
88+
sdata_blobs.pl.render_labels("blobs_labels", color="red").pl.show()
89+
8790
def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData):
8891
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()
8992

0 commit comments

Comments
 (0)