@@ -54,14 +54,16 @@ def _transform_feature_group(x):
5454
5555def filter_data (
5656 data : anndata .AnnData ,
57- max_fraction_nans : float | None = 0.25 ,
57+ max_fraction_not_finite : float | None = 0.25 ,
5858 min_variance : float | None = 0.1 ,
5959 by : str | Sequence | None = None ,
6060) -> anndata .AnnData :
61- """Filter cells using `max_fraction_nans` then filter features using `min_variance`
61+ """Filter cells using `max_fraction_not_finite` then filter features using
62+ `min_variance`
6263
6364 :param data: AnnData object
64- :param max_fraction_nans: Keep cells with <= `max_fraction_nans` missing values
65+ :param max_fraction_not_finite: Keep cells with <= `max_fraction_not_finite`
66+ missing or infinite values
6567 :param min_variance: Keep features with variance >= `min_variance`
6668 :param by: Column(s) in `data.obs` to stratify by when computing variance. If
6769 provided, the median variance is used for filtering.
@@ -70,10 +72,10 @@ def filter_data(
7072 xp = get_namespace (data .X )
7173 keep_cells = None
7274 keep_features = None
73- if max_fraction_nans is not None :
74- nan_counts_per_cell = xp .isnan (data .X ).sum (axis = 1 )
75- max_nans = int (data .shape [1 ] * max_fraction_nans )
76- keep_cells = nan_counts_per_cell <= max_nans
75+ if max_fraction_not_finite is not None :
76+ invalid_counts_per_cell = ( ~ xp .isfinite (data .X ) ).sum (axis = 1 )
77+ max_counts = int (data .shape [1 ] * max_fraction_not_finite )
78+ keep_cells = invalid_counts_per_cell <= max_counts
7779 if min_variance is not None :
7880 if by is not None :
7981 if isinstance (keep_cells , da .Array ):
@@ -102,7 +104,7 @@ def filter_data(
102104 if keep_cells is not None
103105 else xp .var (data .X , axis = 0 )
104106 )
105- keep_features = variance >= min_variance
107+ keep_features = ( variance >= min_variance ) & ( xp . isfinite ( variance ))
106108
107109 if isinstance (data .X , da .Array ):
108110 keep_features , keep_cells = dask .compute (keep_features , keep_cells )
0 commit comments