diff --git a/docs/api.md b/docs/api.md
index 59c40c2..3713549 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -69,6 +69,7 @@ scib_metrics.ilisi_knn(...)
     utils.convert_knn_graph_to_idx
     utils.check_square
     utils.diffusion_nn
+    utils.anderson_ksamp
 ```
 
 ### Nearest neighbors
diff --git a/docs/references.bib b/docs/references.bib
index 310e658..60bac48 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -36,3 +36,13 @@ @article{buttner2018
     pages = {43--49},
     publisher = {Springer Science and Business Media {LLC}}
 }
+
+@article{lutge2021cellmixs,
+    title = {CellMixS: quantifying and visualizing batch effects in single-cell RNA-seq data},
+    author = {L{\"u}tge, Almut and Zyprych-Walczak, Joanna and Kunzmann, Urszula Brykczynska and Crowell, Helena L and Calini, Daniela and Malhotra, Dheeraj and Soneson, Charlotte and Robinson, Mark D},
+    journal = {Life Science Alliance},
+    volume = {4},
+    number = {6},
+    year = {2021},
+    publisher = {Life Science Alliance}
+}
diff --git a/pyproject.toml b/pyproject.toml
index 4afe229..a3f119b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
     "matplotlib",
     "plottable",
     "tqdm",
+    "numba",
 ]
 
 [project.optional-dependencies]
diff --git a/src/scib_metrics/__init__.py b/src/scib_metrics/__init__.py
index 58bcace..29391f5 100644
--- a/src/scib_metrics/__init__.py
+++ b/src/scib_metrics/__init__.py
@@ -2,6 +2,7 @@ from importlib.metadata import version
 
 from . import nearest_neighbors, utils
+from ._cms import cell_mixing_score
 from ._graph_connectivity import graph_connectivity
 from ._isolated_labels import isolated_labels
 from ._kbet import kbet, kbet_per_label
@@ -26,6 +27,7 @@
     "kbet",
     "kbet_per_label",
     "graph_connectivity",
+    "cell_mixing_score",
 ]
 
 __version__ = version("scib-metrics")
diff --git a/src/scib_metrics/_cms.py b/src/scib_metrics/_cms.py
new file mode 100644
index 0000000..6d97423
--- /dev/null
+++ b/src/scib_metrics/_cms.py
@@ -0,0 +1,74 @@
+import warnings
+from functools import partial
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import csr_matrix
+from scipy.stats import anderson_ksamp
+
+from scib_metrics.utils import convert_knn_graph_to_idx
+
+
+def _cms_one_cell(
+    knn_dists: np.ndarray, knn_cats: np.ndarray, n_categories: int, cell_min: int = 4, unbalanced: bool = False
+):
+    # filter out categories with too few cells in the neighborhood (cell_min)
+    cat_values, cat_counts = np.unique(knn_cats, return_counts=True)
+    cats_to_use = np.where(cat_counts >= cell_min)[0]
+    cat_values = cat_values[cats_to_use]
+    mask = np.isin(knn_cats, cat_values)
+    knn_cats = knn_cats[mask]
+    knn_dists = knn_dists[mask]
+
+    # do not perform the AD test if only one group with enough cells is in the knn neighborhood
+    if len(cats_to_use) <= 1:
+        p = np.nan if unbalanced else 0.0
+    else:
+        # filter cells with the same representation
+        if np.any(knn_dists == 0):
+            warnings.warn("Distances equal to 0 - cells with identical representations detected. NaN assigned!")
+            p = np.nan
+        else:
+            # perform the AD test on the remaining cells
+            res = anderson_ksamp([knn_dists[knn_cats == cat] for cat in cat_values])
+            p = res.significance_level
+
+    return p
+
+
+def cell_mixing_score(X: csr_matrix, batches: np.ndarray, cell_min: int = 10, unbalanced: bool = False) -> np.ndarray:
+    """Compute the cell-specific mixing score (cms) :cite:p:`lutge2021cellmixs`.
+
+    Parameters
+    ----------
+    X
+        Array of shape (n_cells, n_cells) with non-zero values
+        representing distances to exactly each cell's k nearest neighbors.
+    batches
+        Array of shape (n_cells,) representing batch label values
+        for each cell.
+    cell_min
+        Minimum number of cells from each batch in a neighborhood for that batch to be included in the Anderson-Darling test.
+    unbalanced
+        If True, neighborhoods with only one batch present will be set to NaN. This way they are not included in
+        any summaries or smoothing.
+
+    Returns
+    -------
+    cms
+        Array of shape (n_cells,) with the cms score for each cell.
+    """
+    categorical_type_batches = pd.Categorical(batches)
+    batches = np.asarray(categorical_type_batches.codes)
+    n_categories = len(categorical_type_batches.categories)
+    knn_dists, knn_idx = convert_knn_graph_to_idx(X)
+    knn_cats = np.asarray(batches[knn_idx])
+    knn_dists = np.asarray(knn_dists)
+
+    cms_fn = partial(_cms_one_cell, n_categories=n_categories, cell_min=cell_min, unbalanced=unbalanced)
+    vectorized_fn = np.vectorize(cms_fn, signature="(n),(n)->()")
+    ps = vectorized_fn(knn_dists, knn_cats)
+
+    # TODO: add smoothing
+
+    return np.array(ps)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 87bc467..e612ca7 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -77,6 +77,12 @@ def test_ilisi_clisi_knn():
     scib_metrics.clisi_knn(X, labels, perplexity=10)
 
 
+def test_cms():
+    X, _, batches = dummy_x_labels_batch(x_is_neighbors_graph=True)
+    score = scib_metrics.cell_mixing_score(X, batches)
+    assert len(score) == X.shape[0]
+
+
 def test_nmi_ari_cluster_labels_kmeans():
     X, labels = dummy_x_labels()
     out = scib_metrics.nmi_ari_cluster_labels_kmeans(X, labels)
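For anyone sanity-checking the new metric outside the test suite, here is a minimal usage sketch, assuming this branch is installed. The synthetic embedding, the `sklearn.neighbors.kneighbors_graph` call used to build the KNN distance graph, and the choice of `n_neighbors`/`cell_min` are illustrative assumptions, not part of this PR; any sparse `(n_cells, n_cells)` graph with exactly k non-zero distances per row accepted by `utils.convert_knn_graph_to_idx` should work.

```python
# Illustrative sketch only: the random embedding and the sklearn-based
# KNN graph construction are assumptions, not part of this PR.
import numpy as np
from sklearn.neighbors import kneighbors_graph

import scib_metrics

rng = np.random.default_rng(0)
n_cells, n_latent = 300, 20
X_latent = rng.normal(size=(n_cells, n_latent))             # stand-in for an integrated embedding
batches = rng.choice(["batch_a", "batch_b"], size=n_cells)  # one batch label per cell

# Sparse (n_cells, n_cells) KNN graph with exactly 90 non-zero distances per row.
knn_graph = kneighbors_graph(X_latent, n_neighbors=90, mode="distance")

cms = scib_metrics.cell_mixing_score(knn_graph, batches, cell_min=10)
assert cms.shape == (n_cells,)

# Scores behave like Anderson-Darling significance levels: higher values mean the
# batches are well mixed within a neighborhood, values near 0 indicate a batch effect.
print(np.nanmean(cms))
```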