def remove_zero_variance_genes(dat, batches):
    """Remove rows (genes) that are constant within at least one batch

    A gene is flagged as soon as all of its values inside a single batch are
    identical, i.e. its within-batch variance is zero.

    Arguments:
        dat {matrix} -- the data matrix (one row per gene, one column per sample)
        batches {array} -- batch indices (one array of column indices per batch)

    Returns:
        non_zero_variance_data {matrix} -- the data matrix without rows with zero variance in a batch
        genes_to_remove {array} -- boolean mask over the rows of dat, True for each removed row
    """
    # Start from "keep everything" so an empty list of batches yields a
    # well-formed all-False mask instead of a scalar.
    genes_to_remove = np.zeros(dat.shape[0], dtype=bool)
    for batch in batches:
        batch_data = dat[:, batch]
        if batch_data.shape[1] == 0:
            # A batch without samples cannot flag any gene.
            continue
        # Comparing max and min detects constant rows exactly. Testing
        # var(axis=1) == 0 misses them when the constant is not exactly
        # representable: the rounded mean differs from the values, so the
        # computed variance is tiny but non-zero (e.g. three times 0.1).
        genes_to_remove |= batch_data.max(axis=1) == batch_data.min(axis=1)
    non_zero_variance_data = dat[~genes_to_remove, :]
    return non_zero_variance_data, genes_to_remove
# tests for remove_zero_variance_genes function
def test_remove_zero_variance_genes():
    """Check that genes constant within any batch are flagged and dropped."""
    batches = [np.array([0, 1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8])]
    # Row 0 is constant in the first batch, row 3 in the second and third;
    # rows 1 and 2 vary in every batch.
    dat = np.array([[0, 0, 0, 0, 1, 2, 1, 3, 4],
                    [0, 1, 2, 3, 0, 1, 2, 1, 2],
                    [0, 1, 1, 2, 2, 3, 2, 1, 2],
                    [1, 2, 2, 1, 0, 0, 0, 1, 1]])
    reduced_dat, genes_to_remove = remove_zero_variance_genes(dat, batches)

    expected_mask = np.array([True, False, False, True])
    assert reduced_dat.shape == (2, 9)
    # array_equal also fails on a length mismatch, which a zip-based
    # element-wise comparison would silently truncate away.
    assert np.array_equal(genes_to_remove, expected_mask)
    # The surviving rows must be exactly the two non-constant genes, in order.
    assert np.array_equal(reduced_dat, dat[[1, 2], :])