A method for compound ranking predictions from gene expression signatures using ensemble neural network models.
This package provides tools for virtual drug screening using transcriptional signatures. It includes:
- DrugReflector: Ensemble neural network models for compound ranking predictions
- Signature Refinement: Tools to refine transcriptional signatures using experimental data
- V-score Computation: Fast vectorized v-score calculations for differential expression analysis
- Data Utilities: Functions for preprocessing and handling gene expression data
- Python 3.8+
- PyTorch
- NumPy
- Pandas
- SciPy
- AnnData
- Scanpy (for signature refinement)
- Clone this repository:
git clone git@github.com:Cellarity/drugreflector.git
cd drugreflectorOr install via pip:
pip install drugreflector- Install dependencies:
pip install torch numpy pandas scipy anndata scanpy- Download model checkpoints from Zenodo:
- DOI: 10.5281/zenodo.16912444
- Download all checkpoint files and place them in the
checkpoints/directory
This can be done from the command line using the zenodo-get package:
pip3 install zenodo-get # if not already installed
zenodo_get --output-dir checkpoints 16912444
- (optional, for reproducing published results): Download transition signatures. These are scanpy AnnData objects containing v-scores for cell stat transitions in normal hematopoiesis, and between healthy and diseased B-ALL and breast cancer tissue. They can be input to DrugReflector to prioritize compounds, as done in the publication cited below.
- DOI: 10.5281/zenodo.16921905
The trained model checkpoints are available on Zenodo at DOI 10.5281/zenodo.16912444.
After downloading, your directory structure should look like:
drugreflector/
├── checkpoints/
│ ├── model_fold_0.pt
│ ├── model_fold_1.pt
│ └── model_fold_2.pt
├── drugreflector/
│ ├── drug_reflector.py
│ └── ...
├── utils.py
└── ...
import pandas as pd
import scanpy as sc
import drugreflector as dr
# All major classes available at top level
# dr.DrugReflector, dr.SignatureRefinement, dr.compute_vscores_adata, etc.
# Step 1: Load PBMC 3k dataset with cell type annotations
adata = sc.datasets.pbmc3k() # Unfiltered dataset with more genes
annots = sc.datasets.pbmc3k_processed().obs # Cell type annotations
# Merge annotations
adata.obs = pd.merge(adata.obs, annots, how='left', left_index=True, right_index=True)
# Step 2: Compute v-scores between two monocyte populations
vscores = dr.compute_vscores_adata(
adata,
group_col='louvain',
group1_value='CD14+ Monocytes', # Classical monocytes
group2_value='FCGR3A+ Monocytes' # Non-classical monocytes
)
print(f"V-score comparison: {vscores.name}")
print(f"Computed v-scores for {len(vscores)} genes")
print(f"Top upregulated genes in FCGR3A+ vs CD14+ monocytes:")
print(vscores.nlargest(10))
# Step 3: Initialize DrugReflector with model checkpoints
model_paths = [
'checkpoints/model_fold_0.pt',
'checkpoints/model_fold_1.pt',
'checkpoints/model_fold_2.pt'
]
model = dr.DrugReflector(checkpoint_paths=model_paths)
# Step 4: Make predictions using v-scores
# DrugReflector will automatically preprocess gene names to HGNC format
predictions = model.predict(vscores, n_top=50)
print(f"Prediction results shape: {predictions.shape}")
print(f"Columns: {predictions.columns.names}")
# Access different metrics
print("\nTop 10 predicted compounds by rank:")
rank_col = ('rank', vscores.name) # Uses informative name: 'louvain:CD14+ Monocytes->FCGR3A+ Monocytes'
print(predictions[rank_col].nsmallest(10))
print("\nTop 10 compounds by probability:")
prob_col = ('prob', vscores.name)
print(predictions[prob_col].nlargest(10))
print(f"\nAvailable columns: {list(predictions.columns)}")DrugReflector accepts v-score data in three formats:
# 1. Pandas Series (single v-score vector)
# The series name will be used as the transition identifier in outputs
vscore_series = pd.Series([1.2, -0.8, 0.5, ...], index=['GENE1', 'GENE2', 'GENE3', ...],
name='treatment:control->drug')
predictions = model.predict(vscore_series)
# Columns will be: ('rank', 'treatment:control->drug'), ('logit', 'treatment:control->drug'), etc.
# 2. Pandas DataFrame (multiple transitions/signatures)
vscores_df = pd.DataFrame({
'GENE1': [1.2, 0.8],
'GENE2': [-0.8, 1.1],
'GENE3': [0.5, -0.3]
}, index=['treatment_A', 'treatment_B'])
predictions = model.predict(vscores_df)
# 3. AnnData (v-scores in .X)
vscores_adata = AnnData(
X=vscores_df.values,
var=pd.DataFrame(index=vscores_df.columns),
obs=pd.DataFrame(index=vscores_df.index)
)
predictions = model.predict(vscores_adata)CRITICAL: DrugReflector requires gene names in HGNC (HUGO Gene Nomenclature Committee) format for accurate predictions.
DrugReflector automatically preprocesses gene names to be HGNC-compatible:
# Example of automatic preprocessing
import pandas as pd
from drugreflector import DrugReflector
# Input with mixed gene name formats
mixed_genes = ['tp53', 'EGFR', 'ENSG00000141510.11', 'CDKN1A_at', 'il6.v2']
vscores = pd.Series([1.2, -0.8, 0.5, 2.1, -1.1], index=mixed_genes)
model = DrugReflector(checkpoint_paths=model_paths)
predictions = model.transform(vscores)
# Output shows preprocessing:
# Preprocessing gene names to HGNC format...
# Preprocessed 4/5 gene names for HGNC compatibility
# Examples of changes:
# tp53 -> TP53
# ENSG00000141510.11 -> ENSG00000141510
# CDKN1A_at -> CDKN1A
# il6.v2 -> IL6- Uppercase conversion:
tp53→TP53 - Remove Ensembl versions:
ENSG00000141510.11→ENSG00000141510 - Remove Affymetrix suffixes:
CDKN1A_at→CDKN1A - Remove version numbers:
IL6.v2→IL6 - Clean non-standard characters: Keep only
A-Z,0-9, and-
- Preferred: Use official HGNC symbols (
TP53,EGFR,CDKN1A) - Acceptable: Mixed case, common prefixes/suffixes (automatically cleaned)
- Check coverage: Ensure your genes overlap with the 978 landmark genes used by the model
Refine transcriptional signatures using paired transcriptional + phenotypic data:
import drugreflector as dr
import pandas as pd
# Starting signature (pandas Series with gene names as index)
starting_signature = pd.Series([1.2, -0.8, 0.5, ...],
index=['GENE1', 'GENE2', 'GENE3', ...])
# Initialize signature refinement (available at top level)
refiner = dr.SignatureRefinement(starting_signature)
# Load experimental data (AnnData with compound treatments)
# adata should have:
# - Gene expression data in .X or layers
# - Compound IDs in .obs (e.g., 'compound_id' column)
# - Sample IDs in .obs (e.g., 'sample_id' column)
refiner.load_counts_data(
adata,
compound_id_obs_col='compound_id',
sample_id_obs_cols=['sample_id'],
layer='raw_counts' # or None to use .X
)
# Load phenotypic readouts
readouts = pd.Series([0.8, -1.2, 0.3, ...],
index=['compound_A', 'compound_B', 'compound_C', ...])
refiner.load_phenotypic_readouts(readouts)
# Compute learned signatures using correlation analysis
refiner.compute_learned_signatures(corr_method='pearson')
# Generate refined signatures (interpolation between starting and learned)
refiner.compute_refined_signatures(
learning_rate=0.5, # 0.5 = equal weight to starting and learned
scale_learned_sig=True # Scale learned signature to match starting signature std
)
# Access results
refined_signatures = refiner.refined_signatures # AnnData object
learned_signatures = refiner.learned_signatures # AnnData object# For multiple experimental conditions, specify signature_id_obs_cols
refiner.load_counts_data(
adata,
compound_id_obs_col='compound_id',
sample_id_obs_cols=['sample_id'],
signature_id_obs_cols=['treatment_type', 'timepoint'], # Creates separate signatures
layer='raw_counts'
)
# This will create one learned/refined signature for each unique combination
# of values in signature_id_obs_colsFast vectorized v-score calculations for differential expression analysis:
from drugreflector import compute_vscores_adata, compute_vscore_two_groups
# Compute v-scores between two cell populations
vscores = compute_vscores_adata(
adata,
group_col='cell_type', # Column identifying groups
group1_value='control', # Reference group
group2_value='treatment', # Comparison group
layer=None # Use .X, or specify layer name
)
# vscores is a pandas Series with gene names as index and informative name
print(f"V-score comparison: {vscores.name}") # e.g., "cell_type:control->treatment"
print(f"Top upregulated genes:")
print(vscores.nlargest(10))
print(f"Top downregulated genes:")
print(vscores.nsmallest(10))
# For two arrays directly
group1_values = [1.2, 0.8, 1.5, 0.9] # Reference/control
group2_values = [2.1, 1.9, 2.3, 2.0] # Treatment/comparison
vscore = compute_vscore_two_groups(group1_values, group2_values)from drugreflector import load_h5ad_file, pseudobulk_adata
# Load H5AD file with preprocessing
adata = load_h5ad_file('data.h5ad')
# Pseudobulk single-cell data
pseudobulked = pseudobulk_adata(
adata,
sample_id_obs_cols=['donor_id', 'condition'], # Columns defining samples
method='sum' # or 'mean'
)from drugreflector import compute_vscores
# Use v-scores in existing workflow
transitions = {
'group_col': 'cell_type',
'group1_value': 'control',
'group2_value': 'treatment'
}
vscores_adata = compute_vscores(adata, transitions=transitions)
# Returns AnnData object with v-scores as .X# Example command line usage (if predict.py exists)
python drugreflector/predict.py input.h5ad \
--model1 checkpoints/model_fold_0.pt \
--model2 checkpoints/model_fold_1.pt \
--model3 checkpoints/model_fold_2.pt \
--output results.csv- checkpoint_paths: List of paths to model checkpoint files (.pt)
- device: PyTorch device ('cuda', 'cpu', or 'auto')
predict(data, n_top=None): Get compound predictions with ranks, scores, probabilitiesget_top_compounds(data, n_top=10): Get top N compounds as separate DataFramespredict_top_compounds(data, n_top=50): Alias for get_top_compoundscheck_gene_coverage(gene_names): Check how many genes are recognized by the model
- starting_signature: pandas Series or AnnData with initial signature
load_counts_data(adata, compound_id_obs_col, layer=None, ...): Load raw count dataload_normalized_data(adata, compound_id_obs_col, layer=None, ...): Load normalized dataload_phenotypic_readouts(readouts, readout_col=None, ...): Load phenotypic datacompute_learned_signatures(corr_method='pearson'): Compute signatures from datacompute_refined_signatures(learning_rate=0.5, scale_learned_sig=True): Generate refined signatures
compute_vscore_two_groups(group0, group1): V-score between two arrayscompute_vscores_adata(adata, group_col, group0_value, group1_value, layer=None): V-scores from AnnDatapseudobulk_adata(adata, sample_id_obs_cols, method='sum'): Pseudobulk expression dataload_h5ad_file(filepath): Load and preprocess H5AD filescreate_synthetic_gene_expression(n_obs, n_vars, ...): Generate synthetic data for testing
- Format: AnnData objects (.h5ad files)
- Genes: Must include the 978 landmark genes used by the model
- Gene Symbols: CRITICAL - Gene names must be in HGNC (HUGO Gene Nomenclature Committee) format
- Examples:
TP53,EGFR,CDKN1A,IL6 - DrugReflector automatically converts gene names to HGNC-compatible format (uppercase, removes prefixes/suffixes)
- Supported input formats: Ensembl IDs, Affymetrix probe IDs, mixed case symbols
- Examples:
- Samples: Expression profiles for compounds/treatments of interest
- Preprocessing: Log-transformed, normalized expression values
- Source: Zenodo DOI 10.5281/zenodo.16912445
- Format: PyTorch .pt files
- Count: 3 model files (ensemble of 3-fold cross-validation)
- Expression Data: Raw counts or normalized expression in AnnData format
- Metadata: Compound IDs, sample IDs, and experimental conditions in
.obs - Phenotypic Readouts: Numeric values (e.g., viability, efficacy scores) as pandas Series
If you use this package, please cite:
[To be added upon publication]
Model checkpoints: DOI 10.5281/zenodo.16912445
Copyright 2025, Cellarity Inc
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
-
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
-
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
-
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
-
Commercial use of this software is strictly prohibited without prior written permission. For commercial licensing, please contact: CellarityPublications@cellarity.com
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, I NDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
If you see a warning about scikit-learn version mismatch:
InconsistentVersionWarning: Trying to unpickle estimator LabelEncoder from version 1.2.2 when using version 1.5.1
This occurs because the model checkpoints were trained with scikit-learn 1.2.2. The warning is generally harmless and does not affect functionality, but indicates a version difference between training and inference environments.
If you get unexpected results or low prediction scores:
- Check gene name format: Ensure genes are in HGNC format or compatible
- Verify gene coverage: Check how many of your genes overlap with the 978 landmark genes
- Review preprocessing output: DrugReflector shows which genes were modified during preprocessing
# Check gene coverage using built-in function
coverage = model.check_gene_coverage(your_data.var_names)
print(f"Gene coverage: {coverage['total_found']}/{coverage['total_input']} ({coverage['coverage_percent']:.1f}%)")
# See which genes were not found
if coverage['missing_genes']:
print(f"Missing genes: {coverage['missing_genes'][:10]}") # Show first 10
# See preprocessing changes
for gene_info in coverage['gene_mapping'][:5]: # Show first 5
if gene_info['original'] != gene_info['processed']:
print(f" {gene_info['original']} -> {gene_info['processed']} ({'found' if gene_info['found'] else 'not found'})")For issues and questions, please use the GitHub issue tracker.