From 6633571c70ef2462c1af22b70de29bcc0de48151 Mon Sep 17 00:00:00 2001 From: Jathurshan0330 Date: Tue, 10 Feb 2026 15:10:41 -0600 Subject: [PATCH 1/4] Fix to TUH tasks - flexibility in preprocessing parameters --- pyhealth/tasks/temple_university_EEG_tasks.py | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/pyhealth/tasks/temple_university_EEG_tasks.py b/pyhealth/tasks/temple_university_EEG_tasks.py index 3e8e20a4a..f2669d070 100644 --- a/pyhealth/tasks/temple_university_EEG_tasks.py +++ b/pyhealth/tasks/temple_university_EEG_tasks.py @@ -37,19 +37,28 @@ class EEGEventsTUEV(BaseTask): input_schema: Dict[str, str] = {"signal": "tensor"} output_schema: Dict[str, str] = {"label": "multiclass"} - def __init__(self) -> None: + def __init__(self, + resample_rate: float = 200, + bandpass_filter: Tuple[float, float] = (0.1, 75.0), + notch_filter: float = 50.0, + ) -> None: super().__init__() + self.resample_rate = resample_rate + self.bandpass_filter = bandpass_filter + self.notch_filter = notch_filter + @staticmethod def BuildEvents( - signals: np.ndarray, times: np.ndarray, EventData: np.ndarray + signals: np.ndarray, times: np.ndarray, EventData: np.ndarray, + resample_rate: float = 200, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: # Ensure 2D in case a .rec has only one row EventData = np.atleast_2d(EventData) numEvents, _ = EventData.shape - fs = 256.0 + fs = resample_rate numChan, _ = signals.shape features = np.zeros([numEvents, numChan, int(fs) * 5]) @@ -104,12 +113,16 @@ def convert_signals(signals: np.ndarray, Rawdata: mne.io.BaseRaw) -> np.ndarray: return new_signals @staticmethod - def readEDF(fileName: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, mne.io.BaseRaw]: + def readEDF(fileName: str, + resample_rate: float = 200, + bandpass_filter: Tuple[float, float] = (0.1, 75.0), + notch_filter: float = 50.0, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, mne.io.BaseRaw]: Rawdata = mne.io.read_raw_edf(fileName, 
preload=True, verbose="error") - Rawdata.filter(l_freq=0.1, h_freq=75.0, verbose="error") - Rawdata.notch_filter(50.0, verbose="error") - Rawdata.resample(256, n_jobs=5, verbose="error") + Rawdata.filter(l_freq=bandpass_filter[0], h_freq=bandpass_filter[1], verbose="error") + Rawdata.notch_filter(notch_filter, verbose="error") + Rawdata.resample(resample_rate, n_jobs=5, verbose="error") _, times = Rawdata[:] signals = Rawdata.get_data(units="uV") @@ -134,9 +147,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: for event in events: edf_path = event.signal_file - signals, times, rec, raw = self.readEDF(edf_path) + signals, times, rec, raw = self.readEDF(edf_path, self.resample_rate, self.bandpass_filter, self.notch_filter) signals = self.convert_signals(signals, raw) - feats, offending_channels, labels = self.BuildEvents(signals, times, rec) + feats, offending_channels, labels = self.BuildEvents(signals, times, rec, self.resample_rate) for idx, (signal, offending_channel, label) in enumerate( zip(feats, offending_channels, labels) @@ -185,16 +198,27 @@ class EEGAbnormalTUAB(BaseTask): input_schema: Dict[str, str] = {"signal": "tensor"} output_schema: Dict[str, str] = {"label": "binary"} - def __init__(self) -> None: + def __init__(self, + resample_rate: float = 200, + bandpass_filter: Tuple[float, float] = (0.1, 75.0), + notch_filter: float = 50.0, + ) -> None: super().__init__() - + self.resample_rate = resample_rate + self.bandpass_filter = bandpass_filter + self.notch_filter = notch_filter + @staticmethod - def read_and_process_edf(fileName: str) -> Tuple[np.ndarray, List[str]]: + def read_and_process_edf(fileName: str, + resample_rate: float = 200, + bandpass_filter: Tuple[float, float] = (0.1, 75.0), + notch_filter: float = 50.0, + ) -> Tuple[np.ndarray, List[str]]: Rawdata = mne.io.read_raw_edf(fileName, preload=True, verbose="error") - Rawdata.filter(l_freq=0.1, h_freq=75.0, verbose="error") - Rawdata.notch_filter(50.0, verbose="error") - 
Rawdata.resample(200, n_jobs=5, verbose="error") + Rawdata.filter(l_freq=bandpass_filter[0], h_freq=bandpass_filter[1], verbose="error") + Rawdata.notch_filter(notch_filter, verbose="error") + Rawdata.resample(resample_rate, n_jobs=5, verbose="error") raw_data = Rawdata.get_data(units="uV") ch_name = Rawdata.ch_names @@ -283,7 +307,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: events = patient.get_events() samples: List[Dict[str, Any]] = [] - fs = 200 + fs = self.resample_rate for event in events: edf_path = event.signal_file @@ -292,7 +316,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: label = 0 elif label == 'abnormal': label = 1 - raw_data, ch_name = self.read_and_process_edf(edf_path) + raw_data, ch_name = self.read_and_process_edf(edf_path, self.resample_rate, self.bandpass_filter, self.notch_filter) bipolar_data = self.convert_to_bipolar(raw_data, ch_name) num_samples = int(bipolar_data.shape[1] // (fs * 10)) From 534ce4cc1a7f088b7f4bb5eb6ec5f2135da17d04 Mon Sep 17 00:00:00 2001 From: Jathurshan0330 Date: Tue, 10 Feb 2026 15:12:50 -0600 Subject: [PATCH 2/4] Fix: tuev_test for default preprocessing parameters, resampling rate to 200 --- tests/core/test_tuev.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_tuev.py b/tests/core/test_tuev.py index 20dd11ce3..7e5dec99f 100644 --- a/tests/core/test_tuev.py +++ b/tests/core/test_tuev.py @@ -132,7 +132,7 @@ def __init__(self, ch_names): np.testing.assert_allclose(out[0], expected0) def test_BuildEvents_single_row_eventdata_and_window_length(self): - fs = 256 + fs = 200 num_chan = 16 num_points = 2000 signals = np.random.randn(num_chan, num_points) @@ -157,7 +157,7 @@ def test_call_returns_one_sample_per_event_and_adjusts_label(self): events=[_DummyEvent(signal_file=os.path.join("C:\\", "dummy.edf"))], ) - feats = np.zeros((2, 16, 256 * 5), dtype=float) + feats = np.zeros((2, 16, 200 * 5), dtype=float) offending = np.array([[3], [7]]) labels = 
np.array([[1], [6]]) # will become 0 and 5 in output @@ -173,7 +173,7 @@ def test_call_returns_one_sample_per_event_and_adjusts_label(self): self.assertEqual(len(samples), 2) self.assertEqual(samples[0]["patient_id"], "patient-0") self.assertIn("signal", samples[0]) - self.assertEqual(samples[0]["signal"].shape, (16, 256 * 5)) + self.assertEqual(samples[0]["signal"].shape, (16, 200 * 5)) self.assertEqual(samples[0]["offending_channel"], 3) self.assertEqual(samples[0]["label"], 0) self.assertEqual(samples[1]["offending_channel"], 7) From 40d904db5eef14f78f1fb429994db65841289933 Mon Sep 17 00:00:00 2001 From: Jathurshan0330 Date: Tue, 10 Feb 2026 16:18:56 -0600 Subject: [PATCH 3/4] Fix: updated examples for TUH tasks --- .../eeg/tuh_eeg/tuab_abnormal_detection.ipynb | 84 ++- .../tuev_eeg_event_classification.ipynb | 582 ++++++++++++++++++ 2 files changed, 621 insertions(+), 45 deletions(-) create mode 100644 examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb diff --git a/examples/eeg/tuh_eeg/tuab_abnormal_detection.ipynb b/examples/eeg/tuh_eeg/tuab_abnormal_detection.ipynb index d2f57ace2..0406c265b 100644 --- a/examples/eeg/tuh_eeg/tuab_abnormal_detection.ipynb +++ b/examples/eeg/tuh_eeg/tuab_abnormal_detection.ipynb @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "d1230c58", "metadata": {}, "outputs": [ @@ -69,10 +69,6 @@ "Using cached metadata from /home/jp65/.cache/pyhealth/tuab\n", "Initializing tuab dataset from /home/jp65/.cache/pyhealth/tuab (dev mode: True)\n", "No cache_dir provided. 
Using default cache dir: /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2\n", - "Scanning table: train from /home/jp65/.cache/pyhealth/tuab/tuab-train-pyhealth.csv\n", - "Scanning table: eval from /home/jp65/.cache/pyhealth/tuab/tuab-eval-pyhealth.csv\n", - "Dev mode enabled: limiting to 1000 patients\n", - "Caching event dataframe to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/global_event_df.parquet...\n", "Dataset: tuab\n", "Dev mode: True\n", "Number of patients: 1000\n", @@ -99,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "66f68916", "metadata": {}, "outputs": [ @@ -132,9 +128,9 @@ "name": "stderr", "output_type": "stream", "text": [ - " 51%|█████ | 512/1000 [15:10<14:20, 1.76s/it]/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/joblib/externals/loky/process_executor.py:782: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n", + " 90%|████████▉ | 896/1000 [26:01<02:47, 1.61s/it]/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/joblib/externals/loky/process_executor.py:782: UserWarning: A worker stopped while some jobs were given to the executor. 
This can be caused by a too short worker timeout or by a memory leak.\n", " warnings.warn(\n", - "100%|██████████| 1000/1000 [28:09<00:00, 1.69s/it]" + "100%|██████████| 1000/1000 [28:39<00:00, 1.72s/it]" ] }, { @@ -157,7 +153,7 @@ "output_type": "stream", "text": [ "Label label vocab: {0: 0, 1: 1}\n", - "Processing samples and saving to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld...\n", + "Processing samples and saving to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld...\n", "Applying processors on data with 1 workers...\n", "Detected Jupyter notebook environment, setting num_workers to 1\n", "Single worker mode, processing sequentially\n", @@ -182,28 +178,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 178760/178760 [05:07<00:00, 580.75it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Worker 0 finished processing samples.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "100%|██████████| 178760/178760 [05:16<00:00, 564.46it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Cached processed samples to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld\n", + "Worker 0 finished processing samples.\n", + "Cached processed samples to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld\n", "Total task samples: 178760\n", "Input schema: {'signal': 'tensor'}\n", "Output schema: {'label': 'binary'}\n", @@ -212,10 +195,21 @@ "Signal shape: torch.Size([16, 2000])\n", "Label: tensor([1.])\n" ] 
+ }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ - "sample_dataset = dataset.set_task(EEGAbnormalTUAB())\n", + "sample_dataset = dataset.set_task(EEGAbnormalTUAB(\n", + " resample_rate=200, # Resample rate\n", + " bandpass_filter=(0.1, 75.0), # Bandpass filter\n", + " notch_filter=50.0, # Notch filter\n", + "))\n", "\n", "print(f\"Total task samples: {len(sample_dataset)}\")\n", "print(f\"Input schema: {sample_dataset.input_schema}\")\n", @@ -239,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "c01a076f", "metadata": {}, "outputs": [ @@ -276,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "1d490449", "metadata": {}, "outputs": [ @@ -322,7 +316,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "id": "7236ddc0", "metadata": {}, "outputs": [ @@ -380,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "id": "11d7f9c5", "metadata": {}, "outputs": [ @@ -389,7 +383,7 @@ "output_type": "stream", "text": [ "Model output shape: torch.Size([32, 1])\n", - "Sample output: tensor([-0.9347], device='cuda:0')\n" + "Sample output: tensor([0.1072], device='cuda:0')\n" ] } ], @@ -417,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "id": "5521de25", "metadata": {}, "outputs": [], @@ -437,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "id": "0c14a78d", "metadata": {}, "outputs": [ @@ -445,16 +439,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/5, Loss: 0.6952\n", - "Validation Loss: 0.6931, Accuracy: 50.77%\n", - "Epoch 2/5, Loss: 0.6931\n", - "Validation Loss: 0.6932, Accuracy: 49.23%\n", - "Epoch 3/5, Loss: 0.6932\n", - "Validation Loss: 0.6931, Accuracy: 50.76%\n", - "Epoch 4/5, Loss: 0.6935\n", - "Validation Loss: 0.6930, Accuracy: 50.77%\n", - "Epoch 5/5, Loss: 0.7017\n", - "Validation 
Loss: 0.6933, Accuracy: 49.24%\n" + "Epoch 1/5, Loss: 0.6092\n", + "Validation Loss: 0.5229, Accuracy: 76.95%\n", + "Epoch 2/5, Loss: 0.4883\n", + "Validation Loss: 0.5344, Accuracy: 72.80%\n", + "Epoch 3/5, Loss: 0.4586\n", + "Validation Loss: 0.7649, Accuracy: 62.00%\n", + "Epoch 4/5, Loss: 0.4442\n", + "Validation Loss: 0.4706, Accuracy: 77.78%\n", + "Epoch 5/5, Loss: 0.4250\n", + "Validation Loss: 0.4473, Accuracy: 79.42%\n" ] } ], @@ -509,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 10, "id": "bbd0eb33", "metadata": {}, "outputs": [ @@ -517,7 +511,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test Loss: 0.6932, Accuracy: 49.61%\n" + "Test Loss: 0.4406, Accuracy: 80.17%\n" ] } ], diff --git a/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb b/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb new file mode 100644 index 000000000..95d7082a6 --- /dev/null +++ b/examples/eeg/tuh_eeg/tuev_eeg_event_classification.ipynb @@ -0,0 +1,582 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a2b5eb60", + "metadata": {}, + "source": [ + "## 1. Environment Setup\n", + "Seed the random generators, import core dependencies, and detect the training device." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f5284e16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cuda\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import TUEVDataset\n", + "from pyhealth.tasks import EEGEventsTUEV\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "from pyhealth.datasets.utils import get_dataloader\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1c999e55", + "metadata": {}, + "source": [ + "## 2. Load TUEV Dataset\n", + "Point to the TUEV dataset root and load the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d1230c58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Using both train and eval subsets\n", + "Wrote train metadata to cache: /home/jp65/.cache/pyhealth/tuev/tuev-train-pyhealth.csv\n", + "Wrote eval metadata to cache: /home/jp65/.cache/pyhealth/tuev/tuev-eval-pyhealth.csv\n", + "Using cached metadata from /home/jp65/.cache/pyhealth/tuev\n", + "Initializing tuev dataset from /home/jp65/.cache/pyhealth/tuev (dev mode: True)\n", + "No cache_dir provided. 
Using default cache dir: /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7\n", + "Scanning table: train from /home/jp65/.cache/pyhealth/tuev/tuev-train-pyhealth.csv\n", + "Scanning table: eval from /home/jp65/.cache/pyhealth/tuev/tuev-eval-pyhealth.csv\n", + "Dev mode enabled: limiting to 1000 patients\n", + "Caching event dataframe to /home/jp65/.cache/pyhealth/fe851030-67cc-5fe9-8de6-9cac645af5a7/global_event_df.parquet...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/dask/dataframe/core.py:382: UserWarning: Insufficient elements for `head`. 1000 elements requested, only 189 elements available. Try passing larger `npartitions` to `head`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: tuev\n", + "Dev mode: True\n", + "Number of patients: 189\n", + "Number of events: 259\n" + ] + } + ], + "source": [ + "dataset = TUEVDataset(\n", + " root='/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf', # Update this path\n", + " dev=True\n", + ")\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "ff3f040f", + "metadata": {}, + "source": [ + "## 3. Prepare PyHealth Dataset\n", + "Set the task for the dataset and convert raw samples into PyHealth format for abnormal EEG classification." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "66f68916", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task EEG_events for tuev base dataset...\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 189 patients. 
(Polars threads: 128)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/189 [00:00 1276, pool: 638\n", + " # conv2: 638 -> 634, pool: 317\n", + " self.fc1 = nn.Linear(15808, 128)\n", + " self.fc2 = nn.Linear(128, num_classes)\n", + " self.relu = nn.ReLU()\n", + "\n", + " def forward(self, signal):\n", + " x = self.relu(self.conv1(signal))\n", + " x = self.pool(x)\n", + " x = self.relu(self.conv2(x))\n", + " x = self.pool(x)\n", + " x = x.view(x.size(0), -1)\n", + " x = self.relu(self.fc1(x))\n", + " x = self.fc2(x)\n", + " return x\n", + "\n", + "model = SimpleEEGClassifier(num_classes=6).to(device)\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "print(f\"Trainable parameters: {trainable_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "912ec100", + "metadata": {}, + "source": [ + "## 7. Test Forward Pass\n", + "Verify the model can process a batch and compute outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "11d7f9c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model output shape: torch.Size([32, 6])\n", + "Sample output: tensor([ 0.4550, -0.1300, 0.4457, -0.0651, 0.2269, -0.5862], device='cuda:0')\n" + ] + } + ], + "source": [ + "# Move batch to device\n", + "test_batch = {key: value.to(device) if hasattr(value, 'to') else value \n", + " for key, value in first_batch.items()}\n", + "\n", + "# Forward pass\n", + "with torch.no_grad():\n", + " output = model(test_batch['signal'])\n", + "\n", + "print(\"Model output shape:\", output.shape)\n", + "print(\"Sample output:\", output[0])" + ] + }, + { + "cell_type": "markdown", + "id": "b0818f3b", + "metadata": {}, + "source": [ + "## 8. 
Configure Loss and Optimizer\n", + "Define the loss function and optimizer for training." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5521de25", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "9bd05c88", + "metadata": {}, + "source": [ + "## 9. Train the Model\n", + "Launch the training loop to learn from the EEG data." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0c14a78d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5, Loss: 0.6861\n", + "Validation Loss: 0.2473, Accuracy: 93.95%\n", + "Epoch 2/5, Loss: 0.2535\n", + "Validation Loss: 0.1431, Accuracy: 96.33%\n", + "Epoch 3/5, Loss: 0.4147\n", + "Validation Loss: 0.3619, Accuracy: 88.27%\n", + "Epoch 4/5, Loss: 0.1992\n", + "Validation Loss: 0.1431, Accuracy: 96.97%\n", + "Epoch 5/5, Loss: 0.1228\n", + "Validation Loss: 0.1890, Accuracy: 96.25%\n" + ] + } + ], + "source": [ + "num_epochs = 5\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " running_loss = 0.0\n", + " for batch in train_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " outputs = model(signals)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " running_loss += loss.item()\n", + " \n", + " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}\")\n", + " \n", + " # Validation\n", + " if val_loader:\n", + " model.eval()\n", + " val_loss = 0.0\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for batch in val_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " outputs = model(signals)\n", + " loss = 
criterion(outputs, labels)\n", + " val_loss += loss.item()\n", + " predicted = torch.argmax(outputs, dim=1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + " print(f\"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "837ec6ed", + "metadata": {}, + "source": [ + "## 10. Evaluate on Test Set\n", + "Evaluate the trained model on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bbd0eb33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Loss: 0.1885, Accuracy: 96.11%\n" + ] + } + ], + "source": [ + "model.eval()\n", + "test_loss = 0.0\n", + "correct = 0\n", + "total = 0\n", + "with torch.no_grad():\n", + " for batch in test_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " outputs = model(signals)\n", + " loss = criterion(outputs, labels)\n", + " test_loss += loss.item()\n", + " predicted = torch.argmax(outputs, dim=1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print(f\"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "294438fb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyhealth", + "language": "python", + "name": "pyhealth" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From baae593d74dde9a9f1ee95d4bbf1f73a1f468785 Mon Sep 17 00:00:00 2001 From: Jathurshan0330 Date: Tue, 10 Feb 2026 16:21:46 -0600 
Subject: [PATCH 4/4] Updated docs for temple university EEG tasks --- .../api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst b/docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst index 6403f0cc0..106c76673 100644 --- a/docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst +++ b/docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst @@ -7,6 +7,11 @@ The tasks are: - EEGEventsTUEV: EEG event classification task for the TUEV dataset. - EEGAbnormalTUAB: Binary classification task for the TUAB dataset (abnormal vs normal). +Task parameters (accepted by both tasks): +- resample_rate: float, default=200 # target sampling rate in Hz +- bandpass_filter: tuple, default=(0.1, 75.0) # (low, high) band-pass cutoffs in Hz +- notch_filter: float, default=50.0 # power-line notch frequency in Hz + .. autoclass:: pyhealth.tasks.temple_university_EEG_tasks.EEGEventsTUEV :members: :undoc-members: