diff --git a/atomsci/ddm/pipeline/splitting.py b/atomsci/ddm/pipeline/splitting.py index cd7a3766..91acde8f 100644 --- a/atomsci/ddm/pipeline/splitting.py +++ b/atomsci/ddm/pipeline/splitting.py @@ -7,6 +7,7 @@ import logging import os import sys +import copy import deepchem as dc import numpy as np import pandas as pd @@ -82,8 +83,7 @@ def select_dset_by_id_list(dataset, id_list): """ #TODO: Need to test id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32)}, index=[str(e) for e in dataset.ids]) - sel_df = pd.DataFrame(index=[str(e) for e in id_list]) - match_df = id_df.join(sel_df, how='inner') + match_df = id_df.loc[id_df.index.isin(id_list)] subset = dataset.select(match_df.indices.values) return subset @@ -102,7 +102,7 @@ def select_attrs_by_dset_ids(dataset, attr_df): """ #TODO: Need to test - id_df = pd.DataFrame(index=[str(e) for e in dataset.ids]) + id_df = pd.DataFrame(index=[str(e) for e in set(dataset.ids)]) subattr_df = id_df.join(attr_df, how='inner') return subattr_df @@ -125,7 +125,7 @@ def select_attrs_by_dset_smiles(dataset, attr_df, smiles_col): """ id_df = pd.DataFrame(index=[str(e) for e in dataset.ids]) - subattr_df = id_df.merge(attr_df, how='inner', left_index=True, right_on=smiles_col) + subattr_df = id_df.merge(attr_df.drop_duplicates(subset=smiles_col), how='inner', left_index=True, right_on=smiles_col) return subattr_df # **************************************************************************************** @@ -339,12 +339,21 @@ def split_dataset(self, dataset, attr_df, smiles_col): Exception if there are duplicate ids or smiles strings in the dataset or the attr_df """ + dataset_dup = False if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col): - raise Exception("Duplicate ids or smiles in the dataset") + print("Duplicate ids or smiles in the dataset, deduplicate first") + dataset_dup = True + dataset_ori = copy.deepcopy(dataset) + id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32), "compound_id": [str(e) for e in dataset.ids]}) + sel_df = id_df.drop_duplicates(subset="compound_id") + dataset = dataset.select(sel_df.indices.values) + if self.needs_smiles(): # Some DeepChem splitters require compound IDs in dataset to be SMILES strings. Swap in the # SMILES strings now; we'll reverse this later. - dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df[smiles_col].values, verbose=False) + dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df.drop_duplicates(subset=smiles_col)[smiles_col].values, verbose=False) + if dataset_dup: + dataset_ori = DiskDataset.from_numpy(dataset_ori.X, dataset_ori.y, ids=attr_df[smiles_col].values, verbose=False) # Under k-fold CV, the training/validation splits are determined by num_folds; only the test set fraction # is directly specified through command line parameters. If we use Butina splitting, we can't control @@ -365,7 +374,7 @@ def split_dataset(self, dataset, attr_df, smiles_col): # TODO: Add special handling for AVE splitter train_cv, test = self.splitter.train_test_split(dataset, seed=np.random.seed(123), frac_train=train_frac) train_cv_pairs = self.splitter.k_fold_split(train_cv, self.num_folds) - + train_valid_dsets = [] train_valid_attr = [] @@ -373,6 +382,10 @@ def split_dataset(self, dataset, attr_df, smiles_col): # Now that DeepChem splitters have done their work, replace the SMILES strings in the split # dataset objects with actual compound IDs. for train, valid in train_cv_pairs: + # assign the subsets to the original dataset if duplicated compounds exist + if dataset_dup: + train = select_dset_by_id_list(dataset_ori, train.ids) + valid = select_dset_by_id_list(dataset_ori, valid.ids) train_attr = select_attrs_by_dset_smiles(train, attr_df, smiles_col) train = DiskDataset.from_numpy(train.X, train.y, ids=train_attr.index.values, verbose=False) @@ -382,15 +395,22 @@ def split_dataset(self, dataset, attr_df, smiles_col): train_valid_dsets.append((train, valid)) train_valid_attr.append((train_attr, valid_attr)) + if dataset_dup: + test = select_dset_by_id_list(dataset_ori, test.ids) test_attr = select_attrs_by_dset_smiles(test, attr_df, smiles_col) test = DiskDataset.from_numpy(test.X, test.y, ids=test_attr.index.values, verbose=False) else: # Otherwise just subset the ID-to-SMILES maps. for train, valid in train_cv_pairs: + if dataset_dup: + train = select_dset_by_id_list(dataset_ori, train.ids) + valid = select_dset_by_id_list(dataset_ori, valid.ids) train_attr = select_attrs_by_dset_ids(train, attr_df) valid_attr = select_attrs_by_dset_ids(valid, attr_df) train_valid_attr.append((train_attr, valid_attr)) train_valid_dsets = train_cv_pairs + if dataset_dup: + test = select_dset_by_id_list(dataset_ori, test.ids) test_attr = select_attrs_by_dset_ids(test, attr_df) return train_valid_dsets, test, train_valid_attr, test_attr @@ -432,10 +452,10 @@ def __init__(self, params): def get_split_prefix(self, parent=''): """Returns a string identifying the split strategy (TVT or k-fold) and the splitting method (index, scaffold, etc.) for use in filenames, dataset keys, etc. - + Args: parent (str): Default to empty string. Sets the parent directory for the output string - + Returns: (str): A string that identifies the split strategy and the splitting method. Appends a parent directory in front of the fold description @@ -450,14 +470,14 @@ def get_split_prefix(self, parent=''): def split_dataset(self, dataset, attr_df, smiles_col): #smiles_col is a hack for now until deepchem fixes their scaffold and butina splitters """Splits dataset into training, testing and validation sets. - + For ave_min, random, scaffold, index splits self.params.split_valid_frac & self.params.split_test_frac should be defined and train_frac = 1.0 - self.params.split_valid_frac - self.params.split_test_frac - + For butina split, test size is not user defined, and depends on available clusters that qualify for placement in the test set train_frac = 1.0 - self.params.split_valid_frac - + For temporal split, test size is also not user defined, and depends on number of compounds with dates after cutoff date. train_frac = 1.0 - self.params.split_valid_frac Args: @@ -466,7 +486,7 @@ def split_dataset(self, dataset, attr_df, smiles_col): attr_df (Pandas DataFrame): dataframe containing SMILES strings indexed by compound IDs, smiles_col (string): name of SMILES column (hack for now until deepchem fixes scaffold and butina splitters) - + Returns: [(train, valid)], test, [(train_attr, valid_attr)], test_attr: train (deepchem Dataset): training dataset. @@ -480,19 +500,28 @@ def split_dataset(self, dataset, attr_df, smiles_col): valid_attr (Pandas DataFrame): dataframe of SMILES strings indexed by compound IDs for validation set. test_attr (Pandas DataFrame): dataframe of SMILES strings indexed by compound IDs for test set. - + Raises: Exception if there are duplicate ids or smiles strings in the dataset or the attr_df """ log.warning("Splitting data by %s" % self.params.splitter) + dataset_dup = False + if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col): + print("Duplicate ids or smiles in the dataset, deduplicate first") + dataset_dup = True + dataset_ori = copy.deepcopy(dataset) + id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32), "compound_id": [str(e) for e in dataset.ids]}) + sel_df = id_df.drop_duplicates(subset="compound_id") + dataset = dataset.select(sel_df.indices.values) + if self.needs_smiles(): - if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col): - raise Exception("Duplicate ids or smiles in the dataset") # Some DeepChem splitters require compound IDs in dataset to be SMILES strings. Swap in the # SMILES strings now; we'll reverse this later. - dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df[smiles_col].values, verbose=False) + dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df.drop_duplicates(subset=smiles_col)[smiles_col].values, verbose=False) + if dataset_dup: + dataset_ori = DiskDataset.from_numpy(dataset_ori.X, dataset_ori.y, ids=attr_df[smiles_col].values, verbose=False) if self.split == 'butina': #train_valid, test = self.splitter.train_test_split(dataset, cutoff=self.params.butina_cutoff) @@ -529,6 +558,11 @@ def split_dataset(self, dataset, attr_df, smiles_col): seed=np.random.seed(123)) # Extract the ID-to_SMILES maps from attr_df for each subset. + # assign the subsets to the original dataset if duplicated compounds exist + if dataset_dup: + train = select_dset_by_id_list(dataset_ori, train.ids) + valid = select_dset_by_id_list(dataset_ori, valid.ids) + test = select_dset_by_id_list(dataset_ori, test.ids) if self.needs_smiles(): # Now that DeepChem splitters have done their work, replace the SMILES strings in the split # dataset objects with actual compound IDs.