3737 _ax_show_and_transform ,
3838 _create_image_from_datashader_result ,
3939 _datashader_aggregate_with_function ,
40+ _datashader_map_aggregate_to_color ,
4041 _datshader_get_how_kw_for_spread ,
4142 _decorate_axs ,
4243 _get_collection_shape ,
@@ -229,18 +230,20 @@ def _render_shapes(
229230 line_width = render_params .outline_params .linewidth ,
230231 )
231232
233+ ds_span = None
232234 if norm .vmin is not None or norm .vmax is not None :
233235 norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
234236 norm .vmax = np .max (agg ) if norm .vmax is None else norm .vmax
235- norm . clip = True # NOTE: mpl currently behaves like clip is always True
237+ ds_span = [ norm . vmin , norm . vmax ]
236238 if norm .vmin == norm .vmax :
237- # data is mapped to 0
238- agg = agg - agg
239- else :
240- agg = (agg - norm .vmin ) / (norm .vmax - norm .vmin )
239+ # edge case, value vmin is rendered as the middle of the cmap
240+ ds_span = [0 , 1 ]
241241 if norm .clip :
242- agg = np .maximum (agg , 0 )
243- agg = np .minimum (agg , 1 )
242+ agg = (agg - agg ) + 0.5
243+ else :
244+ agg = agg .where ((agg >= norm .vmin ) | (np .isnan (agg )), other = - 1 )
245+ agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
246+ agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
244247
245248 color_key = (
246249 [x [:- 2 ] for x in color_vector .categories .values ]
@@ -256,13 +259,12 @@ def _render_shapes(
256259 if isinstance (ds_cmap , str ) and ds_cmap [0 ] == "#" :
257260 ds_cmap = ds_cmap [:- 2 ]
258261
259- ds_result = ds . tf . shade (
262+ ds_result = _datashader_map_aggregate_to_color (
260263 agg ,
261264 cmap = ds_cmap ,
262265 color_key = color_key ,
263266 min_alpha = np .min ([254 , render_params .fill_alpha * 255 ]),
264- how = "linear" ,
265- )
267+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
266268 elif aggregate_with_reduction is not None : # to shut up mypy
267269 ds_cmap = render_params .cmap_params .cmap
268270 # in case all elements have the same value X: we render them using cmap(0.0),
@@ -272,12 +274,13 @@ def _render_shapes(
272274 ds_cmap = matplotlib .colors .to_hex (render_params .cmap_params .cmap (0.0 ), keep_alpha = False )
273275 aggregate_with_reduction = (aggregate_with_reduction [0 ], aggregate_with_reduction [0 ] + 1 )
274276
275- ds_result = ds . tf . shade (
277+ ds_result = _datashader_map_aggregate_to_color (
276278 agg ,
277279 cmap = ds_cmap ,
278- how = "linear" ,
279280 min_alpha = np .min ([254 , render_params .fill_alpha * 255 ]),
280- )
281+ span = ds_span ,
282+ clip = norm .clip ,
283+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
281284
282285 # shade outlines if needed
283286 outline_color = render_params .outline_params .outline_color
@@ -294,7 +297,7 @@ def _render_shapes(
294297 cmap = outline_color ,
295298 min_alpha = np .min ([254 , render_params .outline_alpha * 255 ]),
296299 how = "linear" ,
297- )
300+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
298301
299302 rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
300303 _cax = _ax_show_and_transform (
@@ -322,8 +325,10 @@ def _render_shapes(
322325 vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
323326 vmax = aggregate_with_reduction [1 ].values if norm .vmin is None else norm .vmax
324327 if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
325- vmin = norm .vmin
326- vmax = norm .vmin + 1
328+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
329+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
330+ vmin = norm .vmin - 0.5
331+ vmax = norm .vmin + 0.5
327332 cax = ScalarMappable (
328333 norm = matplotlib .colors .Normalize (vmin = vmin , vmax = vmax ),
329334 cmap = render_params .cmap_params .cmap ,
@@ -586,18 +591,21 @@ def _render_points(
586591 else :
587592 agg = cvs .points (transformed_element , "x" , "y" , agg = ds .count ())
588593
594+ ds_span = None
589595 if norm .vmin is not None or norm .vmax is not None :
590596 norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
591597 norm .vmax = np .max (agg ) if norm .vmax is None else norm .vmax
592- norm . clip = True # NOTE: mpl currently behaves like clip is always True
598+ ds_span = [ norm . vmin , norm . vmax ]
593599 if norm .vmin == norm .vmax :
594- # data is mapped to 0
595- agg = agg - agg
596- else :
597- agg = (agg - norm .vmin ) / (norm .vmax - norm .vmin )
600+ ds_span = [0 , 1 ]
598601 if norm .clip :
599- agg = np .maximum (agg , 0 )
600- agg = np .minimum (agg , 1 )
602+ # all data is mapped to 0.5
603+ agg = (agg - agg ) + 0.5
604+ else :
605+ # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
606+ agg = agg .where ((agg >= norm .vmin ) | (np .isnan (agg )), other = - 1 )
607+ agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
608+ agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
601609
602610 color_key = (
603611 list (color_vector .categories .values )
@@ -615,13 +623,12 @@ def _render_points(
615623 color_vector = np .asarray ([x [:- 2 ] for x in color_vector ])
616624
617625 if color_by_categorical or col_for_color is None :
618- ds_result = ds . tf . shade (
626+ ds_result = _datashader_map_aggregate_to_color (
619627 ds .tf .spread (agg , px = px ),
620628 cmap = color_vector [0 ],
621629 color_key = color_key ,
622630 min_alpha = np .min ([254 , render_params .alpha * 255 ]),
623- how = "linear" ,
624- )
631+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
625632 else :
626633 spread_how = _datshader_get_how_kw_for_spread (render_params .ds_reduction )
627634 agg = ds .tf .spread (agg , px = px , how = spread_how )
@@ -631,15 +638,17 @@ def _render_points(
631638 # in case all elements have the same value X: we render them using cmap(0.0),
632639 # using an artificial "span" of [X, X + 1] for the color bar
633640 # else: all elements would get alpha=0 and the color bar would have a weird range
634- if aggregate_with_reduction [0 ] == aggregate_with_reduction [1 ]:
641+ if aggregate_with_reduction [0 ] == aggregate_with_reduction [1 ] and ( ds_span is None or ds_span != [ 0 , 1 ]) :
635642 ds_cmap = matplotlib .colors .to_hex (render_params .cmap_params .cmap (0.0 ), keep_alpha = False )
636643 aggregate_with_reduction = (aggregate_with_reduction [0 ], aggregate_with_reduction [0 ] + 1 )
637644
638- ds_result = ds . tf . shade (
645+ ds_result = _datashader_map_aggregate_to_color (
639646 agg ,
640647 cmap = ds_cmap ,
641- how = "linear" ,
642- )
648+ span = ds_span ,
649+ clip = norm .clip ,
650+ min_alpha = np .min ([254 , render_params .alpha * 255 ]),
651+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
643652
644653 rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
645654 _ax_show_and_transform (
@@ -656,8 +665,10 @@ def _render_points(
656665 vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
657666 vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
658667 if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
659- vmin = norm .vmin
660- vmax = norm .vmin + 1
668+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
669+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
670+ vmin = norm .vmin - 0.5
671+ vmax = norm .vmin + 0.5
661672 cax = ScalarMappable (
662673 norm = matplotlib .colors .Normalize (vmin = vmin , vmax = vmax ),
663674 cmap = render_params .cmap_params .cmap ,
@@ -723,7 +734,6 @@ def _render_images(
723734 legend_params : LegendParams ,
724735 rasterize : bool ,
725736) -> None :
726-
727737 sdata_filt = sdata .filter_by_coordinate_system (
728738 coordinate_system = coordinate_system ,
729739 filter_tables = False ,
@@ -781,9 +791,6 @@ def _render_images(
781791 if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
782792 layer = img .sel (c = channels [0 ]).squeeze () if isinstance (channels [0 ], str ) else img .isel (c = channels [0 ]).squeeze ()
783793
784- if render_params .cmap_params .norm : # type: ignore[attr-defined]
785- layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
786-
787794 cmap = (
788795 _get_linear_colormap (palette , "k" )[0 ]
789796 if isinstance (palette , list ) and all (isinstance (p , str ) for p in palette )
@@ -794,7 +801,10 @@ def _render_images(
794801 cmap ._init ()
795802 cmap ._lut [:, - 1 ] = render_params .alpha
796803
797- _ax_show_and_transform (layer , trans_data , ax , cmap = cmap , zorder = render_params .zorder )
804+ # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
805+ _ax_show_and_transform (
806+ layer , trans_data , ax , cmap = cmap , zorder = render_params .zorder , norm = render_params .cmap_params .norm
807+ )
798808
799809 if legend_params .colorbar :
800810 sm = plt .cm .ScalarMappable (cmap = cmap , norm = render_params .cmap_params .norm )
0 commit comments