diff --git a/.gitignore b/.gitignore index aaf66fc36..cbc16ee49 100644 --- a/.gitignore +++ b/.gitignore @@ -133,7 +133,13 @@ leaderboard/rtd_token.txt # locally pre-trained models pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model + +# local testing files +halo_testing/ +halo_testing_script.py +test_halo_model.slurm + data/physionet.org/ # VSCode settings -.vscode/ \ No newline at end of file +.vscode/ diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..bdca73271 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,55 @@ +# PyHealth Examples + +This directory contains example scripts and notebooks for using PyHealth. + +## HALO Synthetic Data Generation + +### Google Colab Notebook (No Cluster Required) + +**File**: `halo_mimic3_colab.ipynb` + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sunlabuiuc/PyHealth/blob/master/examples/halo_mimic3_colab.ipynb) + +Train HALO and generate synthetic MIMIC-III data directly in your browser using Google Colab. + +**Requirements**: +- Google account (for Colab) +- MIMIC-III access from PhysioNet +- Files: ADMISSIONS.csv, DIAGNOSES_ICD.csv, PATIENTS.csv, patient_ids.txt + +**Quick Start**: +1. Open `halo_mimic3_colab.ipynb` in Google Colab +2. Enable GPU (Runtime → Change runtime type → GPU) +3. Run cells in order +4. Upload your MIMIC-III files when prompted +5. Download synthetic data CSV + +**Demo vs Production**: +- **Demo** (default): 5 epochs, 1K samples, ~30 min +- **Production**: 80 epochs, 10K samples, ~6-10 hours (change configuration) + +**Features**: +- Google Drive integration for persistence +- Resume capability if session times out +- Automatic checkpoint saving +- CSV output format +- Data quality validation + +### Cluster Training (SLURM) + +**Files**: +- `slurm/train_halo_mimic3.slurm` - Training script +- `slurm/generate_halo_mimic3.slurm` - Generation script +- `halo_mimic3_training.py` - Python training code +- `generate_synthetic_mimic3_halo.py` - Python generation code + +For users with access to GPU clusters. See individual script headers for usage. + +**Example**: +```bash +# Train +sbatch slurm/train_halo_mimic3.slurm + +# Generate +sbatch slurm/generate_halo_mimic3.slurm +``` diff --git a/examples/generate_synthetic_mimic3_halo.py b/examples/generate_synthetic_mimic3_halo.py new file mode 100644 index 000000000..d88f48110 --- /dev/null +++ b/examples/generate_synthetic_mimic3_halo.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" +Example: Generate synthetic MIMIC-III patients using a trained HALO checkpoint. + +Loads MIMIC3Dataset with the halo_generation_mimic3_fn task (identical to +training) so that the vocabulary is reconstructed, then loads the saved +checkpoint and calls model.synthesize_dataset(). Output is saved as JSON. + +Usage: + python examples/generate_synthetic_mimic3_halo.py + python examples/generate_synthetic_mimic3_halo.py --save_dir ./my_save/ --num_samples 500 +""" + +import argparse +import json +import os + +import torch + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.models.generators.halo import HALO +from pyhealth.tasks import halo_generation_mimic3_fn + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Generate synthetic MIMIC-III patients with HALO" + ) + parser.add_argument( + "--mimic3_root", + default="/path/to/mimic3", + help="Root directory of MIMIC-III data (default: /path/to/mimic3)", + ) + parser.add_argument( + "--save_dir", + default="./save/", + help="Directory containing the trained halo_model checkpoint (default: ./save/)", + ) + parser.add_argument( + "--num_samples", + type=int, + default=1000, + help="Number of synthetic patients to generate (default: 1000)", + ) + parser.add_argument( + "--output", + default="synthetic_patients.json", + help="Output JSON file path (default: synthetic_patients.json)", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # ------------------------------------------------------------------ + # STEP 1: Load MIMIC-III dataset + # The dataset must use the same tables and code_mapping as training + # so that the vocabulary is identical. + # ------------------------------------------------------------------ + print("Loading MIMIC-III dataset...") + base_dataset = MIMIC3Dataset( + root=args.mimic3_root, + tables=["diagnoses_icd"], # If you trained with different tables=, update this to match. + code_mapping={}, + dev=False, + refresh_cache=False, + ) + print(f" Loaded {len(base_dataset.patients)} patients") + + # ------------------------------------------------------------------ + # STEP 2: Apply the HALO generation task + # set_task builds the vocabulary via NestedSequenceProcessor — must + # match the task used during training exactly. + # ------------------------------------------------------------------ + print("Applying HALO generation task...") + sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn) + print(f" {len(sample_dataset)} samples after task filtering") + + # ------------------------------------------------------------------ + # STEP 3: Instantiate HALO with the same hyperparameters as training + # The model constructor uses the dataset to determine vocab sizes; + # the weights are loaded from the checkpoint immediately after. + # ------------------------------------------------------------------ + print("Initializing HALO model...") + model = HALO( + dataset=sample_dataset, + embed_dim=768, + n_heads=12, + n_layers=12, + n_ctx=48, + batch_size=48, + epochs=50, # unused during generation; must match training for checkpoint compatibility + pos_loss_weight=None, + lr=1e-4, + save_dir=args.save_dir, + ) + + # ------------------------------------------------------------------ + # STEP 4: Load trained checkpoint + # The training loop saves to save_dir/halo_model with keys + # "model" (halo_model state dict) and "optimizer". + # ------------------------------------------------------------------ + checkpoint_path = os.path.join(args.save_dir, "halo_model") + print(f"Loading checkpoint from {checkpoint_path} ...") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Checkpoint not found at {checkpoint_path}. " + "Train the model first with examples/halo_mimic3_training.py." + ) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + model.halo_model.load_state_dict(checkpoint["model"]) + print(" Checkpoint loaded successfully") + + # ------------------------------------------------------------------ + # STEP 5: Generate synthetic patients + # synthesize_dataset returns List[Dict] where each dict has: + # "patient_id": "synthetic_N" + # "visits": [[code, ...], ...] + # ------------------------------------------------------------------ + print(f"Generating {args.num_samples} synthetic patients...") + synthetic_data = model.synthesize_dataset( + num_samples=args.num_samples, + random_sampling=True, + ) + + # ------------------------------------------------------------------ + # STEP 6: Save output as JSON + # ------------------------------------------------------------------ + print(f"Saving output to {args.output} ...") + with open(args.output, "w") as f: + json.dump(synthetic_data, f, indent=2) + + # ------------------------------------------------------------------ + # STEP 7: Print summary statistics + # ------------------------------------------------------------------ + total_patients = len(synthetic_data) + total_visits = sum(len(p["visits"]) for p in synthetic_data) + avg_visits = total_visits / total_patients if total_patients > 0 else 0.0 + + print("\n--- Generation Summary ---") + print(f" Patients generated : {total_patients}") + print(f" Total visits : {total_visits}") + print(f" Avg visits/patient : {avg_visits:.2f}") + print(f" Output saved to : {args.output}") + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/halo_mimic3_colab.ipynb b/examples/halo_mimic3_colab.ipynb new file mode 100644 index 000000000..e1dcd1f57 --- /dev/null +++ b/examples/halo_mimic3_colab.ipynb @@ -0,0 +1,1553 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HALO Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-02_\n\nThis notebook trains the HALO (Hierarchical Autoregressive Language mOdel) on your MIMIC-III data and generates synthetic patients.\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `ADMISSIONS.csv`\n - `DIAGNOSES_ICD.csv`\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~20-30 min on GPU\n - Production (80 epochs, 10K samples): ~6-10 hours on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Train HALO model (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients using trained model\n6. **Download**: Get CSV file with synthetic data\n\n## Important Notes\n\n **Colab Timeout**: Free Colab sessions timeout after 12 hours. For production training (80 epochs), consider:\n- Colab Pro for longer sessions\n- Running on your own GPU cluster using `examples/slurm/train_halo_mimic3.slurm`\n\n **Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the pipeline quickly\n- Production settings (80 epochs, 10K samples) match the published HALO results\n\n## References\n\n- [HALO Paper](https://arxiv.org/abs/2406.16061)\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 1. Setup & Installation" + ] + }, + { + "cell_type": "code", + "source": [ + "import subprocess\n", + "import sys\n", + "\n", + "FORK = 'jalengg'\n", + "BRANCH = 'halo-pr-528' # set to None to use main branch\n", + "\n", + "if BRANCH:\n", + " install_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n", + " print(f\"Installing PyHealth from branch '{BRANCH}'...\")\n", + "else:\n", + " install_url = f\"git+https://github.com/{FORK}/PyHealth.git\"\n", + " print(\"Installing PyHealth from main branch...\")\n", + "\n", + "result = subprocess.run(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-cache-dir\", \"--force-reinstall\", install_url],\n", + " capture_output=True, text=True\n", + ")\n", + "if result.returncode == 0:\n", + " print(\"PyHealth installed successfully!\")\n", + "else:\n", + " print(\"PyHealth installation failed!\")\n", + " print(result.stderr)\n", + " raise RuntimeError(\"PyHealth installation failed. Please check the error above.\")\n" + ], + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": "# Import required libraries\nimport os\nimport sys\nimport torch\nimport pandas as pd\nimport shutil\nfrom google.colab import drive, files\nfrom IPython.display import display, Markdown, HTML\n\nprint(\" All libraries imported successfully!\")\nprint(f\"PyTorch version: {torch.__version__}\")\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": "# Mount Google Drive for persistent storage\nprint(\"Mounting Google Drive...\")\ndrive.mount('/content/drive')\nprint(\" Google Drive mounted at /content/drive\")\n\n# Create directory structure in Drive\nbase_dir = '/content/drive/MyDrive/HALO_Training'\ndata_dir = f'{base_dir}/data'\ncheckpoint_dir = f'{base_dir}/checkpoints'\noutput_dir = f'{base_dir}/output'\n\nfor dir_path in [base_dir, data_dir, checkpoint_dir, output_dir]:\n os.makedirs(dir_path, exist_ok=True)\n\nprint(f\"\\n Directory structure created:\")\nprint(f\" Base: {base_dir}\")\nprint(f\" Data: {data_dir}\")\nprint(f\" Checkpoints: {checkpoint_dir}\")\nprint(f\" Output: {output_dir}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 2. Configuration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure your training and generation parameters below.\n", + "\n", + "**For Quick Demo (recommended first time):**\n", + "- Leave defaults (5 epochs, 1K samples)\n", + "- Runs in ~20-30 minutes on GPU\n", + "\n", + "**For Production Quality:**\n", + "- Set `EPOCHS = 80`\n", + "- Set `N_SYNTHETIC_SAMPLES = 10000`\n", + "- Expect ~6-10 hours on GPU" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": "# ============================================\n# CONFIGURATION - Modify these parameters\n# ============================================\n\n# Training configuration\nEPOCHS = 5 # Demo: 5, Production: 80\nBATCH_SIZE = 32 # Demo: 32, Production: 48\nLEARNING_RATE = 0.0001 # Standard HALO learning rate\n\n# Generation configuration\nN_SYNTHETIC_SAMPLES = 1000 # Demo: 1000, Production: 10000\n\n# Advanced HALO model configuration (usually don't need to change)\nN_CTX = 48 # Context window\nN_EMBD = 768 # Embedding dimension\nN_LAYER = 12 # Number of transformer layers\nN_HEAD = 12 # Number of attention heads\n\n# Display configuration\nprint(\"=\" * 60)\nprint(\"HALO CONFIGURATION\")\nprint(\"=\" * 60)\nprint(f\"Training:\")\nprint(f\" Epochs: {EPOCHS}\")\nprint(f\" Batch size: {BATCH_SIZE}\")\nprint(f\" Learning rate: {LEARNING_RATE}\")\nprint(f\"\\nGeneration:\")\nprint(f\" Synthetic samples: {N_SYNTHETIC_SAMPLES}\")\nprint(f\"\\nModel architecture:\")\nprint(f\" Embedding dim: {N_EMBD}\")\nprint(f\" Layers: {N_LAYER}\")\nprint(f\" Attention heads: {N_HEAD}\")\nprint(\"=\" * 60)\n\n# Estimate runtime\nif torch.cuda.is_available():\n runtime_min = EPOCHS * 4 # Rough estimate: 4 min per epoch on GPU\n print(f\"\\nEstimated training time: ~{runtime_min}-{runtime_min*2} minutes on GPU\")\nelse:\n print(f\"\\n WARNING: No GPU detected! Training will be very slow.\")\n print(f\" Go to Runtime \u2192 Change runtime type \u2192 Select GPU\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 3. Data Upload" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Upload your MIMIC-III CSV files. You need these 3 files (ICUSTAYS.csv is optional):\n", + "\n", + "1. `PATIENTS.csv` - Patient demographics\n", + "2. `ADMISSIONS.csv` - Patient admission records\n", + "3. `DIAGNOSES_ICD.csv` - Diagnosis codes (ICD-9)\n", + "\n", + "**Optional**: `ICUSTAYS.csv` - ICU stay records (included automatically if present)\n", + "\n", + "**Note**: Files will be saved to Google Drive and persist across Colab sessions. If files already exist in Drive, you can skip uploading and reuse them." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Required MIMIC-III files (only 2 needed for HALO)\n", + "required_files = {\n", + " 'PATIENTS.csv': 'Patient demographics',\n", + " 'ADMISSIONS.csv': 'Patient admission records',\n", + " 'DIAGNOSES_ICD.csv': 'Diagnosis codes (ICD-9)'\n", + "}\n", + "\n", + "# Check if files already exist in Google Drive\n", + "existing_files = {f: os.path.exists(f'{data_dir}/{f}') for f in required_files}\n", + "missing_files = [f for f, exists in existing_files.items() if not exists]\n", + "\n", + "if not missing_files:\n", + " print(\" All required files already exist in Google Drive!\")\n", + " for filename in required_files:\n", + " file_path = f'{data_dir}/{filename}'\n", + " file_size_mb = os.path.getsize(file_path) / (1024 * 1024)\n", + " print(f\" {filename} ({file_size_mb:.1f} MB)\")\n", + " print(f\"\\nSkipping upload. To re-upload, delete files from: {data_dir}\")\n", + "else:\n", + " print(\"Upload the following MIMIC-III files:\")\n", + " for i, filename in enumerate(missing_files, 1):\n", + " print(f\" {i}. {filename} - {required_files[filename]}\")\n", + " print(\"\\nYou can upload files one at a time or all together.\\n\")\n", + " \n", + " uploaded = files.upload()\n", + " \n", + " # Normalize filenames - handle Colab's automatic renaming (e.g., \"ADMISSIONS (1).csv\")\n", + " print(\"\\nProcessing uploaded files...\")\n", + " for uploaded_name in uploaded.keys():\n", + " # Find which required file this matches\n", + " matched_file = None\n", + " for req_file in required_files:\n", + " # Extract base name without extension\n", + " req_base = req_file.replace('.csv', '')\n", + " # Check if uploaded name contains the required base name\n", + " if req_base in uploaded_name and uploaded_name.endswith('.csv'):\n", + " matched_file = req_file\n", + " break\n", + " \n", + " if matched_file:\n", + " src = f'/content/{uploaded_name}'\n", + " dst = f'{data_dir}/{matched_file}'\n", + " shutil.copy(src, dst)\n", + " file_size_mb = os.path.getsize(dst) / (1024 * 1024)\n", + " print(f\" {matched_file} ({file_size_mb:.1f} MB)\")\n", + " if matched_file in missing_files:\n", + " missing_files.remove(matched_file)\n", + " \n", + " # Check if we still have missing files\n", + " if missing_files:\n", + " print(f\"\\n Still missing: {missing_files}\")\n", + " print(\"Please run this cell again to upload the remaining files.\")\n", + " else:\n", + " print(\"\\n All files uploaded successfully!\")\n", + " print(f\" Location: {data_dir}\")" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# Validate uploaded files\n", + "print(\"Validating data files...\")\n", + "\n", + "# Check PATIENTS.csv\n", + "patients = pd.read_csv(f'{data_dir}/PATIENTS.csv', nrows=5)\n", + "print(f\"\\n PATIENTS.csv: {len(patients.columns)} columns\")\n", + "print(f\" Sample columns: {', '.join(patients.columns[:3])}\")\n", + "\n", + "# Check ADMISSIONS.csv\n", + "admissions = pd.read_csv(f'{data_dir}/ADMISSIONS.csv', nrows=5)\n", + "print(f\"\\n ADMISSIONS.csv: {len(admissions.columns)} columns\")\n", + "print(f\" Sample columns: {', '.join(admissions.columns[:3])}\")\n", + "\n", + "\n", + "# Check DIAGNOSES_ICD.csv\n", + "diagnoses = pd.read_csv(f'{data_dir}/DIAGNOSES_ICD.csv', nrows=5)\n", + "print(f\"\\n DIAGNOSES_ICD.csv: {len(diagnoses.columns)} columns\")\n", + "print(f\" Sample columns: {', '.join(diagnoses.columns[:3])}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\" DATA VALIDATION COMPLETE\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 4. Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "Train the HALO model on your MIMIC-III data.\n\n**What happens during training:**\n- Dataset is preprocessed and vocabularies are created\n- HALO model trains for the specified number of epochs\n- Best checkpoint is automatically saved when validation loss improves\n- Progress updates print every 1,000 iterations\n\n**Checkpoints saved to Drive:**\n- `halo_model` - Best model (lowest validation loss)\n\n**Resume capability**: If training is interrupted, checkpoints in Drive allow resuming." + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n", + "from pyhealth.tasks import halo_generation_mimic3_fn\n", + "from pyhealth.models.generators.halo import HALO\n", + "\n", + "# Check for existing checkpoint\n", + "existing_checkpoint = os.path.exists(f'{checkpoint_dir}/halo_model')\n", + "if existing_checkpoint:\n", + " print(\"=\" * 60)\n", + " print(\"EXISTING CHECKPOINT FOUND\")\n", + " print(\"=\" * 60)\n", + " print(f\"Found checkpoint at: {checkpoint_dir}/halo_model\")\n", + " print(\"\\nOptions:\")\n", + " print(\" 1. Skip training and go to generation (recommended if training completed)\")\n", + " print(\" 2. Continue running this cell to retrain (will overwrite checkpoint)\")\n", + " print(\"\\nTo skip training, jump to the 'Generation' section below.\")\n", + " print(\"=\" * 60)\n", + "\n", + "# Device configuration\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"\\nUsing device: {device}\", flush=True)\n", + "if torch.cuda.is_available():\n", + " print(f\"GPU: {torch.cuda.get_device_name(0)}\", flush=True)\n", + " print(f\"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\", flush=True)\n", + "else:\n", + " print(\"WARNING: Training on CPU will be very slow!\", flush=True)\n", + "\n", + "# Load MIMIC-III dataset via PyHealth 2.0\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"Loading and preprocessing MIMIC-III dataset...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "print(f\"Data directory: {data_dir}\", flush=True)\n", + "\n", + "base_dataset = MIMIC3Dataset(\n", + " root=data_dir,\n", + " tables=[\"diagnoses_icd\", \"admissions\"],\n", + " dev=False,\n", + ")\n", + "print(f\"Loaded {len(base_dataset.unique_patient_ids)} patients\", flush=True)\n", + "\n", + "# Apply HALO generation task (builds vocabulary via NestedSequenceProcessor)\n", + "sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn)\n", + "print(f\"Task samples after filtering: {len(sample_dataset)}\", flush=True)\n", + "\n", + "# Split into train / val\n", + "train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n", + "print(f\"\\nTrain: {len(train_dataset)}, Val: {len(val_dataset)}\", flush=True)\n", + "\n", + "# Initialize HALO model (vocabulary derived automatically from dataset)\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"Initializing HALO model...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "model = HALO(\n", + " dataset=sample_dataset,\n", + " embed_dim=N_EMBD,\n", + " n_heads=N_HEAD,\n", + " n_layers=N_LAYER,\n", + " n_ctx=N_CTX,\n", + " batch_size=BATCH_SIZE,\n", + " epochs=EPOCHS,\n", + " lr=LEARNING_RATE,\n", + " save_dir=checkpoint_dir,\n", + ")\n", + "print(f\"Model initialized on {model.device}\", flush=True)\n", + "print(f\" Total vocabulary size: {model.config.total_vocab_size}\", flush=True)\n", + "print(f\" Embedding dim: {N_EMBD}, Layers: {N_LAYER}, Heads: {N_HEAD}\", flush=True)\n", + "\n", + "# Train model (best checkpoint saved automatically when val loss improves)\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"Training HALO model...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "print(f\"Training for {EPOCHS} epochs\", flush=True)\n", + "print(f\"Checkpoints saved to: {checkpoint_dir}/halo_model\", flush=True)\n", + "print(f\"{'='*60}\\n\", flush=True)\n", + "\n", + "model.train_model(train_dataset, val_dataset)\n", + "\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"TRAINING COMPLETE!\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "print(f\"Best checkpoint saved to: {checkpoint_dir}/halo_model\", flush=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 5. Generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate synthetic patients using the trained HALO model.\n", + "\n", + "**What happens during generation:**\n", + "- Loads the best checkpoint from training\n", + "- Generates the specified number of synthetic patients\n", + "- Each patient has multiple visits with ICD-9 diagnosis codes\n", + "- Outputs CSV file with columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\n", + "\n", + "**Generation time**:\n", + "- 1,000 patients: ~10-15 minutes\n", + "- 10,000 patients: ~1-2 hours" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.tasks import halo_generation_mimic3_fn\n", + "from pyhealth.models.generators.halo import HALO\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\", flush=True)\n", + "\n", + "# Reconstruct the same vocabulary used during training.\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"Reconstructing vocabulary from MIMIC-III dataset...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "base_dataset = MIMIC3Dataset(\n", + " root=data_dir,\n", + " tables=[\"diagnoses_icd\", \"admissions\"],\n", + " dev=False,\n", + ")\n", + "sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn)\n", + "print(f\"Vocabulary reconstructed ({len(sample_dataset)} samples)\", flush=True)\n", + "\n", + "# Initialize model with same hyperparameters as training\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"Initializing HALO model...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "model = HALO(\n", + " dataset=sample_dataset,\n", + " embed_dim=N_EMBD,\n", + " n_heads=N_HEAD,\n", + " n_layers=N_LAYER,\n", + " n_ctx=N_CTX,\n", + " batch_size=BATCH_SIZE,\n", + " epochs=EPOCHS,\n", + " lr=LEARNING_RATE,\n", + " save_dir=checkpoint_dir,\n", + ")\n", + "\n", + "# Load trained checkpoint\n", + "checkpoint_path = os.path.join(checkpoint_dir, 'halo_model')\n", + "if not os.path.exists(checkpoint_path):\n", + " raise FileNotFoundError(\n", + " f\"No checkpoint found at {checkpoint_path}.\\n\"\n", + " f\"Please run the Training section first.\"\n", + " )\n", + "\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(f\"Loading checkpoint from {checkpoint_path}\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "checkpoint = torch.load(checkpoint_path, map_location=model.device, weights_only=False)\n", + "model.halo_model.load_state_dict(checkpoint['model'])\n", + "print(\"Checkpoint loaded successfully\", flush=True)\n", + "\n", + "# Generate synthetic patients\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(f\"Generating {N_SYNTHETIC_SAMPLES} synthetic patients...\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "if torch.cuda.is_available():\n", + " est_minutes = max(10, N_SYNTHETIC_SAMPLES // 100)\n", + " print(f\"Estimated time: ~{est_minutes} minutes on GPU\", flush=True)\n", + "print(f\"\", flush=True)\n", + "\n", + "synthetic_data = model.synthesize_dataset(\n", + " num_samples=N_SYNTHETIC_SAMPLES,\n", + " random_sampling=True,\n", + ")\n", + "\n", + "print(f\"\\n{'='*60}\", flush=True)\n", + "print(\"GENERATION COMPLETE!\", flush=True)\n", + "print(f\"{'='*60}\", flush=True)\n", + "\n", + "print(f\"\\nSynthetic data statistics:\", flush=True)\n", + "print(f\" Total patients: {len(synthetic_data)}\", flush=True)\n", + "\n", + "total_visits = sum(len(p['visits']) for p in synthetic_data)\n", + "avg_visits = total_visits / len(synthetic_data)\n", + "print(f\" Total visits: {total_visits}\", flush=True)\n", + "print(f\" Avg visits per patient: {avg_visits:.2f}\", flush=True)\n", + "\n", + "total_codes = sum(len(visit) for p in synthetic_data for visit in p['visits'])\n", + "avg_codes = total_codes / total_visits if total_visits > 0 else 0\n", + "print(f\" Avg codes per visit: {avg_codes:.2f}\", flush=True)\n" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "#", + " ", + "C", + "o", + "n", + "v", + "e", + "r", + "t", + " ", + "t", + "o", + " ", + "C", + "S", + "V", + " ", + "f", + "o", + "r", + "m", + "a", + "t", + "\n", + "#", + " ", + "s", + "y", + "n", + "t", + "h", + "e", + "s", + "i", + "z", + "e", + "_", + "d", + "a", + "t", + "a", + "s", + "e", + "t", + "(", + ")", + " ", + "r", + "e", + "t", + "u", + "r", + "n", + "s", + " ", + "s", + "t", + "r", + "i", + "n", + "g", + " ", + "c", + "o", + "d", + "e", + "s", + " ", + "d", + "i", + "r", + "e", + "c", + "t", + "l", + "y", + " ", + "\u2014", + " ", + "n", + "o", + " ", + "i", + "n", + "d", + "e", + "x", + " ", + "l", + "o", + "o", + "k", + "u", + "p", + " ", + "n", + "e", + "e", + "d", + "e", + "d", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + "\\", + "n", + "{", + "'", + "=", + "'", + "*", + "6", + "0", + "}", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "\"", + "C", + "o", + "n", + "v", + "e", + "r", + "t", + "i", + "n", + "g", + " ", + "t", + "o", + " ", + "C", + "S", + "V", + " ", + "f", + "o", + "r", + "m", + "a", + "t", + ".", + ".", + ".", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + "{", + "'", + "=", + "'", + "*", + "6", + "0", + "}", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "\n", + "c", + "s", + "v", + "_", + "p", + "a", + "t", + "h", + " ", + "=", + " ", + "f", + "'", + "{", + "o", + "u", + "t", + "p", + "u", + "t", + "_", + "d", + "i", + "r", + "}", + "/", + "h", + "a", + "l", + "o", + "_", + "s", + "y", + "n", + "t", + "h", + "e", + "t", + "i", + "c", + "_", + "d", + "a", + "t", + "a", + ".", + "c", + "s", + "v", + "'", + "\n", + "\n", + "r", + "e", + "c", + "o", + "r", + "d", + "s", + " ", + "=", + " ", + "[", + "]", + "\n", + "f", + "o", + "r", + " ", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + "_", + "i", + "d", + "x", + ",", + " ", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + " ", + "i", + "n", + " ", + "e", + "n", + "u", + "m", + "e", + "r", + "a", + "t", + "e", + "(", + "s", + "y", + "n", + "t", + "h", + "e", + "t", + "i", + "c", + "_", + "d", + "a", + "t", + "a", + ")", + ":", + "\n", + " ", + " ", + " ", + " ", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + "_", + "i", + "d", + " ", + "=", + " ", + "f", + "\"", + "S", + "Y", + "N", + "T", + "H", + "E", + "T", + "I", + "C", + "_", + "{", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + "_", + "i", + "d", + "x", + "+", + "1", + ":", + "0", + "6", + "d", + "}", + "\"", + "\n", + " ", + " ", + " ", + " ", + "f", + "o", + "r", + " ", + "v", + "i", + "s", + "i", + "t", + "_", + "n", + "u", + "m", + ",", + " ", + "v", + "i", + "s", + "i", + "t", + " ", + "i", + "n", + " ", + "e", + "n", + "u", + "m", + "e", + "r", + "a", + "t", + "e", + "(", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + "[", + "'", + "v", + "i", + "s", + "i", + "t", + "s", + "'", + "]", + ",", + " ", + "1", + ")", + ":", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "f", + "o", + "r", + " ", + "i", + "c", + "d", + "9", + "_", + "c", + "o", + "d", + "e", + " ", + "i", + "n", + " ", + "v", + "i", + "s", + "i", + "t", + ":", + " ", + " ", + "#", + " ", + "c", + "o", + "d", + "e", + "s", + " ", + "a", + "r", + "e", + " ", + "a", + "l", + "r", + "e", + "a", + "d", + "y", + " ", + "s", + "t", + "r", + "i", + "n", + "g", + "s", + ",", + " ", + "e", + ".", + "g", + ".", + " ", + "\"", + "4", + "1", + "4", + "0", + "1", + "\"", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "r", + "e", + "c", + "o", + "r", + "d", + "s", + ".", + "a", + "p", + "p", + "e", + "n", + "d", + "(", + "{", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "'", + "S", + "U", + "B", + "J", + "E", + "C", + "T", + "_", + "I", + "D", + "'", + ":", + " ", + "p", + "a", + "t", + "i", + "e", + "n", + "t", + "_", + "i", + "d", + ",", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "'", + "V", + "I", + "S", + "I", + "T", + "_", + "N", + "U", + "M", + "'", + ":", + " ", + "v", + "i", + "s", + "i", + "t", + "_", + "n", + "u", + "m", + ",", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "'", + "I", + "C", + "D", + "9", + "_", + "C", + "O", + "D", + "E", + "'", + ":", + " ", + "i", + "c", + "d", + "9", + "_", + "c", + "o", + "d", + "e", + "\n", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + " ", + "}", + ")", + "\n", + "\n", + "d", + "f", + " ", + "=", + " ", + "p", + "d", + ".", + "D", + "a", + "t", + "a", + "F", + "r", + "a", + "m", + "e", + "(", + "r", + "e", + "c", + "o", + "r", + "d", + "s", + ")", + "\n", + "d", + "f", + ".", + "t", + "o", + "_", + "c", + "s", + "v", + "(", + "c", + "s", + "v", + "_", + "p", + "a", + "t", + "h", + ",", + " ", + "i", + "n", + "d", + "e", + "x", + "=", + "F", + "a", + "l", + "s", + "e", + ")", + "\n", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + "C", + "S", + "V", + " ", + "s", + "a", + "v", + "e", + "d", + " ", + "t", + "o", + ":", + " ", + "{", + "c", + "s", + "v", + "_", + "p", + "a", + "t", + "h", + "}", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + " ", + " ", + "T", + "o", + "t", + "a", + "l", + " ", + "r", + "e", + "c", + "o", + "r", + "d", + "s", + ":", + " ", + "{", + "l", + "e", + "n", + "(", + "d", + "f", + ")", + ":", + ",", + "}", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + " ", + " ", + "F", + "i", + "l", + "e", + " ", + "s", + "i", + "z", + "e", + ":", + " ", + "{", + "o", + "s", + ".", + "p", + "a", + "t", + "h", + ".", + "g", + "e", + "t", + "s", + "i", + "z", + "e", + "(", + "c", + "s", + "v", + "_", + "p", + "a", + "t", + "h", + ")", + " ", + "/", + " ", + "(", + "1", + "0", + "2", + "4", + "*", + "1", + "0", + "2", + "4", + ")", + ":", + ".", + "2", + "f", + "}", + " ", + "M", + "B", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "\n", + "p", + "r", + "i", + "n", + "t", + "(", + "f", + "\"", + "\\", + "n", + "F", + "i", + "r", + "s", + "t", + " ", + "1", + "0", + " ", + "r", + "o", + "w", + "s", + ":", + "\"", + ",", + " ", + "f", + "l", + "u", + "s", + "h", + "=", + "T", + "r", + "u", + "e", + ")", + "\n", + "d", + "i", + "s", + "p", + "l", + "a", + "y", + "(", + "d", + "f", + ".", + "h", + "e", + "a", + "d", + "(", + "1", + "0", + ")", + ")", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# 6. Results & Download" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "View statistics and download your synthetic data.\n", + "\n", + "**CSV Format**:\n", + "- `SUBJECT_ID`: Synthetic patient identifier (SYNTHETIC_000001, etc.)\n", + "- `VISIT_NUM`: Visit sequence number (1, 2, 3, ...)\n", + "- `ICD9_CODE`: ICD-9 diagnosis code\n", + "\n", + "**Data Quality Checks**:\n", + "- Patient IDs should be unique and sequential\n", + "- Visit numbers should be sequential for each patient\n", + "- ICD-9 codes should match MIMIC-III code format" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": "# Validate generated data\nprint(\"=\" * 60)\nprint(\"DATA QUALITY CHECKS\")\nprint(\"=\" * 60)\n\n# Check 1: Patient IDs\nunique_patients = df['SUBJECT_ID'].nunique()\nprint(f\"\\n Unique patients: {unique_patients} out of {N_SYNTHETIC_SAMPLES} synthetic samples\")\n\n# Check 2: No empty values\nempty_subjects = df['SUBJECT_ID'].isna().sum()\nempty_visits = df['VISIT_NUM'].isna().sum()\nempty_codes = df['ICD9_CODE'].isna().sum()\n\nprint(f\" Empty values check:\")\nprint(f\" Subject IDs: {empty_subjects} (should be 0)\")\nprint(f\" Visit numbers: {empty_visits} (should be 0)\")\nprint(f\" ICD9 codes: {empty_codes} (should be 0)\")\n\nassert empty_subjects == 0 and empty_visits == 0 and empty_codes == 0, \"Found empty values!\"\n\n# Check 3: Visit number sequencing (sample first patient)\n# Each visit maps to many rows (one per code), so compare unique visit nums.\nvisit_nums = sorted(df[df['SUBJECT_ID'] == 'SYNTHETIC_000001']['VISIT_NUM'].unique().tolist())\nis_sequential = visit_nums == list(range(1, len(visit_nums) + 1))\nprint(f\"\\n Visit sequencing (first patient): {is_sequential}\")\n\n# Check 4: ICD9 code format (should be strings, common patterns)\nsample_codes = df['ICD9_CODE'].head(20).tolist()\nprint(f\"\\n Sample ICD9 codes: {', '.join(map(str, sample_codes[:10]))}\")\n\n# Check 5: Distribution statistics\ncodes_per_patient = df.groupby('SUBJECT_ID').size()\nprint(f\"\\n Codes per patient distribution:\")\nprint(f\" Min: {codes_per_patient.min()}\")\nprint(f\" Max: {codes_per_patient.max()}\")\nprint(f\" Mean: {codes_per_patient.mean():.2f}\")\nprint(f\" Median: {codes_per_patient.median():.2f}\")\n\nvisits_per_patient = df.groupby('SUBJECT_ID')['VISIT_NUM'].max()\nprint(f\"\\n Visits per patient distribution:\")\nprint(f\" Min: {visits_per_patient.min()}\")\nprint(f\" Max: {visits_per_patient.max()}\")\nprint(f\" Mean: {visits_per_patient.mean():.2f}\")\nprint(f\" Median: {visits_per_patient.median():.2f}\")\n\nprint(\"\\n\" + \"=\" * 60)\nprint(\" ALL QUALITY CHECKS PASSED\")\nprint(\"=\" * 60)" + }, + { + "cell_type": "code", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# Download CSV file\n", + "print(\"=\" * 60)\n", + "print(\"DOWNLOAD SYNTHETIC DATA\")\n", + "print(\"=\" * 60)\n", + "\n", + "unique_patients = df['SUBJECT_ID'].nunique() # recompute in case cell 21 was skipped\n", + "print(f\"\\nYour synthetic data is ready:\")\n", + "print(f\" File: halo_synthetic_data.csv\")\n", + "print(f\" Patients: {unique_patients:,}\")\n", + "print(f\" Total records: {len(df):,}\")\n", + "print(f\" Size: {os.path.getsize(csv_path) / (1024*1024):.2f} MB\")\n", + "\n", + "print(f\"\\nDownloading file to your computer...\")\n", + "files.download(csv_path)\n", + "\n", + "print(f\"\\n Download started!\")\n", + "print(f\"\\nFile also saved in Google Drive:\")\n", + "print(f\" {csv_path}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Congratulations!\n", + "\n", + "You've successfully:\n", + "1. Trained a HALO model on your MIMIC-III data\n", + "2. Generated synthetic patients\n", + "3. Validated the synthetic data quality\n", + "4. Downloaded the CSV file\n", + "\n", + "## Next Steps\n", + "\n", + "**Use your synthetic data:**\n", + "- Train predictive models (readmission, mortality, etc.)\n", + "- Develop clinical decision support tools\n", + "- Share data without privacy concerns\n", + "\n", + "**Generate more samples:**\n", + "- Change `N_SYNTHETIC_SAMPLES` and re-run generation section\n", + "- No need to retrain - the checkpoint is saved!\n", + "\n", + "**Production training:**\n", + "- For publication-quality results, train with `EPOCHS = 80`\n", + "- Consider using Colab Pro or a dedicated GPU cluster\n", + "- See `examples/slurm/train_halo_mimic3.slurm` for cluster usage\n", + "\n", + "## Troubleshooting\n", + "\n", + "**Out of memory errors:**\n", + "- Reduce `BATCH_SIZE` (try 24 or 16)\n", + "- Close other browser tabs\n", + "- Restart runtime and try again\n", + "\n", + "**Training too slow:**\n", + "- Verify GPU is enabled (Runtime \u2192 Change runtime type)\n", + "- Reduce dataset size for testing\n", + "- Consider using Colab Pro for better GPUs\n", + "\n", + "**Questions or issues:**\n", + "- PyHealth docs: https://pyhealth.readthedocs.io/\n", + "- HALO paper: https://arxiv.org/abs/2406.16061\n", + "- GitHub issues: https://github.com/sunlabuiuc/PyHealth/issues" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/halo_mimic3_training.py b/examples/halo_mimic3_training.py new file mode 100644 index 000000000..3bfcfdcc3 --- /dev/null +++ b/examples/halo_mimic3_training.py @@ -0,0 +1,66 @@ +""" +Example: Training HALO on MIMIC-III for synthetic EHR generation. + +This script demonstrates how to train the HALO model using PyHealth's +standard dataset and task patterns. HALO learns to generate synthetic +patient visit sequences via autoregressive transformer training. + +Usage: + python examples/halo_mimic3_training.py + +Replace the ``root`` path below with the local path to your MIMIC-III +data directory before running. +""" + +from pyhealth.datasets import MIMIC3Dataset, split_by_patient +from pyhealth.models.generators.halo import HALO +from pyhealth.tasks import halo_generation_mimic3_fn + +# Step 1: Load MIMIC-III dataset +print("Loading MIMIC-III dataset...") +base_dataset = MIMIC3Dataset( + root="/path/to/mimic3", + tables=["diagnoses_icd"], +) +base_dataset.stats() + +# Step 2: Set task for HALO generation +# halo_generation_mimic3_fn extracts diagnosis code sequences per patient. +# Each patient produces one sample with all their visits (admissions with +# at least one ICD-9 code). Patients with fewer than 2 qualifying visits +# are excluded. +print("Setting HALO generation task...") +sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn) +print(f"Samples after task: {len(sample_dataset)}") + +# Step 3: Split dataset by patient (no patient appears in more than one split) +print("Splitting dataset...") +train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] +) +print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}") + +# Step 4: Initialize HALO model +# The model derives vocabulary size automatically from the dataset's +# NestedSequenceProcessor. No manual vocabulary setup is needed. +print("Initializing HALO model...") +model = HALO( + dataset=sample_dataset, + embed_dim=768, + n_heads=12, + n_layers=12, + n_ctx=48, + batch_size=48, + epochs=50, + pos_loss_weight=None, + lr=1e-4, + save_dir="./save/", +) + +# Step 5: Train using HALO's custom training loop +# HALO does not use the PyHealth Trainer; it has its own loop that +# validates after every epoch and saves the best checkpoint to save_dir. +print("Starting training...") +model.train_model(train_dataset, val_dataset) + +print("Training complete. Best checkpoint saved to ./save/halo_model") diff --git a/examples/slurm/generate_halo_mimic3.slurm b/examples/slurm/generate_halo_mimic3.slurm new file mode 100644 index 000000000..9d563d708 --- /dev/null +++ b/examples/slurm/generate_halo_mimic3.slurm @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --job-name=halo_generate +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=16G +#SBATCH --time=2:00:00 +#SBATCH --output=/scratch/%u/logs/halo_generate_%j.out + +# Canonical SLURM script for generating synthetic data with HALO +# Adjust paths, partition names, and resource allocations for your cluster + +# Navigate to working directory +cd "${SLURM_SUBMIT_DIR}" || exit 1 + +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "Starting HALO generation at: $(date)" +echo "========================================" + +# Activate your Python environment +# Example: conda activate pyhealth +# Example: source venv/bin/activate + +# Set Python path if needed +# export PYTHONPATH=/path/to/PyHealth:${PYTHONPATH} + +# Generation script +python examples/generate_synthetic_mimic3_halo.py \ + --checkpoint_dir /scratch/jalenj4/halo_results/ \ + --checkpoint_file halo_model_best \ + --output_pkl /scratch/jalenj4/halo_results/synthetic/halo_synthetic_10k.pkl \ + --output_csv /scratch/jalenj4/halo_results/synthetic/halo_synthetic_10k.csv \ + --n_samples 10000 + +echo "========================================" +echo "Generation completed at: $(date)" diff --git a/examples/slurm/train_halo_mimic3.slurm b/examples/slurm/train_halo_mimic3.slurm new file mode 100644 index 000000000..ba860e595 --- /dev/null +++ b/examples/slurm/train_halo_mimic3.slurm @@ -0,0 +1,38 @@ +#!/bin/bash +#SBATCH --job-name=halo_train +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --time=12:00:00 +#SBATCH --output=/scratch/%u/logs/halo_train_%j.out + +# Canonical SLURM script for training HALO on MIMIC-III +# Adjust paths, partition names, and resource allocations for your cluster + +# Navigate to working directory +cd "${SLURM_SUBMIT_DIR}" || exit 1 + +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "Starting HALO training at: $(date)" +echo "========================================" + +# Activate your Python environment +# Example: conda activate pyhealth +# Example: source venv/bin/activate + +# Set Python path if needed +# export PYTHONPATH=/path/to/PyHealth:${PYTHONPATH} + +# Training script +python examples/halo_mimic3_training.py \ + --mimic3_dir /u/jalenj4/pehr_scratch/data_files_train/ \ + --output_dir /scratch/jalenj4/halo_results/ \ + --epochs 80 \ + --batch_size 48 \ + --learning_rate 0.0001 \ + --save_best \ + --save_final + +echo "========================================" +echo "Training completed at: $(date)" diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index effb47133..6933723a9 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -48,10 +48,16 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset -from .chestxray14 import ChestXray14Dataset +try: + from .chestxray14 import ChestXray14Dataset +except ImportError: + pass # PIL/torchvision unavailable from .clinvar import ClinVarDataset from .cosmic import COSMICDataset -from .covid19_cxr import COVID19CXRDataset +try: + from .covid19_cxr import COVID19CXRDataset +except ImportError: + pass # PIL/torchvision unavailable from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset @@ -63,7 +69,10 @@ def __init__(self, *args, **kwargs): from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset -from .sleepedf import SleepEDFDataset +try: + from .sleepedf import SleepEDFDataset +except ImportError: + pass # mne unavailable from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset @@ -76,8 +85,14 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) -from .tuab import TUABDataset -from .tuev import TUEVDataset +try: + from .tuab import TUABDataset +except ImportError: + pass # mne unavailable; TUABDataset not registered +try: + from .tuev import TUEVDataset +except ImportError: + pass # mne unavailable; TUEVDataset not registered from .utils import ( collate_fn_dict, collate_fn_dict_with_padding, diff --git a/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml b/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml new file mode 100644 index 000000000..cf0caddfa --- /dev/null +++ b/pyhealth/datasets/configs/hcup_ccs_2015_definitions_benchmark.yaml @@ -0,0 +1,149 @@ +"Septicemia (except in labor)": + use_in_benchmark: True + type: "acute" + id: 2 + codes: [ "0031", "0202", "0223", "0362", "0380", "0381", "03810", "03811", "03812", "03819", "0382", "0383", "03840", "03841", "03842", "03843", "03844", "03849", "0388", "0389", "0545", "449", "77181", "7907", "99591", "99592" ] + +"Diabetes mellitus without complication": + use_in_benchmark: True + type: "chronic" + id: 49 + codes: [ "24900", "25000", "25001", "7902", "79021", "79022", "79029", "7915", "7916", "V4585", "V5391", "V6546" ] + +"Diabetes mellitus with complications": + use_in_benchmark: True + type: "chronic" + id: 50 + codes: [ "24901", "24910", "24911", "24920", "24921", "24930", "24931", "24940", "24941", "24950", "24951", "24960", "24961", "24970", "24971", "24980", "24981", "24990", "24991", "25002", "25003", "25010", "25011", "25012", "25013", "25020", "25021", "25022", "25023", "25030", "25031", "25032", "25033", "25040", "25041", "25042", "25043", "25050", "25051", "25052", "25053", "25060", "25061", "25062", "25063", "25070", "25071", "25072", "25073", "25080", "25081", "25082", "25083", "25090", "25091", "25092", "25093" ] + +"Disorders of lipid metabolism": + use_in_benchmark: True + type: "chronic" + id: 53 + codes: [ "2720", "2721", "2722", "2723", "2724" ] + +"Fluid and electrolyte disorders": + use_in_benchmark: True + type: "acute" + id: 55 + codes: [ "2760", "2761", "2762", "2763", "2764", "2765", "27650", "27651", "27652", "2766", "27669", "2767", "2768", "2769", "9951" ] + +"Essential hypertension": + use_in_benchmark: True + type: "chronic" + id: 98 + codes: [ "4011", "4019" ] + +"Hypertension with complications and secondary hypertension": + use_in_benchmark: True + type: "chronic" + id: 99 + codes: [ "4010", "40200", "40201", "40210", "40211", "40290", "40291", "4030", "40300", "40301", "4031", "40310", "40311", "4039", "40390", "40391", "4040", "40400", "40401", "40402", "40403", "4041", "40410", "40411", "40412", "40413", "4049", "40490", "40491", "40492", "40493", "40501", "40509", "40511", "40519", "40591", "40599", "4372" ] + +"Acute myocardial infarction": + use_in_benchmark: True + type: "acute" + id: 100 + codes: [ "4100", "41000", "41001", "41002", "4101", "41010", "41011", "41012", "4102", "41020", "41021", "41022", "4103", "41030", "41031", "41032", "4104", "41040", "41041", "41042", "4105", "41050", "41051", "41052", "4106", "41060", "41061", "41062", "4107", "41070", "41071", "41072", "4108", "41080", "41081", "41082", "4109", "41090", "41091", "41092" ] + +"Coronary atherosclerosis and other heart disease": + use_in_benchmark: True + type: "chronic" + id: 101 + codes: [ "4110", "4111", "4118", "41181", "41189", "412", "4130", "4131", "4139", "4140", "41400", "41401", "41406", "4142", "4143", "4144", "4148", "4149", "V4581", "V4582" ] + +"Conduction disorders": + use_in_benchmark: True + type: "chronic" + id: 105 + codes: [ "4260", "42610", "42611", "42612", "42613", "4262", "4263", "4264", "42650", "42651", "42652", "42653", "42654", "4266", "4267", "42681", "42682", "42689", "4269", "V450", "V4500", "V4501", "V4502", "V4509", "V533", "V5331", "V5332", "V5339" ] + +"Cardiac dysrhythmias": + use_in_benchmark: True + type: "chronic" + id: 106 + codes: [ "4270", "4271", "4272", "42731", "42732", "42760", "42761", "42769", "42781", "42789", "4279", "7850", "7851" ] + +"Congestive heart failure; nonhypertensive": + use_in_benchmark: True + type: "acute" + id: 108 + codes: [ "39891", "4280", "4281", "42820", "42821", "42822", "42823", "42830", "42831", "42832", "42833", "42840", "42841", "42842", "42843", "4289" ] + +"Acute cerebrovascular disease": + use_in_benchmark: True + type: "acute" + id: 109 + codes: [ "34660", "34661", "34662", "34663", "430", "431", "4320", "4321", "4329", "43301", "43311", "43321", "43331", "43381", "43391", "4340", "43400", "43401", "4341", "43410", "43411", "4349", "43490", "43491", "436" ] + +"Pneumonia (except that caused by tuberculosis or sexually transmitted disease)": + use_in_benchmark: True + type: "acute" + id: 122 + codes: [ "00322", "0203", "0204", "0205", "0212", "0221", "0310", "0391", "0521", "0551", "0730", "0830", "1124", "1140", "1144", "1145", "11505", "11515", "11595", "1304", "1363", "4800", "4801", "4802", "4803", "4808", "4809", "481", "4820", "4821", "4822", "4823", "48230", "48231", "48232", "48239", "4824", "48240", "48241", "48242", "48249", "4828", "48281", "48282", "48283", "48284", "48289", "4829", "483", "4830", "4831", "4838", "4841", "4843", "4845", "4846", "4847", "4848", "485", "486", "5130", "5171" ] + +"Chronic obstructive pulmonary disease and bronchiectasis": + use_in_benchmark: True + type: "chronic" + id: 127 + codes: [ "490", "4910", "4911", "4912", "49120", "49121", "49122", "4918", "4919", "4920", "4928", "494", "4940", "4941", "496" ] + +"Pleurisy; pneumothorax; pulmonary collapse": + use_in_benchmark: True + type: "acute" + id: 130 + codes: [ "5100", "5109", "5110", "5111", "5118", "51189", "5119", "5120", "5128", "51281", "51282", "51283", "51284", "51289", "5180", "5181", "5182" ] + +"Respiratory failure; insufficiency; arrest (adult)": + use_in_benchmark: True + type: "acute" + id: 131 + codes: [ "5173", "5185", "51851", "51852", "51853", "51881", "51882", "51883", "51884", "7991", "V461", "V4611", "V4612", "V4613", "V4614", "V462" ] + +"Other lower respiratory disease": + use_in_benchmark: True + type: "acute" + id: 133 + codes: [ "5131", "514", "515", "5160", "5161", "5162", "5163", "51630", "51631", "51632", "51633", "51634", "51635", "51636", "51637", "5164", "5165", "51661", "51662", "51663", "51664", "51669", "5168", "5169", "5172", "5178", "5183", "5184", "51889", "5194", "5198", "5199", "7825", "78600", "78601", "78602", "78603", "78604", "78605", "78606", "78607", "78609", "7862", "7863", "78630", "78631", "78639", "7864", "78652", "7866", "7867", "7868", "7869", "7931", "79311", "79319", "7942", "V126", "V1260", "V1261", "V1269", "V426" ] + +"Other upper respiratory disease": + use_in_benchmark: True + type: "acute" + id: 134 + codes: [ "470", "4710", "4711", "4718", "4719", "4720", "4721", "4722", "4760", "4761", "4770", "4772", "4778", "4779", "4780", "4781", "47811", "47819", "47820", "47821", "47822", "47824", "47825", "47826", "47829", "47830", "47831", "47832", "47833", "47834", "4784", "4785", "4786", "47870", "47871", "47874", "47875", "47879", "4788", "4789", "5191", "51911", "51919", "5192", "5193", "7841", "78440", "78441", "78442", "78443", "78444", "78449", "7847", "7848", "7849", "78499", "7861", "V414", "V440", "V550" ] + +"Other liver diseases": + use_in_benchmark: True + type: "acute" + id: 151 + codes: [ "570", "5715", "5716", "5718", "5719", "5720", "5721", "5722", "5723", "5724", "5728", "5730", "5734", "5735", "5738", "5739", "7824", "7891", "7895", "78959", "7904", "7905", "7948", "V427" ] + +"Gastrointestinal hemorrhage": + use_in_benchmark: True + type: "acute" + id: 153 + codes: [ "4560", "45620", "5307", "53082", "53100", "53101", "53120", "53121", "53140", "53141", "53160", "53161", "53200", "53201", "53220", "53221", "53240", "53241", "53260", "53261", "53300", "53301", "53320", "53321", "53340", "53341", "53360", "53361", "53400", "53401", "53420", "53421", "53440", "53441", "53460", "53461", "5693", "5780", "5781", "5789" ] + +"Acute and unspecified renal failure": + use_in_benchmark: True + type: "acute" + id: 157 + codes: [ "5845", "5846", "5847", "5848", "5849", "586" ] + +"Chronic kidney disease": + use_in_benchmark: True + type: "chronic" + id: 158 + codes: [ "585", "5851", "5852", "5853", "5854", "5855", "5856", "5859", "7925", "V420", "V451", "V4511", "V4512", "V560", "V561", "V562", "V5631", "V5632", "V568" ] + +"Complications of surgical procedures or medical care": + use_in_benchmark: True + type: "acute" + id: 238 + codes: [ "27661", "27783", "27788", "2853", "28741", "3490", "3491", "34931", "41511", "4294", "4582", "45821", "45829", "5121", "5122", "5187", "5190", "51900", "51901", "51902", "51909", "53086", "53087", "53640", "53641", "53642", "53649", "53901", "53909", "53981", "53989", "5642", "5643", "5644", "5696", "56962", "56971", "56979", "5793", "59681", "78062", "78063", "78066", "9093", "99524", "9954", "99586", "9970", "99700", "99701", "99702", "99709", "9971", "9972", "9973", "99731", "99732", "99739", "9974", "99741", "99749", "9975", "99760", "99761", "99762", "99769", "99771", "99772", "99779", "9979", "99791", "99799", "9980", "99800", "99801", "99802", "99809", "9981", "99811", "99812", "99813", "9982", "9983", "99830", "99831", "99832", "99833", "9984", "9985", "99851", "99859", "9986", "9987", "9988", "99881", "99882", "99883", "99889", "9989", "9990", "9991", "9992", "9993", "99934", "99939", "9994", "99941", "99942", "99949", "9995", "99951", "99952", "99959", "9996", "99960", "99961", "99962", "99963", "99969", "9997", "99970", "99971", "99972", "99973", "99974", "99975", "99976", "99977", "99978", "99979", "9998", "99980", "99981", "99982", "99983", "99984", "99985", "99988", "99989", "9999", "V1553", "V1580", "V1583", "V9001", "V9009" ] + +"Shock": + use_in_benchmark: True + type: "acute" + id: 249 + codes: [ "78550", "78551", "78552", "78559" ] \ No newline at end of file diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 7e569d2f3..bdc00f904 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -53,7 +53,7 @@ def __init__( if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "mimic3.yaml" - default_tables = ["patients", "admissions", "icustays"] + default_tables = ["patients", "admissions"] tables = default_tables + tables if "prescriptions" in tables: warnings.warn( diff --git a/pyhealth/datasets/sleepedf.py b/pyhealth/datasets/sleepedf.py index 71ef3a450..cc9aa6bac 100644 --- a/pyhealth/datasets/sleepedf.py +++ b/pyhealth/datasets/sleepedf.py @@ -6,7 +6,10 @@ import pandas as pd from pyhealth.datasets import BaseDataset -from pyhealth.tasks.sleep_staging_v2 import SleepStagingSleepEDF +try: + from pyhealth.tasks.sleep_staging_v2 import SleepStagingSleepEDF +except ImportError: + SleepStagingSleepEDF = None # mne unavailable logger = logging.getLogger(__name__) @@ -230,7 +233,7 @@ def normalize(value: object) -> str: return flattened @property - def default_task(self) -> SleepStagingSleepEDF: + def default_task(self) -> "SleepStagingSleepEDF": """Returns the default task for this dataset. Returns: diff --git a/pyhealth/datasets/tuab.py b/pyhealth/datasets/tuab.py index e2a3fc69c..adc59e396 100644 --- a/pyhealth/datasets/tuab.py +++ b/pyhealth/datasets/tuab.py @@ -5,7 +5,10 @@ from typing import Optional from .base_dataset import BaseDataset -from pyhealth.tasks import EEGAbnormalTUAB +try: + from pyhealth.tasks import EEGAbnormalTUAB +except ImportError: + EEGAbnormalTUAB = None logger = logging.getLogger(__name__) diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py index 7e8dacf98..1e7f3aae2 100644 --- a/pyhealth/datasets/tuev.py +++ b/pyhealth/datasets/tuev.py @@ -5,7 +5,10 @@ from typing import Optional from .base_dataset import BaseDataset -from pyhealth.tasks import EEGEventsTUEV +try: + from pyhealth.tasks import EEGEventsTUEV +except ImportError: + EEGEventsTUEV = None logger = logging.getLogger(__name__) diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..8e834ea5a 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -1,4 +1,10 @@ -from .binary import binary_metrics_fn +try: + from .binary import binary_metrics_fn + from .multiclass import multiclass_metrics_fn + from .multilabel import multilabel_metrics_fn + from .regression import regression_metrics_fn +except ImportError: + pass # sklearn unavailable from .drug_recommendation import ddi_rate_score from .interpretability import ( ComprehensivenessMetric, @@ -7,12 +13,8 @@ SufficiencyMetric, evaluate_attribution, ) -from .multiclass import multiclass_metrics_fn -from .multilabel import multilabel_metrics_fn - # from .fairness import fairness_metrics_fn from .ranking import ranking_metrics_fn -from .regression import regression_metrics_fn __all__ = [ "binary_metrics_fn", diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..3f89126a4 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,8 +1,14 @@ from .adacare import AdaCare, AdaCareLayer from .agent import Agent, AgentLayer from .base_model import BaseModel -from .biot import BIOT -from .cnn import CNN, CNNLayer +try: + from .biot import BIOT +except ImportError: + pass # einops unavailable +try: + from .cnn import CNN, CNNLayer +except ImportError: + pass # PIL/torchvision unavailable from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D from .deepr import Deepr, DeeprLayer @@ -11,35 +17,65 @@ from .jamba_ehr import JambaEHR, JambaLayer from .logistic_regression import LogisticRegression from .gan import GAN +from .generators.halo import HALO from .gnn import GAT, GCN -from .graph_torchvision_model import Graph_TorchvisionModel -from .grasp import GRASP, GRASPLayer -from .medlink import MedLink +try: + from .graph_torchvision_model import Graph_TorchvisionModel +except ImportError: + pass # torchvision unavailable +try: + from .grasp import GRASP, GRASPLayer +except ImportError: + pass # sklearn unavailable from .micron import MICRON, MICRONLayer from .mlp import MLP -from .molerec import MoleRec, MoleRecLayer +try: + from .molerec import MoleRec, MoleRecLayer +except ImportError: + pass # rdkit unavailable from .retain import RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer -from .safedrug import SafeDrug, SafeDrugLayer +try: + from .safedrug import SafeDrug, SafeDrugLayer +except ImportError: + pass # rdkit unavailable from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer from .tcn import TCN, TCNLayer -from .tfm_tokenizer import ( - TFMTokenizer, - TFM_VQVAE2_deep, - TFM_TOKEN_Classifier, - get_tfm_tokenizer_2x2x8, - get_tfm_token_classifier_64x4, - load_embedding_weights, -) -from .torchvision_model import TorchvisionModel +try: + from .tfm_tokenizer import ( + TFMTokenizer, + TFM_VQVAE2_deep, + TFM_TOKEN_Classifier, + get_tfm_tokenizer_2x2x8, + get_tfm_token_classifier_64x4, + load_embedding_weights, + ) +except ImportError: + pass # einops unavailable +try: + from .torchvision_model import TorchvisionModel +except ImportError: + pass # torchvision unavailable from .transformer import Transformer, TransformerLayer -from .transformers_model import TransformersModel +try: + from .transformers_model import TransformersModel +except ImportError: + pass # transformers unavailable from .ehrmamba import EHRMamba, MambaBlock from .vae import VAE -from .vision_embedding import VisionEmbeddingModel -from .text_embedding import TextEmbedding -from .sdoh import SdohClassifier +try: + from .vision_embedding import VisionEmbeddingModel +except ImportError: + pass # PIL/torchvision unavailable +try: + from .text_embedding import TextEmbedding +except ImportError: + pass # transformers unavailable +try: + from .sdoh import SdohClassifier +except ImportError: + pass # transformers/peft unavailable from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py new file mode 100644 index 000000000..66f57e496 --- /dev/null +++ b/pyhealth/models/generators/__init__.py @@ -0,0 +1 @@ +from .halo import HALO diff --git a/pyhealth/models/generators/halo.py b/pyhealth/models/generators/halo.py new file mode 100644 index 000000000..06494a039 --- /dev/null +++ b/pyhealth/models/generators/halo.py @@ -0,0 +1,362 @@ +import os +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing import Dict, List, Optional + +from pyhealth.models import BaseModel + +from pyhealth.models.generators.halo_resources.halo_model import HALOModel +from pyhealth.models.generators.halo_resources.halo_config import HALOConfig + + +def _halo_collate_fn(batch): + """Collate HALO samples, padding the visit dimension across the batch.""" + visits = pad_sequence( + [item["visits"] for item in batch], + batch_first=True, + padding_value=0, + ) + collated = {k: v for k, v in batch[0].items() if k != "visits"} + collated["visits"] = visits + return collated + + +class HALO(BaseModel): + """HALO: Heterogeneous Autoregressive Language mOdel for synthetic EHR generation. + + Trains a GPT-2-style transformer on patient visit sequences and generates + synthetic patients by autoregressive sampling. + + Args: + dataset (SampleDataset): A SampleDataset whose input_schema contains + ``{"visits": "nested_sequence"}`` and whose output_schema is empty. + embed_dim: Transformer embedding dimension. Default: 768. + n_heads: Number of attention heads. Default: 12. + n_layers: Number of transformer layers. Default: 12. + n_ctx: Maximum number of visits (context length). Default: 48. + batch_size: Training batch size. Default: 48. + epochs: Number of training epochs. Default: 50. + pos_loss_weight: Positive-class weight for BCE loss. None means no + weighting. Default: None. + lr: Learning rate for Adam optimizer. Default: 1e-4. + save_dir: Directory to save model checkpoints. Default: ``"./save/"``. + + Examples: + >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset + >>> samples = [ + ... {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + ... {"patient_id": "p2", "visits": [["A"], ["B", "C"]]}, + ... ] + >>> dataset = InMemorySampleDataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, + ... ) + >>> model = HALO(dataset, embed_dim=64, n_heads=2, n_layers=2, n_ctx=8) + >>> isinstance(model, HALO) + True + """ + + def __init__( + self, + dataset, + embed_dim: int = 768, + n_heads: int = 12, + n_layers: int = 12, + n_ctx: int = 48, + batch_size: int = 48, + epochs: int = 50, + pos_loss_weight: Optional[float] = None, + lr: float = 1e-4, + save_dir: str = "./save/", + ) -> None: + super(HALO, self).__init__(dataset) + + self.save_dir = save_dir + self._batch_size = batch_size + self._epochs = epochs + self._lr = lr + + # Derive vocab sizes from the dataset's NestedSequenceProcessor. + visits_processor = dataset.input_processors["visits"] + code_vocab_size = visits_processor.vocab_size() # includes and + label_vocab_size = 0 # generative task — no output labels + # +3 special tokens: start-of-sequence, end-of-sequence, pad-visit + total_vocab_size = code_vocab_size + label_vocab_size + 3 + + self.config = HALOConfig( + total_vocab_size=total_vocab_size, + code_vocab_size=code_vocab_size, + label_vocab_size=label_vocab_size, + special_vocab_size=3, + n_positions=n_ctx + 8, # position embedding table; needs a bit of slack + n_ctx=n_ctx, + n_embd=embed_dim, + n_layer=n_layers, + n_head=n_heads, + batch_size=batch_size, + epoch=epochs, + pos_loss_weight=pos_loss_weight, + lr=lr, + ) + + # Store processor reference for use in synthesize_dataset. + self.visits_processor = visits_processor + + # Register as an nn.Module sub-module so parameters() works correctly. + self.halo_model = HALOModel(self.config) + + # ------------------------------------------------------------------ + # Multi-hot encoding helper + # ------------------------------------------------------------------ + + def _encode_visits(self, visits: torch.Tensor): + """Convert a padded index tensor to HALO multi-hot format. + + The NestedSequenceProcessor returns indices; HALO's transformer expects + multi-hot vectors of shape (batch, n_ctx, total_vocab_size). + + Args: + visits: LongTensor of shape (batch, max_visits, max_codes_per_visit). + Index 0 is the pad token and is skipped. + + Returns: + batch_ehr: FloatTensor of shape ``(batch, n_ctx, total_vocab_size)``. + batch_mask: FloatTensor of shape ``(batch, n_ctx-1, 1)``, shifted to + align with the autoregressive prediction targets. + """ + cfg = self.config + batch_size = visits.shape[0] + + batch_ehr = torch.zeros( + batch_size, cfg.n_ctx, cfg.total_vocab_size, device=self.device + ) + batch_mask = torch.zeros(batch_size, cfg.n_ctx, 1, device=self.device) + + for i in range(batch_size): + n_visits = min(visits.shape[1], cfg.n_ctx - 2) + for j in range(n_visits): + for code_idx in visits[i, j]: + if code_idx > 0: # skip padding (index 0) + batch_ehr[i, j + 2, code_idx] = 1 # visits occupy positions 2+ + if visits[i, j].sum() > 0: + batch_mask[i, j + 2] = 1 + + # Special tokens (label_vocab_size == 0, so the 3 extras are contiguous): + batch_ehr[i, 0, cfg.code_vocab_size + cfg.label_vocab_size] = 1 # start + batch_ehr[i, n_visits + 2, cfg.code_vocab_size + cfg.label_vocab_size + 1] = 1 # end + batch_ehr[i, n_visits + 3:, cfg.code_vocab_size + cfg.label_vocab_size + 2] = 1 # pad + + batch_mask[:, 1] = 1 # label-position mask row + batch_mask = batch_mask[:, 1:, :] # shift mask to align with shifted labels/preds + + return batch_ehr, batch_mask + + # ------------------------------------------------------------------ + # forward — required by BaseModel (abstract in nn.Module) + # ------------------------------------------------------------------ + + def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Accepts the padded index tensor produced by NestedSequenceProcessor, + converts it to HALO multi-hot format, and runs the transformer. + + Args: + visits: LongTensor of shape ``(batch, max_visits, max_codes_per_visit)``. + **kwargs: Additional keys from the batch dict are ignored. + + Returns: + loss: scalar BCE loss tensor. + predictions: code probability tensor of shape + ``(batch, n_ctx, total_vocab_size)``. + """ + visits = visits.to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + + loss, predictions, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + return {"loss": loss, "predictions": predictions} + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + + def train_model(self, train_dataset, val_dataset=None) -> None: + """Train the HALO model using a custom loop. + + Named ``train_model`` (not ``train``) to avoid shadowing ``nn.Module.train()``. + + Args: + train_dataset: SampleDataset for training. + val_dataset: Optional SampleDataset for validation. When provided, + validation loss is computed after every epoch and the best + checkpoint is saved to ``self.save_dir``. + """ + os.makedirs(self.save_dir, exist_ok=True) + optimizer = torch.optim.Adam(self.halo_model.parameters(), lr=self._lr) + + checkpoint_path = os.path.join(self.save_dir, "halo_model") + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.halo_model.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + train_loader = DataLoader( + train_dataset, + batch_size=self._batch_size, + shuffle=False, # IterableDataset (litdata.StreamingDataset) does not support shuffle + drop_last=False, + collate_fn=_halo_collate_fn, + ) + + global_loss = 1e10 + + for epoch in tqdm(range(self._epochs)): + self.halo_model.train() + for batch in train_loader: + visits = batch["visits"].to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + + optimizer.zero_grad() + loss, _, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + loss.backward() + optimizer.step() + + if val_dataset is not None: + self.halo_model.eval() + val_loader = DataLoader( + val_dataset, + batch_size=self._batch_size, + shuffle=False, + drop_last=False, + collate_fn=_halo_collate_fn, + ) + val_losses = [] + with torch.no_grad(): + for val_batch in val_loader: + visits = val_batch["visits"].to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + val_loss, _, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + val_losses.append(val_loss.item()) + + cur_val_loss = float(np.mean(val_losses)) + print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}") + if cur_val_loss < global_loss: + global_loss = cur_val_loss + state = { + "model": self.halo_model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + } + torch.save(state, checkpoint_path) + print("------------ Save best model ------------") + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + + def synthesize_dataset( + self, num_samples: int, random_sampling: bool = True + ) -> List[Dict]: + """Generate synthetic patients using the trained HALO model. + + Autoregressive sampling: feeds a start token and iteratively calls + ``halo_model.sample()`` until an end token is produced or ``n_ctx`` + steps are reached. + + Args: + num_samples: Number of synthetic patients to generate. + random_sampling: If True, samples via Bernoulli (stochastic). + If False, uses rounding (deterministic). Default: True. + + Returns: + list of dict: Synthetic patient records. Each dict has two keys: + ``"patient_id"`` (str): unique identifier, e.g. ``"synthetic_0"``. + ``"visits"`` (list of list of str): decoded code strings per visit. + """ + cfg = self.config + # Invert vocabulary: index → code string + index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()} + + end_token_idx = cfg.code_vocab_size + cfg.label_vocab_size + 1 + + # Build the start-token vector + stoken = torch.zeros(cfg.total_vocab_size, device=self.device, dtype=torch.float32) + stoken[cfg.code_vocab_size + cfg.label_vocab_size] = 1 # start token + + self.halo_model.eval() + synthetic_dataset = [] + sample_batch_size = min(num_samples, 256) + generated = 0 + + with torch.no_grad(): + while generated < num_samples: + bs = min(sample_batch_size, num_samples - generated) + # prev: (bs, 1, total_vocab_size) — starts with just the start token + prev = stoken.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1) + empty = torch.zeros( + bs, 1, cfg.total_vocab_size, device=self.device, dtype=torch.float32 + ) + + for _ in range(cfg.n_ctx - 1): + prev = self.halo_model.sample( + torch.cat((prev, empty), dim=1), random_sampling + ) + # Early stop when all sequences have produced an end token + has_end = prev[:, :, end_token_idx].sum(dim=1).bool() + if has_end.all(): + break + + batch_ehrs = prev.cpu().detach().numpy() + + for i in range(bs): + ehr = batch_ehrs[i] # (seq_len, total_vocab_size) + visits_out = [] + # Position 0 = start token; visits occupy positions 1+ + for j in range(1, len(ehr)): + visit_row = ehr[j] + indices = np.nonzero(visit_row)[0] + visit_codes = [] + hit_end = False + for idx in indices: + if idx < cfg.code_vocab_size: + code = index_to_code.get(idx) + if code not in (None, "", ""): + visit_codes.append(code) + elif idx == end_token_idx: + hit_end = True + if visit_codes: + visits_out.append(visit_codes) + if hit_end: + break + + synthetic_dataset.append( + { + "patient_id": f"synthetic_{generated + i}", + "visits": visits_out, + } + ) + generated += bs + + return synthetic_dataset diff --git a/pyhealth/models/generators/halo_resources/__init__.py b/pyhealth/models/generators/halo_resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyhealth/models/generators/halo_resources/halo_config.py b/pyhealth/models/generators/halo_resources/halo_config.py new file mode 100644 index 000000000..2127aa3a0 --- /dev/null +++ b/pyhealth/models/generators/halo_resources/halo_config.py @@ -0,0 +1,42 @@ +''' + code by Brandon Theodorou + Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 + Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT + GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch +''' +class HALOConfig(object): + def __init__( + self, + total_vocab_size=6984, + code_vocab_size=6841, + label_vocab_size=25, + special_vocab_size=3, + n_positions=56, + n_ctx=48, + n_embd=768, + n_layer=12, + n_head=12, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + batch_size=48, + sample_batch_size=256, + epoch=50, + pos_loss_weight=None, + lr=1e-4, + ): + self.total_vocab_size = total_vocab_size + self.code_vocab_size = code_vocab_size + self.label_vocab_size = label_vocab_size + self.special_vocab_size = special_vocab_size + self.n_positions = n_positions + self.n_ctx = n_ctx + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.batch_size = batch_size + self.sample_batch_size = sample_batch_size + self.epoch = epoch + self.pos_loss_weight = pos_loss_weight + self.lr = lr \ No newline at end of file diff --git a/pyhealth/models/generators/halo_resources/halo_model.py b/pyhealth/models/generators/halo_resources/halo_model.py new file mode 100644 index 000000000..a25a9fe04 --- /dev/null +++ b/pyhealth/models/generators/halo_resources/halo_model.py @@ -0,0 +1,237 @@ +''' + code by Brandon Theodorou + Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 + Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT + GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch +''' +import copy +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +def gelu(x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root).""" + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + super(Conv1D, self).__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + return x + +class Attention(nn.Module): + def __init__(self, nx, n_ctx, config, scale=False): + super(Attention, self).__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + assert n_state % config.n_head == 0 + self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.c_attn = Conv1D(n_state * 3, nx) + self.c_proj = Conv1D(n_state, nx) + + def _attn(self, q, k, v): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + nd, ns = w.size(-2), w.size(-1) + b = self.bias[:, :, ns-nd:ns, :ns] + w = w * b - 1e10 * (1 - b) + w = nn.Softmax(dim=-1)(w) + return torch.matmul(w, v) + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) + if k: + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) + else: + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def forward(self, x, layer_past=None): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + key = torch.cat((past_key, key), dim=-1) + value = torch.cat((past_value, value), dim=-2) + present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking + a = self._attn(query, key, value) + a = self.merge_heads(a) + a = self.c_proj(a) + return a, present + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super(MLP, self).__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = gelu + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return h2 + +class Block(nn.Module): + def __init__(self, n_ctx, config, scale=False): + super(Block, self).__init__() + nx = config.n_embd + self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.attn = Attention(nx, n_ctx, config, scale) + self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.mlp = MLP(4 * nx, config) + + def forward(self, x, layer_past=None): + a, present = self.attn(self.ln_1(x), layer_past=layer_past) + x = x + a + m = self.mlp(self.ln_2(x)) + x = x + m + return x, present + +class CoarseTransformerModel(nn.Module): + def __init__(self, config): + super(CoarseTransformerModel, self).__init__() + self.n_layer = config.n_layer + self.n_embd = config.n_embd + self.n_vocab = config.total_vocab_size + + self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) + self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) + block = Block(config.n_ctx, config, scale=True) + self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) + self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + + def forward(self, input_visits, position_ids=None, past=None): + if past is None: + past_length = 0 + past = [None] * len(self.h) + else: + past_length = past[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_visits.size(1) + past_length, dtype=torch.long, + device=input_visits.device) + position_ids = position_ids.unsqueeze(0).expand(input_visits.size(0), input_visits.size(1)) + + inputs_embeds = self.vis_embed_mat(input_visits) + position_embeds = self.pos_embed_mat(position_ids) + hidden_states = inputs_embeds + position_embeds + for block, layer_past in zip(self.h, past): + hidden_states, _ = block(hidden_states, layer_past) + hidden_states = self.ln_f(hidden_states) + return hidden_states + +class AutoregressiveLinear(nn.Linear): + """ same as Linear except has a configurable mask on the weights """ + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features, bias) + self.register_buffer('mask', torch.tril(torch.ones(in_features, out_features)).int()) + + def forward(self, input): + return F.linear(input, self.mask * self.weight, self.bias) + +class FineAutoregressiveHead(nn.Module): + def __init__(self, config): + super(FineAutoregressiveHead, self).__init__() + self.auto1 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) + self.auto2 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) + self.n_embd = config.n_embd + self.tot_vocab = config.total_vocab_size + + def forward(self, history, input_visits): + history = history[:,:-1,:] + input_visits = input_visits[:,1:,:] + code_logits = self.auto2(torch.relu(self.auto1(torch.cat((history, input_visits), dim=2))))[:,:,self.n_embd-1:-1] + return code_logits + + def sample(self, history, input_visits): + history = history[:,:-1,:] + input_visits = input_visits[:,1:,:] + currVisit = torch.cat((history, input_visits), dim=2)[:,-1,:].unsqueeze(1) + code_logits = self.auto2(torch.relu(self.auto1(currVisit)))[:,:,self.n_embd-1:-1] + return code_logits + +class HALOModel(nn.Module): + def __init__(self, config): + super(HALOModel, self).__init__() + self.transformer = CoarseTransformerModel(config) + self.ehr_head = FineAutoregressiveHead(config) + + def forward(self, input_visits, position_ids=None, ehr_labels=None, ehr_masks=None, past=None, pos_loss_weight=None): + hidden_states = self.transformer(input_visits, position_ids, past) + code_logits = self.ehr_head(hidden_states, input_visits) + sig = nn.Sigmoid() + code_probs = sig(code_logits) + if ehr_labels is not None: + shift_labels = ehr_labels[..., 1:, :].contiguous() + loss_weights = None + if pos_loss_weight is not None: + loss_weights = torch.ones(code_probs.shape, device=code_probs.device) + loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels + if ehr_masks is not None: + code_probs = code_probs * ehr_masks + shift_labels = shift_labels * ehr_masks + if pos_loss_weight is not None: + loss_weights = loss_weights * ehr_masks + + bce = nn.BCELoss(weight=loss_weights) + loss = bce(code_probs, shift_labels) + return loss, code_probs, shift_labels + + return code_probs + + def sample(self, input_visits, random=True): + sig = nn.Sigmoid() + hidden_states = self.transformer(input_visits) + i = 0 + while i < self.ehr_head.tot_vocab: + next_logits = self.ehr_head.sample(hidden_states, input_visits) + next_probs = sig(next_logits) + if random: + visit = torch.bernoulli(next_probs) + else: + visit = torch.round(next_probs) + + remaining_visit = visit[:,0,i:] + nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1] + if nonzero.numel() == 0: + break + + first_nonzero = nonzero.min() + input_visits[:,-1,i + first_nonzero] = visit[:,0,i + first_nonzero] + i = i + first_nonzero + 1 + + return input_visits \ No newline at end of file diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyhealth/models/generators/promptehr.py b/pyhealth/models/generators/promptehr.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index b48072270..c1c41f1da 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -18,7 +18,10 @@ def get_processor(name: str): # Import all processors so they register themselves -from .image_processor import ImageProcessor +try: + from .image_processor import ImageProcessor +except ImportError: + pass # PIL/torchvision unavailable from .label_processor import ( BinaryLabelProcessor, MultiClassLabelProcessor, @@ -44,7 +47,10 @@ def get_processor(name: str): from .tensor_processor import TensorProcessor from .text_processor import TextProcessor from .timeseries_processor import TimeseriesProcessor -from .time_image_processor import TimeImageProcessor +try: + from .time_image_processor import TimeImageProcessor +except ImportError: + pass # PIL/torchvision unavailable from .graph_processor import GraphProcessor from .audio_processor import AudioProcessor from .ignore_processor import IgnoreProcessor diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..136039a8d 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -2,13 +2,16 @@ from .benchmark_ehrshot import BenchmarkEHRShot from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction from .bmd_hs_disease_classification import BMDHSDiseaseClassification -from .cardiology_detect import ( - cardiology_isAD_fn, - cardiology_isAR_fn, - cardiology_isBBBFB_fn, - cardiology_isCD_fn, - cardiology_isWA_fn, -) +try: + from .cardiology_detect import ( + cardiology_isAD_fn, + cardiology_isAR_fn, + cardiology_isBBBFB_fn, + cardiology_isCD_fn, + cardiology_isWA_fn, + ) +except ImportError: + pass # scipy unavailable; cardiology tasks not registered from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification @@ -21,8 +24,20 @@ drug_recommendation_mimic4_fn, drug_recommendation_omop_fn, ) -from .EEG_abnormal import EEG_isAbnormal_fn -from .EEG_events import EEG_events_fn +from .halo_generation import ( + HaloGenerationMIMIC3, + HaloGenerationMIMIC4, + halo_generation_mimic3_fn, + halo_generation_mimic4_fn, +) +try: + from .EEG_abnormal import EEG_isAbnormal_fn +except ImportError: + pass # mne unavailable +try: + from .EEG_events import EEG_events_fn +except ImportError: + pass # mne unavailable from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, @@ -53,16 +68,25 @@ ReadmissionPredictionMIMIC4, ReadmissionPredictionOMOP, ) -from .sleep_staging import ( - sleep_staging_isruc_fn, - sleep_staging_shhs_fn, - sleep_staging_sleepedf_fn, -) -from .sleep_staging_v2 import SleepStagingSleepEDF -from .temple_university_EEG_tasks import ( - EEGEventsTUEV, - EEGAbnormalTUAB -) +try: + from .sleep_staging import ( + sleep_staging_isruc_fn, + sleep_staging_shhs_fn, + sleep_staging_sleepedf_fn, + ) +except ImportError: + pass # mne unavailable +try: + from .sleep_staging_v2 import SleepStagingSleepEDF +except ImportError: + pass # mne unavailable +try: + from .temple_university_EEG_tasks import ( + EEGEventsTUEV, + EEGAbnormalTUAB + ) +except ImportError: + pass # mne unavailable from .variant_classification import ( MutationPathogenicityPrediction, VariantClassificationClinVar, diff --git a/pyhealth/tasks/halo_generation.py b/pyhealth/tasks/halo_generation.py new file mode 100644 index 000000000..2e0945853 --- /dev/null +++ b/pyhealth/tasks/halo_generation.py @@ -0,0 +1,107 @@ +"""Task function for HALO synthetic data generation.""" + +from typing import Any, Dict, List + +import polars as pl + +from pyhealth.tasks.base_task import BaseTask + + +class HaloGenerationMIMIC3(BaseTask): + """Task for HALO synthetic data generation using MIMIC-III dataset. + + HALO trains an autoregressive transformer to generate synthetic EHR data. + This task extracts diagnosis code sequences per patient, where each patient + produces one sample containing all their visits. + + Each sample contains all admissions for a patient, with the ICD-9 diagnosis + codes for each admission grouped into a nested list. Patients with fewer than + 2 admissions that contain diagnosis codes are excluded. + + Attributes: + task_name (str): "HaloGenerationMIMIC3" + input_schema (Dict[str, str]): {"visits": "nested_sequence"} + output_schema (Dict[str, str]): {} (generative task, no prediction target) + _icd_col (str): Polars column name for ICD codes in diagnoses_icd table. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import HaloGenerationMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="path/to/mimic3", + ... tables=["diagnoses_icd"], + ... ) + >>> task = HaloGenerationMIMIC3() + >>> sample_dataset = dataset.set_task(task) + >>> sample_dataset[0] # doctest: +ELLIPSIS + {'patient_id': ..., 'visits': [[...], ...]} + """ + + task_name: str = "HaloGenerationMIMIC3" + input_schema: Dict[str, str] = {"visits": "nested_sequence"} + output_schema: Dict[str, str] = {} + _icd_col: str = "diagnoses_icd/icd9_code" + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a patient for HALO generation task. + + Extracts diagnosis codes per admission, grouped by visit. Returns one + sample per patient containing all visits with diagnosis codes. Excludes + patients with fewer than 2 visits containing diagnosis codes. + + Args: + patient: Patient object (PyHealth 2.0 Polars-based API) + + Returns: + List of at most one sample dict: + { + "patient_id": str, + "visits": [[code1, code2, ...], [code3, ...], ...] + } + Returns empty list if patient has fewer than 2 admissions with + diagnosis codes. + """ + admissions = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + visits = [] + for admission in admissions: + diagnoses_df = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + return_df=True, + ) + codes = ( + diagnoses_df.select(pl.col(self._icd_col)) + .to_series() + .drop_nulls() + .to_list() + ) + if len(codes) > 0: + visits.append(codes) + + if len(visits) < 2: + return [] + + return [{"patient_id": patient.patient_id, "visits": visits}] + + +class HaloGenerationMIMIC4(HaloGenerationMIMIC3): + """Task for HALO synthetic data generation using MIMIC-IV dataset. + + Same logic as HaloGenerationMIMIC3. MIMIC-IV stores ICD codes under + ``diagnoses_icd/icd_code`` rather than ``diagnoses_icd/icd9_code``. + + Attributes: + task_name (str): "HaloGenerationMIMIC4" + _icd_col (str): "diagnoses_icd/icd_code" + """ + + task_name: str = "HaloGenerationMIMIC4" + _icd_col: str = "diagnoses_icd/icd_code" + + +# Convenience callable instances (same pattern as other PyHealth tasks) +halo_generation_mimic3_fn = HaloGenerationMIMIC3() +halo_generation_mimic4_fn = HaloGenerationMIMIC4() diff --git a/tests/core/test_halo_generation_task.py b/tests/core/test_halo_generation_task.py new file mode 100644 index 000000000..865816747 --- /dev/null +++ b/tests/core/test_halo_generation_task.py @@ -0,0 +1,375 @@ +"""Tests for HaloGenerationMIMIC3 task function.""" + +import unittest +from datetime import datetime +from pathlib import Path +from unittest.mock import MagicMock + +import polars as pl + +from pyhealth.data import Patient +from pyhealth.tasks.base_task import BaseTask +from pyhealth.tasks.halo_generation import ( + HaloGenerationMIMIC3, + HaloGenerationMIMIC4, + halo_generation_mimic3_fn, + halo_generation_mimic4_fn, +) + + +def _make_patient_df(admissions, diagnoses): + """Build a minimal Polars DataFrame for a Patient. + + Args: + admissions: list of dicts with keys: hadm_id, admittime, dischtime + diagnoses: list of dicts with keys: hadm_id, icd9_code + Use icd9_code=None to simulate a null/missing code. + + Returns: + pl.DataFrame suitable for Patient(patient_id, df). + """ + rows = [] + + for adm in admissions: + rows.append( + { + "event_type": "admissions", + "timestamp": adm["admittime"], + "admissions/hadm_id": str(adm["hadm_id"]), + "admissions/admittime": adm["admittime"], + "admissions/dischtime": adm.get("dischtime", adm["admittime"]), + "diagnoses_icd/hadm_id": None, + "diagnoses_icd/icd9_code": None, + } + ) + + for diag in diagnoses: + rows.append( + { + "event_type": "diagnoses_icd", + "timestamp": diag.get("timestamp", datetime(2100, 1, 1)), + "admissions/hadm_id": None, + "admissions/admittime": None, + "admissions/dischtime": None, + "diagnoses_icd/hadm_id": str(diag["hadm_id"]), + "diagnoses_icd/icd9_code": diag.get("icd9_code"), + } + ) + + schema = { + "event_type": pl.Utf8, + "timestamp": pl.Datetime, + "admissions/hadm_id": pl.Utf8, + "admissions/admittime": pl.Datetime, + "admissions/dischtime": pl.Datetime, + "diagnoses_icd/hadm_id": pl.Utf8, + "diagnoses_icd/icd9_code": pl.Utf8, + } + return pl.DataFrame(rows, schema=schema) + + +class TestHaloGenerationMIMIC3TaskSchema(unittest.TestCase): + """Tests for task class attributes (schema, naming).""" + + def test_task_name(self): + self.assertEqual(HaloGenerationMIMIC3.task_name, "HaloGenerationMIMIC3") + + def test_input_schema(self): + self.assertIn("visits", HaloGenerationMIMIC3.input_schema) + self.assertEqual(HaloGenerationMIMIC3.input_schema["visits"], "nested_sequence") + + def test_output_schema_is_empty(self): + self.assertEqual(HaloGenerationMIMIC3.output_schema, {}) + + def test_inherits_base_task(self): + task = HaloGenerationMIMIC3() + self.assertIsInstance(task, BaseTask) + + def test_convenience_instance_is_base_task(self): + self.assertIsInstance(halo_generation_mimic3_fn, BaseTask) + self.assertIsInstance(halo_generation_mimic3_fn, HaloGenerationMIMIC3) + + def test_convenience_instance_is_callable(self): + self.assertTrue(callable(halo_generation_mimic3_fn)) + + +class TestHaloGenerationMIMIC3Call(unittest.TestCase): + """Unit tests for HaloGenerationMIMIC3.__call__ using mock Patient objects.""" + + def setUp(self): + self.task = HaloGenerationMIMIC3() + + def _make_patient(self, patient_id, admissions, diagnoses): + df = _make_patient_df(admissions, diagnoses) + return Patient(patient_id=patient_id, data_source=df) + + def test_valid_patient_two_visits_returns_one_sample(self): + """Patient with 2 admissions each having diagnosis codes returns 1 sample.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "100", "icd9_code": "D002", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + ], + ) + result = self.task(patient) + self.assertEqual(len(result), 1) + + def test_valid_patient_sample_structure(self): + """Sample must contain patient_id and visits keys.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + ], + ) + result = self.task(patient) + sample = result[0] + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + + def test_valid_patient_patient_id_correct(self): + """Sample patient_id must match the patient's ID.""" + patient = self._make_patient( + patient_id="P42", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + ], + ) + result = self.task(patient) + self.assertEqual(result[0]["patient_id"], "P42") + + def test_valid_patient_visits_are_nested_list(self): + """visits must be a list of lists of strings.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "100", "icd9_code": "D002", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + ], + ) + result = self.task(patient) + visits = result[0]["visits"] + self.assertIsInstance(visits, list) + for visit in visits: + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + + def test_valid_patient_codes_are_correct(self): + """Codes in visits must match the ICD codes from the input data.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "100", "icd9_code": "D002", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + ], + ) + result = self.task(patient) + visits = result[0]["visits"] + self.assertEqual(len(visits), 2) + self.assertIn("D001", visits[0]) + self.assertIn("D002", visits[0]) + self.assertIn("D003", visits[1]) + + def test_patient_with_one_admission_returns_empty(self): + """Patients with only 1 admission are excluded.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + ], + ) + result = self.task(patient) + self.assertEqual(result, []) + + def test_patient_with_zero_admissions_returns_empty(self): + """Patients with no admissions return empty list.""" + patient = self._make_patient( + patient_id="P001", + admissions=[], + diagnoses=[], + ) + result = self.task(patient) + self.assertEqual(result, []) + + def test_patient_only_one_visit_has_codes_returns_empty(self): + """Patient with 2 admissions but only 1 with diagnosis codes returns empty.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + ], + diagnoses=[ + # Only hadm_id 100 has codes; 200 has none + {"hadm_id": "100", "icd9_code": "D001", "timestamp": datetime(2100, 1, 2)}, + ], + ) + result = self.task(patient) + self.assertEqual(result, []) + + def test_patient_with_null_codes_skipped(self): + """Null diagnosis codes must be dropped; if only nulls in a visit, skip it.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + {"hadm_id": "300", "admittime": datetime(2102, 1, 1)}, + ], + diagnoses=[ + # hadm 100: only a null code → should be skipped + {"hadm_id": "100", "icd9_code": None, "timestamp": datetime(2100, 1, 2)}, + # hadm 200: valid code + {"hadm_id": "200", "icd9_code": "D003", "timestamp": datetime(2101, 1, 2)}, + # hadm 300: valid code + {"hadm_id": "300", "icd9_code": "D004", "timestamp": datetime(2102, 1, 2)}, + ], + ) + result = self.task(patient) + # hadm 100 skipped → only 2 visits remain → valid patient + self.assertEqual(len(result), 1) + visits = result[0]["visits"] + self.assertEqual(len(visits), 2) + self.assertIn("D003", visits[0]) + self.assertIn("D004", visits[1]) + + def test_three_valid_visits_all_included(self): + """All visits with codes are included in the nested list.""" + patient = self._make_patient( + patient_id="P001", + admissions=[ + {"hadm_id": "100", "admittime": datetime(2100, 1, 1)}, + {"hadm_id": "200", "admittime": datetime(2101, 1, 1)}, + {"hadm_id": "300", "admittime": datetime(2102, 1, 1)}, + ], + diagnoses=[ + {"hadm_id": "100", "icd9_code": "A", "timestamp": datetime(2100, 1, 2)}, + {"hadm_id": "200", "icd9_code": "B", "timestamp": datetime(2101, 1, 2)}, + {"hadm_id": "300", "icd9_code": "C", "timestamp": datetime(2102, 1, 2)}, + ], + ) + result = self.task(patient) + self.assertEqual(len(result), 1) + visits = result[0]["visits"] + self.assertEqual(len(visits), 3) + + +class TestHaloGenerationMIMIC4TaskSchema(unittest.TestCase): + """Tests for HaloGenerationMIMIC4 class attributes.""" + + def test_task_name(self): + self.assertEqual(HaloGenerationMIMIC4.task_name, "HaloGenerationMIMIC4") + + def test_inherits_mimic3(self): + task = HaloGenerationMIMIC4() + self.assertIsInstance(task, HaloGenerationMIMIC3) + + def test_convenience_instance(self): + self.assertIsInstance(halo_generation_mimic4_fn, HaloGenerationMIMIC4) + self.assertIsInstance(halo_generation_mimic4_fn, BaseTask) + + +class TestHaloGenerationMIMIC3Integration(unittest.TestCase): + """Integration tests using the MIMIC-III demo dataset.""" + + @classmethod + def setUpClass(cls): + try: + import inspect + + from pyhealth.datasets import MIMIC3Dataset + + if not inspect.isclass(MIMIC3Dataset): + raise ImportError( + "pyhealth.datasets.MIMIC3Dataset is not a real class " + "(stub injected by another test module); skipping integration tests." + ) + + demo_path = str( + Path(__file__).parent.parent.parent + / "test-resources" + / "core" + / "mimic3demo" + ) + cls.dataset = MIMIC3Dataset( + root=demo_path, + tables=["diagnoses_icd"], + ) + cls.task = HaloGenerationMIMIC3() + cls.sample_dataset = cls.dataset.set_task(cls.task) + cls.skip_integration = False + except (FileNotFoundError, OSError, ImportError, ValueError) as e: + cls.skip_integration = True + cls.skip_reason = str(e) + + def setUp(self): + if self.skip_integration: + self.skipTest(f"Integration test skipped: {self.skip_reason}") + + def test_set_task_returns_dataset(self): + self.assertIsNotNone(self.sample_dataset) + + def test_samples_generated(self): + """Should produce at least one sample from the demo dataset.""" + self.assertGreater(len(self.sample_dataset), 0) + + def test_sample_has_required_keys(self): + """Every sample must have patient_id and visits.""" + for sample in self.sample_dataset: + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + + def test_visits_are_nested_lists(self): + """visits must be a list of lists of strings.""" + for sample in self.sample_dataset: + visits = sample["visits"] + self.assertIsInstance(visits, list) + self.assertGreaterEqual(len(visits), 2) + for visit in visits: + self.assertIsInstance(visit, list) + self.assertGreater(len(visit), 0) + for code in visit: + self.assertIsInstance(code, str) + + def test_no_patients_with_single_visit(self): + """No sample should come from a patient with only one valid visit.""" + for sample in self.sample_dataset: + self.assertGreaterEqual( + len(sample["visits"]), + 2, + f"Patient {sample['patient_id']} has fewer than 2 visits", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_halo_model.py b/tests/core/test_halo_model.py new file mode 100644 index 000000000..5ae8c8005 --- /dev/null +++ b/tests/core/test_halo_model.py @@ -0,0 +1,461 @@ +"""Tests for HALO model inheriting BaseModel. + +TDD tests written before the implementation. All tests use a mock SampleDataset +built with NestedSequenceProcessor — no MIMIC-III data required. + +Uses importlib to load pyhealth.models.base_model and pyhealth.models.generators.halo +directly without triggering pyhealth.models.__init__.py (which requires optional +heavy dependencies like litdata, pyarrow, einops not present in the test venv). +""" + +import importlib.util +import os +import sys +import types +import unittest +from unittest.mock import MagicMock + + +def _bootstrap_imports(): + """Set up sys.modules so BaseModel and HALO can be imported cleanly. + + The test venv only has torch and polars installed. Heavy deps declared in + pyproject.toml (litdata, pyarrow, einops, ...) are not present. This + function stubs out exactly the modules needed to load base_model.py and + halo.py without hitting missing-import errors from unrelated models. + """ + # Import top-level pyhealth package (safe -- __init__.py has no heavy deps) + import pyhealth # noqa: F401 + + # Stub pyhealth.datasets so that base_model.py's + # "from ..datasets import SampleDataset" resolves cleanly. + if "pyhealth.datasets" not in sys.modules: + ds_stub = MagicMock() + class _FakeSampleDataset: # noqa: N801 + pass + ds_stub.SampleDataset = _FakeSampleDataset + sys.modules["pyhealth.datasets"] = ds_stub + + # Stub pyhealth.models package so we can control what gets loaded + # without running the real __init__.py (which requires einops, litdata, etc.) + if "pyhealth.models" not in sys.modules or isinstance( + sys.modules["pyhealth.models"], MagicMock + ): + models_stub = MagicMock() + sys.modules["pyhealth.models"] = models_stub + else: + models_stub = sys.modules["pyhealth.models"] + + # Load pyhealth.processors (clean — no heavy deps) + from pyhealth.processors import PROCESSOR_REGISTRY # noqa: F401 + + # Load base_model.py directly via importlib + def _load_file(mod_name, filepath): + spec = importlib.util.spec_from_file_location(mod_name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + _PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + base = os.path.join(_PROJECT_ROOT, "pyhealth", "models") + res = os.path.join(_PROJECT_ROOT, "pyhealth", "models", "generators", "halo_resources") + + # Load base_model + bm_mod = _load_file("pyhealth.models.base_model", f"{base}/base_model.py") + BaseModel = bm_mod.BaseModel + + # Make BaseModel accessible via the stub so halo.py can import it + models_stub.BaseModel = BaseModel + + # Load halo_resources sub-modules + _load_file( + "pyhealth.models.generators.halo_resources.halo_config", + f"{res}/halo_config.py", + ) + _load_file( + "pyhealth.models.generators.halo_resources.halo_model", + f"{res}/halo_model.py", + ) + + # Stub the generators package so relative imports in halo.py work + gen_stub = MagicMock() + sys.modules.setdefault("pyhealth.models.generators", gen_stub) + + # Load halo.py directly + halo_mod = _load_file( + "pyhealth.models.generators.halo", + f"{base}/generators/halo.py", + ) + + return BaseModel, halo_mod.HALO + + +BaseModel, HALO = _bootstrap_imports() + +import torch # noqa: E402 +import torch.nn as nn # noqa: E402 +from pyhealth.processors.nested_sequence_processor import NestedSequenceProcessor # noqa: E402 + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +def _make_mock_dataset(): + """Build a minimal mock SampleDataset with a fitted NestedSequenceProcessor.""" + samples = [ + {"patient_id": "p1", "visits": [["A", "B"], ["C", "D"], ["E"]]}, + {"patient_id": "p2", "visits": [["A", "C"], ["B", "E"], ["D", "F"]]}, + {"patient_id": "p3", "visits": [["G", "H"], ["I", "J"], ["A", "B"]]}, + ] + processor = NestedSequenceProcessor() + processor.fit(samples, "visits") + + dataset = MagicMock() + dataset.input_schema = {"visits": "nested_sequence"} + dataset.output_schema = {} + dataset.input_processors = {"visits": processor} + dataset.output_processors = {} + return dataset, processor + + +def _make_visits_tensor(processor, batch_size=2, n_visits=3): + """Build a padded visits tensor of shape (batch, n_visits, max_inner_len). + + Only writes to positions that exist given the provided n_visits and batch_size. + """ + max_inner = processor._max_inner_len + visits = torch.zeros(batch_size, n_visits, max_inner, dtype=torch.long) + + # Batch item 0 — fill up to min(3, n_visits) visits + if n_visits > 0: + visits[0, 0, 0] = processor.code_vocab.get("A", 0) + if max_inner > 1: + visits[0, 0, 1] = processor.code_vocab.get("B", 0) + if n_visits > 1: + visits[0, 1, 0] = processor.code_vocab.get("C", 0) + if n_visits > 2: + visits[0, 2, 0] = processor.code_vocab.get("E", 0) + + # Batch item 1 (only if batch_size > 1) + if batch_size > 1: + if n_visits > 0: + visits[1, 0, 0] = processor.code_vocab.get("A", 0) + if max_inner > 1: + visits[1, 0, 1] = processor.code_vocab.get("C", 0) + if n_visits > 1: + visits[1, 1, 0] = processor.code_vocab.get("B", 0) + + return visits + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +class TestHALOInheritsBaseModel(unittest.TestCase): + """Structural and inheritance tests for the HALO class.""" + + def setUp(self): + self.dataset, self.processor = _make_mock_dataset() + # Use tiny dimensions to keep tests fast + self.model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + ) + + def test_halo_inherits_basemodel(self): + """HALO must be a subclass of BaseModel.""" + self.assertIsInstance(self.model, BaseModel) + + def test_halo_is_nn_module(self): + """HALO must be a subclass of nn.Module.""" + self.assertIsInstance(self.model, nn.Module) + + def test_halo_feature_keys(self): + """feature_keys must be ['visits'] from input_schema.""" + self.assertEqual(self.model.feature_keys, ["visits"]) + + def test_halo_label_keys_empty(self): + """label_keys must be [] because output_schema is empty.""" + self.assertEqual(self.model.label_keys, []) + + def test_halo_has_device_property(self): + """model.device must return a torch.device instance.""" + self.assertIsInstance(self.model.device, torch.device) + + def test_halo_has_halo_model_submodule(self): + """self.halo_model must exist as a registered nn.Module submodule.""" + self.assertTrue(hasattr(self.model, "halo_model")) + self.assertIsInstance(self.model.halo_model, nn.Module) + + def test_halo_config_vocab_sizes(self): + """config.code_vocab_size must equal processor.vocab_size().""" + self.assertEqual( + self.model.config.code_vocab_size, + self.processor.vocab_size(), + ) + + def test_halo_config_label_vocab_size_zero(self): + """config.label_vocab_size must be 0 (no labels).""" + self.assertEqual(self.model.config.label_vocab_size, 0) + + def test_halo_config_total_vocab_size(self): + """total_vocab_size == code_vocab_size + label_vocab_size + 3.""" + expected = self.processor.vocab_size() + 0 + 3 + self.assertEqual(self.model.config.total_vocab_size, expected) + + def test_halo_halo_model_is_registered_submodule(self): + """halo_model must appear in named_modules() so it participates in state_dict.""" + module_names = [name for name, _ in self.model.named_modules()] + self.assertIn("halo_model", module_names) + + def test_halo_parameters_non_empty(self): + """Model must have trainable parameters (from halo_model).""" + params = list(self.model.parameters()) + self.assertGreater(len(params), 0) + + +class TestHALOForward(unittest.TestCase): + """Tests for HALO.forward() method.""" + + def setUp(self): + self.dataset, self.processor = _make_mock_dataset() + self.model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + ) + self.visits_tensor = _make_visits_tensor(self.processor, batch_size=2, n_visits=3) + + def test_halo_forward_returns_dict(self): + """forward() must return a dict.""" + output = self.model(visits=self.visits_tensor) + self.assertIsInstance(output, dict) + + def test_halo_forward_returns_loss(self): + """forward() must have 'loss' key in output.""" + output = self.model(visits=self.visits_tensor) + self.assertIn("loss", output) + + def test_halo_forward_loss_is_tensor(self): + """loss must be a torch.Tensor.""" + output = self.model(visits=self.visits_tensor) + self.assertIsInstance(output["loss"], torch.Tensor) + + def test_halo_forward_loss_is_scalar(self): + """loss must be a scalar (0-dimensional tensor).""" + output = self.model(visits=self.visits_tensor) + self.assertEqual(output["loss"].dim(), 0) + + def test_halo_forward_returns_predictions(self): + """forward() must also return 'predictions' key.""" + output = self.model(visits=self.visits_tensor) + self.assertIn("predictions", output) + + def test_halo_forward_shape_batch2_visits3(self): + """forward() with (2, 3, max_codes) visits tensor -> scalar loss.""" + output = self.model(visits=self.visits_tensor) + self.assertEqual(output["loss"].shape, torch.Size([])) + + def test_halo_train_mode_forward(self): + """In training mode, forward() must return loss.""" + self.model.train() + output = self.model(visits=self.visits_tensor) + self.assertIn("loss", output) + self.assertIsInstance(output["loss"], torch.Tensor) + + def test_halo_eval_mode_forward(self): + """In eval mode with no_grad, forward() must still return loss.""" + self.model.eval() + with torch.no_grad(): + output = self.model(visits=self.visits_tensor) + self.assertIn("loss", output) + self.assertIsInstance(output["loss"], torch.Tensor) + + def test_halo_forward_loss_is_finite(self): + """loss must be a finite number (not NaN or Inf).""" + output = self.model(visits=self.visits_tensor) + self.assertTrue(torch.isfinite(output["loss"])) + + def test_halo_forward_batch_size_1(self): + """forward() must work with batch_size=1.""" + visits_1 = _make_visits_tensor(self.processor, batch_size=1, n_visits=2) + output = self.model(visits=visits_1) + self.assertIn("loss", output) + self.assertEqual(output["loss"].dim(), 0) + + def test_halo_forward_accepts_kwargs(self): + """forward() must accept extra kwargs without crashing (e.g., patient_id).""" + output = self.model(visits=self.visits_tensor, patient_id=["p1", "p2"]) + self.assertIn("loss", output) + + +class TestHALOInit(unittest.TestCase): + """Tests for HALO.__init__ parameter handling.""" + + def setUp(self): + self.dataset, self.processor = _make_mock_dataset() + + def test_save_dir_stored(self): + """save_dir parameter must be stored.""" + model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + save_dir="/tmp/test_halo/", + ) + self.assertEqual(model.save_dir, "/tmp/test_halo/") + + def test_visits_processor_stored(self): + """visits_processor must be stored as a reference.""" + model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + ) + self.assertTrue(hasattr(model, "visits_processor")) + + def test_config_n_ctx(self): + """config.n_ctx must reflect the n_ctx parameter.""" + model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=10, + ) + self.assertEqual(model.config.n_ctx, 10) + + def test_config_embed_dim(self): + """config.n_embd must reflect the embed_dim parameter.""" + model = HALO( + dataset=self.dataset, + embed_dim=32, + n_heads=2, + n_layers=2, + n_ctx=12, + ) + self.assertEqual(model.config.n_embd, 32) + + def test_config_epochs(self): + """config.epoch must reflect the epochs parameter.""" + model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + epochs=10, + ) + self.assertEqual(model.config.epoch, 10) + + +class TestHALOSynthesizeDataset(unittest.TestCase): + """Tests for HALO.synthesize_dataset().""" + + def setUp(self): + self.dataset, self.processor = _make_mock_dataset() + self.model = HALO( + dataset=self.dataset, + embed_dim=32, + n_heads=2, + n_layers=1, + n_ctx=8, + ) + + def test_synthesize_dataset_returns_list_of_correct_length(self): + """synthesize_dataset(num_samples=2) must return a list of length 2.""" + result = self.model.synthesize_dataset(num_samples=2) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_synthesize_dataset_items_have_required_keys(self): + """Each item in the result must be a dict with 'patient_id' and 'visits'.""" + result = self.model.synthesize_dataset(num_samples=2) + for item in result: + self.assertIsInstance(item, dict) + self.assertIn("patient_id", item) + self.assertIn("visits", item) + + +class TestHALOTrainModelVariableVisits(unittest.TestCase): + """Tests that train_model() handles patients with different visit counts.""" + + def _make_variable_visit_dataset(self, processor): + """Build a dataset-like object whose items have different visit counts. + + Patient p1 has 2 visits, patient p2 has 3 visits. Without a custom + collate_fn the default torch.stack would raise RuntimeError because + the visit-dimension sizes differ. + """ + max_inner = processor._max_inner_len + + item_p1 = { + "patient_id": "p1", + "visits": torch.zeros(2, max_inner, dtype=torch.long), + } + item_p1["visits"][0, 0] = processor.code_vocab.get("A", 0) + item_p1["visits"][1, 0] = processor.code_vocab.get("B", 0) + + item_p2 = { + "patient_id": "p2", + "visits": torch.zeros(3, max_inner, dtype=torch.long), + } + item_p2["visits"][0, 0] = processor.code_vocab.get("A", 0) + item_p2["visits"][1, 0] = processor.code_vocab.get("C", 0) + item_p2["visits"][2, 0] = processor.code_vocab.get("E", 0) + + samples = [item_p1, item_p2] + + class _ListDataset: + """Minimal map-style dataset backed by a plain list.""" + def __init__(self, items): + self._items = items + self.input_schema = {"visits": "nested_sequence"} + self.output_schema = {} + self.input_processors = {"visits": processor} + self.output_processors = {} + + def __len__(self): + return len(self._items) + + def __getitem__(self, idx): + return self._items[idx] + + return _ListDataset(samples) + + def test_train_model_variable_visit_counts_no_error(self): + """train_model() with patients having different visit counts must not raise RuntimeError.""" + base_dataset, processor = _make_mock_dataset() + model = HALO( + dataset=base_dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=12, + epochs=1, + batch_size=2, + ) + + train_dataset = self._make_variable_visit_dataset(processor) + + # Should complete without raising RuntimeError from collate + try: + model.train_model(train_dataset, val_dataset=None) + except RuntimeError as e: + self.fail(f"train_model() raised RuntimeError with variable visit counts: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_halo_end_to_end.py b/tests/integration/test_halo_end_to_end.py new file mode 100644 index 000000000..754bba30c --- /dev/null +++ b/tests/integration/test_halo_end_to_end.py @@ -0,0 +1,401 @@ +"""End-to-end integration tests for the HALO synthetic EHR generation pipeline. + +Category A tests use InMemorySampleDataset with synthetic data — no external +data required and must always pass. + +Category B tests require actual MIMIC-III data and are skipped gracefully when +the data is unavailable. + +The bootstrap pattern (loading HALO and InMemorySampleDataset via importlib +while stubbing out heavy optional dependencies) mirrors the approach used in +tests/core/test_halo_model.py. +""" + +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Bootstrap: load HALO, BaseModel, and InMemorySampleDataset without +# triggering pyhealth.models.__init__ (requires einops, litdata, etc.) or +# pyhealth.datasets.__init__ (requires litdata, pyarrow, pandas, dask, ...). +# --------------------------------------------------------------------------- + +def _bootstrap(): + """Load HALO, BaseModel, and InMemorySampleDataset via importlib. + + Returns: + (BaseModel, HALO, InMemorySampleDataset) + """ + import pyhealth # noqa: F401 — top-level __init__ has no heavy deps + + # Stub pyhealth.datasets so that base_model.py's + # "from ..datasets import SampleDataset" resolves cleanly. + if "pyhealth.datasets" not in sys.modules: + ds_stub = MagicMock() + class _FakeSampleDataset: # noqa: N801 + pass + ds_stub.SampleDataset = _FakeSampleDataset + sys.modules["pyhealth.datasets"] = ds_stub + + # Stub pyhealth.models so we can control loading without the real __init__. + if "pyhealth.models" not in sys.modules or isinstance( + sys.modules["pyhealth.models"], MagicMock + ): + models_stub = MagicMock() + sys.modules["pyhealth.models"] = models_stub + else: + models_stub = sys.modules["pyhealth.models"] + + # Processors are safe to import normally. + from pyhealth.processors import PROCESSOR_REGISTRY # noqa: F401 + + def _load_file(mod_name, filepath): + spec = importlib.util.spec_from_file_location(mod_name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base = os.path.join(root, "pyhealth", "models") + res = os.path.join(base, "generators", "halo_resources") + + # Load base_model and expose via stub. + bm_mod = _load_file("pyhealth.models.base_model", os.path.join(base, "base_model.py")) + BaseModel = bm_mod.BaseModel + models_stub.BaseModel = BaseModel + + # Load HALO sub-dependencies. + _load_file( + "pyhealth.models.generators.halo_resources.halo_config", + os.path.join(res, "halo_config.py"), + ) + _load_file( + "pyhealth.models.generators.halo_resources.halo_model", + os.path.join(res, "halo_model.py"), + ) + + gen_stub = MagicMock() + sys.modules.setdefault("pyhealth.models.generators", gen_stub) + + halo_mod = _load_file( + "pyhealth.models.generators.halo", + os.path.join(base, "generators", "halo.py"), + ) + HALO = halo_mod.HALO + + # Stub litdata so sample_dataset.py can be loaded without the full package. + # sample_dataset.py imports litdata.StreamingDataset and + # litdata.utilities.train_test_split.deepcopy_dataset. + if "litdata" not in sys.modules: + litdata_pkg = MagicMock() + litdata_pkg.StreamingDataset = type( + "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None} + ) + litdata_utilities = MagicMock() + litdata_utilities_train_test = MagicMock() + litdata_utilities_train_test.deepcopy_dataset = lambda x: x + litdata_utilities.train_test_split = litdata_utilities_train_test + litdata_pkg.utilities = litdata_utilities + sys.modules["litdata"] = litdata_pkg + sys.modules["litdata.utilities"] = litdata_utilities + sys.modules["litdata.utilities.train_test_split"] = litdata_utilities_train_test + + # Load sample_dataset.py directly (bypasses datasets/__init__.py). + ds_file_mod = _load_file( + "pyhealth.datasets.sample_dataset", + os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"), + ) + InMemorySampleDataset = ds_file_mod.InMemorySampleDataset + + return BaseModel, HALO, InMemorySampleDataset + + +BaseModel, HALO, InMemorySampleDataset = _bootstrap() + +import torch # noqa: E402 +from torch.utils.data import DataLoader # noqa: E402 + + +# --------------------------------------------------------------------------- +# Shared helper +# --------------------------------------------------------------------------- + +_SMALL_SAMPLES = [ + {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + {"patient_id": "p2", "visits": [["A"], ["B", "C"]]}, + {"patient_id": "p3", "visits": [["D", "E"], ["F", "G"], ["H"]]}, + {"patient_id": "p4", "visits": [["A", "B"], ["C", "D"]]}, + {"patient_id": "p5", "visits": [["B", "C"], ["D", "E"]]}, +] + +_SMALL_MODEL_KWARGS = dict( + embed_dim=32, + n_heads=2, + n_layers=1, + n_ctx=8, + batch_size=4, + epochs=1, +) + + +def _make_dataset(samples=None): + if samples is None: + samples = _SMALL_SAMPLES + return InMemorySampleDataset( + samples=samples, + input_schema={"visits": "nested_sequence"}, + output_schema={}, + ) + + +# --------------------------------------------------------------------------- +# Category A: In-Memory Integration Tests (must always pass) +# --------------------------------------------------------------------------- + + +class TestHALOFullPipelineForward(unittest.TestCase): + """Full pipeline: build dataset -> create HALO -> forward pass.""" + + def setUp(self): + self.dataset = _make_dataset() + self.tmpdir = tempfile.mkdtemp() + self.model = HALO(self.dataset, save_dir=self.tmpdir, **_SMALL_MODEL_KWARGS) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_full_pipeline_forward(self): + """Forward pass on a batch from the dataset returns loss and predictions.""" + sample = self.dataset[0] + visits = sample["visits"].unsqueeze(0) # add batch dimension + output = self.model(visits=visits) + self.assertIsInstance(output, dict) + self.assertIn("loss", output) + self.assertIn("predictions", output) + self.assertIsInstance(output["loss"], torch.Tensor) + self.assertEqual(output["loss"].dim(), 0) + self.assertTrue(torch.isfinite(output["loss"])) + + +class TestHALOTrainModelOneEpoch(unittest.TestCase): + """train_model completes one epoch without error.""" + + def test_train_model_runs_one_epoch(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + # Should complete without raising any exception. + try: + model.train_model(dataset, val_dataset=None) + except Exception as exc: # noqa: BLE001 + self.fail(f"train_model raised an unexpected exception: {exc}") + + +class TestHALOCheckpointSavedAfterTraining(unittest.TestCase): + """Checkpoint file exists after train_model with a validation dataset.""" + + def test_checkpoint_saved_after_training(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + # Checkpoint is only written when val_dataset is provided and + # the validation loss improves. + model.train_model(dataset, val_dataset=dataset) + checkpoint_path = os.path.join(tmpdir, "halo_model") + self.assertTrue( + os.path.exists(checkpoint_path), + f"Expected checkpoint at {checkpoint_path}", + ) + + +class TestHALOSynthesizeReturnsCorrectCount(unittest.TestCase): + """synthesize_dataset(num_samples=5) returns exactly 5 dicts.""" + + def setUp(self): + self.dataset = _make_dataset() + self.tmpdir = tempfile.mkdtemp() + self.model = HALO(self.dataset, save_dir=self.tmpdir, **_SMALL_MODEL_KWARGS) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_synthesize_returns_correct_count(self): + result = self.model.synthesize_dataset(num_samples=5) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 5) + + +class TestHALOSynthesizeOutputStructure(unittest.TestCase): + """Each synthesized dict has patient_id and visits keys.""" + + def setUp(self): + self.dataset = _make_dataset() + self.tmpdir = tempfile.mkdtemp() + self.model = HALO(self.dataset, save_dir=self.tmpdir, **_SMALL_MODEL_KWARGS) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_synthesize_output_structure(self): + result = self.model.synthesize_dataset(num_samples=3) + for i, item in enumerate(result): + self.assertIsInstance(item, dict, f"Item {i} is not a dict") + self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'") + self.assertIn("visits", item, f"Item {i} missing 'visits'") + self.assertIsInstance(item["visits"], list, f"visits in item {i} is not a list") + for j, visit in enumerate(item["visits"]): + self.assertIsInstance( + visit, list, f"visit {j} in item {i} is not a list" + ) + + +class TestHALOVariableVisitCounts(unittest.TestCase): + """DataLoader does not crash when patients have different numbers of visits.""" + + def test_pipeline_with_variable_visit_counts(self): + # Patients with 2, 3, and 5 visits respectively. + samples = [ + {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + {"patient_id": "p2", "visits": [["A"], ["B", "C"], ["D"]]}, + {"patient_id": "p3", "visits": [["E"], ["F"], ["G"], ["H"], ["I"]]}, + ] + dataset = _make_dataset(samples) + + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + # If _halo_collate_fn is missing or broken, DataLoader will raise + # RuntimeError due to mismatched visit-dimension sizes. + try: + model.train_model(dataset, val_dataset=None) + except RuntimeError as exc: + self.fail( + f"train_model raised RuntimeError with variable visit counts: {exc}" + ) + + +class TestHALOIsBaseModelInstance(unittest.TestCase): + """HALO model is an instance of BaseModel.""" + + def test_model_is_basemodel_instance(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + self.assertIsInstance(model, BaseModel) + + +class TestHALOFeatureKeys(unittest.TestCase): + """model.feature_keys equals ['visits'].""" + + def test_feature_keys(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + self.assertEqual(model.feature_keys, ["visits"]) + + +# --------------------------------------------------------------------------- +# Category B: MIMIC-III Integration Tests (skipped if data unavailable) +# --------------------------------------------------------------------------- + +_MIMIC3_PATH = os.environ.get( + "PYHEALTH_MIMIC3_PATH", + "/srv/local/data/physionet.org/files/mimiciii/1.4", +) + + +class TestHALOMIMIC3Integration(unittest.TestCase): + """End-to-end pipeline test with actual MIMIC-III data. + + Skipped automatically when MIMIC-III is not present on this machine. + """ + + @classmethod + def setUpClass(cls): + cls.skip_integration = False + cls.skip_reason = "" + try: + # Remove the bootstrap stub for pyhealth.datasets so we can attempt + # a real import (which will raise ImportError if litdata is absent). + _saved_stub = sys.modules.pop("pyhealth.datasets", None) + try: + import importlib as _il + _il.invalidate_caches() + from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset + from pyhealth.tasks.halo_generation import HaloGenerationMIMIC3 + except (ImportError, ModuleNotFoundError) as exc: + # Restore stub so the rest of the test session is unaffected. + if _saved_stub is not None: + sys.modules["pyhealth.datasets"] = _saved_stub + raise ImportError(str(exc)) from exc + # Restore whatever was there (real module or stub). + # If the import succeeded, sys.modules["pyhealth.datasets"] is now the + # real module — keep it. + + cls.dataset = _MIMIC3Dataset( + root=_MIMIC3_PATH, + tables=["diagnoses_icd"], + ) + task = HaloGenerationMIMIC3() + cls.sample_dataset = cls.dataset.set_task(task) + except (FileNotFoundError, OSError, ImportError, ValueError) as exc: + cls.skip_integration = True + cls.skip_reason = str(exc) + + def setUp(self): + if self.skip_integration: + self.skipTest(f"MIMIC-III integration test skipped: {self.skip_reason}") + + def test_mimic3_set_task_returns_nonempty_dataset(self): + """set_task produces at least one sample from MIMIC-III.""" + self.assertGreater(len(self.sample_dataset), 0) + + def test_mimic3_sample_keys(self): + """Every sample must contain patient_id and visits keys.""" + for sample in self.sample_dataset: + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + + def test_mimic3_visits_are_nested_lists_of_strings(self): + """visits must be a list of lists of strings with at least 2 visits.""" + for sample in self.sample_dataset: + visits = sample["visits"] + self.assertIsInstance(visits, list) + self.assertGreaterEqual(len(visits), 2) + for visit in visits: + self.assertIsInstance(visit, list) + self.assertGreater(len(visit), 0) + for code in visit: + self.assertIsInstance(code, str) + + def test_mimic3_full_pipeline_train_and_synthesize(self): + """Train one epoch on MIMIC-III data and synthesize a small batch.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = HALO( + self.sample_dataset, + embed_dim=32, + n_heads=2, + n_layers=1, + n_ctx=8, + batch_size=16, + epochs=1, + save_dir=tmpdir, + ) + model.train_model(self.sample_dataset, val_dataset=None) + synthetic = model.synthesize_dataset(num_samples=10) + self.assertEqual(len(synthetic), 10) + for item in synthetic: + self.assertIn("patient_id", item) + self.assertIn("visits", item) + + +if __name__ == "__main__": + unittest.main()