@@ -110,17 +110,21 @@ def __init__(self,
110110
111111 def split (self , dataset_iterator : DatasetIteratorIF ) -> Tuple [List [DatasetIteratorIF ], List [List [DatasetIteratorIF ]]]:
112112 # create outer loop folds
113- targets = [sample [self .target_pos ] for sample in dataset_iterator ]
114- folds_indices = [fold [1 ] for fold in self .outer_splitter .split (X = np .zeros (len (targets )), y = targets )]
115- outer_folds = [DatasetIteratorView (dataset_iterator , fold_indices ) for fold_indices in folds_indices ]
113+ targets = np . array ( [sample [self .target_pos ] for sample in dataset_iterator ])
114+ outer_folds_indices = [fold [1 ] for fold in self .outer_splitter .split (X = np .zeros (len (targets )), y = targets )]
115+ outer_fold_iterators = [DatasetIteratorView (dataset_iterator , fold_indices ) for fold_indices in outer_folds_indices ]
116116 # create inner loop folds
117- inner_folds_list = [] # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
118- for iterator in outer_folds :
119- targets = [sample [self .target_pos ] for sample in iterator ]
120- folds_indices = [fold [1 ] for fold in self .inner_splitter .split (X = np .zeros (len (targets )), y = targets )]
121- inner_folds = [DatasetIteratorView (iterator , fold_indices ) for fold_indices in folds_indices ]
122- inner_folds_list .append (inner_folds )
123- return outer_folds , inner_folds_list
117+ inner_folds_iterators_list = [] # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
118+ for outer_fold_id in range (len (outer_fold_iterators )):
119+ # concat the indices of the splits which belong to the train splits
120+ train_split_ids = [i for i in range (len (outer_folds_indices )) if i != outer_fold_id ]
121+ outer_train_fold_indices = np .array ([indice for i in train_split_ids for indice in outer_folds_indices [i ]])
122+ inner_targets = targets [outer_train_fold_indices ]
123+ inner_folds_indices = [outer_train_fold_indices [inner_fold [1 ]]
124+ for inner_fold in self .inner_splitter .split (X = np .zeros (len (inner_targets )), y = inner_targets )]
125+ inner_folds = [DatasetIteratorView (dataset_iterator , fold_indices ) for fold_indices in inner_folds_indices ]
126+ inner_folds_iterators_list .append (inner_folds )
127+ return outer_fold_iterators , inner_folds_iterators_list
124128
125129 def get_indices (self , dataset_iterator : DatasetIteratorIF ) -> Tuple [List [List [int ]], List [List [int ]]]:
126130 outer_folds , inner_folds_list = self .split (dataset_iterator )
0 commit comments