From 4859bd1537e54d2d98dc5c5768d65df5f32a3e26 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 11 Jan 2025 15:59:48 +0000 Subject: [PATCH] add python typing --- doubletdetection/doubletdetection.py | 42 +++++++++++++++------------- doubletdetection/plot.py | 32 +++++++++++++-------- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/doubletdetection/doubletdetection.py b/doubletdetection/doubletdetection.py index f0793f0..841c0db 100644 --- a/doubletdetection/doubletdetection.py +++ b/doubletdetection/doubletdetection.py @@ -1,12 +1,14 @@ """Doublet detection in single-cell RNA-seq data.""" import collections +from collections.abc import Callable import io import warnings from contextlib import redirect_stdout import anndata import numpy as np +from numpy.typing import NDArray import phenograph import scanpy as sc import scipy.sparse as sp_sparse @@ -83,20 +85,20 @@ class BoostClassifier: def __init__( self, - boost_rate=0.25, - n_components=30, - n_top_var_genes=10000, - replace=False, - clustering_algorithm="phenograph", - clustering_kwargs=None, - n_iters=10, - normalizer=None, - pseudocount=0.1, - random_state=0, - verbose=False, - standard_scaling=False, - n_jobs=1, - ): + boost_rate: float = 0.25, + n_components: int = 30, + n_top_var_genes: int = 10000, + replace: bool = False, + clustering_algorithm: str = "phenograph", + clustering_kwargs: dict | None = None, + n_iters: int = 10, + normalizer: Callable | None = None, + pseudocount: float = 0.1, + random_state: int = 0, + verbose: bool = False, + standard_scaling: bool = False, + n_jobs: int = 1, + ) -> None: self.boost_rate = boost_rate self.replace = replace self.clustering_algorithm = clustering_algorithm @@ -145,7 +147,7 @@ def __init__( n_components, n_top_var_genes ) - def fit(self, raw_counts): + def fit(self, raw_counts: NDArray | sp_sparse.csr_matrix) -> "BoostClassifier": """Fits the classifier on raw_counts. Args: @@ -226,7 +228,7 @@ def fit(self, raw_counts): return self - def predict(self, p_thresh=1e-7, voter_thresh=0.9): + def predict(self, p_thresh: float = 1e-7, voter_thresh: float = 0.9) -> NDArray: """Produce doublet calls from fitted classifier Args: @@ -266,7 +268,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9): return self.labels_ - def doublet_score(self): + def doublet_score(self) -> NDArray: """Produce doublet scores The doublet score is the average negative log p-value of doublet enrichment @@ -284,7 +286,7 @@ def doublet_score(self): return -avg_log_p - def _one_fit(self): + def _one_fit(self) -> tuple[NDArray, NDArray]: if self.verbose: print("\nCreating synthetic doublets...") self._createDoublets() @@ -395,7 +397,7 @@ def _one_fit(self): return scores, log_p_values - def _createDoublets(self): + def _createDoublets(self) -> None: """Create synthetic doublets. Sets .parents_ @@ -414,7 +416,7 @@ def _createDoublets(self): self._raw_synthetics = synthetic self.parents_ = parents - def _set_clustering_kwargs(self): + def _set_clustering_kwargs(self) -> None: """Sets .clustering_kwargs""" if self.clustering_algorithm == "phenograph": if "prune" not in self.clustering_kwargs: diff --git a/doubletdetection/plot.py b/doubletdetection/plot.py index db39356..e470c1a 100644 --- a/doubletdetection/plot.py +++ b/doubletdetection/plot.py @@ -1,8 +1,11 @@ import os import warnings +from typing import Any import matplotlib import numpy as np +from numpy.typing import NDArray +from matplotlib.figure import Figure try: os.environ["DISPLAY"] @@ -11,7 +14,7 @@ import matplotlib.pyplot as plt -def normalize_counts(raw_counts, pseudocount=0.1): +def normalize_counts(raw_counts: NDArray, pseudocount: float = 0.1) -> NDArray: """Normalize count array. Default normalizer used by BoostClassifier. Args: @@ -22,7 +25,6 @@ def normalize_counts(raw_counts, pseudocount=0.1): ndarray: Normalized data. """ # Sum across cells - cell_sums = np.sum(raw_counts, axis=1) # Mutiply by median and divide each cell by cell sum @@ -34,7 +36,13 @@ def normalize_counts(raw_counts, pseudocount=0.1): return normed -def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9): +def convergence( + clf: Any, + show: bool = False, + save: str | None = None, + p_thresh: float = 1e-7, + voter_thresh: float = 0.9, +) -> Figure: """Produce a plot showing number of cells called doublet per iter Args: @@ -81,15 +89,15 @@ def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9): def threshold( - clf, - show=False, - save=None, - log10=True, - log_p_grid=None, - voter_grid=None, - v_step=2, - p_step=5, -): + clf: Any, + show: bool = False, + save: str | None = None, + log10: bool = True, + log_p_grid: NDArray | None = None, + voter_grid: NDArray | None = None, + v_step: int = 2, + p_step: int = 5, +) -> Figure: """Produce a plot showing number of cells called doublet across various thresholds