68 changes: 51 additions & 17 deletions atomsci/ddm/pipeline/splitting.py
@@ -7,6 +7,7 @@
import logging
import os
import sys
+import copy
import deepchem as dc
import numpy as np
import pandas as pd
@@ -82,8 +83,7 @@ def select_dset_by_id_list(dataset, id_list):
    """
    #TODO: Need to test
    id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32)}, index=[str(e) for e in dataset.ids])
-    sel_df = pd.DataFrame(index=[str(e) for e in id_list])
-    match_df = id_df.join(sel_df, how='inner')
+    match_df = id_df.loc[id_df.index.isin(id_list)]
    subset = dataset.select(match_df.indices.values)
    return subset
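The new selection logic can be sanity-checked with plain pandas. The old join against a frame indexed by id_list could multiply rows when both frames carried duplicate ids, while isin() keeps each matching dataset row exactly once, in dataset order. A minimal sketch with toy ids (no DeepChem needed):

import numpy as np
import pandas as pd

dataset_ids = ['c1', 'c2', 'c2', 'c3']   # note the duplicated 'c2'
id_list = ['c2', 'c3']

# Same construction as in select_dset_by_id_list: index = compound id,
# 'indices' = positional index into the dataset.
id_df = pd.DataFrame({'indices': np.arange(len(dataset_ids), dtype=np.int32)},
                     index=[str(e) for e in dataset_ids])
match_df = id_df.loc[id_df.index.isin(id_list)]
print(match_df.indices.values)   # -> [1 2 3]; both 'c2' rows are kept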

@@ -102,7 +102,7 @@ def select_attrs_by_dset_ids(dataset, attr_df):

    """
    #TODO: Need to test
-    id_df = pd.DataFrame(index=[str(e) for e in dataset.ids])
+    id_df = pd.DataFrame(index=[str(e) for e in set(dataset.ids)])
    subattr_df = id_df.join(attr_df, how='inner')
    return subattr_df
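The only change here is wrapping dataset.ids in set(), which makes the join index unique so the inner join with attr_df cannot emit one row per duplicated id. A toy sketch of the effect ('rdkit_smiles' is an assumed column name; set() does not preserve order, which is harmless since the result is keyed by id):

import pandas as pd

ids = ['c1', 'c2', 'c2']                               # duplicated 'c2'
attr_df = pd.DataFrame({'rdkit_smiles': ['CCO', 'CCN', 'c1ccccc1']},
                       index=['c1', 'c2', 'c3'])

# With duplicates in the index, the join would return 'c2' twice;
# with set() each id appears once.
id_df = pd.DataFrame(index=[str(e) for e in set(ids)])
print(id_df.join(attr_df, how='inner'))                # one row each for 'c1' and 'c2'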

@@ -125,7 +125,7 @@ def select_attrs_by_dset_smiles(dataset, attr_df, smiles_col):

    """
    id_df = pd.DataFrame(index=[str(e) for e in dataset.ids])
-    subattr_df = id_df.merge(attr_df, how='inner', left_index=True, right_on=smiles_col)
+    subattr_df = id_df.merge(attr_df.drop_duplicates(subset=smiles_col), how='inner', left_index=True, right_on=smiles_col)
    return subattr_df
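This helper runs after needs_smiles() has swapped compound ids for SMILES strings, so the merge key can repeat whenever two compound ids share a structure. Deduplicating attr_df on the SMILES column first keeps one row per structure. A minimal sketch, assuming an 'rdkit_smiles' column:

import pandas as pd

attr_df = pd.DataFrame({'rdkit_smiles': ['CCO', 'CCO', 'c1ccccc1']},
                       index=['c1', 'c2', 'c3'])   # c1 and c2 share a SMILES
id_df = pd.DataFrame(index=['CCO', 'c1ccccc1'])    # dataset.ids, already SMILES

subattr_df = id_df.merge(attr_df.drop_duplicates(subset='rdkit_smiles'),
                         how='inner', left_index=True, right_on='rdkit_smiles')
print(subattr_df)   # one row per SMILES, indexed by the first matching compound id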

# ****************************************************************************************
@@ -339,12 +339,21 @@ def split_dataset(self, dataset, attr_df, smiles_col):
            Exception if there are duplicate ids or smiles strings in the dataset or the attr_df

        """
+        dataset_dup = False
        if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col):
-            raise Exception("Duplicate ids or smiles in the dataset")
+            print("Duplicate ids or smiles in the dataset, deduplicate first")
+            dataset_dup = True
+            dataset_ori = copy.deepcopy(dataset)
+            id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32), "compound_id": [str(e) for e in dataset.ids]})
+            sel_df = id_df.drop_duplicates(subset="compound_id")
+            dataset = dataset.select(sel_df.indices.values)
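Instead of raising, the splitter now deduplicates first-occurrence-wins: drop_duplicates() on the id column yields the positional index of the first row per compound, which dataset.select() then keeps, while dataset_ori retains every row. A standalone sketch of the index arithmetic:

import numpy as np
import pandas as pd

ids = ['c1', 'c2', 'c2', 'c3']            # stand-in for dataset.ids
id_df = pd.DataFrame({'indices': np.arange(len(ids), dtype=np.int32),
                      'compound_id': [str(e) for e in ids]})
sel_df = id_df.drop_duplicates(subset='compound_id')
print(sel_df.indices.values)              # -> [0 1 3]; the second 'c2' row is dropped
# split_dataset then keeps exactly those positions via dataset.select(...)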

        if self.needs_smiles():
            # Some DeepChem splitters require compound IDs in dataset to be SMILES strings. Swap in the
            # SMILES strings now; we'll reverse this later.
-            dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df[smiles_col].values, verbose=False)
+            dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df.drop_duplicates(subset=smiles_col)[smiles_col].values, verbose=False)
+            if dataset_dup:
+                dataset_ori = DiskDataset.from_numpy(dataset_ori.X, dataset_ori.y, ids=attr_df[smiles_col].values, verbose=False)
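Because the dataset may have just been deduplicated, the replacement id vector must be deduplicated the same way or the row counts drift apart; dataset_ori keeps the full-length SMILES vector. A hedged sketch of the row-count alignment (assuming, as this code implicitly does, that duplicated ids and duplicated SMILES coincide row-for-row):

import pandas as pd

attr_df = pd.DataFrame({'rdkit_smiles': ['CCO', 'CCN', 'CCN', 'c1ccccc1']})
deduped = attr_df.drop_duplicates(subset='rdkit_smiles')['rdkit_smiles'].values
print(len(deduped))   # -> 3, matching the deduplicated dataset's row count
print(len(attr_df))   # -> 4, matching dataset_ori's row count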

        # Under k-fold CV, the training/validation splits are determined by num_folds; only the test set fraction
        # is directly specified through command line parameters. If we use Butina splitting, we can't control
@@ -365,14 +374,18 @@ def split_dataset(self, dataset, attr_df, smiles_col):
        # TODO: Add special handling for AVE splitter
        train_cv, test = self.splitter.train_test_split(dataset, seed=np.random.seed(123), frac_train=train_frac)
        train_cv_pairs = self.splitter.k_fold_split(train_cv, self.num_folds)

        train_valid_dsets = []
        train_valid_attr = []

        if self.needs_smiles():
            # Now that DeepChem splitters have done their work, replace the SMILES strings in the split
            # dataset objects with actual compound IDs.
            for train, valid in train_cv_pairs:
+                # assign the subsets to the original dataset if duplicated compounds exist
+                if dataset_dup:
+                    train = select_dset_by_id_list(dataset_ori, train.ids)
+                    valid = select_dset_by_id_list(dataset_ori, valid.ids)
                train_attr = select_attrs_by_dset_smiles(train, attr_df, smiles_col)
                train = DiskDataset.from_numpy(train.X, train.y, ids=train_attr.index.values, verbose=False)

@@ -382,15 +395,22 @@ def split_dataset(self, dataset, attr_df, smiles_col):
                train_valid_dsets.append((train, valid))
                train_valid_attr.append((train_attr, valid_attr))

+            if dataset_dup:
+                test = select_dset_by_id_list(dataset_ori, test.ids)
            test_attr = select_attrs_by_dset_smiles(test, attr_df, smiles_col)
            test = DiskDataset.from_numpy(test.X, test.y, ids=test_attr.index.values, verbose=False)
        else:
            # Otherwise just subset the ID-to-SMILES maps.
            for train, valid in train_cv_pairs:
+                if dataset_dup:
+                    train = select_dset_by_id_list(dataset_ori, train.ids)
+                    valid = select_dset_by_id_list(dataset_ori, valid.ids)
                train_attr = select_attrs_by_dset_ids(train, attr_df)
                valid_attr = select_attrs_by_dset_ids(valid, attr_df)
                train_valid_attr.append((train_attr, valid_attr))
            train_valid_dsets = train_cv_pairs
+            if dataset_dup:
+                test = select_dset_by_id_list(dataset_ori, test.ids)
            test_attr = select_attrs_by_dset_ids(test, attr_df)

        return train_valid_dsets, test, train_valid_attr, test_attr
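The key property of this round trip: the split is computed on the deduplicated dataset, and each subset's ids are then re-expanded against the original, so all rows sharing a compound id land in the same partition and cannot leak across folds. A toy illustration in plain Python (names hypothetical):

deduped  = ['c1', 'c2', 'c3']             # ids after dataset.select(...)
fold_ids = ['c2', 'c3']                   # ids the splitter put in one fold
original = ['c1', 'c2', 'c2', 'c3']       # dataset_ori.ids, with a duplicate
expanded = [i for i in original if i in fold_ids]
print(expanded)   # -> ['c2', 'c2', 'c3']; both 'c2' rows stay together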
@@ -432,10 +452,10 @@ def __init__(self, params):
    def get_split_prefix(self, parent=''):
        """Returns a string identifying the split strategy (TVT or k-fold) and the splitting method
        (index, scaffold, etc.) for use in filenames, dataset keys, etc.

        Args:
            parent (str): Default to empty string. Sets the parent directory for the output string

        Returns:
            (str): A string that identifies the split strategy and the splitting method.
            Appends a parent directory in front of the fold description
@@ -450,14 +470,14 @@ def get_split_prefix(self, parent=''):
    def split_dataset(self, dataset, attr_df, smiles_col):
        #smiles_col is a hack for now until deepchem fixes their scaffold and butina splitters
        """Splits dataset into training, testing and validation sets.

        For ave_min, random, scaffold, index splits
        self.params.split_valid_frac & self.params.split_test_frac should be defined and
        train_frac = 1.0 - self.params.split_valid_frac - self.params.split_test_frac

        For butina split, test size is not user defined, and depends on available clusters that qualify for placement in the test set
        train_frac = 1.0 - self.params.split_valid_frac

        For temporal split, test size is also not user defined, and depends on number of compounds with dates after cutoff date.
        train_frac = 1.0 - self.params.split_valid_frac
        Args:
@@ -466,7 +486,7 @@ def split_dataset(self, dataset, attr_df, smiles_col):
            attr_df (Pandas DataFrame): dataframe containing SMILES strings indexed by compound IDs,

            smiles_col (string): name of SMILES column (hack for now until deepchem fixes scaffold and butina splitters)

        Returns:
            [(train, valid)], test, [(train_attr, valid_attr)], test_attr:
                train (deepchem Dataset): training dataset.
@@ -480,19 +500,28 @@ def split_dataset(self, dataset, attr_df, smiles_col):
                valid_attr (Pandas DataFrame): dataframe of SMILES strings indexed by compound IDs for validation set.

                test_attr (Pandas DataFrame): dataframe of SMILES strings indexed by compound IDs for test set.

        Raises:
            Exception if there are duplicate ids or smiles strings in the dataset or the attr_df

        """
        log.warning("Splitting data by %s" % self.params.splitter)

+        dataset_dup = False
+        if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col):
+            print("Duplicate ids or smiles in the dataset, deduplicate first")
+            dataset_dup = True
+            dataset_ori = copy.deepcopy(dataset)
+            id_df = pd.DataFrame({'indices' : np.arange(len(dataset.ids), dtype=np.int32), "compound_id": [str(e) for e in dataset.ids]})
+            sel_df = id_df.drop_duplicates(subset="compound_id")
+            dataset = dataset.select(sel_df.indices.values)

        if self.needs_smiles():
-            if check_if_dupe_smiles_dataset(dataset, attr_df, smiles_col):
-                raise Exception("Duplicate ids or smiles in the dataset")
            # Some DeepChem splitters require compound IDs in dataset to be SMILES strings. Swap in the
            # SMILES strings now; we'll reverse this later.
-            dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df[smiles_col].values, verbose=False)
+            dataset = DiskDataset.from_numpy(dataset.X, dataset.y, ids=attr_df.drop_duplicates(subset=smiles_col)[smiles_col].values, verbose=False)
+            if dataset_dup:
+                dataset_ori = DiskDataset.from_numpy(dataset_ori.X, dataset_ori.y, ids=attr_df[smiles_col].values, verbose=False)

        if self.split == 'butina':
            #train_valid, test = self.splitter.train_test_split(dataset, cutoff=self.params.butina_cutoff)
@@ -529,6 +558,11 @@ def split_dataset(self, dataset, attr_df, smiles_col):
                                                                  seed=np.random.seed(123))

        # Extract the ID-to_SMILES maps from attr_df for each subset.
+        # assign the subsets to the original dataset if duplicated compounds exist
+        if dataset_dup:
+            train = select_dset_by_id_list(dataset_ori, train.ids)
+            valid = select_dset_by_id_list(dataset_ori, valid.ids)
+            test = select_dset_by_id_list(dataset_ori, test.ids)
        if self.needs_smiles():
            # Now that DeepChem splitters have done their work, replace the SMILES strings in the split
            # dataset objects with actual compound IDs.