diff --git a/examples/eeg/seizure_detection/EEG_isAbnormal_SparcNet.py b/examples/eeg/SpaRCNet/EEG_abnormal_classification_SparcNet.py similarity index 72% rename from examples/eeg/seizure_detection/EEG_isAbnormal_SparcNet.py rename to examples/eeg/SpaRCNet/EEG_abnormal_classification_SparcNet.py index 0c801f94e..f279f3522 100644 --- a/examples/eeg/seizure_detection/EEG_isAbnormal_SparcNet.py +++ b/examples/eeg/SpaRCNet/EEG_abnormal_classification_SparcNet.py @@ -1,7 +1,7 @@ from pyhealth.datasets import split_by_visit, get_dataloader from pyhealth.trainer import Trainer from pyhealth.datasets import TUABDataset -from pyhealth.tasks import EEG_isAbnormal_fn +from pyhealth.tasks import EEGAbnormalTUAB from pyhealth.models import SparcNet # step 1: load signal data @@ -9,10 +9,22 @@ dev=True, refresh_cache=True, ) - +print(dataset.stats()) # step 2: set task -TUAB_ds = dataset.set_task(EEG_isAbnormal_fn) -TUAB_ds.stat() +TUAB_ds = dataset.set_task(EEGAbnormalTUAB( + resample_rate=200, + bandpass_filter=(0.1, 75.0), + notch_filter=50.0, +)) +print(f"Total task samples: {len(TUAB_ds)}") +print(f"Input schema: {TUAB_ds.input_schema}") +print(f"Output schema: {TUAB_ds.output_schema}") + +# Inspect a sample +sample = TUAB_ds[0] +print(f"\nSample keys: {sample.keys()}") +print(f"Signal shape: {sample['signal'].shape}") +print(f"Label: {sample['label']}") # split dataset train_dataset, val_dataset, test_dataset = split_by_visit( diff --git a/examples/eeg/seizure_detection/EEG_events_SparcNet.py b/examples/eeg/SpaRCNet/EEG_events_classification_SparcNet.py similarity index 69% rename from examples/eeg/seizure_detection/EEG_events_SparcNet.py rename to examples/eeg/SpaRCNet/EEG_events_classification_SparcNet.py index 6e1e79bd8..d38c274a4 100644 --- a/examples/eeg/seizure_detection/EEG_events_SparcNet.py +++ b/examples/eeg/SpaRCNet/EEG_events_classification_SparcNet.py @@ -1,7 +1,7 @@ from pyhealth.datasets import split_by_visit, get_dataloader from pyhealth.trainer import Trainer from pyhealth.datasets import TUEVDataset -from pyhealth.tasks import EEG_events_fn +from pyhealth.tasks import EEGEventsTUEV from pyhealth.models import SparcNet # step 1: load signal data @@ -9,10 +9,24 @@ dev=True, refresh_cache=True, ) - +print(dataset.stats()) # step 2: set task -TUEV_ds = dataset.set_task(EEG_events_fn) -TUEV_ds.stat() +TUEV_ds = dataset.set_task(EEGEventsTUEV( + resample_rate=200, # Resample rate + bandpass_filter=(0.1, 75.0), # Bandpass filter + notch_filter=50.0, # Notch filter +)) + +print(f"Total task samples: {len(TUEV_ds)}") +print(f"Input schema: {TUEV_ds.input_schema}") +print(f"Output schema: {TUEV_ds.output_schema}") + +# Inspect a sample +sample = TUEV_ds[0] +print(f"\nSample keys: {sample.keys()}") +print(f"Signal shape: {sample['signal'].shape}") +print(f"Label: {sample['label']}") + # split dataset train_dataset, val_dataset, test_dataset = split_by_visit(