Skip to content

Commit 623e899

Browse files
author
Sonja Stockhaus
committed
fix datashader vmin==vmax behavior and add tests
1 parent f36ea13 commit 623e899

File tree

4 files changed

+90
-22
lines changed

4 files changed

+90
-22
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,14 @@ def _render_shapes(
236236
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
237237
ds_span = [norm.vmin, norm.vmax]
238238
if norm.vmin == norm.vmax:
239+
# edge case, value vmin is rendered as the middle of the cmap
240+
ds_span = [0, 1]
239241
if norm.clip:
240-
# all data is mapped to 0
241-
agg = agg - agg
242+
agg = (agg - agg) + 0.5
242243
else:
243-
# values equal to norm.vmin are mapped to 0, the rest to -1 or 1
244244
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
245-
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=1)
246-
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0)
247-
ds_span = [-1, 1]
245+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
246+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
248247

249248
color_key = (
250249
[x[:-2] for x in color_vector.categories.values]
@@ -326,8 +325,8 @@ def _render_shapes(
326325
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
327326
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
328327
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
329-
vmin = norm.vmin
330-
vmax = norm.vmin + 1
328+
vmin = norm.vmin - 0.5
329+
vmax = norm.vmin + 0.5
331330
cax = ScalarMappable(
332331
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
333332
cmap=render_params.cmap_params.cmap,
@@ -596,15 +595,15 @@ def _render_points(
596595
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
597596
ds_span = [norm.vmin, norm.vmax]
598597
if norm.vmin == norm.vmax:
598+
ds_span = [0, 1]
599599
if norm.clip:
600-
# all data is mapped to 0
601-
agg = agg - agg
600+
# all data is mapped to 0.5
601+
agg = (agg - agg) + 0.5
602602
else:
603-
# values equal to norm.vmin are mapped to 0, the rest to -1 or 1
603+
# values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
604604
agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
605-
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=1)
606-
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0)
607-
ds_span = [-1, 1]
605+
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
606+
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
608607

609608
color_key = (
610609
list(color_vector.categories.values)
@@ -637,7 +636,7 @@ def _render_points(
637636
# in case all elements have the same value X: we render them using cmap(0.0),
638637
# using an artificial "span" of [X, X + 1] for the color bar
639638
# else: all elements would get alpha=0 and the color bar would have a weird range
640-
if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
639+
if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
641640
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
642641
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
643642

@@ -646,6 +645,7 @@ def _render_points(
646645
cmap=ds_cmap,
647646
span=ds_span,
648647
clip=norm.clip,
648+
min_alpha=np.min([254, render_params.alpha * 255]),
649649
)
650650

651651
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
@@ -663,12 +663,14 @@ def _render_points(
663663
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
664664
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
665665
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
666-
if norm.clip:
667-
vmin = norm.vmin
668-
vmax = norm.vmin + 1
669-
else:
670-
vmin = norm.vmin - 0.5
671-
vmax = norm.vmin + 0.5
666+
vmin = norm.vmin - 0.5
667+
vmax = norm.vmin + 0.5
668+
# if norm.clip:
669+
# vmin = norm.vmin
670+
# vmax = norm.vmin + 1
671+
# else:
672+
# vmin = norm.vmin - 0.5
673+
# vmax = norm.vmin + 0.5
672674
cax = ScalarMappable(
673675
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
674676
cmap=render_params.cmap_params.cmap,

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2279,8 +2279,14 @@ def _datashader_shade(
22792279
span: None | list[float] = None,
22802280
clip: bool = True,
22812281
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
2282-
"""If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results."""
2282+
"""ds.tf.shade() part, ensuring correct clipping behavior.
2283+
2284+
If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
2285+
This ensures the correct clipping behavior, because else datashader would always automatically clip.
2286+
"""
22832287
if not clip and isinstance(cmap, Colormap) and span is not None:
2288+
# in case we use datashader together with a Normalize object where clip=False
2289+
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
22842290
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
22852291
img_in = ds.tf.shade(
22862292
agg_in,

tests/pl/test_render_points.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,32 @@ def test_plot_datashader_can_use_norm_without_clip(self, sdata_blobs: SpatialDat
266266
datashader_reduction="max",
267267
).pl.show()
268268

269+
def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData):
270+
cmap = matplotlib.colormaps["viridis"]
271+
cmap.set_under("black")
272+
cmap.set_over("grey")
273+
sdata_blobs.pl.render_points(
274+
color="instance_id",
275+
size=40,
276+
norm=Normalize(5, 5, clip=True),
277+
cmap=cmap,
278+
method="datashader",
279+
datashader_reduction="max",
280+
).pl.show()
281+
282+
def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData):
283+
cmap = matplotlib.colormaps["viridis"]
284+
cmap.set_under("black")
285+
cmap.set_over("grey")
286+
sdata_blobs.pl.render_points(
287+
color="instance_id",
288+
size=40,
289+
norm=Normalize(5, 5, clip=False),
290+
cmap=cmap,
291+
method="datashader",
292+
datashader_reduction="max",
293+
).pl.show()
294+
269295
def test_plot_can_annotate_points_with_table_obs(self, sdata_blobs: SpatialData):
270296
nrows, ncols = 200, 3
271297
feature_matrix = RNG.random((nrows, ncols))

tests/pl/test_render_shapes.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,40 @@ def test_plot_datashader_can_color_with_norm_no_clipping(self, sdata_blobs: Spat
502502
datashader_reduction="max",
503503
).pl.show()
504504

505+
def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData):
506+
blob = deepcopy(sdata_blobs)
507+
blob["table"].obs["region"] = "blobs_polygons"
508+
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
509+
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
510+
cmap = matplotlib.colormaps["viridis"]
511+
cmap.set_under("black")
512+
cmap.set_over("grey")
513+
blob.pl.render_shapes(
514+
element="blobs_polygons",
515+
color="value",
516+
norm=Normalize(3, 3, clip=False),
517+
cmap=cmap,
518+
method="datashader",
519+
datashader_reduction="max",
520+
).pl.show()
521+
522+
def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData):
523+
blob = deepcopy(sdata_blobs)
524+
blob["table"].obs["region"] = "blobs_polygons"
525+
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
526+
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
527+
cmap = matplotlib.colormaps["viridis"]
528+
cmap.set_under("black")
529+
cmap.set_over("grey")
530+
blob.pl.render_shapes(
531+
element="blobs_polygons",
532+
color="value",
533+
norm=Normalize(3, 3, clip=True),
534+
cmap=cmap,
535+
method="datashader",
536+
datashader_reduction="max",
537+
).pl.show()
538+
505539
def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialData):
506540
nrows, ncols = 5, 3
507541
feature_matrix = RNG.random((nrows, ncols))

0 commit comments

Comments
 (0)