Skip to content

Commit c45c082

Browse files
author
Sonja Stockhaus
committed
incorporate review feedback
1 parent 2cecbf1 commit c45c082

File tree

7 files changed

+73
-111
lines changed

7 files changed

+73
-111
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
_ax_show_and_transform,
3838
_create_image_from_datashader_result,
3939
_datashader_aggregate_with_function,
40-
_datashader_shade,
40+
_datashader_map_aggregate_to_color,
4141
_datshader_get_how_kw_for_spread,
4242
_decorate_axs,
4343
_get_collection_shape,
@@ -259,11 +259,12 @@ def _render_shapes(
259259
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
260260
ds_cmap = ds_cmap[:-2]
261261

262-
ds_result = _datashader_shade(
262+
ds_result = _datashader_map_aggregate_to_color(
263263
agg,
264264
cmap=ds_cmap,
265265
color_key=color_key,
266-
min_alpha=np.min([254, render_params.fill_alpha * 255]),
266+
# min_alpha=np.min([254, render_params.fill_alpha * 255]),
267+
min_alpha=render_params.fill_alpha * 255,
267268
)
268269
elif aggregate_with_reduction is not None: # to shut up mypy
269270
ds_cmap = render_params.cmap_params.cmap
@@ -274,10 +275,11 @@ def _render_shapes(
274275
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
275276
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
276277

277-
ds_result = _datashader_shade(
278+
ds_result = _datashader_map_aggregate_to_color(
278279
agg,
279280
cmap=ds_cmap,
280-
min_alpha=np.min([254, render_params.fill_alpha * 255]),
281+
# min_alpha=np.min([254, render_params.fill_alpha * 255]),
282+
min_alpha=render_params.fill_alpha * 255,
281283
span=ds_span,
282284
clip=norm.clip,
283285
)
@@ -295,7 +297,8 @@ def _render_shapes(
295297
ds_outlines = ds.tf.shade(
296298
agg_outlines,
297299
cmap=outline_color,
298-
min_alpha=np.min([254, render_params.outline_alpha * 255]),
300+
# min_alpha=np.min([254, render_params.outline_alpha * 255]),
301+
min_alpha=render_params.outline_alpha * 255,
299302
how="linear",
300303
)
301304

@@ -621,11 +624,12 @@ def _render_points(
621624
color_vector = np.asarray([x[:-2] for x in color_vector])
622625

623626
if color_by_categorical or col_for_color is None:
624-
ds_result = _datashader_shade(
627+
ds_result = _datashader_map_aggregate_to_color(
625628
ds.tf.spread(agg, px=px),
626629
cmap=color_vector[0],
627630
color_key=color_key,
628-
min_alpha=np.min([254, render_params.alpha * 255]),
631+
# min_alpha=np.min([254, render_params.alpha * 255]),
632+
min_alpha=render_params.alpha * 255,
629633
)
630634
else:
631635
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
@@ -640,12 +644,13 @@ def _render_points(
640644
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
641645
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
642646

643-
ds_result = _datashader_shade(
647+
ds_result = _datashader_map_aggregate_to_color(
644648
agg,
645649
cmap=ds_cmap,
646650
span=ds_span,
647651
clip=norm.clip,
648-
min_alpha=np.min([254, render_params.alpha * 255]),
652+
# min_alpha=np.min([254, render_params.alpha * 255]),
653+
min_alpha=render_params.alpha * 255,
649654
)
650655

651656
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
@@ -665,12 +670,6 @@ def _render_points(
665670
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
666671
vmin = norm.vmin - 0.5
667672
vmax = norm.vmin + 0.5
668-
# if norm.clip:
669-
# vmin = norm.vmin
670-
# vmax = norm.vmin + 1
671-
# else:
672-
# vmin = norm.vmin - 0.5
673-
# vmax = norm.vmin + 0.5
674673
cax = ScalarMappable(
675674
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
676675
cmap=render_params.cmap_params.cmap,
@@ -804,6 +803,7 @@ def _render_images(
804803
cmap._init()
805804
cmap._lut[:, -1] = render_params.alpha
806805

806+
# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
807807
_ax_show_and_transform(
808808
layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
809809
)

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2271,7 +2271,7 @@ def _get_transformation_matrix_for_datashader(
22712271
return _get_datashader_trans_matrix_of_single_element(trans)
22722272

22732273

2274-
def _datashader_shade(
2274+
def _datashader_map_aggregate_to_color(
22752275
agg: DataArray,
22762276
cmap: str | list[str] | ListedColormap,
22772277
color_key: None | list[str] = None,

tests/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import wraps
44
from pathlib import Path
55

6+
import matplotlib
67
import matplotlib.pyplot as plt
78
import numpy as np
89
import pandas as pd
@@ -149,6 +150,23 @@ def test_sdata_multiple_images_diverging_dims():
149150
return sdata
150151

151152

153+
@pytest.fixture
154+
def sdata_blobs_shapes_annotated() -> SpatialData:
155+
"""Get blobs sdata with continuous annotation of polygons."""
156+
blob = blobs()
157+
blob["table"].obs["region"] = "blobs_polygons"
158+
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
159+
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
160+
return blob
161+
162+
163+
def _viridis_with_under_over() -> matplotlib.colors.ListedColormap:
164+
cmap = matplotlib.colormaps["viridis"]
165+
cmap.set_under("black")
166+
cmap.set_over("grey")
167+
return cmap
168+
169+
152170
# Code below taken from spatialdata main repo
153171

154172

tests/pl/test_render_images.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from spatialdata import SpatialData
88

99
import spatialdata_plot # noqa: F401
10-
from tests.conftest import DPI, PlotTester, PlotTesterMeta
10+
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
1111

1212
RNG = np.random.default_rng(seed=42)
1313
sc.pl.set_rcParams_defaults()
@@ -68,17 +68,15 @@ def test_plot_can_render_two_channels_str_from_multiscale_image(self, sdata_blob
6868

6969
def test_plot_can_pass_normalize_clip_True(self, sdata_blobs: SpatialData):
7070
norm = Normalize(vmin=0.1, vmax=0.5, clip=True)
71-
cmap = matplotlib.colormaps["viridis"]
72-
cmap.set_under("black")
73-
cmap.set_over("grey")
74-
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm, cmap=cmap).pl.show()
71+
sdata_blobs.pl.render_images(
72+
element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over()
73+
).pl.show()
7574

7675
def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData):
7776
norm = Normalize(vmin=0.1, vmax=0.5, clip=False)
78-
cmap = matplotlib.colormaps["viridis"]
79-
cmap.set_under("black")
80-
cmap.set_over("grey")
81-
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm, cmap=cmap).pl.show()
77+
sdata_blobs.pl.render_images(
78+
element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over()
79+
).pl.show()
8280

8381
def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
8482
sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show()

tests/pl/test_render_labels.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from spatialdata.models import TableModel
1313

1414
import spatialdata_plot # noqa: F401
15-
from tests.conftest import DPI, PlotTester, PlotTesterMeta
15+
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
1616

1717
RNG = np.random.default_rng(seed=42)
1818
sc.pl.set_rcParams_defaults()
@@ -236,19 +236,16 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str
236236
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
237237

238238
def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData):
239-
cmap = matplotlib.colormaps["viridis"]
240-
cmap.set_under("black")
241-
cmap.set_over("grey")
242239
sdata_blobs.pl.render_labels(
243-
"blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=cmap
240+
"blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over()
244241
).pl.show()
245242

246243
def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
247-
cmap = matplotlib.colormaps["viridis"]
248-
cmap.set_under("black")
249-
cmap.set_over("grey")
250244
sdata_blobs.pl.render_labels(
251-
"blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=False), cmap=cmap
245+
"blobs_labels",
246+
color="channel_0_sum",
247+
norm=Normalize(400, 1000, clip=False),
248+
cmap=_viridis_with_under_over(),
252249
).pl.show()
253250

254251
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):

tests/pl/test_render_points.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from spatialdata.transformations._utils import _set_transformations
1515

1616
import spatialdata_plot # noqa: F401
17-
from tests.conftest import DPI, PlotTester, PlotTesterMeta
17+
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
1818

1919
RNG = np.random.default_rng(seed=42)
2020
sc.pl.set_rcParams_defaults()
@@ -227,67 +227,51 @@ def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData):
227227
sdata_blobs.pl.render_points("blobs_points", method="datashader", color="black", size=5).pl.show()
228228

229229
def test_plot_can_use_norm_with_clip(self, sdata_blobs: SpatialData):
230-
cmap = matplotlib.colormaps["viridis"]
231-
cmap.set_under("black")
232-
cmap.set_over("grey")
233-
sdata_blobs.pl.render_points(color="instance_id", size=40, norm=Normalize(3, 7, clip=True), cmap=cmap).pl.show()
230+
sdata_blobs.pl.render_points(
231+
color="instance_id", size=40, norm=Normalize(3, 7, clip=True), cmap=_viridis_with_under_over()
232+
).pl.show()
234233

235234
def test_plot_can_use_norm_without_clip(self, sdata_blobs: SpatialData):
236-
cmap = matplotlib.colormaps["viridis"]
237-
cmap.set_under("black")
238-
cmap.set_over("grey")
239235
sdata_blobs.pl.render_points(
240-
color="instance_id", size=40, norm=Normalize(3, 7, clip=False), cmap=cmap
236+
color="instance_id", size=40, norm=Normalize(3, 7, clip=False), cmap=_viridis_with_under_over()
241237
).pl.show()
242238

243239
def test_plot_datashader_can_use_norm_with_clip(self, sdata_blobs: SpatialData):
244-
cmap = matplotlib.colormaps["viridis"]
245-
cmap.set_under("black")
246-
cmap.set_over("grey")
247240
sdata_blobs.pl.render_points(
248241
color="instance_id",
249242
size=40,
250243
norm=Normalize(3, 7, clip=True),
251-
cmap=cmap,
244+
cmap=_viridis_with_under_over(),
252245
method="datashader",
253246
datashader_reduction="max",
254247
).pl.show()
255248

256249
def test_plot_datashader_can_use_norm_without_clip(self, sdata_blobs: SpatialData):
257-
cmap = matplotlib.colormaps["viridis"]
258-
cmap.set_under("black")
259-
cmap.set_over("grey")
260250
sdata_blobs.pl.render_points(
261251
color="instance_id",
262252
size=40,
263253
norm=Normalize(3, 7, clip=False),
264-
cmap=cmap,
254+
cmap=_viridis_with_under_over(),
265255
method="datashader",
266256
datashader_reduction="max",
267257
).pl.show()
268258

269259
def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData):
270-
cmap = matplotlib.colormaps["viridis"]
271-
cmap.set_under("black")
272-
cmap.set_over("grey")
273260
sdata_blobs.pl.render_points(
274261
color="instance_id",
275262
size=40,
276263
norm=Normalize(5, 5, clip=True),
277-
cmap=cmap,
264+
cmap=_viridis_with_under_over(),
278265
method="datashader",
279266
datashader_reduction="max",
280267
).pl.show()
281268

282269
def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData):
283-
cmap = matplotlib.colormaps["viridis"]
284-
cmap.set_under("black")
285-
cmap.set_over("grey")
286270
sdata_blobs.pl.render_points(
287271
color="instance_id",
288272
size=40,
289273
norm=Normalize(5, 5, clip=False),
290-
cmap=cmap,
274+
cmap=_viridis_with_under_over(),
291275
method="datashader",
292276
datashader_reduction="max",
293277
).pl.show()

tests/pl/test_render_shapes.py

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from spatialdata.transformations._utils import _set_transformations
1717

1818
import spatialdata_plot # noqa: F401
19-
from tests.conftest import DPI, PlotTester, PlotTesterMeta
19+
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
2020

2121
RNG = np.random.default_rng(seed=42)
2222
sc.pl.set_rcParams_defaults()
@@ -456,82 +456,47 @@ def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData):
456456

457457
sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show()
458458

459-
def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
460-
blob = deepcopy(sdata_blobs)
461-
blob["table"].obs["region"] = "blobs_polygons"
462-
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
463-
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
464-
cmap = matplotlib.colormaps["viridis"]
465-
cmap.set_under("black")
466-
cmap.set_over("grey")
467-
blob.pl.render_shapes(
468-
element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=cmap
459+
def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
460+
sdata_blobs_shapes_annotated.pl.render_shapes(
461+
element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over()
469462
).pl.show()
470463

471-
def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData):
472-
blob = deepcopy(sdata_blobs)
473-
blob["table"].obs["region"] = "blobs_polygons"
474-
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
475-
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
476-
cmap = matplotlib.colormaps["viridis"]
477-
cmap.set_under("black")
478-
cmap.set_over("grey")
479-
blob.pl.render_shapes(
464+
def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
465+
sdata_blobs_shapes_annotated.pl.render_shapes(
480466
element="blobs_polygons",
481467
color="value",
482468
norm=Normalize(2, 4, clip=True),
483-
cmap=cmap,
469+
cmap=_viridis_with_under_over(),
484470
method="datashader",
485471
datashader_reduction="max",
486472
).pl.show()
487473

488-
def test_plot_datashader_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
489-
blob = deepcopy(sdata_blobs)
490-
blob["table"].obs["region"] = "blobs_polygons"
491-
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
492-
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
493-
cmap = matplotlib.colormaps["viridis"]
494-
cmap.set_under("black")
495-
cmap.set_over("grey")
496-
blob.pl.render_shapes(
474+
def test_plot_datashader_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
475+
sdata_blobs_shapes_annotated.pl.render_shapes(
497476
element="blobs_polygons",
498477
color="value",
499478
norm=Normalize(2, 4, clip=False),
500-
cmap=cmap,
479+
cmap=_viridis_with_under_over(),
501480
method="datashader",
502481
datashader_reduction="max",
503482
).pl.show()
504483

505-
def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData):
506-
blob = deepcopy(sdata_blobs)
507-
blob["table"].obs["region"] = "blobs_polygons"
508-
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
509-
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
510-
cmap = matplotlib.colormaps["viridis"]
511-
cmap.set_under("black")
512-
cmap.set_over("grey")
513-
blob.pl.render_shapes(
484+
def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs_shapes_annotated: SpatialData):
485+
sdata_blobs_shapes_annotated.pl.render_shapes(
514486
element="blobs_polygons",
515487
color="value",
516488
norm=Normalize(3, 3, clip=False),
517-
cmap=cmap,
489+
cmap=_viridis_with_under_over(),
518490
method="datashader",
519491
datashader_reduction="max",
520492
).pl.show()
521493

522-
def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData):
523-
blob = deepcopy(sdata_blobs)
524-
blob["table"].obs["region"] = "blobs_polygons"
525-
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
526-
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
527-
cmap = matplotlib.colormaps["viridis"]
528-
cmap.set_under("black")
529-
cmap.set_over("grey")
530-
blob.pl.render_shapes(
494+
def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs_shapes_annotated: SpatialData):
495+
sdata_blobs_shapes_annotated.pl.render_shapes(
531496
element="blobs_polygons",
532497
color="value",
533498
norm=Normalize(3, 3, clip=True),
534-
cmap=cmap,
499+
cmap=_viridis_with_under_over(),
535500
method="datashader",
536501
datashader_reduction="max",
537502
).pl.show()

0 commit comments

Comments
 (0)