9 changes: 9 additions & 0 deletions benchmark_utils/utils.py
@@ -168,6 +168,15 @@ def get_params_per_dataset(dataset_name, n_classes):
'max_epochs': 200,
'num_features': 40 * 69,
},
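# Training configuration for the concatenated BCI dataset (see datasets/bci_concat.py):
# FBCSP-style network trained with AdamW and a cosine-annealed learning rate.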
'bci_concat': {
'batch_size': 256,
'model': FBCSPNet(n_chans=22, n_classes=n_classes, input_window_samples=1125,),
'lr_scheduler': LRScheduler("CosineAnnealingLR", T_max=200 - 1),
'optimizer': AdamW,
'lr': 0.0625 * 0.01,
'max_epochs': 200,
'num_features': 40 * 69,
},
}

if dataset_name not in dataset_configs:
133 changes: 133 additions & 0 deletions datasets/bci_concat.py
@@ -0,0 +1,133 @@
from benchopt import BaseDataset, safe_import_context

# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
from skada.utils import source_target_merge
import numpy as np
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
exponential_moving_standardize,
preprocess,
Preprocessor,
)
from numpy import multiply
from braindecode.preprocessing import create_windows_from_events


# All datasets must be named `Dataset` and inherit from `BaseDataset`
class Dataset(BaseDataset):

# Name to select the dataset in the CLI and to display the results.
name = "BCI_concat"

requirements = ['mne==1.6.1', 'braindecode==0.8.1',
'moabb==0.5', 'pyriemann==0.3']

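# Each parametrization uses one BNCI2014001 session as the source domain
# and the other as the target domain.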
parameters = {
'source_target': [("session_T", "session_E"), ("session_E", "session_T")],
}

def get_data(self):
# The return arguments of this function are passed as keyword arguments
# to `Objective.set_data`. This defines the benchmark's
# API to pass data. It is customizable for each benchmark.

source, target = self.source_target
subject_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
X_source_list = []
y_source_list = []
X_target_list = []
y_target_list = []
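# Process each subject independently; all subjects are pooled after the loop.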
for subject_id in subject_ids:
dataset = MOABBDataset(
dataset_name="BNCI2014001", subject_ids=[subject_id]
)
low_cut_hz = 4.0 # low cut frequency for filtering
high_cut_hz = 40.0 # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
Preprocessor("pick_types", eeg=True, meg=False, stim=False),
Preprocessor(lambda data: multiply(data, factor)),
Preprocessor(
"filter", l_freq=low_cut_hz, h_freq=high_cut_hz
),
Preprocessor(
exponential_moving_standardize,
factor_new=factor_new,
init_block_size=init_block_size,
),
]

# Transform the data
preprocess(dataset, preprocessors)

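# Cut windows starting 0.5 s before each event onset.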
trial_start_offset_seconds = -0.5
# Extract the sampling frequency and check that it is the same across recordings
sfreq = dataset.datasets[0].raw.info["sfreq"]

assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

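# Left as None so each trial yields exactly one window covering the full trial.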
window_size_samples = None
window_stride_samples = None

windows_dataset = create_windows_from_events(
dataset,
trial_start_offset_samples=trial_start_offset_samples,
trial_stop_offset_samples=0,
preload=False,
window_size_samples=window_size_samples,
window_stride_samples=window_stride_samples,
)

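# Split the windowed dataset by session name.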
splitted = windows_dataset.split("session")
sessions = list(splitted.keys())

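# Collect every source-session window across runs.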
sess_source = source
n_runs = len(splitted[sess_source].datasets)
x = []
y = []
for run in range(n_runs):
x += [sample[0] for sample in splitted[sess_source].datasets[run]]
y += [sample[1] for sample in splitted[sess_source].datasets[run]]
X_source = np.array(x)
y_source = np.array(y)

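# Collect every target-session window across runs.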
sess_target = target
n_runs = len(splitted[sess_target].datasets)
x = []
y = []
for run in range(n_runs):
x += [sample[0] for sample in splitted[sess_target].datasets[run]]
y += [sample[1] for sample in splitted[sess_target].datasets[run]]
X_target = np.array(x)
y_target = np.array(y)

X_source_list.append(X_source)
y_source_list.append(y_source)
X_target_list.append(X_target)
y_target_list.append(y_target)

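# Stack the per-subject arrays so every subject contributes to both domains.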
X_source_concat = np.concatenate(X_source_list)
y_source_concat = np.concatenate(y_source_list)
X_target_concat = np.concatenate(X_target_list)
y_target_concat = np.concatenate(y_target_list)

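# Merge source and target into a single X, y with a sample_domain array
# identifying each sample's domain.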
X, y, sample_domain = source_target_merge(
X_source_concat, X_target_concat, y_source_concat, y_target_concat)

return dict(
X=X,
y=y,
sample_domain=sample_domain
)
2 changes: 1 addition & 1 deletion objective.py
@@ -48,7 +48,7 @@ class Objective(BaseObjective):

# Random state
random_state = 0
n_splits_data = 5
n_splits_data = 1
test_size_data = 0.2

# Set random states