diff --git a/src/dbspace/readout/ClinVect.py b/src/dbspace/readout/ClinVect.py index d6714c1..bce07d1 100644 --- a/src/dbspace/readout/ClinVect.py +++ b/src/dbspace/readout/ClinVect.py @@ -244,7 +244,7 @@ def Stim_Change_Table(self): ) # find the phase corresponding to the stim change bump_phases = np.array( - [np.array(dbo.all_phases)[0:][idxs] for idxs in diff_matrix] + [np.array(Phase_List('all'))[0:][idxs] for idxs in diff_matrix] ) full_table = [ diff --git a/src/dbspace/readout/OBands.py b/src/dbspace/readout/OBands.py index 78ce3ff..e540956 100644 --- a/src/dbspace/readout/OBands.py +++ b/src/dbspace/readout/OBands.py @@ -11,26 +11,17 @@ import sys sys.path.append('/home/virati/Dropbox/projects/Research/MDD-DBS/Ephys/DBSpace/') -import DBSpace as dbo -from DBSpace import nestdict - -from DBSpace import unity,displog -import scipy.stats as stats - -import pdb -import numpy as np - -import matplotlib.pyplot as plt -import seaborn as sns from collections import defaultdict -#sns.set() +import dbspace as dbo +import matplotlib.pyplot as plt +import numpy as np +import scipy.stats as stats +from dbspace.utils.functions import unity +from dbspace.utils.structures import nestdict -sns.set_context('talk') -sns.set(font_scale=4) -sns.set_style('white') -class naive_readout: +class naive_biomarker: def __init__(self,feat_frame,ClinFrame): self.feat_frame = feat_frame @@ -174,7 +165,7 @@ def feat_extract(self,do_corrections=False): rr.update({'FeatVect':feat_dict}) -#Standard two state/categorical analyses HERE + #Standard two state/categorical analyses HERE def mean_psds(self,patients='all',weeks=['C01','C23']): if patients == 'all': patients = dbo.all_pts @@ -476,4 +467,4 @@ def scatter_state(self,weeks='all',pt='all',feat='Alpha',circ='',plot=True,plot_ plt.suptitle(feat + ' over weeks; ' + str(pt)) - return feats,outstats, weeks_osc_distr \ No newline at end of file + return feats,outstats, weeks_osc_distr diff --git a/src/dbspace/readout/controllers.py b/src/dbspace/readout/controllers.py new file mode 100644 index 0000000..668e906 --- /dev/null +++ b/src/dbspace/readout/controllers.py @@ -0,0 +1,220 @@ +import numpy as np +from sklearn import metrics +from sklearn.metrics import ( + auc, + average_precision_score, + mean_absolute_error, + mean_squared_error, + precision_recall_curve, + roc_auc_score, + roc_curve, +) +from sklearn.utils import shuffle +from dbspace.utils.structures import nestdict +import matplotlib as plt +from scipy import interp +import random + + +class controller_analysis: + def __init__(self, readout, **kwargs): + self.readout_model = readout + # get our binarized disease states + self.binarized_type = kwargs["bin_type"] + + def gen_binarized_state(self, **kwargs): + # redo our testing set + if kwargs["approach"] == "threshold": + binarized = kwargs["input_c"] > 0.5 + elif kwargs["approach"] == "stim_changes": + query_array = kwargs["input_ptph"] + binarized = [ + self.readout_model.CFrame.query_stim_change(pt, ph) + for pt, ph in query_array + ] + else: + raise Exception + + return binarized + + def pr_classif(self, binarized, predicted): + + precision, recall, thresholds = precision_recall_curve(binarized, predicted) + + # plt.figure() + # plt.step(recall,precision) + return precision, recall + + def pr_oracle(self, binarized, level=0.5): + oracle = np.array(np.copy(binarized)).astype(np.float) + oracle += np.random.normal(0, level, size=oracle.shape) + + precision, recall, thresholds = precision_recall_curve(binarized, oracle) + return precision, recall + + def pr_classif_2pred(self, binarized, predicted, empirical): + empirical = np.array(empirical).squeeze() + precision, recall, thresholds = precision_recall_curve( + binarized, empirical - predicted + ) + return precision, recall + + def bin_classif(self, binarized, predicted): + fpr, tpr, thresholds = metrics.roc_curve(binarized, predicted) + roc_curve = (fpr, tpr, thresholds) + auc = roc_auc_score(binarized, predicted) + + return auc, roc_curve + + def controller_simulations(self): + """ + Controller Types: + "Readout": The main DR-SCC + "Empirical + Readout": The nHDRS with the readout + "Empirical + inv_readout": The nHDRS with the inverse of the readout + "Oracle": The best case scenario along with some noise + "Null": Pure chance + "Empirical": The nHDRS + + """ + controllers = nestdict() + controllers = nestdict() + + for ii in range(100): + test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = zip( + *random.sample( + list( + zip( + self.readout_model.test_set_y, + self.readout_model.test_set_c, + self.readout_model.test_set_pt, + self.readout_model.test_set_ph, + ) + ), + np.ceil(0.8 * len(self.readout_model.test_set_y)).astype(np.int), + ) + ) + predicted_c = self.readout_model.decode_model.predict(test_subset_y) + + # test_subset_pt = shuffle(test_subset_pt);print('PR_Classif: Shuffling Data') + binarized_c = self.gen_binarized_state( + approach="stim_changes", + input_ptph=list(zip(test_subset_pt, test_subset_ph)), + ) + # shuffle? + # binarized_c = shuffle(binarized_c);print('PR_Classif: Shuffling binarization') + coinflip = np.random.choice( + [0, 1], size=(len(test_subset_pt),), p=[0.5, 0.5] + ) + + controllers["readout"].append(self.pr_classif(binarized_c, predicted_c)) + controllers["inv_readout"].append(self.pr_classif(binarized_c, 1/predicted_c)) + controllers["empirical+readout"].append( + self.pr_classif_2pred(binarized_c, predicted_c, test_subset_c) + ) + controllers["empirical+inv_readout"].append( + self.pr_classif_2pred(binarized_c, 1/predicted_c, test_subset_c) + ) + controllers["oracle"].append(self.pr_oracle(binarized_c, level=0.5)) + controllers["empirical"].append(self.pr_classif(binarized_c, test_subset_c)) + controllers["null"].append(self.pr_classif(binarized_c, coinflip)) + + self.controllers = controllers + + def controller_sim_metrics(self): + # organize results + controllers = self.controllers + aucs = nestdict() + pr_curves = nestdict() + + plot_controllers = ["readout","empirical","null","oracle"] + for kk in plot_controllers: + for ii in range(100): + aucs[kk].append( + metrics.auc(controllers[kk][ii][1], controllers[kk][ii][0]) + ) + pr_curves[kk].append((controllers[kk][ii][0], controllers[kk][ii][1])) + + self.plot_controller_runs(aucs[kk], pr_curves[kk], title=kk) + + def classif_runs( + self, + ): + aucs = [] + roc_curves = [] + + null_aucs = [] + null_roc_curves = [] + + for ii in range(100): + test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = zip( + *random.sample( + list( + zip( + self.readout_model.test_set_y, + self.readout_model.test_set_c, + self.readout_model.test_set_pt, + self.readout_model.test_set_ph, + ) + ), + np.ceil(0.8 * len(self.readout_model.test_set_y)).astype(np.int), + ) + ) + # THIS IS WHERE WE NEED TO SHUFFLE TO TEST THE READOU + # test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = shuffle(test_subset_y, test_subset_c, test_subset_pt, test_subset_ph) + predicted_c = self.readout_model.decode_model.predict(test_subset_y) + + binarized_c = self.gen_binarized_state( + approach="threshold", input_c=np.array(test_subset_c) + ) + auc, roc_curve = self.bin_classif(binarized_c, predicted_c) + aucs.append(auc) + roc_curves.append(roc_curve) + + coinflip = np.random.choice( + [0, 1], size=(len(test_subset_pt),), p=[0.5, 0.5] + ) + + n_auc, n_roc = self.bin_classif(binarized_c, coinflip) + null_aucs.append(n_auc) + null_roc_curves.append(n_roc) + + self.plot_controller_runs(aucs, roc_curves) + + def plot_controller_runs(self, aucs, roc_curves, **kwargs): + plt.figure() + plt.hist(aucs) + plt.vlines(np.mean(aucs), -1, 10, linewidth=10) + plt.xlim((0.0, 1.0)) + plt.title(kwargs["title"]) + + fig, ax = plt.subplots() + mean_fpr = np.linspace(0, 1, 100) + interp_tpr = [] + for aa in roc_curves: + interp_tpr_individ = interp(mean_fpr, aa[0], aa[1]) + interp_tpr_individ[0] = 0 + interp_tpr.append(interp_tpr_individ) + + mean_tpr = np.mean(interp_tpr, axis=0) + std_tpr = np.std(interp_tpr, axis=0) + + tprs_upper = np.minimum(mean_tpr + std_tpr, 1) + tprs_lower = np.maximum(mean_tpr - std_tpr, 0) + + ax.plot(mean_fpr, mean_tpr) + ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.2) + ax.plot(mean_fpr, mean_fpr, linestyle="dotted") + plt.plot([0, 1], [0, 1], linestyle="dotted") + if "title" in kwargs: + plt.title(kwargs["title"]) + + def plot_controller_simulations(self, plot_controller_list): + """ + Generate a compound plot of all simulations desired + """ + fig, ax = plt.subplots() + if set(plot_controller_list) != set(self.controllers.keys()): + raise ValueError("There's a mismatch in the controllers you want...") + for controller in plot_controller_list: + pass diff --git a/src/dbspace/readout/decoder.py b/src/dbspace/readout/decoder.py index 93ba6d0..5b7b550 100755 --- a/src/dbspace/readout/decoder.py +++ b/src/dbspace/readout/decoder.py @@ -6,55 +6,30 @@ @author: virati NEW classes for readout training, testing, and validation """ +import copy +import itertools import random +import dbspace as dbo +import dbspace.signal.oscillations as dbo import matplotlib.cm as cm import matplotlib.pylab as pl import matplotlib.pyplot as plt import numpy as np import scipy.stats as stats -from scipy import interp -from sklearn import metrics -from sklearn.linear_model import ElasticNet, RidgeCV, LassoCV -from dbspace.signal.oscillations import poly_subtr -from sklearn.metrics import ( - auc, - average_precision_score, - mean_absolute_error, - mean_squared_error, - precision_recall_curve, - roc_auc_score, - roc_curve, -) -from sklearn.model_selection import train_test_split -from sklearn.utils import shuffle - -np.random.seed(seed=2011) -random.seed(2011) - -import dbspace as dbo +from dbspace.readout.ClinVect import Phase_List +from dbspace.signal.oscillations import DEFAULT_FEAT_ORDER, poly_subtr from dbspace.utils.structures import nestdict from sklearn import linear_model - -default_params = {"CrossValid": 10} - +from sklearn.linear_model import ElasticNet, LassoCV, RidgeCV +from sklearn.model_selection import train_test_split +from dbspace.utils.functions import zero_mean import seaborn as sns -sns.set_context("paper") - -sns.set(font_scale=4) -sns.set_style("white") - -import copy -import itertools - -from dbspace.signal.oscillations import DEFAULT_FEAT_ORDER -from dbspace.readout.ClinVect import Phase_List -import dbspace.signal.oscillations as dbo - -def zero_mean(inp): - return inp - np.mean(inp) +default_params = {"CrossValid": 10} +np.random.seed(seed=2011) +random.seed(2011) #%% class base_decoder: @@ -947,7 +922,6 @@ def plot_test_stats(self): plt.title("Slope") """PLOTTING--------------------------------------------------------""" - """Plot the decoding CV coefficients""" def plot_decode_CV(self): @@ -971,197 +945,6 @@ def plot_decode_CV(self): plt.xlim((-1, len(self.do_feats) * 2)) -class controller_analysis: - def __init__(self, readout, **kwargs): - self.readout_model = readout - # get our binarized disease states - self.binarized_type = kwargs["bin_type"] - - def gen_binarized_state(self, **kwargs): - # redo our testing set - if kwargs["approach"] == "threshold": - binarized = kwargs["input_c"] > 0.5 - elif kwargs["approach"] == "stim_changes": - query_array = kwargs["input_ptph"] - binarized = [ - self.readout_model.CFrame.query_stim_change(pt, ph) - for pt, ph in query_array - ] - else: - raise Exception - - return binarized - - def pr_classif(self, binarized, predicted): - - precision, recall, thresholds = precision_recall_curve(binarized, predicted) - - # plt.figure() - # plt.step(recall,precision) - return precision, recall - - def pr_oracle(self, binarized, level=0.5): - oracle = np.array(np.copy(binarized)).astype(np.float) - oracle += np.random.normal(0, level, size=oracle.shape) - - precision, recall, thresholds = precision_recall_curve(binarized, oracle) - return precision, recall - - def pr_classif_2pred(self, binarized, predicted, empirical): - empirical = np.array(empirical).squeeze() - precision, recall, thresholds = precision_recall_curve( - binarized, empirical - predicted - ) - return precision, recall - - def bin_classif(self, binarized, predicted): - fpr, tpr, thresholds = metrics.roc_curve(binarized, predicted) - roc_curve = (fpr, tpr, thresholds) - auc = roc_auc_score(binarized, predicted) - - return auc, roc_curve - - def controller_runs(self): - controller_types = [ - "readout", - "empirical+readout", - "oracle", - "null", - "empirical", - ] - controllers = {key: [] for key in controller_types} - aucs = {key: [] for key in controller_types} - pr_curves = {key: [] for key in controller_types} - - for ii in range(100): - test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = zip( - *random.sample( - list( - zip( - self.readout_model.test_set_y, - self.readout_model.test_set_c, - self.readout_model.test_set_pt, - self.readout_model.test_set_ph, - ) - ), - np.ceil(0.8 * len(self.readout_model.test_set_y)).astype(np.int), - ) - ) - predicted_c = self.readout_model.decode_model.predict(test_subset_y) - - # test_subset_pt = shuffle(test_subset_pt);print('PR_Classif: Shuffling Data') - binarized_c = self.gen_binarized_state( - approach="stim_changes", - input_ptph=list(zip(test_subset_pt, test_subset_ph)), - ) - # shuffle? - # binarized_c = shuffle(binarized_c);print('PR_Classif: Shuffling binarization') - coinflip = np.random.choice( - [0, 1], size=(len(test_subset_pt),), p=[0.5, 0.5] - ) - - controllers["readout"].append(self.pr_classif(binarized_c, predicted_c)) - controllers["empirical+readout"].append( - self.pr_classif_2pred(binarized_c, predicted_c, test_subset_c) - ) - controllers["oracle"].append(self.pr_oracle(binarized_c, level=0.5)) - controllers["empirical"].append(self.pr_classif(binarized_c, test_subset_c)) - controllers["null"].append(self.pr_classif(binarized_c, coinflip)) - - # organize results - for kk in controller_types: - for ii in range(100): - aucs[kk].append( - metrics.auc(controllers[kk][ii][1], controllers[kk][ii][0]) - ) - pr_curves[kk].append((controllers[kk][ii][0], controllers[kk][ii][1])) - - self.plot_classif_runs(aucs[kk], pr_curves[kk], title=kk) - - def classif_runs( - self, - ): - aucs = [] - roc_curves = [] - - null_aucs = [] - null_roc_curves = [] - - for ii in range(100): - test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = zip( - *random.sample( - list( - zip( - self.readout_model.test_set_y, - self.readout_model.test_set_c, - self.readout_model.test_set_pt, - self.readout_model.test_set_ph, - ) - ), - np.ceil(0.8 * len(self.readout_model.test_set_y)).astype(np.int), - ) - ) - # THIS IS WHERE WE NEED TO SHUFFLE TO TEST THE READOU - # test_subset_y, test_subset_c, test_subset_pt, test_subset_ph = shuffle(test_subset_y, test_subset_c, test_subset_pt, test_subset_ph) - predicted_c = self.readout_model.decode_model.predict(test_subset_y) - - binarized_c = self.gen_binarized_state( - approach="threshold", input_c=np.array(test_subset_c) - ) - auc, roc_curve = self.bin_classif(binarized_c, predicted_c) - aucs.append(auc) - roc_curves.append(roc_curve) - - coinflip = np.random.choice( - [0, 1], size=(len(test_subset_pt),), p=[0.5, 0.5] - ) - - n_auc, n_roc = self.bin_classif(binarized_c, coinflip) - null_aucs.append(n_auc) - null_roc_curves.append(n_roc) - - self.plot_classif_runs(aucs, roc_curves) - # self.plot_classif_runs(null_aucs,null_roc_curves) # if you want a sanity check with a coinflip null - - """Here we'll do a 2-d density plot for error rates using both DR-SCC and nHDRS""" - - def density_plot(self): - pass - - def plot_classif_runs(self, aucs, roc_curves, **kwargs): - plt.figure() - plt.hist(aucs) - plt.vlines(np.mean(aucs), -1, 10, linewidth=10) - plt.xlim((0.0, 1.0)) - plt.title(kwargs["title"]) - - fig, ax = plt.subplots() - mean_fpr = np.linspace(0, 1, 100) - interp_tpr = [] - for aa in roc_curves: - interp_tpr_individ = interp(mean_fpr, aa[0], aa[1]) - interp_tpr_individ[0] = 0 - interp_tpr.append(interp_tpr_individ) - - mean_tpr = np.mean(interp_tpr, axis=0) - std_tpr = np.std(interp_tpr, axis=0) - - tprs_upper = np.minimum(mean_tpr + std_tpr, 1) - tprs_lower = np.maximum(mean_tpr - std_tpr, 0) - - ax.plot(mean_fpr, mean_tpr) - ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.2) - ax.plot(mean_fpr, mean_fpr, linestyle="dotted") - plt.plot([0, 1], [0, 1], linestyle="dotted") - if "title" in kwargs: - plt.title(kwargs["title"]) - - # for aa in roc_curves: - # plt.plot(aa[0],aa[1],alpha=0.2) - - def roc_auc(self): - pass - #%% class feat_check(base_decoder): diff --git a/src/dbspace/readout/BR_DataFrame.py b/src/dbspace/utils/frames/BR_DataFrame.py similarity index 100% rename from src/dbspace/readout/BR_DataFrame.py rename to src/dbspace/utils/frames/BR_DataFrame.py diff --git a/src/dbspace/utils/functions.py b/src/dbspace/utils/functions.py index 6ba60f5..e025fc2 100644 --- a/src/dbspace/utils/functions.py +++ b/src/dbspace/utils/functions.py @@ -1,3 +1,5 @@ +import numpy as np + def unity(x): return x @@ -8,3 +10,8 @@ def quadratic(x, c=[0, 0]): # Used in this module for finding the nearest datetime def nearest(items, pivot): return min(items, key=lambda x: abs(x - pivot)) + + +def zero_mean(inp): + return inp - np.mean(inp) +