7070
7171
7272def _coerce_categorical_source (cat_source : Any ) -> pd .Categorical :
73- """Return a pandas Categorical from known, concrete sources only."""
73+ """Return a pandas Categorical from known, concrete sources only.
74+
75+ Raises
76+ ------
77+ TypeError
78+ If *cat_source* is not a ``dd.Series``, ``pd.Series``,
79+ ``pd.Categorical``, or ``np.ndarray``.
80+ """
7481 if isinstance (cat_source , dd .Series ):
75- if pd . api . types . is_categorical_dtype (cat_source .dtype ) and getattr (cat_source .cat , "known" , True ) is False :
82+ if isinstance (cat_source .dtype , pd . CategoricalDtype ) and getattr (cat_source .cat , "known" , True ) is False :
7683 cat_source = cat_source .cat .as_known ()
7784 cat_source = cat_source .compute ()
7885
7986 if isinstance (cat_source , pd .Series ):
80- if pd . api . types . is_categorical_dtype (cat_source .dtype ):
87+ if isinstance (cat_source .dtype , pd . CategoricalDtype ):
8188 return cat_source .array
8289 return pd .Categorical (cat_source )
8390 if isinstance (cat_source , pd .Categorical ):
8491 return cat_source
92+ if isinstance (cat_source , np .ndarray ):
93+ return pd .Categorical (cat_source )
8594
86- return pd .Categorical (pd .Series (cat_source ))
95+ raise TypeError (
96+ f"Cannot coerce { type (cat_source ).__name__ } to pd.Categorical. "
97+ "Expected dd.Series, pd.Series, pd.Categorical, or np.ndarray."
98+ )
8799
88100
89101def _build_datashader_color_key (
@@ -106,6 +118,22 @@ def _build_datashader_color_key(
106118 return color_key
107119
108120
121+ def _filter_groups_transparent_na (
122+ groups : str | list [str ],
123+ color_source_vector : pd .Categorical ,
124+ color_vector : pd .Series | np .ndarray | list [str ],
125+ ) -> tuple [np .ndarray , pd .Categorical , np .ndarray ]:
126+ """Return a boolean mask and filtered color vectors for groups filtering.
127+
128+ Used when ``na_color=None`` (fully transparent) so that non-matching
129+ elements are removed entirely instead of rendered invisibly.
130+ """
131+ keep = color_source_vector .isin (groups )
132+ filtered_csv = color_source_vector [keep ]
133+ filtered_cv = np .asarray (color_vector )[keep ]
134+ return keep , filtered_csv , filtered_cv
135+
136+
109137def _split_colorbar_params (params : dict [str , object ] | None ) -> tuple [dict [str , object ], dict [str , object ], str | None ]:
110138 """Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
111139 layout : dict [str , object ] = {}
@@ -212,16 +240,11 @@ def _render_shapes(
212240 # When groups are specified and na_color is fully transparent (na_color=None),
213241 # filter out non-matching elements instead of showing them as invisible geometry.
214242 if groups is not None and values_are_categorical and render_params .cmap_params .na_color .alpha == "00" :
215- csv_series = pd .Series (color_source_vector )
216- keep = csv_series .isin (groups ).values
243+ keep , color_source_vector , color_vector = _filter_groups_transparent_na (
244+ groups , color_source_vector , color_vector
245+ )
217246 shapes = shapes [keep ].reset_index (drop = True )
218247 sdata_filt [element ] = shapes
219- color_source_vector = pd .Categorical (csv_series [keep ].reset_index (drop = True ))
220- color_vector = (
221- np .asarray (color_vector )[keep ]
222- if not hasattr (color_vector , "reset_index" )
223- else (color_vector [keep ].reset_index (drop = True ))
224- )
225248
226249 # color_source_vector is None when the values aren't categorical
227250 if values_are_categorical and render_params .transfunc is not None :
@@ -352,7 +375,7 @@ def _render_shapes(
352375 color_by_categorical = col_for_color is not None and color_source_vector is not None
353376 if color_by_categorical :
354377 cat_series = transformed_element [col_for_color ]
355- if not pd . api . types . is_categorical_dtype (cat_series ):
378+ if not isinstance (cat_series . dtype , pd . CategoricalDtype ):
356379 cat_series = cat_series .astype ("category" )
357380 transformed_element [col_for_color ] = cat_series
358381
@@ -845,13 +868,8 @@ def _render_points(
845868 # When groups are specified and na_color is fully transparent (na_color=None),
846869 # filter out non-matching points instead of rendering invisible geometry.
847870 if groups is not None and color_source_vector is not None and render_params .cmap_params .na_color .alpha == "00" :
848- csv_series = pd .Series (color_source_vector )
849- keep = csv_series .isin (groups ).values
850- color_source_vector = pd .Categorical (csv_series [keep ].reset_index (drop = True ))
851- color_vector = (
852- np .asarray (color_vector )[keep ]
853- if not hasattr (color_vector , "reset_index" )
854- else (color_vector [keep ].reset_index (drop = True ))
871+ keep , color_source_vector , color_vector = _filter_groups_transparent_na (
872+ groups , color_source_vector , color_vector
855873 )
856874 # filter the materialized points, adata, and re-register in sdata_filt
857875 points = points [keep ].reset_index (drop = True )
@@ -931,11 +949,11 @@ def _render_points(
931949 color_dtype = transformed_element [col_for_color ].dtype if col_for_color is not None else None
932950 color_by_categorical = col_for_color is not None and (
933951 color_source_vector is not None
934- or pd . api . types . is_categorical_dtype (color_dtype )
952+ or isinstance (color_dtype , pd . CategoricalDtype )
935953 or pd .api .types .is_object_dtype (color_dtype )
936954 or pd .api .types .is_string_dtype (color_dtype )
937955 )
938- if color_by_categorical and not pd . api . types . is_categorical_dtype (color_dtype ):
956+ if color_by_categorical and not isinstance (color_dtype , pd . CategoricalDtype ):
939957 transformed_element [col_for_color ] = transformed_element [col_for_color ].astype ("category" )
940958
941959 aggregate_with_reduction = None
@@ -944,7 +962,7 @@ def _render_points(
944962 if color_by_categorical :
945963 # add nan as category so that nan points are shown in the nan color
946964 cat_series = transformed_element [col_for_color ]
947- if not pd . api . types . is_categorical_dtype (cat_series ):
965+ if not isinstance (cat_series . dtype , pd . CategoricalDtype ):
948966 cat_series = cat_series .astype ("category" )
949967 if hasattr (cat_series .cat , "as_known" ):
950968 cat_series = cat_series .cat .as_known ()
0 commit comments