Draft
63 commits
d1f97af
init generators commit
chufangao Jun 15, 2025
ee8c52c
base
Jul 16, 2025
00f10c2
Stab at implementation
Jul 16, 2025
b666f82
Misc. changes for testing
Jul 27, 2025
ec4f23d
Remove testing logs
Jul 27, 2025
4ce8e21
Clean up things a bit
Jul 27, 2025
b1584fd
Clean up hardcoded file path
Jul 27, 2025
d374603
Remove testing files from PR
Jul 27, 2025
4f456f9
Init model properly
Jul 27, 2025
56380f6
Update comments
Jul 27, 2025
5d4ede6
Add HALO generator with training and generation examples
jalengg Feb 4, 2026
d2b8da3
Remove non-HALO README changes
jalengg Feb 16, 2026
97050b8
Create HALO Colab notebook structure with headers
jalengg Feb 16, 2026
58ef738
Add setup and installation cells to HALO notebook
jalengg Feb 16, 2026
21f394f
Add configuration, data upload, training, generation, and results cel…
jalengg Feb 16, 2026
b1458fe
Add README documentation for HALO Colab notebook
jalengg Feb 16, 2026
702d65c
Fix installation cell to detect pip failures
jalengg Feb 16, 2026
261f819
Remove pandas<2 constraint for Python 3.12 compatibility
jalengg Feb 16, 2026
4acf2f2
Fix MIMIC-III file upload issues in Colab notebook
jalengg Feb 16, 2026
b8b4c96
Add missing __init__.py to halo_resources module
jalengg Feb 16, 2026
564cf0a
Add --no-cache-dir to pip install for latest code
jalengg Feb 16, 2026
8002123
Fix path concatenation bug in HALO_MIMIC3Dataset
jalengg Feb 16, 2026
2d4fbbd
Add MANIFEST.in to include YAML config files in package
jalengg Feb 16, 2026
6ce060d
Fix YAML config packaging: use package_data instead of MANIFEST.in
jalengg Feb 17, 2026
2418788
Add install timestamp to Colab notebook success message
jalengg Feb 17, 2026
f1ceb35
Add last-updated timestamp to Colab notebook header
jalengg Feb 17, 2026
663dbb8
Use human-readable timestamp format in notebook header
jalengg Feb 17, 2026
bc0f41b
Fix pkl file path concatenation bug in HALO_MIMIC3Dataset
jalengg Feb 17, 2026
200e693
added trailing slash
shiitavie Feb 18, 2026
76de88e
added trailing slash
shiitavie Feb 18, 2026
5ec2c42
format string error
shiitavie Feb 18, 2026
bb5de81
remove assertion (issue #23)
shiitavie Feb 19, 2026
4040422
fix: complete merge - add missing processor files from upstream/master
jalengg Feb 23, 2026
0c9e973
chore: merge upstream/master, resolve processor/model conflicts
jalengg Feb 23, 2026
a864781
feat: add halo_generation task function (HaloGenerationMIMIC3/4)
jalengg Feb 23, 2026
fe08005
refactor: make HALO inherit BaseModel with forward() and train_model()
jalengg Feb 25, 2026
6a814d1
test: add synthesize_dataset coverage for HALO
jalengg Feb 25, 2026
661ec69
fix: collate_fn for variable visit lengths and end-token position in …
jalengg Feb 25, 2026
5fb8ce0
feat: update halo training example to PyHealth 2.0 API
jalengg Feb 25, 2026
8892dda
feat: update halo generation example to PyHealth 2.0 API
jalengg Feb 25, 2026
7df241f
feat: remove HALO_MIMIC3Dataset (replaced by HaloGenerationMIMIC3 task)
jalengg Feb 25, 2026
299d272
docs: update HALO docstrings to Google/PyHealth style
jalengg Feb 25, 2026
4dcff7c
docs: fix synthesize_dataset Returns style and dataset type annotation
jalengg Feb 25, 2026
6142476
test: add HALO end-to-end integration tests
jalengg Feb 25, 2026
1b744eb
test: fix tearDown cleanup, env var path, and relative bootstrap paths
jalengg Feb 25, 2026
d5248de
test: guard integration test against sys.modules stub contamination f…
jalengg Feb 25, 2026
0d7775e
Update halo_mimic3_colab.ipynb to PyHealth 2.0 API and remove emojis
jalengg Mar 2, 2026
855f232
Update last updated date in halo_mimic3_colab.ipynb
jalengg Mar 2, 2026
c65ca42
Full dry-run audit of halo_mimic3_colab.ipynb against halo-pr-528 API
jalengg Mar 2, 2026
9715453
Fix numpy/scipy Colab incompatibility with kernel-restart pattern in …
jalengg Mar 2, 2026
b80f837
Replace numpy-downgrade hack with scipy>=1.14.0 upgrade in cell 2
jalengg Mar 2, 2026
5f66455
Wrap cardiology_detect import in try/except to fix scipy cascade failure
jalengg Mar 2, 2026
ad9051b
Wrap mne-dependent task imports in try/except to fix Colab scipy cascade
jalengg Mar 2, 2026
b1470ad
Guard mne/scipy-dependent imports to fix Colab numpy 2.x cascade failure
jalengg Mar 2, 2026
de00ece
Guard optional-dep imports to fix Colab numpy 2.x cascade failures
jalengg Mar 2, 2026
0f06422
Remove PyHealth 1.x args code_mapping and refresh_cache from MIMIC3Da…
jalengg Mar 2, 2026
71f5c71
Fix DataLoader shuffle=True incompatibility with IterableDataset; gua…
jalengg Mar 2, 2026
4cca0a4
Fix base_dataset.patients → unique_patient_ids (PyHealth 2.0 API)
jalengg Mar 2, 2026
f718886
Fix notebook: require 4 MIMIC-III files, remove nonexistent stat() call
jalengg Mar 2, 2026
d9408bc
Make icustays optional in MIMIC3Dataset; update notebook required files
jalengg Mar 2, 2026
c52aa0b
Remove icustays from MIMIC3Dataset defaults; add --force-reinstall to…
jalengg Mar 3, 2026
dada186
Wrap ChestXray14 and COVID19CXR dataset imports in try/except for PIL…
jalengg Mar 3, 2026
d1f49ac
Wrap CNN and VisionEmbeddingModel imports in try/except for PIL/torch…
jalengg Mar 3, 2026
8 changes: 7 additions & 1 deletion .gitignore
@@ -133,7 +133,13 @@ leaderboard/rtd_token.txt

# locally pre-trained models
pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model

# local testing files
halo_testing/
halo_testing_script.py
test_halo_model.slurm

data/physionet.org/

# VSCode settings
.vscode/
55 changes: 55 additions & 0 deletions examples/README.md
@@ -0,0 +1,55 @@
# PyHealth Examples

This directory contains example scripts and notebooks for using PyHealth.

## HALO Synthetic Data Generation

### Google Colab Notebook (No Cluster Required)

**File**: `halo_mimic3_colab.ipynb`

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sunlabuiuc/PyHealth/blob/master/examples/halo_mimic3_colab.ipynb)

Train HALO and generate synthetic MIMIC-III data directly in your browser using Google Colab.

**Requirements**:
- Google account (for Colab)
- MIMIC-III access from PhysioNet
- Files: ADMISSIONS.csv, DIAGNOSES_ICD.csv, PATIENTS.csv, patient_ids.txt
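
Before opening the notebook, it can be worth checking locally that all four required files are present. A minimal sketch (run it from the directory holding your MIMIC-III export):

```python
from pathlib import Path

# The four inputs the Colab notebook prompts you to upload.
required = ["ADMISSIONS.csv", "DIAGNOSES_ICD.csv", "PATIENTS.csv", "patient_ids.txt"]
missing = [name for name in required if not Path(name).exists()]
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All required MIMIC-III files found.")
```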

**Quick Start**:
1. Open `halo_mimic3_colab.ipynb` in Google Colab
2. Enable GPU (Runtime → Change runtime type → GPU)
3. Run cells in order
4. Upload your MIMIC-III files when prompted
5. Download synthetic data CSV

**Demo vs Production**:
- **Demo** (default): 5 epochs, 1K samples, ~30 min
- **Production**: 80 epochs, 10K samples, ~6-10 hours (edit the notebook's configuration cell)
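
The demo/production split above amounts to a small configuration cell. A sketch with hypothetical variable names (the actual notebook cell may use different identifiers):

```python
# Illustrative configuration values (names are hypothetical; the real
# notebook cell may differ). Flip DEMO_MODE to switch to a full run.
DEMO_MODE = True

if DEMO_MODE:
    epochs, num_samples = 5, 1_000    # demo: ~30 min on a Colab GPU
else:
    epochs, num_samples = 80, 10_000  # production: ~6-10 hours

print(f"epochs={epochs}, num_samples={num_samples}")
```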

**Features**:
- Google Drive integration for persistence
- Resume capability if session times out
- Automatic checkpoint saving
- CSV output format
- Data quality validation

### Cluster Training (SLURM)

**Files**:
- `slurm/train_halo_mimic3.slurm` - Training script
- `slurm/generate_halo_mimic3.slurm` - Generation script
- `halo_mimic3_training.py` - Python training code
- `generate_synthetic_mimic3_halo.py` - Python generation code

For users with access to GPU clusters. See individual script headers for usage.

**Example**:
```bash
# Train
sbatch slurm/train_halo_mimic3.slurm

# Generate
sbatch slurm/generate_halo_mimic3.slurm
```
150 changes: 150 additions & 0 deletions examples/generate_synthetic_mimic3_halo.py
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Example: Generate synthetic MIMIC-III patients using a trained HALO checkpoint.

Loads MIMIC3Dataset with the halo_generation_mimic3_fn task (identical to
training) so that the vocabulary is reconstructed, then loads the saved
checkpoint and calls model.synthesize_dataset(). Output is saved as JSON.

Usage:
python examples/generate_synthetic_mimic3_halo.py
python examples/generate_synthetic_mimic3_halo.py --save_dir ./my_save/ --num_samples 500
"""

import argparse
import json
import os

import torch

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models.generators.halo import HALO
from pyhealth.tasks import halo_generation_mimic3_fn


def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate synthetic MIMIC-III patients with HALO"
    )
    parser.add_argument(
        "--mimic3_root",
        default="/path/to/mimic3",
        help="Root directory of MIMIC-III data (default: /path/to/mimic3)",
    )
    parser.add_argument(
        "--save_dir",
        default="./save/",
        help="Directory containing the trained halo_model checkpoint (default: ./save/)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1000,
        help="Number of synthetic patients to generate (default: 1000)",
    )
    parser.add_argument(
        "--output",
        default="synthetic_patients.json",
        help="Output JSON file path (default: synthetic_patients.json)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # ------------------------------------------------------------------
    # STEP 1: Load MIMIC-III dataset
    # The dataset must use the same tables and code_mapping as training
    # so that the vocabulary is identical.
    # ------------------------------------------------------------------
    print("Loading MIMIC-III dataset...")
    base_dataset = MIMIC3Dataset(
        root=args.mimic3_root,
        tables=["diagnoses_icd"],  # If you trained with different tables=, update this to match.
        code_mapping={},
        dev=False,
        refresh_cache=False,
    )
    print(f" Loaded {len(base_dataset.unique_patient_ids)} patients")

    # ------------------------------------------------------------------
    # STEP 2: Apply the HALO generation task
    # set_task builds the vocabulary via NestedSequenceProcessor — must
    # match the task used during training exactly.
    # ------------------------------------------------------------------
    print("Applying HALO generation task...")
    sample_dataset = base_dataset.set_task(halo_generation_mimic3_fn)
    print(f" {len(sample_dataset)} samples after task filtering")

    # ------------------------------------------------------------------
    # STEP 3: Instantiate HALO with the same hyperparameters as training
    # The model constructor uses the dataset to determine vocab sizes;
    # the weights are loaded from the checkpoint immediately after.
    # ------------------------------------------------------------------
    print("Initializing HALO model...")
    model = HALO(
        dataset=sample_dataset,
        embed_dim=768,
        n_heads=12,
        n_layers=12,
        n_ctx=48,
        batch_size=48,
        epochs=50,  # unused during generation; must match training for checkpoint compatibility
        pos_loss_weight=None,
        lr=1e-4,
        save_dir=args.save_dir,
    )

    # ------------------------------------------------------------------
    # STEP 4: Load trained checkpoint
    # The training loop saves to save_dir/halo_model with keys
    # "model" (halo_model state dict) and "optimizer".
    # ------------------------------------------------------------------
    checkpoint_path = os.path.join(args.save_dir, "halo_model")
    print(f"Loading checkpoint from {checkpoint_path} ...")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at {checkpoint_path}. "
            "Train the model first with examples/halo_mimic3_training.py."
        )
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.halo_model.load_state_dict(checkpoint["model"])
    print(" Checkpoint loaded successfully")

    # ------------------------------------------------------------------
    # STEP 5: Generate synthetic patients
    # synthesize_dataset returns List[Dict] where each dict has:
    #   "patient_id": "synthetic_N"
    #   "visits": [[code, ...], ...]
    # ------------------------------------------------------------------
    print(f"Generating {args.num_samples} synthetic patients...")
    synthetic_data = model.synthesize_dataset(
        num_samples=args.num_samples,
        random_sampling=True,
    )

    # ------------------------------------------------------------------
    # STEP 6: Save output as JSON
    # ------------------------------------------------------------------
    print(f"Saving output to {args.output} ...")
    with open(args.output, "w") as f:
        json.dump(synthetic_data, f, indent=2)

    # ------------------------------------------------------------------
    # STEP 7: Print summary statistics
    # ------------------------------------------------------------------
    total_patients = len(synthetic_data)
    total_visits = sum(len(p["visits"]) for p in synthetic_data)
    avg_visits = total_visits / total_patients if total_patients > 0 else 0.0

    print("\n--- Generation Summary ---")
    print(f" Patients generated : {total_patients}")
    print(f" Total visits       : {total_visits}")
    print(f" Avg visits/patient : {avg_visits:.2f}")
    print(f" Output saved to    : {args.output}")
    print("Done.")


if __name__ == "__main__":
    main()
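
Once the script has run, the JSON output can be inspected with a few lines. A sketch against an in-memory example in the `List[Dict]` shape documented in STEP 5 (the codes here are made up for illustration; a real run would load `synthetic_patients.json` with `json.load` instead):

```python
import json

# Records in the shape synthesize_dataset() is documented to return:
# {"patient_id": "synthetic_N", "visits": [[code, ...], ...]}.
# Codes below are illustrative placeholders, not real output.
synthetic_data = [
    {"patient_id": "synthetic_0", "visits": [["401.9", "250.00"], ["414.01"]]},
    {"patient_id": "synthetic_1", "visits": [["038.9"]]},
]

total_visits = sum(len(p["visits"]) for p in synthetic_data)
avg_visits = total_visits / len(synthetic_data)
print(f"{len(synthetic_data)} patients, {avg_visits:.2f} avg visits/patient")
```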