Overview
Create examples/corgan_mimic3_colab.ipynb — a Google Colab notebook for CorGAN synthetic EHR generation on MIMIC-III data, modeled after examples/halo_mimic3_colab.ipynb from the HALO PR. The notebook must serve three audiences simultaneously: practitioners who want a working pipeline, researchers who need to verify scientific claims, and CI systems that need to run a smoke-test without MIMIC-III data.
Reference notebook: examples/halo_mimic3_colab.ipynb (branch halo-pr-528)
Branch: corgan-pr-integration
Why CorGAN Needs More Than a Cloned HALO Notebook
CorGAN and HALO differ fundamentally in what needs to be demonstrated:
| Dimension | HALO | CorGAN |
|---|---|---|
| Data structure | Sequential visits (temporal) | Flat bag-of-codes (multi-hot binary vector per patient) |
| Architecture | Transformer language model | CNN Autoencoder + WGAN (3 interacting components) |
| Key innovation | Hierarchical autoregressive generation | Correlation capture via 1D convolutions |
| Training paradigm | Supervised (validation loss) | Adversarial (3 separate loss curves) |
| Convergence signal | Decreasing validation loss | Wasserstein distance stabilization |
| What must be shown | Output format, code/visit stats | Correlation preservation, adversarial dynamics |
CorGAN's entire scientific contribution is that CNN autoencoders capture inter-code correlations better than linear models. This must be empirically demonstrated in the notebook, not just asserted. The HALO notebook only validates format and counts; a CorGAN notebook must validate statistical fidelity.
Required Notebook Sections
Section 0: Preamble (Markdown)
Adapted from HALO's excellent preamble pattern.
- What you'll need: MIMIC-III access (or demo mode without it), GPU recommended, Colab with runtime set to GPU
- What you'll get: A trained CorGAN model + 10,000 synthetic MIMIC-III patients
- How long it takes: Demo (5 epochs, 20–30 min), Production (50 epochs, 2–4 hrs on T4)
- What makes CorGAN different: One short paragraph explaining the CNN autoencoder + WGAN architecture and why correlation capture matters clinically
- Reference: Torfi & Fox, "CorGAN: Correlation-Capturing Convolutional Generative Adversarial Networks for Generating Synthetic Healthcare Records", FLAIRS 2020
Important: Unlike HALO's notebook which omits an architecture explanation, CorGAN's adversarial training is non-obvious to practitioners. The preamble should briefly explain what the generator, discriminator, and autoencoder each do and why adversarial loss produces better correlations than a plain autoencoder.
Section 1: Setup & Installation
# Install PyHealth from corgan branch
FORK = 'jalengg'
BRANCH = 'corgan-pr-integration'
install_url = f"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}"
!pip install {install_url} --quiet --no-cache-dir

# Detect environment (Colab vs local)
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# GPU detection
import torch
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected. Training will be very slow.")

Required: Environment detection cell — if not in Colab, use ./ paths instead of /content/drive/MyDrive/. This makes the notebook usable on SLURM clusters and laptops, not just Colab. (The HALO notebook fails outside Colab due to hardcoded /content/drive/ paths — this is a known gap to fix.)
Section 2: Configuration
Centralized config block at the top — modeled after HALO but with CorGAN-specific parameters:
# ============================================================
# CONFIGURATION — Modify these parameters
# ============================================================
# --- Preset: Change this to switch Demo/Production ---
PRESET = "demo" # "demo" or "production"
# Training parameters
if PRESET == "demo":
    EPOCHS = 5                  # Quick smoke test (~20-30 min on T4)
    BATCH_SIZE = 64
    N_SYNTHETIC_SAMPLES = 1000
    N_EPOCHS_PRETRAIN = 1       # Autoencoder pre-training epochs
elif PRESET == "production":
    EPOCHS = 50                 # Full training (~2-4 hrs on T4)
    BATCH_SIZE = 512
    N_SYNTHETIC_SAMPLES = 10000
    N_EPOCHS_PRETRAIN = 3
# Model architecture
LATENT_DIM = 128 # Generator + decoder latent dimension
HIDDEN_DIM = 128 # Generator hidden dimension
AUTOENCODER_TYPE = "cnn" # "cnn", "cnn8layer", or "linear"
# WGAN parameters
N_ITER_D = 5 # Discriminator updates per generator update
CLAMP_LOWER = -0.01 # WGAN weight clipping
CLAMP_UPPER = 0.01
LR = 0.001 # Learning rate for all optimizers
# Reproducibility
SEED = 42
# Paths
BASE_DIR = '/content/drive/MyDrive/CorGAN_Training' if IN_COLAB else './corgan_training'

Required: Comments must explain why each WGAN parameter exists (not just what it is). For example: "N_ITER_D = 5: Discriminator is updated more often than generator to prevent generator from over-optimizing a weak critic." This is absent from the HALO notebook and is essential for CorGAN because these parameters directly affect training stability.
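A sketch of the requested commenting style, using the parameter values from the config above (the rationales follow the standard WGAN training recipe of Arjovsky et al., stated here as general guidance rather than CorGAN-specific tuning advice):

```python
# WGAN parameters, each annotated with *why* it exists, not just what it is.

N_ITER_D = 5         # Critic (discriminator) takes 5 update steps per generator
                     # step: a well-trained critic gives a more reliable
                     # Wasserstein estimate, so the generator gets useful gradients.
CLAMP_LOWER = -0.01  # Weight clipping enforces the 1-Lipschitz constraint that
CLAMP_UPPER = 0.01   # the Wasserstein distance requires; too wide a range makes
                     # the critic unstable, too narrow starves it of capacity.
LR = 0.001           # Shared learning rate for all three optimizers; WGAN
                     # training tolerates this better than a standard GAN because
                     # the critic loss does not saturate.
```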
Section 3: Data Upload
Two code paths:
Path A (with MIMIC-III):
# Upload DIAGNOSES_ICD.csv.gz and ADMISSIONS.csv.gz
# Validate CSV structure (columns, row count, sample display)

Path B (demo/CI mode — no MIMIC-III):
if not mimic3_available:
    print("MIMIC-III not found — using synthetic demo data")
    from pyhealth.datasets.sample_dataset import InMemorySampleDataset
    # Create InMemorySampleDataset with realistic-looking fake ICD-9 codes
    # Sufficient to demonstrate the full pipeline

Rationale: The HALO notebook has no fallback — if MIMIC-III files aren't present, the notebook simply fails. For a CorGAN notebook, CI smoke tests and first-time users (who haven't yet obtained MIMIC-III) should be able to run the full pipeline on synthetic stand-in data. The InMemorySampleDataset pattern already exists in the test suite and should be reused here.
Section 4: Training
from pyhealth.models.generators.corgan import CorGAN
from pyhealth.tasks.corgan_generation import corgan_generation_mimic3_fn
# Apply task
sample_dataset = dataset.set_task(corgan_generation_mimic3_fn)
train_dataset, val_dataset = split_by_patient(sample_dataset, [0.85, 0.15])
# Initialize model
model = CorGAN(
    dataset=train_dataset,
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    n_epochs_pretrain=N_EPOCHS_PRETRAIN,
    n_iter_D=N_ITER_D,
    clamp_lower=CLAMP_LOWER,
    clamp_upper=CLAMP_UPPER,
    lr=LR,
    autoencoder_type=AUTOENCODER_TYPE,
    save_dir=checkpoint_dir,
)
# Train
model.train_model(train_dataset)

Additionally required in this section:
- Random seed cell — before any training:
import random
import numpy as np
import torch

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
- Hyperparameter config save — after model init, before training:
import json
from datetime import datetime

# Record every UPPERCASE config constant (filter to JSON-serializable values)
config_record = {k: v for k, v in globals().items()
                 if k.isupper() and isinstance(v, (int, float, str, bool))}
config_record['timestamp'] = datetime.now().isoformat()
with open(f'{checkpoint_dir}/config.json', 'w') as f:
    json.dump(config_record, f, indent=2)
- Loss curve plot — after training completes. This is the most important addition that HALO's notebook lacks entirely. CorGAN has three distinct losses that must be plotted together:
- Autoencoder reconstruction loss (pre-training phase)
- Discriminator Wasserstein loss
- Generator adversarial loss
Plot these as separate subplots (or overlaid with secondary axis). For demo mode (5 epochs, few batches), even a sparse curve is informative.
Markdown explanation in this section should cover the two-phase training: (1) autoencoder pre-trains to learn a latent representation of ICD-9 code vectors, then (2) the WGAN trains the generator to produce latent codes that the decoder maps back to realistic binary code vectors.
Section 5: Generation
synthetic = model.synthesize_dataset(num_samples=N_SYNTHETIC_SAMPLES)
# Display sample output
import pandas as pd
sample_df = pd.DataFrame([
    {
        "patient_id": p["patient_id"],
        "n_codes": len(p["visits"]),
        "sample_codes": ", ".join(p["visits"][:5]) + ("..." if len(p["visits"]) > 5 else "")
    }
    for p in synthetic[:10]
])
display(sample_df)

Save in multiple formats:
- synthetic_patients.json — full records for direct PyHealth reuse
- synthetic_patients.csv — flat SUBJECT_ID, ICD9_CODE rows for external tools
- corgan_final.pt — model checkpoint
One important contrast from HALO: In HALO's output, each row is a (SUBJECT_ID, VISIT_NUM, ICD9_CODE) triple because HALO generates visit sequences. In CorGAN, the output is a flat bag-of-codes — no VISIT_NUM column. The CSV should be SUBJECT_ID, ICD9_CODE only. Explain this in a markdown cell: CorGAN aggregates all a patient's diagnoses across admissions into a single flat set.
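The flattening step described above can be sketched as follows (the `to_flat_csv` helper name is illustrative; the record layout matches the `patient_id`/`visits` structure used elsewhere in this issue):

```python
import pandas as pd

def to_flat_csv(synthetic, path="synthetic_patients.csv"):
    """Flatten bag-of-codes records into (SUBJECT_ID, ICD9_CODE) rows.

    Deliberately no VISIT_NUM column: CorGAN output has no visit structure,
    unlike HALO's (SUBJECT_ID, VISIT_NUM, ICD9_CODE) triples.
    """
    rows = [
        {"SUBJECT_ID": p["patient_id"], "ICD9_CODE": code}
        for p in synthetic
        for code in p["visits"]
    ]
    df = pd.DataFrame(rows, columns=["SUBJECT_ID", "ICD9_CODE"])
    df.to_csv(path, index=False)
    return df

# Tiny worked example: two patients, three code rows total
demo = [{"patient_id": "S1", "visits": ["250.00", "401.9"]},
        {"patient_id": "S2", "visits": ["486"]}]
flat = to_flat_csv(demo)
```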
Section 6: Results & Evaluation
This is where CorGAN's notebook diverges most significantly from HALO's. HALO's results section only shows row counts and null-value checks. For CorGAN, the core scientific contribution must be validated.
Required cells:
6a. Vocabulary Coverage
all_generated_codes = set(code for p in synthetic for code in p["visits"])
vocab_size = dataset.input_processors["visits"].size()
coverage = len(all_generated_codes) / vocab_size * 100
print(f"Vocabulary coverage: {coverage:.1f}%")

Low coverage indicates mode collapse. High coverage indicates diversity. This is a key CorGAN health metric.
6b. Code Count Statistics & Distribution
- Codes per patient: mean, std, min, max, median
- Histogram comparison — real training data vs synthetic (overlaid, normalized to density)
- This directly answers "does synthetic data have realistic patient complexity?"
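The statistics above can be computed with a small helper like this sketch (`code_count_stats` is a hypothetical name; it operates on the same `visits` record structure as the rest of the notebook):

```python
import numpy as np

def code_count_stats(patients):
    """Summary statistics for codes-per-patient: mean, std, min, max, median."""
    counts = np.array([len(p["visits"]) for p in patients])
    return {
        "mean": float(counts.mean()),
        "std": float(counts.std()),
        "min": int(counts.min()),
        "max": int(counts.max()),
        "median": float(np.median(counts)),
    }

# Worked example: patients with 2, 1, and 3 codes
real_stats = code_count_stats([
    {"visits": ["250.00", "401.9"]},
    {"visits": ["486"]},
    {"visits": ["250.00", "486", "272.4"]},
])
```

For the histogram comparison, the same per-patient counts for real and synthetic cohorts can be passed to `plt.hist` with `density=True` and `alpha` below 1 so the two distributions overlay legibly.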
6c. Code Frequency Comparison (Top 20 Codes)
# Real training data: count how often each code appears
# Synthetic data: count how often each code appears
# Plot: Side-by-side bar chart of top 20 codes in each
# Metric: Pearson correlation between real and synthetic code frequencies

This is the central validation for CorGAN's claim of capturing code distributions. A high correlation (>0.9) indicates the model learned the marginal code distribution. This should be clearly labeled with its value and interpretation.
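A sketch of the Pearson computation, aligning real and synthetic frequencies on the union of observed codes (missing codes count as zero); the helper name is illustrative:

```python
from collections import Counter
import numpy as np

def code_frequency_pearson(real_patients, synth_patients):
    """Pearson r between per-code frequencies in real vs synthetic data."""
    real_counts = Counter(c for p in real_patients for c in p["visits"])
    synth_counts = Counter(c for p in synth_patients for c in p["visits"])
    # Align both frequency vectors on the union of codes
    codes = sorted(set(real_counts) | set(synth_counts))
    x = np.array([real_counts[c] for c in codes], dtype=float)
    y = np.array([synth_counts[c] for c in codes], dtype=float)
    return float(np.corrcoef(x, y)[0, 1])

# Sanity check: identical cohorts give r == 1.0
same = [{"visits": ["250.00", "401.9"]}, {"visits": ["250.00"]}]
r = code_frequency_pearson(same, same)
```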
6d. All-Zeros Detection
empty_patients = sum(1 for p in synthetic if len(p["visits"]) == 0)
if empty_patients > 0:
    print(f"WARNING: {empty_patients} patients with no codes (all-zeros generation)")
    print("Consider adjusting binarization threshold or retraining with more epochs.")

The HALO notebook checks for null values; CorGAN's equivalent is checking for all-zeros binary vectors (the documented edge case in synthesize_dataset's docstring).
6e. Quality Report
Save all metrics to quality_report.json:
{
  "total_synthetic_patients": 10000,
  "mean_codes_per_patient": 12.3,
  "std_codes_per_patient": 8.1,
  "min_codes": 0,
  "max_codes": 45,
  "unique_codes_generated": 4821,
  "vocabulary_size": 6955,
  "vocabulary_coverage_percent": 69.4,
  "empty_patients_count": 3,
  "code_frequency_pearson_r": 0.87,
  "seed": 42,
  "timestamp": "2025-01-31T14:23:00"
}

This JSON file is the key CI artifact — a CI job can parse this file and fail if vocabulary_coverage_percent < 30 or code_frequency_pearson_r < 0.5.
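The CI gate described above could be sketched like this (`ci_gate` and its threshold defaults are illustrative, matching the thresholds stated in this issue):

```python
import json

def ci_gate(report_path, min_coverage=30.0, min_pearson=0.5):
    """Parse quality_report.json; return (passed, failures) for a CI job."""
    with open(report_path) as f:
        report = json.load(f)
    failures = []
    if report["vocabulary_coverage_percent"] < min_coverage:
        failures.append("vocabulary_coverage_percent below threshold")
    if report["code_frequency_pearson_r"] < min_pearson:
        failures.append("code_frequency_pearson_r below threshold")
    return (not failures, failures)

# Example using the two gated metrics from the report shown above:
with open("quality_report.json", "w") as f:
    json.dump({"vocabulary_coverage_percent": 69.4,
               "code_frequency_pearson_r": 0.87}, f)
ok, problems = ci_gate("quality_report.json")
```

A CI job would call `sys.exit(1)` when `ok` is false, printing `problems` for the build log.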
Section 7: Download & Next Steps
- Download synthetic CSV, quality report JSON, and checkpoint
- "Congratulations" section (adopt HALO's encouraging tone)
- Troubleshooting: Out-of-memory (reduce batch_size), mode collapse (check empty_patients_count, reduce LR, increase N_ITER_D), slow training (use a smaller vocabulary with tables=["diagnoses_icd"] only)
- References: CorGAN paper, PyHealth documentation, MIMIC-III PhysioNet link
Technical Requirements
File Location
examples/corgan_mimic3_colab.ipynb
Dependencies
No new dependencies. Uses only packages already in PyHealth's requirements.txt (torch, numpy, pandas, matplotlib). No seaborn or scipy required — use matplotlib directly.
Environment Compatibility
The notebook must work in both Colab and non-Colab environments (at minimum: local execution, SLURM cluster). Use IN_COLAB flag for drive mounting and files.download() calls.
Seeds
All stochastic operations must use SEED = 42 by default. This includes PyTorch, NumPy, and Python's random module.
Kernel Restart Safety
Section 1 (Setup) must be idempotent — re-running after a kernel restart should not fail. Checkpoint loading must be conditional on checkpoint existence.
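The conditional checkpoint load could be sketched as follows (`maybe_load_checkpoint` is a hypothetical helper; `torch.load` and `load_state_dict` are the standard PyTorch calls for this):

```python
import os

def maybe_load_checkpoint(model, checkpoint_path):
    """Resume from a checkpoint only if one exists (kernel-restart safe).

    Re-running this cell after a kernel restart is harmless: with no
    checkpoint on disk it simply reports a fresh start instead of failing.
    """
    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint at {checkpoint_path}; starting fresh.")
        return False
    import torch  # deferred so the existence check works without a GPU runtime
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    print(f"Resumed from {checkpoint_path}")
    return True

# Safe to call even when nothing has been saved yet:
resumed = maybe_load_checkpoint(model=None, checkpoint_path="./does_not_exist.pt")
```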
Acceptance Criteria
- Notebook runs end-to-end in Colab on MIMIC-III without manual intervention after uploading CSVs
- Notebook runs end-to-end without MIMIC-III (demo/CI mode using InMemorySampleDataset)
- Section 2 Configuration is the only place hyperparameters need to be changed
- Three loss curves (autoencoder, discriminator, generator) are plotted after training
- Code frequency comparison (real vs synthetic, top 20) is shown in Section 6
- quality_report.json is saved with all required metrics
- synthetic_patients.json and synthetic_patients.csv are saved
- config.json with all hyperparameters is saved alongside the checkpoint
- All-zeros patient detection is present
- Random seed is set before any stochastic operation
- Notebook works without Colab-specific imports outside of Drive mounting
- Preamble explains the CorGAN architecture in plain language (not just "what it does" but "why it's different from HALO")
- Troubleshooting section covers mode collapse and OOM errors specifically
Out of Scope
- Correlation heatmaps (co-occurrence matrices) — too slow to compute in Colab for full vocabulary; could be a follow-up
- Conditional generation (generate patients with specific codes) — future work
- Comparison with HALO on same data — out of scope for this issue
- Downloading from Google Drive programmatically — use Colab's files.download() UI