Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ The tasks are:
- EEGEventsTUEV: EEG event classification task for the TUEV dataset.
- EEGAbnormalTUAB: Binary classification task for the TUAB dataset (abnormal vs normal).

Tasks Parameters:
- resample_rate: int, default=200 # Resample rate
- bandpass_filter: tuple, default=(0.1, 75.0) # Bandpass filter
- notch_filter: float, default=50.0 # Notch filter

.. autoclass:: pyhealth.tasks.temple_university_EEG_tasks.EEGEventsTUEV
:members:
:undoc-members:
Expand Down
84 changes: 39 additions & 45 deletions examples/eeg/tuh_eeg/tuab_abnormal_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "d1230c58",
"metadata": {},
"outputs": [
Expand All @@ -69,10 +69,6 @@
"Using cached metadata from /home/jp65/.cache/pyhealth/tuab\n",
"Initializing tuab dataset from /home/jp65/.cache/pyhealth/tuab (dev mode: True)\n",
"No cache_dir provided. Using default cache dir: /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2\n",
"Scanning table: train from /home/jp65/.cache/pyhealth/tuab/tuab-train-pyhealth.csv\n",
"Scanning table: eval from /home/jp65/.cache/pyhealth/tuab/tuab-eval-pyhealth.csv\n",
"Dev mode enabled: limiting to 1000 patients\n",
"Caching event dataframe to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/global_event_df.parquet...\n",
"Dataset: tuab\n",
"Dev mode: True\n",
"Number of patients: 1000\n",
Expand All @@ -99,7 +95,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "66f68916",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -132,9 +128,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 51%|█████ | 512/1000 [15:10<14:20, 1.76s/it]/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/joblib/externals/loky/process_executor.py:782: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
" 90%|████████▉ | 896/1000 [26:01<02:47, 1.61s/it]/home/jp65/miniconda3/envs/pyhealth/lib/python3.12/site-packages/joblib/externals/loky/process_executor.py:782: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
" warnings.warn(\n",
"100%|██████████| 1000/1000 [28:09<00:00, 1.69s/it]"
"100%|██████████| 1000/1000 [28:39<00:00, 1.72s/it]"
]
},
{
Expand All @@ -157,7 +153,7 @@
"output_type": "stream",
"text": [
"Label label vocab: {0: 0, 1: 1}\n",
"Processing samples and saving to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld...\n",
"Processing samples and saving to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld...\n",
"Applying processors on data with 1 workers...\n",
"Detected Jupyter notebook environment, setting num_workers to 1\n",
"Single worker mode, processing sequentially\n",
Expand All @@ -182,28 +178,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 178760/178760 [05:07<00:00, 580.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Worker 0 finished processing samples.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
"100%|██████████| 178760/178760 [05:16<00:00, 564.46it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cached processed samples to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld\n",
"Worker 0 finished processing samples.\n",
"Cached processed samples to /home/jp65/.cache/pyhealth/9e59c95b-42bb-596b-a667-dd694bc64ac2/tasks/EEG_abnormal_d595851a-c8f6-5f1d-bdfb-b3d527c2deb0/samples_160cf897-aed5-541d-8dff-f65c7676d862.ld\n",
"Total task samples: 178760\n",
"Input schema: {'signal': 'tensor'}\n",
"Output schema: {'label': 'binary'}\n",
Expand All @@ -212,10 +195,21 @@
"Signal shape: torch.Size([16, 2000])\n",
"Label: tensor([1.])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"sample_dataset = dataset.set_task(EEGAbnormalTUAB())\n",
"sample_dataset = dataset.set_task(EEGAbnormalTUAB(\n",
" resample_rate=200, # Resample rate\n",
" bandpass_filter=(0.1, 75.0), # Bandpass filter\n",
" notch_filter=50.0, # Notch filter\n",
"))\n",
"\n",
"print(f\"Total task samples: {len(sample_dataset)}\")\n",
"print(f\"Input schema: {sample_dataset.input_schema}\")\n",
Expand All @@ -239,7 +233,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "c01a076f",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -276,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "1d490449",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -322,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 6,
"id": "7236ddc0",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -380,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"id": "11d7f9c5",
"metadata": {},
"outputs": [
Expand All @@ -389,7 +383,7 @@
"output_type": "stream",
"text": [
"Model output shape: torch.Size([32, 1])\n",
"Sample output: tensor([-0.9347], device='cuda:0')\n"
"Sample output: tensor([0.1072], device='cuda:0')\n"
]
}
],
Expand Down Expand Up @@ -417,7 +411,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"id": "5521de25",
"metadata": {},
"outputs": [],
Expand All @@ -437,24 +431,24 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 9,
"id": "0c14a78d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5, Loss: 0.6952\n",
"Validation Loss: 0.6931, Accuracy: 50.77%\n",
"Epoch 2/5, Loss: 0.6931\n",
"Validation Loss: 0.6932, Accuracy: 49.23%\n",
"Epoch 3/5, Loss: 0.6932\n",
"Validation Loss: 0.6931, Accuracy: 50.76%\n",
"Epoch 4/5, Loss: 0.6935\n",
"Validation Loss: 0.6930, Accuracy: 50.77%\n",
"Epoch 5/5, Loss: 0.7017\n",
"Validation Loss: 0.6933, Accuracy: 49.24%\n"
"Epoch 1/5, Loss: 0.6092\n",
"Validation Loss: 0.5229, Accuracy: 76.95%\n",
"Epoch 2/5, Loss: 0.4883\n",
"Validation Loss: 0.5344, Accuracy: 72.80%\n",
"Epoch 3/5, Loss: 0.4586\n",
"Validation Loss: 0.7649, Accuracy: 62.00%\n",
"Epoch 4/5, Loss: 0.4442\n",
"Validation Loss: 0.4706, Accuracy: 77.78%\n",
"Epoch 5/5, Loss: 0.4250\n",
"Validation Loss: 0.4473, Accuracy: 79.42%\n"
]
}
],
Expand Down Expand Up @@ -509,15 +503,15 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 10,
"id": "bbd0eb33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.6932, Accuracy: 49.61%\n"
"Test Loss: 0.4406, Accuracy: 80.17%\n"
]
}
],
Expand Down
Loading