Skip to content

Commit 295fe40

Browse files
timtreisclaude
andcommitted
Deduplicate shared logic between _render_shapes and _render_points
Extract 8 helper functions from the near-identical datashader rendering paths in _render_shapes() and _render_points(): - _apply_datashader_norm: norm vmin/vmax edge-case handling - _build_datashader_colorbar_mappable: ScalarMappable construction - _datashader_aggregate: categorical/continuous/no-color aggregation - _datashader_shade_continuous: continuous color mapping + spread + NaN - _datashader_shade_categorical: categorical/no-color color mapping - _render_datashader_result: RGBA image rendering + NaN overlay - _make_palette: ListedColormap construction - _decorate_render: legend/colorbar/scalebar decoration Also refactor the show() dispatch loop in basic.py from 4 if/elif branches to a table-driven pattern. No public API changes. No behavioral changes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2893343 commit 295fe40

File tree

2 files changed

+434
-415
lines changed

2 files changed

+434
-415
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 68 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,125 +1063,88 @@ def _draw_colorbar(
10631063
assert isinstance(ax, Axes)
10641064
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
10651065

1066-
wants_images = False
1067-
wants_labels = False
1068-
wants_points = False
1069-
wants_shapes = False
1066+
# Built per-CS because has_images/labels/points/shapes vary by coordinate system
1067+
_CMD_DISPATCH = {
1068+
"render_images": ("images", _render_images, has_images),
1069+
"render_shapes": ("shapes", _render_shapes, has_shapes),
1070+
"render_points": ("points", _render_points, has_points),
1071+
"render_labels": ("labels", _render_labels, has_labels),
1072+
}
1073+
wants_element_type: dict[str, bool] = dict.fromkeys(("images", "labels", "points", "shapes"), False)
10701074
wanted_elements: list[str] = []
10711075

10721076
for cmd, params in render_cmds:
1077+
dispatch = _CMD_DISPATCH.get(cmd)
1078+
if dispatch is None:
1079+
continue
1080+
element_type_str, render_fn, has_type = dispatch
1081+
element_type = cast(Literal["images", "labels", "points", "shapes"], element_type_str)
1082+
if not has_type:
1083+
continue
1084+
10731085
# We create a copy here as the wanted elements can change from one cs to another.
10741086
params_copy = deepcopy(params)
1075-
if cmd == "render_images" and has_images:
1076-
wanted_elements, wanted_images_on_this_cs, wants_images = _get_wanted_render_elements(
1077-
sdata, wanted_elements, params_copy, cs, "images"
1078-
)
1087+
wanted_elements, wanted_on_cs, wants = _get_wanted_render_elements(
1088+
sdata, wanted_elements, params_copy, cs, element_type
1089+
)
1090+
wants_element_type[element_type] = wants
10791091

1080-
if wanted_images_on_this_cs:
1081-
rasterize = (params_copy.scale is None) or (
1082-
isinstance(params_copy.scale, str)
1083-
and params_copy.scale != "full"
1084-
and (dpi is not None or figsize is not None)
1085-
)
1086-
_render_images(
1087-
sdata=sdata,
1088-
render_params=params_copy,
1089-
coordinate_system=cs,
1090-
ax=ax,
1091-
fig_params=fig_params,
1092-
scalebar_params=scalebar_params,
1093-
legend_params=legend_params,
1094-
colorbar_requests=axis_colorbar_requests,
1095-
rasterize=rasterize,
1096-
)
1097-
1098-
elif cmd == "render_shapes" and has_shapes:
1099-
wanted_elements, wanted_shapes_on_this_cs, wants_shapes = _get_wanted_render_elements(
1100-
sdata, wanted_elements, params_copy, cs, "shapes"
1101-
)
1092+
if not wanted_on_cs:
1093+
continue
11021094

1103-
if wanted_shapes_on_this_cs:
1104-
_render_shapes(
1105-
sdata=sdata,
1106-
render_params=params_copy,
1107-
coordinate_system=cs,
1108-
ax=ax,
1109-
fig_params=fig_params,
1110-
scalebar_params=scalebar_params,
1111-
legend_params=legend_params,
1112-
colorbar_requests=axis_colorbar_requests,
1113-
)
1114-
1115-
elif cmd == "render_points" and has_points:
1116-
wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
1117-
sdata, wanted_elements, params_copy, cs, "points"
1118-
)
1095+
# Pre-render hooks for specific element types
1096+
extra_kwargs: dict[str, Any] = {}
11191097

1120-
if wanted_points_on_this_cs:
1121-
_render_points(
1122-
sdata=sdata,
1123-
render_params=params_copy,
1124-
coordinate_system=cs,
1125-
ax=ax,
1126-
fig_params=fig_params,
1127-
scalebar_params=scalebar_params,
1128-
legend_params=legend_params,
1129-
colorbar_requests=axis_colorbar_requests,
1130-
)
1131-
1132-
elif cmd == "render_labels" and has_labels:
1133-
wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements(
1134-
sdata, wanted_elements, params_copy, cs, "labels"
1098+
if cmd in ("render_images", "render_labels"):
1099+
extra_kwargs["rasterize"] = (params_copy.scale is None) or (
1100+
isinstance(params_copy.scale, str)
1101+
and params_copy.scale != "full"
1102+
and (dpi is not None or figsize is not None)
11351103
)
11361104

1137-
if wanted_labels_on_this_cs:
1138-
table = params_copy.table_name
1139-
if table is not None and params_copy.col_for_color is not None:
1140-
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1141-
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
1142-
_maybe_set_colors(
1143-
source=sdata[table],
1144-
target=sdata[table],
1145-
key=params_copy.col_for_color,
1146-
palette=params_copy.palette,
1147-
)
1148-
1149-
rasterize = (params_copy.scale is None) or (
1150-
isinstance(params_copy.scale, str)
1151-
and params_copy.scale != "full"
1152-
and (dpi is not None or figsize is not None)
1153-
)
1154-
_render_labels(
1155-
sdata=sdata,
1156-
render_params=params_copy,
1157-
coordinate_system=cs,
1158-
ax=ax,
1159-
fig_params=fig_params,
1160-
scalebar_params=scalebar_params,
1161-
legend_params=legend_params,
1162-
colorbar_requests=axis_colorbar_requests,
1163-
rasterize=rasterize,
1164-
)
1165-
1166-
if title is None:
1167-
t = cs
1168-
elif len(title) == 1:
1169-
t = title[0]
1170-
else:
1171-
try:
1172-
t = title[i]
1173-
except IndexError as e:
1174-
raise IndexError("The number of titles must match the number of coordinate systems.") from e
1175-
ax.set_title(t)
1176-
ax.set_aspect("equal")
1105+
if cmd == "render_labels":
1106+
table = params_copy.table_name
1107+
if table is not None and params_copy.col_for_color is not None:
1108+
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
1109+
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
1110+
_maybe_set_colors(
1111+
source=sdata[table],
1112+
target=sdata[table],
1113+
key=params_copy.col_for_color,
1114+
palette=params_copy.palette,
1115+
)
1116+
1117+
render_fn( # type: ignore[operator]
1118+
sdata=sdata,
1119+
render_params=params_copy,
1120+
coordinate_system=cs,
1121+
ax=ax,
1122+
fig_params=fig_params,
1123+
scalebar_params=scalebar_params,
1124+
legend_params=legend_params,
1125+
colorbar_requests=axis_colorbar_requests,
1126+
**extra_kwargs,
1127+
)
1128+
1129+
if title is None:
1130+
t = cs
1131+
elif len(title) == 1:
1132+
t = title[0]
1133+
else:
1134+
try:
1135+
t = title[i]
1136+
except IndexError as e:
1137+
raise IndexError("The number of titles must match the number of coordinate systems.") from e
1138+
ax.set_title(t)
1139+
ax.set_aspect("equal")
11771140

11781141
extent = get_extent(
11791142
sdata,
11801143
coordinate_system=cs,
1181-
has_images=has_images and wants_images,
1182-
has_labels=has_labels and wants_labels,
1183-
has_points=has_points and wants_points,
1184-
has_shapes=has_shapes and wants_shapes,
1144+
has_images=has_images and wants_element_type["images"],
1145+
has_labels=has_labels and wants_element_type["labels"],
1146+
has_points=has_points and wants_element_type["points"],
1147+
has_shapes=has_shapes and wants_element_type["shapes"],
11851148
elements=wanted_elements,
11861149
)
11871150
cs_x_min, cs_x_max = extent["x"]

0 commit comments

Comments
 (0)