From 824f7a8fb9d9c7cb4a7a742185f0383ef50aa7b6 Mon Sep 17 00:00:00 2001
From: tgnassou
Date: Wed, 25 Sep 2024 13:29:05 +0200
Subject: [PATCH 1/2] add bci concat

---
 benchmark_utils/utils.py |   9 +++
 datasets/bci_concat.py   | 132 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 141 insertions(+)
 create mode 100644 datasets/bci_concat.py

diff --git a/benchmark_utils/utils.py b/benchmark_utils/utils.py
index 111f977..fa9739b 100644
--- a/benchmark_utils/utils.py
+++ b/benchmark_utils/utils.py
@@ -168,6 +168,15 @@ def get_params_per_dataset(dataset_name, n_classes):
             'max_epochs': 200,
             'num_features': 40 * 69,
         },
+        'bci_concat': {
+            'batch_size': 256,
+            'model': FBCSPNet(n_chans=22, n_classes=n_classes, input_window_samples=1125),
+            'lr_scheduler': LRScheduler("CosineAnnealingLR", T_max=200 - 1),
+            'optimizer': AdamW,
+            'lr': 0.0625 * 0.01,
+            'max_epochs': 200,
+            'num_features': 40 * 69,
+        },
     }
 
     if dataset_name not in dataset_configs:

diff --git a/datasets/bci_concat.py b/datasets/bci_concat.py
new file mode 100644
index 0000000..4514b1c
--- /dev/null
+++ b/datasets/bci_concat.py
@@ -0,0 +1,132 @@
+from benchopt import BaseDataset, safe_import_context
+
+# Protect the import with `safe_import_context()`. This allows:
+# - skipping import to speed up autocompletion in CLI.
+# - getting requirements info when all dependencies are not installed.
+with safe_import_context() as import_ctx:
+    from skada.utils import source_target_merge
+    import numpy as np
+    from braindecode.datasets import MOABBDataset
+    from braindecode.preprocessing import (
+        exponential_moving_standardize,
+        preprocess,
+        Preprocessor,
+    )
+    from numpy import multiply
+    from braindecode.preprocessing import create_windows_from_events
+
+
+# All datasets must be named `Dataset` and inherit from `BaseDataset`
+class Dataset(BaseDataset):
+
+    # Name to select the dataset in the CLI and to display the results.
+    name = "BCI_concat"
+
+    requirements = ['mne==1.6.1', 'braindecode==0.8.1',
+                    'moabb==0.5', 'pyriemann==0.3']
+
+    parameters = {
+        'source_target': [("session_T", "session_E"), ("session_E", "session_T")],
+    }
+
+    def get_data(self):
+        # The return arguments of this function are passed as keyword arguments
+        # to `Objective.set_data`. This defines the benchmark's
+        # API to pass data. It is customizable for each benchmark.
+
+        source, target = self.source_target
+        subject_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+        X_source_list = []
+        y_source_list = []
+        X_target_list = []
+        y_target_list = []
+        for subject_id in subject_ids:
+            dataset = MOABBDataset(
+                dataset_name="BNCI2014001", subject_ids=[subject_id]
+            )
+            low_cut_hz = 4.0  # low cut frequency for filtering
+            high_cut_hz = 40.0  # high cut frequency for filtering
+            # Parameters for exponential moving standardization
+            factor_new = 1e-3
+            init_block_size = 1000
+            # Factor to convert from V to uV
+            factor = 1e6
+
+            preprocessors = [
+                Preprocessor("pick_types", eeg=True, meg=False, stim=False),
+                Preprocessor(lambda data: multiply(data, factor)),
+                Preprocessor(
+                    "filter", l_freq=low_cut_hz, h_freq=high_cut_hz
+                ),
+                Preprocessor(
+                    exponential_moving_standardize,
+                    factor_new=factor_new,
+                    init_block_size=init_block_size,
+                ),
+            ]
+
+            # Transform the data
+            preprocess(dataset, preprocessors)
+
+            trial_start_offset_seconds = -0.5
+            # Extract the sampling frequency and check that it is the same
+            # for every recording in the dataset.
+            sfreq = dataset.datasets[0].raw.info["sfreq"]
+            assert all(ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets)
+            # Calculate the trial start offset in samples.
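+            # (BNCI2014001 is recorded at 250 Hz, so the -0.5 s offset
+            # starts each window 125 samples before the cue onset.)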
+            trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
+
+            window_size_samples = None
+            window_stride_samples = None
+
+            windows_dataset = create_windows_from_events(
+                dataset,
+                trial_start_offset_samples=trial_start_offset_samples,
+                trial_stop_offset_samples=0,
+                preload=False,
+                window_size_samples=window_size_samples,
+                window_stride_samples=window_stride_samples,
+            )
+
+            splitted = windows_dataset.split("session")
+
+            sess_source = source
+            n_runs = len(splitted[sess_source].datasets)
+            x = []
+            y = []
+            for run in range(n_runs):
+                x += [sample[0] for sample in splitted[sess_source].datasets[run]]
+                y += [sample[1] for sample in splitted[sess_source].datasets[run]]
+            X_source = np.array(x)
+            y_source = np.array(y)
+
+            sess_target = target
+            n_runs = len(splitted[sess_target].datasets)
+            x = []
+            y = []
+            for run in range(n_runs):
+                x += [sample[0] for sample in splitted[sess_target].datasets[run]]
+                y += [sample[1] for sample in splitted[sess_target].datasets[run]]
+            X_target = np.array(x)
+            y_target = np.array(y)
+
+            X_source_list.append(X_source)
+            y_source_list.append(y_source)
+            X_target_list.append(X_target)
+            y_target_list.append(y_target)
+
+        X_source_concat = np.concatenate(X_source_list)
+        y_source_concat = np.concatenate(y_source_list)
+        X_target_concat = np.concatenate(X_target_list)
+        y_target_concat = np.concatenate(y_target_list)
+
+        X, y, sample_domain = source_target_merge(
+            X_source_concat, X_target_concat, y_source_concat, y_target_concat)
+
+        return dict(
+            X=X,
+            y=y,
+            sample_domain=sample_domain
+        )

From 7d7867561019b792498ffcd5d4cab73d14c3f987 Mon Sep 17 00:00:00 2001
From: tgnassou
Date: Thu, 26 Sep 2024 08:26:46 +0200
Subject: [PATCH 2/2] n_split=1

---
 objective.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/objective.py b/objective.py
index 798f27b..8aa2bf3 100644
--- a/objective.py
+++ b/objective.py
@@ -48,7 +48,7 @@ class Objective(BaseObjective):
 
     # Random state
     random_state = 0
-    n_splits_data = 5
+    n_splits_data = 1
     test_size_data = 0.2
 
     # Set random states
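
For reference, a minimal sketch of the data layout that `get_data` returns,
assuming skada's documented behaviour for `source_target_merge`: the source
arrays are stacked before the target ones, and `sample_domain` carries
positive labels for source samples and negative labels for target samples
(the exact label values below are illustrative; the toy shapes mirror
BNCI2014001's 22 channels and 1125 samples per trial):

    import numpy as np
    from skada.utils import source_target_merge

    # Stand-ins for the concatenated epochs: 4 source trials and
    # 3 target trials, each of shape (n_channels=22, n_times=1125).
    X_source = np.random.randn(4, 22, 1125)
    y_source = np.array([0, 1, 2, 3])
    X_target = np.random.randn(3, 22, 1125)
    y_target = np.array([0, 1, 2])

    X, y, sample_domain = source_target_merge(
        X_source, X_target, y_source, y_target
    )

    print(X.shape)        # (7, 22, 1125): source trials first, then target
    print(sample_domain)  # e.g. [1 1 1 1 -2 -2 -2]: >0 source, <0 target

With `n_splits_data = 1`, the objective then draws a single train/test split
(`test_size_data = 0.2`) from this merged array instead of five.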