Skip to content

Commit e7cf27c

Browse files
committed
Merge branch 'nested_cv_implementation'
2 parents fa40436 + efa79b4 commit e7cf27c

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

src/data_stack/dataset/splitter.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,21 @@ def __init__(self,
110110

111111
def split(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[DatasetIteratorIF], List[List[DatasetIteratorIF]]]:
112112
# create outer loop folds
113-
targets = [sample[self.target_pos] for sample in dataset_iterator]
114-
folds_indices = [fold[1] for fold in self.outer_splitter.split(X=np.zeros(len(targets)), y=targets)]
115-
outer_folds = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in folds_indices]
113+
targets = np.array([sample[self.target_pos] for sample in dataset_iterator])
114+
outer_folds_indices = [fold[1] for fold in self.outer_splitter.split(X=np.zeros(len(targets)), y=targets)]
115+
outer_fold_iterators = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in outer_folds_indices]
116116
# create inner loop folds
117-
inner_folds_list = [] # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
118-
for iterator in outer_folds:
119-
targets = [sample[self.target_pos] for sample in iterator]
120-
folds_indices = [fold[1] for fold in self.inner_splitter.split(X=np.zeros(len(targets)), y=targets)]
121-
inner_folds = [DatasetIteratorView(iterator, fold_indices) for fold_indices in folds_indices]
122-
inner_folds_list.append(inner_folds)
123-
return outer_folds, inner_folds_list
117+
inner_folds_iterators_list = [] # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
118+
for outer_fold_id in range(len(outer_fold_iterators)):
119+
# concat the indices of the splits which belong to the train splits
120+
train_split_ids = [i for i in range(len(outer_folds_indices)) if i != outer_fold_id]
121+
outer_train_fold_indices = np.array([indice for i in train_split_ids for indice in outer_folds_indices[i]])
122+
inner_targets = targets[outer_train_fold_indices]
123+
inner_folds_indices = [outer_train_fold_indices[inner_fold[1]]
124+
for inner_fold in self.inner_splitter.split(X=np.zeros(len(inner_targets)), y=inner_targets)]
125+
inner_folds = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in inner_folds_indices]
126+
inner_folds_iterators_list.append(inner_folds)
127+
return outer_fold_iterators, inner_folds_iterators_list
124128

125129
def get_indices(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[List[int]], List[List[int]]]:
126130
outer_folds, inner_folds_list = self.split(dataset_iterator)

unittests/dataset/test_splitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_nested_cv_splitter(self, num_outer_loop_folds: int, num_inner_loop_fold
7474
if inner_stratification:
7575
for i in range(len(inner_folds)):
7676
class_counts = dict(collections.Counter([t for _, t in outer_folds[i]]))
77-
class_counts_per_fold = {target_class: int(count/num_inner_loop_folds) for target_class, count in class_counts.items()}
77+
class_counts_per_fold = {target_class: int(count*(num_outer_loop_folds-1)/num_inner_loop_folds) for target_class, count in class_counts.items()}
7878
for fold in inner_folds[i]:
7979
fold_class_counts = dict(collections.Counter([t for _, t in fold]))
8080
for key in list(class_counts_per_fold.keys()) + list(fold_class_counts.keys()):

0 commit comments

Comments
 (0)