Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.datasets import TUABDataset
from pyhealth.tasks import EEG_isAbnormal_fn
from pyhealth.tasks import EEGAbnormalTUAB
from pyhealth.models import SparcNet

# step 1: load signal data
dataset = TUABDataset(root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
dev=True,
refresh_cache=True,
)

print(dataset.stats())
# step 2: set task
TUAB_ds = dataset.set_task(EEG_isAbnormal_fn)
TUAB_ds.stat()
TUAB_ds = dataset.set_task(EEGAbnormalTUAB(
resample_rate=200,
bandpass_filter=(0.1, 75.0),
notch_filter=50.0,
))
print(f"Total task samples: {len(TUAB_ds)}")
print(f"Input schema: {TUAB_ds.input_schema}")
print(f"Output schema: {TUAB_ds.output_schema}")

# Inspect a sample
sample = TUAB_ds[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

# split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
from pyhealth.datasets import split_by_visit, get_dataloader
from pyhealth.trainer import Trainer
from pyhealth.datasets import TUEVDataset
from pyhealth.tasks import EEG_events_fn
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.models import SparcNet

# step 1: load signal data
dataset = TUEVDataset(root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/",
dev=True,
refresh_cache=True,
)

print(dataset.stats())
# step 2: set task
TUEV_ds = dataset.set_task(EEG_events_fn)
TUEV_ds.stat()
TUEV_ds = dataset.set_task(EEGEventsTUEV(
resample_rate=200, # Resample rate
bandpass_filter=(0.1, 75.0), # Bandpass filter
notch_filter=50.0, # Notch filter
))

print(f"Total task samples: {len(TUEV_ds)}")
print(f"Input schema: {TUEV_ds.input_schema}")
print(f"Output schema: {TUEV_ds.output_schema}")

# Inspect a sample
sample = TUEV_ds[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")


# split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
Expand Down