Skip to content

Commit d25fff2

Browse files
authored
Merge pull request #85 from Genentech/handle-inf
Handle infinite values in addition to nans
2 parents a99dac9 + d1fa308 commit d25fff2

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

scallops/features/preprocessing.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,16 @@ def _transform_feature_group(x):
5454

5555
def filter_data(
5656
data: anndata.AnnData,
57-
max_fraction_nans: float | None = 0.25,
57+
max_fraction_not_finite: float | None = 0.25,
5858
min_variance: float | None = 0.1,
5959
by: str | Sequence | None = None,
6060
) -> anndata.AnnData:
61-
"""Filter cells using `max_fraction_nans` then filter features using `min_variance`
61+
"""Filter cells using `max_fraction_not_finite` then filter features using
62+
`min_variance`
6263
6364
:param data: AnnData object
64-
:param max_fraction_nans: Keep cells with <= `max_fraction_nans` missing values
65+
:param max_fraction_not_finite: Keep cells whose fraction of non-finite values
66+
(NaN or +/-inf) is <= `max_fraction_not_finite`
6567
:param min_variance: Keep features with variance >= `min_variance`
6668
:param by: Column(s) in `data.obs` to stratify by when computing variance. If
6769
provided, the median variance is used for filtering.
@@ -70,10 +72,10 @@ def filter_data(
7072
xp = get_namespace(data.X)
7173
keep_cells = None
7274
keep_features = None
73-
if max_fraction_nans is not None:
74-
nan_counts_per_cell = xp.isnan(data.X).sum(axis=1)
75-
max_nans = int(data.shape[1] * max_fraction_nans)
76-
keep_cells = nan_counts_per_cell <= max_nans
75+
if max_fraction_not_finite is not None:
76+
invalid_counts_per_cell = (~xp.isfinite(data.X)).sum(axis=1)
77+
max_counts = int(data.shape[1] * max_fraction_not_finite)
78+
keep_cells = invalid_counts_per_cell <= max_counts
7779
if min_variance is not None:
7880
if by is not None:
7981
if isinstance(keep_cells, da.Array):
@@ -102,7 +104,7 @@ def filter_data(
102104
if keep_cells is not None
103105
else xp.var(data.X, axis=0)
104106
)
105-
keep_features = variance >= min_variance
107+
keep_features = (variance >= min_variance) & (xp.isfinite(variance))
106108

107109
if isinstance(data.X, da.Array):
108110
keep_features, keep_cells = dask.compute(keep_features, keep_cells)

scallops/tests/test_features_preprocessing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ def test_filter_data(use_dask, by):
2828
adata.X[1, 0] = 100
2929
adata.X[0, 0] = np.nan
3030
# np.var(adata.X, axis=0) array([nan, 5.], dtype=float32)
31-
test_nan_filter = filter_data(adata, max_fraction_nans=0, min_variance=None)
31+
test_nan_filter = filter_data(adata, max_fraction_not_finite=0, min_variance=None)
3232
assert test_nan_filter.shape == (3, 2)
3333
# np.var(adata.X, axis=0) # array([nan, 5.])
3434
# np.var(adata[adata.obs['well'] == 'well1'].X, axis=0) # array([nan, 4.])
3535
# np.var(adata[adata.obs['well'] == 'well2'].X, axis=0) # array([2209., 4.])
36-
d1 = filter_data(adata, max_fraction_nans=None, min_variance=0, by=by)
36+
d1 = filter_data(adata, max_fraction_not_finite=None, min_variance=0, by=by)
3737
# np.var(adata[1:].X, axis=0) array([2006.2222, 2.6666667])
38-
d2 = filter_data(adata, max_fraction_nans=0, min_variance=5, by=by)
38+
d2 = filter_data(adata, max_fraction_not_finite=0, min_variance=5, by=by)
3939

4040
assert d1.shape == (4, 1)
4141
assert d2.shape == (3, 1)

0 commit comments

Comments
 (0)