Skip to content

Commit d26829d

Browse files
author
Sonja Stockhaus
committed
use datashader span arg, fix image single channel with norm rendering
1 parent b91de80 commit d26829d

2 files changed

Lines changed: 31 additions & 12 deletions

File tree

src/spatialdata_plot/pl/render.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -549,18 +549,26 @@ def _render_points(
549549
else:
550550
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())
551551

552+
ds_span = None
552553
if norm.vmin is not None or norm.vmax is not None:
553554
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
554555
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
555-
norm.clip = True # NOTE: mpl currently behaves like clip is always True
556+
ds_span = [norm.vmin, norm.vmax]
556557
if norm.vmin == norm.vmax:
557-
# data is mapped to 0
558-
agg = agg - agg
559-
else:
560-
agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
561558
if norm.clip:
562-
agg = np.maximum(agg, 0)
563-
agg = np.minimum(agg, 1)
559+
# all data is mapped to 0
560+
agg = agg - agg
561+
else:
562+
# values equal to norm.vmin are mapped to 0, the rest to -1 or 1
563+
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
564+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=1)
565+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0)
566+
ds_span = [-1, 1]
567+
# else:
568+
# agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
569+
# if norm.clip:
570+
# agg = np.maximum(agg, 0)
571+
# agg = np.minimum(agg, 1)
564572

565573
color_key = (
566574
list(color_vector.categories.values)
@@ -602,6 +610,7 @@ def _render_points(
602610
agg,
603611
cmap=ds_cmap,
604612
how="linear",
613+
span=ds_span,
605614
)
606615

607616
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
@@ -619,8 +628,12 @@ def _render_points(
619628
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
620629
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
621630
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
622-
vmin = norm.vmin
623-
vmax = norm.vmin + 1
631+
if norm.clip:
632+
vmin = norm.vmin
633+
vmax = norm.vmin + 1
634+
else:
635+
vmin = norm.vmin - 0.5
636+
vmax = norm.vmin + 0.5
624637
cax = ScalarMappable(
625638
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
626639
cmap=render_params.cmap_params.cmap,
@@ -744,8 +757,9 @@ def _render_images(
744757
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
745758
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
746759

747-
if render_params.cmap_params.norm: # type: ignore[attr-defined]
748-
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
760+
# TODO: remove, pushed norm to imshow()
761+
# if render_params.cmap_params.norm: # type: ignore[attr-defined]
762+
# layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
749763

750764
cmap = (
751765
_get_linear_colormap(palette, "k")[0]
@@ -757,7 +771,9 @@ def _render_images(
757771
cmap._init()
758772
cmap._lut[:, -1] = render_params.alpha
759773

760-
_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
774+
_ax_show_and_transform(
775+
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
776+
)
761777

762778
if legend_params.colorbar:
763779
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)

src/spatialdata_plot/pl/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,7 @@ def _ax_show_and_transform(
19821982
cmap: ListedColormap | LinearSegmentedColormap | None = None,
19831983
zorder: int = 0,
19841984
extent: list[float] | None = None,
1985+
norm: Normalize | None = None,
19851986
) -> matplotlib.image.AxesImage:
19861987
# default extent in mpl:
19871988
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
@@ -2004,6 +2005,7 @@ def _ax_show_and_transform(
20042005
alpha=alpha,
20052006
zorder=zorder,
20062007
extent=tuple(image_extent),
2008+
norm=norm,
20072009
)
20082010
im.set_transform(trans_data)
20092011
else:
@@ -2012,6 +2014,7 @@ def _ax_show_and_transform(
20122014
cmap=cmap,
20132015
zorder=zorder,
20142016
extent=tuple(image_extent),
2017+
norm=norm,
20152018
)
20162019
im.set_transform(trans_data)
20172020
return im

0 commit comments

Comments
 (0)