Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions doubletdetection/doubletdetection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Doublet detection in single-cell RNA-seq data."""

import collections
from collections.abc import Callable
import io
import warnings
from contextlib import redirect_stdout

import anndata
import numpy as np
from numpy.typing import NDArray
import phenograph
import scanpy as sc
import scipy.sparse as sp_sparse
Expand Down Expand Up @@ -83,20 +85,20 @@ class BoostClassifier:

def __init__(
self,
boost_rate=0.25,
n_components=30,
n_top_var_genes=10000,
replace=False,
clustering_algorithm="phenograph",
clustering_kwargs=None,
n_iters=10,
normalizer=None,
pseudocount=0.1,
random_state=0,
verbose=False,
standard_scaling=False,
n_jobs=1,
):
boost_rate: float = 0.25,
n_components: int = 30,
n_top_var_genes: int = 10000,
replace: bool = False,
clustering_algorithm: str = "phenograph",
clustering_kwargs: dict | None = None,
n_iters: int = 10,
normalizer: Callable | None = None,
pseudocount: float = 0.1,
random_state: int = 0,
verbose: bool = False,
standard_scaling: bool = False,
n_jobs: int = 1,
) -> None:
self.boost_rate = boost_rate
self.replace = replace
self.clustering_algorithm = clustering_algorithm
Expand Down Expand Up @@ -145,7 +147,7 @@ def __init__(
n_components, n_top_var_genes
)

def fit(self, raw_counts):
def fit(self, raw_counts: NDArray | sp_sparse.csr_matrix) -> "BoostClassifier":
"""Fits the classifier on raw_counts.

Args:
Expand Down Expand Up @@ -226,7 +228,7 @@ def fit(self, raw_counts):

return self

def predict(self, p_thresh=1e-7, voter_thresh=0.9):
def predict(self, p_thresh: float = 1e-7, voter_thresh: float = 0.9) -> NDArray:
"""Produce doublet calls from fitted classifier

Args:
Expand Down Expand Up @@ -266,7 +268,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):

return self.labels_

def doublet_score(self):
def doublet_score(self) -> NDArray:
"""Produce doublet scores

The doublet score is the average negative log p-value of doublet enrichment
Expand All @@ -284,7 +286,7 @@ def doublet_score(self):

return -avg_log_p

def _one_fit(self):
def _one_fit(self) -> tuple[NDArray, NDArray]:
if self.verbose:
print("\nCreating synthetic doublets...")
self._createDoublets()
Expand Down Expand Up @@ -395,7 +397,7 @@ def _one_fit(self):

return scores, log_p_values

def _createDoublets(self):
def _createDoublets(self) -> None:
"""Create synthetic doublets.

Sets .parents_
Expand All @@ -414,7 +416,7 @@ def _createDoublets(self):
self._raw_synthetics = synthetic
self.parents_ = parents

def _set_clustering_kwargs(self):
def _set_clustering_kwargs(self) -> None:
"""Sets .clustering_kwargs"""
if self.clustering_algorithm == "phenograph":
if "prune" not in self.clustering_kwargs:
Expand Down
32 changes: 20 additions & 12 deletions doubletdetection/plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import warnings
from typing import Any

import matplotlib
import numpy as np
from numpy.typing import NDArray
from matplotlib.figure import Figure

try:
os.environ["DISPLAY"]
Expand All @@ -11,7 +14,7 @@
import matplotlib.pyplot as plt


def normalize_counts(raw_counts, pseudocount=0.1):
def normalize_counts(raw_counts: NDArray, pseudocount: float = 0.1) -> NDArray:
"""Normalize count array. Default normalizer used by BoostClassifier.

Args:
Expand All @@ -22,7 +25,6 @@ def normalize_counts(raw_counts, pseudocount=0.1):
ndarray: Normalized data.
"""
# Sum across cells

cell_sums = np.sum(raw_counts, axis=1)

# Multiply by median and divide each cell by cell sum
Expand All @@ -34,7 +36,13 @@ def normalize_counts(raw_counts, pseudocount=0.1):
return normed


def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):
def convergence(
clf: Any,
show: bool = False,
save: str | None = None,
p_thresh: float = 1e-7,
voter_thresh: float = 0.9,
) -> Figure:
"""Produce a plot showing number of cells called doublet per iter

Args:
Expand Down Expand Up @@ -81,15 +89,15 @@ def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):


def threshold(
clf,
show=False,
save=None,
log10=True,
log_p_grid=None,
voter_grid=None,
v_step=2,
p_step=5,
):
clf: Any,
show: bool = False,
save: str | None = None,
log10: bool = True,
log_p_grid: NDArray | None = None,
voter_grid: NDArray | None = None,
v_step: int = 2,
p_step: int = 5,
) -> Figure:
"""Produce a plot showing number of cells called doublet across
various thresholds

Expand Down