
[corgan-notebook] Create CorGAN Colab notebook: end-to-end training, generation, and evaluation #35

@jalengg

Description


Overview

Create examples/corgan_mimic3_colab.ipynb — a Google Colab notebook for CorGAN synthetic EHR generation on MIMIC-III data, modeled after examples/halo_mimic3_colab.ipynb from the HALO PR. The notebook must serve three audiences simultaneously: practitioners who want a working pipeline, researchers who need to verify scientific claims, and CI systems that need to run a smoke-test without MIMIC-III data.

Reference notebook: examples/halo_mimic3_colab.ipynb (branch halo-pr-528)
Branch: corgan-pr-integration


Why CorGAN Needs More Than a Cloned HALO Notebook

CorGAN and HALO differ fundamentally in what needs to be demonstrated:

| Dimension | HALO | CorGAN |
|---|---|---|
| Data structure | Sequential visits (temporal) | Flat bag-of-codes (multi-hot binary vector per patient) |
| Architecture | Transformer language model | CNN autoencoder + WGAN (3 interacting components) |
| Key innovation | Hierarchical autoregressive generation | Correlation capture via 1D convolutions |
| Training paradigm | Supervised (validation loss) | Adversarial (3 separate loss curves) |
| Convergence signal | Decreasing validation loss | Wasserstein distance stabilization |
| What must be shown | Output format, code/visit stats | Correlation preservation, adversarial dynamics |

CorGAN's entire scientific contribution is that CNN autoencoders capture inter-code correlations better than linear models. This must be empirically demonstrated in the notebook, not just asserted. The HALO notebook only validates format and counts; a CorGAN notebook must validate statistical fidelity.


Required Notebook Sections

Section 0: Preamble (Markdown)

Adapted from HALO's excellent preamble pattern.

  • What you'll need: MIMIC-III access (or demo mode without it), GPU recommended, Colab with runtime set to GPU
  • What you'll get: A trained CorGAN model + 10,000 synthetic MIMIC-III patients
  • How long it takes: Demo (5 epochs, 20–30 min), Production (50 epochs, 2–4 hrs on T4)
  • What makes CorGAN different: One short paragraph explaining the CNN autoencoder + WGAN architecture and why correlation capture matters clinically
  • Reference: Baowaly et al., "Synthesizing Electronic Health Records Using Improved Generative Adversarial Networks", JAMIA 2019

Important: HALO's notebook omits an architecture explanation, but CorGAN's adversarial training is non-obvious to practitioners. The preamble should briefly explain what the generator, discriminator, and autoencoder each do, and why adversarial loss produces better correlations than a plain autoencoder.


Section 1: Setup & Installation

# Install PyHealth from corgan branch
FORK = 'jalengg'
BRANCH = 'corgan-pr-integration'
install_url = f"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}"
!pip install {install_url} --quiet --no-cache-dir

# Detect environment (Colab vs local)
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# GPU detection
import torch
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU. Training will be very slow.")

Required: Environment detection cell — if not in Colab, use ./ paths instead of /content/drive/MyDrive/. This makes the notebook usable on SLURM clusters and laptops, not just Colab. (The HALO notebook fails outside Colab due to hardcoded /content/drive/ paths — this is a known gap to fix.)


Section 2: Configuration

Centralized config block at the top — modeled after HALO but with CorGAN-specific parameters:

# ============================================================
# CONFIGURATION — Modify these parameters
# ============================================================

# --- Preset: Change this to switch Demo/Production ---
PRESET = "demo"  # "demo" or "production"

# Training parameters
if PRESET == "demo":
    EPOCHS = 5              # Quick smoke test (~20-30 min on T4)
    BATCH_SIZE = 64
    N_SYNTHETIC_SAMPLES = 1000
    N_EPOCHS_PRETRAIN = 1   # Autoencoder pre-training epochs
elif PRESET == "production":
    EPOCHS = 50             # Full training (~2-4 hrs on T4)
    BATCH_SIZE = 512
    N_SYNTHETIC_SAMPLES = 10000
    N_EPOCHS_PRETRAIN = 3
else:
    raise ValueError(f"Unknown PRESET: {PRESET!r} — use 'demo' or 'production'")

# Model architecture
LATENT_DIM = 128        # Generator + decoder latent dimension
HIDDEN_DIM = 128        # Generator hidden dimension
AUTOENCODER_TYPE = "cnn"  # "cnn", "cnn8layer", or "linear"

# WGAN parameters
N_ITER_D = 5            # Discriminator updates per generator update
CLAMP_LOWER = -0.01     # WGAN weight clipping
CLAMP_UPPER = 0.01
LR = 0.001              # Learning rate for all optimizers

# Reproducibility
SEED = 42

# Paths
BASE_DIR = '/content/drive/MyDrive/CorGAN_Training' if IN_COLAB else './corgan_training'

Required: Comments must explain why each WGAN parameter exists (not just what it is). For example: "N_ITER_D = 5: Discriminator is updated more often than generator to prevent generator from over-optimizing a weak critic." This is absent from the HALO notebook and is essential for CorGAN because these parameters directly affect training stability.
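To make the "why" concrete, here is a minimal sketch of one WGAN training step showing how `N_ITER_D` and weight clipping interact. This is an illustration of the weight-clipping WGAN scheme (Arjovsky et al., 2017) the parameters come from, not CorGAN's actual training loop; the function and argument names are made up for this sketch.

```python
import torch

def wgan_step(generator, discriminator, real_batch, opt_g, opt_d,
              n_iter_d=5, clamp=(-0.01, 0.01), latent_dim=128):
    """One illustrative WGAN step (weight-clipping variant)."""
    # The critic is updated n_iter_d times per generator update so its
    # Wasserstein estimate stays accurate; a weak critic hands the
    # generator a meaningless gradient.
    for _ in range(n_iter_d):
        opt_d.zero_grad()
        z = torch.randn(real_batch.size(0), latent_dim)
        fake = generator(z).detach()
        # Critic maximizes E[D(real)] - E[D(fake)]; minimize the negative.
        loss_d = -(discriminator(real_batch).mean() - discriminator(fake).mean())
        loss_d.backward()
        opt_d.step()
        # Weight clipping enforces the Lipschitz constraint the critic
        # needs in order to approximate the Wasserstein distance.
        for p in discriminator.parameters():
            p.data.clamp_(*clamp)
    # Generator tries to raise the critic's score on fake samples.
    opt_g.zero_grad()
    z = torch.randn(real_batch.size(0), latent_dim)
    loss_g = -discriminator(generator(z)).mean()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()
```

Comments of this shape, attached to the config block, give practitioners enough context to tune `N_ITER_D`, `CLAMP_LOWER`/`CLAMP_UPPER`, and `LR` deliberately rather than by trial and error.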


Section 3: Data Upload

Two code paths:

Path A (with MIMIC-III):

# Upload DIAGNOSES_ICD.csv.gz and ADMISSIONS.csv.gz
# Validate CSV structure (columns, row count, sample display)

Path B (demo/CI mode — no MIMIC-III):

if not mimic3_available:
    print("MIMIC-III not found — using synthetic demo data")
    from pyhealth.datasets.sample_dataset import InMemorySampleDataset
    # Create InMemorySampleDataset with realistic-looking fake ICD-9 codes
    # Sufficient to demonstrate the full pipeline

Rationale: The HALO notebook has no fallback — if MIMIC-III files aren't present, the notebook simply fails. For a CorGAN notebook, CI smoke tests and first-time users (who haven't yet obtained MIMIC-III) should be able to run the full pipeline on synthetic stand-in data. The InMemorySampleDataset pattern already exists in the test suite and should be reused here.
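A minimal sketch of what the demo fallback could generate, assuming patient records are plain dicts with `patient_id` and a flat `visits` bag-of-codes (the record shape used elsewhere in this issue). The helper name and the code format are illustrative; the actual notebook should feed equivalent records into `InMemorySampleDataset`.

```python
import random

def make_demo_patients(n_patients=200, vocab_size=100, seed=42):
    """Generate fake patients with plausible-looking ICD-9-style codes.

    Stands in for MIMIC-III so the pipeline can run in CI; the codes are
    synthetic and carry no clinical meaning.
    """
    rng = random.Random(seed)
    # Sequential prefixes guarantee a unique vocabulary.
    vocab = [f"{i:03d}.{i % 100:02d}" for i in range(1, vocab_size + 1)]
    patients = []
    for i in range(n_patients):
        n_codes = rng.randint(1, 15)  # realistic codes-per-patient range
        patients.append({
            "patient_id": f"demo-{i}",
            "visits": rng.sample(vocab, n_codes),  # flat bag-of-codes
        })
    return patients
```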


Section 4: Training

from pyhealth.datasets import split_by_patient
from pyhealth.models.generators.corgan import CorGAN
from pyhealth.tasks.corgan_generation import corgan_generation_mimic3_fn

# Apply task
sample_dataset = dataset.set_task(corgan_generation_mimic3_fn)
train_dataset, val_dataset = split_by_patient(sample_dataset, [0.85, 0.15])

# Initialize model
model = CorGAN(
    dataset=train_dataset,
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    n_epochs_pretrain=N_EPOCHS_PRETRAIN,
    n_iter_D=N_ITER_D,
    clamp_lower=CLAMP_LOWER,
    clamp_upper=CLAMP_UPPER,
    lr=LR,
    autoencoder_type=AUTOENCODER_TYPE,
    save_dir=checkpoint_dir,
)

# Train
model.train_model(train_dataset)

Additionally required in this section:

  1. Random seed cell — before any training:

    import random, numpy as np
    import torch
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
  2. Hyperparameter config save — after model init, before training:

    import json
    from datetime import datetime
    config_record = {k: v for k, v in globals().items()
                     if k.isupper() and isinstance(v, (int, float, str, bool))}
    config_record['timestamp'] = datetime.now().isoformat()
    with open(f'{checkpoint_dir}/config.json', 'w') as f:
        json.dump(config_record, f, indent=2)
  3. Loss curve plot — after training completes. This is the most important addition that HALO's notebook lacks entirely. CorGAN has three distinct losses that must be plotted together:

    • Autoencoder reconstruction loss (pre-training phase)
    • Discriminator Wasserstein loss
    • Generator adversarial loss

    Plot these as separate subplots (or overlaid with secondary axis). For demo mode (5 epochs, few batches), even a sparse curve is informative.
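The three-subplot layout could look like the sketch below. It assumes the trainer exposes per-epoch loss lists under the names `ae_loss`, `d_loss`, and `g_loss` — whether CorGAN's trainer records them under exactly those names is an assumption, not confirmed by the PR.

```python
import matplotlib.pyplot as plt

def plot_corgan_losses(ae_loss, d_loss, g_loss, out_path="loss_curves.png"):
    """Plot CorGAN's three training losses as side-by-side subplots."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, losses, title in zip(
        axes,
        [ae_loss, d_loss, g_loss],
        ["Autoencoder reconstruction (pre-train)",
         "Discriminator (Wasserstein)",
         "Generator (adversarial)"],
    ):
        # Markers keep sparse demo-mode curves (few epochs) readable.
        ax.plot(range(1, len(losses) + 1), losses, marker="o")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    return fig
```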

Markdown explanation in this section should cover the two-phase training: (1) autoencoder pre-trains to learn a latent representation of ICD-9 code vectors, then (2) the WGAN trains the generator to produce latent codes that the decoder maps back to realistic binary code vectors.


Section 5: Generation

synthetic = model.synthesize_dataset(num_samples=N_SYNTHETIC_SAMPLES)

# Display sample output
import pandas as pd
sample_df = pd.DataFrame([
    {
        "patient_id": p["patient_id"],
        "n_codes": len(p["visits"]),
        "sample_codes": ", ".join(p["visits"][:5]) + ("..." if len(p["visits"]) > 5 else "")
    }
    for p in synthetic[:10]
])
display(sample_df)

Save in multiple formats:

  • synthetic_patients.json — full records for direct PyHealth reuse
  • synthetic_patients.csv — flat SUBJECT_ID, ICD9_CODE rows for external tools
  • corgan_final.pt — model checkpoint

One important contrast with HALO: in HALO's output, each row is a (SUBJECT_ID, VISIT_NUM, ICD9_CODE) triple because HALO generates visit sequences. In CorGAN, the output is a flat bag-of-codes with no VISIT_NUM column, so the CSV should contain SUBJECT_ID and ICD9_CODE only. Explain this in a markdown cell: CorGAN aggregates all of a patient's diagnoses across admissions into a single flat set.
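The flat CSV export could be written as below, assuming the same `patient_id`/`visits` record shape as in Section 5 (the helper name is illustrative):

```python
import csv

def save_flat_csv(synthetic, path="synthetic_patients.csv"):
    """Write CorGAN output as flat (SUBJECT_ID, ICD9_CODE) rows.

    Unlike HALO there is no VISIT_NUM column: each patient contributes
    one row per code in their aggregated bag-of-codes.
    """
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["SUBJECT_ID", "ICD9_CODE"])
        for patient in synthetic:
            for code in patient["visits"]:
                writer.writerow([patient["patient_id"], code])
```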


Section 6: Results & Evaluation

This is where CorGAN's notebook diverges most significantly from HALO's. HALO's results section only shows row counts and null-value checks. For CorGAN, the core scientific contribution must be validated.

Required cells:

6a. Vocabulary Coverage

all_generated_codes = set(code for p in synthetic for code in p["visits"])
vocab_size = dataset.input_processors["visits"].size()
coverage = len(all_generated_codes) / vocab_size * 100
print(f"Vocabulary coverage: {coverage:.1f}%")

Low coverage indicates mode collapse. High coverage indicates diversity. This is a key CorGAN health metric.

6b. Code Count Statistics & Distribution

  • Codes per patient: mean, std, min, max, median
  • Histogram comparison — real training data vs synthetic (overlaid, normalized to density)
  • This directly answers "does synthetic data have realistic patient complexity?"
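The overlaid, density-normalized histogram could be sketched as follows, assuming `real_counts` and `synth_counts` are lists of codes-per-patient integers computed from the two datasets (the function name is illustrative):

```python
import matplotlib.pyplot as plt

def plot_code_count_histograms(real_counts, synth_counts,
                               out_path="code_counts.png"):
    """Overlay normalized codes-per-patient histograms, real vs synthetic."""
    fig, ax = plt.subplots(figsize=(8, 4))
    # Shared integer bins and density=True make the two populations
    # comparable even when their sizes differ.
    bins = range(0, max(max(real_counts), max(synth_counts)) + 2)
    ax.hist(real_counts, bins=bins, density=True, alpha=0.5, label="Real")
    ax.hist(synth_counts, bins=bins, density=True, alpha=0.5, label="Synthetic")
    ax.set_xlabel("Codes per patient")
    ax.set_ylabel("Density")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    return fig
```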

6c. Code Frequency Comparison (Top 50 Codes)

# Real training data: count how often each code appears
# Synthetic data: count how often each code appears
# Plot: Side-by-side bar chart of top 20 codes in each
# Metric: Pearson correlation between real and synthetic code frequencies

This is the central validation for CorGAN's claim of capturing code distributions. A high correlation (>0.9) indicates the model learned the marginal code distribution. This should be clearly labeled with its value and interpretation.
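The frequency correlation itself is a short computation. This sketch aligns frequencies over the union of codes seen in either set (a code generated but never observed, or vice versa, counts as 0 on the other side); the helper name is illustrative:

```python
from collections import Counter

import numpy as np

def code_frequency_pearson(real_patients, synth_patients):
    """Pearson r between real and synthetic per-code frequencies."""
    real_freq = Counter(c for p in real_patients for c in p["visits"])
    synth_freq = Counter(c for p in synth_patients for c in p["visits"])
    # Union of codes so unmatched codes contribute a 0 on the other side.
    codes = sorted(set(real_freq) | set(synth_freq))
    x = np.array([real_freq[c] for c in codes], dtype=float)
    y = np.array([synth_freq[c] for c in codes], dtype=float)
    return float(np.corrcoef(x, y)[0, 1])
```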

6d. All-Zeros Detection

empty_patients = sum(1 for p in synthetic if len(p["visits"]) == 0)
if empty_patients > 0:
    print(f"WARNING: {empty_patients} patients with no codes (all-zeros generation)")
    print("Consider adjusting binarization threshold or retraining with more epochs.")

The HALO notebook checks for null values; CorGAN's equivalent is checking for all-zeros binary vectors (the documented edge case in synthesize_dataset's docstring).

6e. Quality Report

Save all metrics to quality_report.json:

{
    "total_synthetic_patients": 10000,
    "mean_codes_per_patient": 12.3,
    "std_codes_per_patient": 8.1,
    "min_codes": 0,
    "max_codes": 45,
    "unique_codes_generated": 4821,
    "vocabulary_size": 6955,
    "vocabulary_coverage_percent": 69.4,
    "empty_patients_count": 3,
    "code_frequency_pearson_r": 0.87,
    "seed": 42,
    "timestamp": "2025-01-31T14:23:00"
}

This JSON file is the key CI artifact — a CI job can parse this file and fail if vocabulary_coverage_percent < 30 or code_frequency_pearson_r < 0.5.
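A CI gate over that artifact could be as simple as the sketch below, using the thresholds suggested above (the function name and threshold defaults are illustrative, not an existing script in the repo):

```python
import json

def check_quality_report(path, min_coverage=30.0, min_pearson=0.5):
    """Parse quality_report.json and return a list of failure messages.

    An empty list means the report passes; a CI job can exit nonzero
    when any failures are returned.
    """
    with open(path) as f:
        report = json.load(f)
    failures = []
    if report["vocabulary_coverage_percent"] < min_coverage:
        failures.append(
            f"vocabulary coverage {report['vocabulary_coverage_percent']}% "
            f"< {min_coverage}% (possible mode collapse)")
    if report["code_frequency_pearson_r"] < min_pearson:
        failures.append(
            f"code frequency Pearson r {report['code_frequency_pearson_r']} "
            f"< {min_pearson}")
    return failures
```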


Section 7: Download & Next Steps

  • Download synthetic CSV, quality report JSON, and checkpoint
  • "Congratulations" section (adopt HALO's encouraging tone)
  • Troubleshooting: Out-of-memory (reduce batch_size), mode collapse (check empty_patients_count, reduce LR, increase N_ITER_D), slow training (use smaller vocab with tables=["diagnoses_icd"] only)
  • References: CorGAN paper, PyHealth documentation, MIMIC-III PhysioNet link

Technical Requirements

File Location

examples/corgan_mimic3_colab.ipynb

Dependencies

No new dependencies. Uses only packages already in PyHealth's requirements.txt (torch, numpy, pandas, matplotlib). No seaborn or scipy required — use matplotlib directly.

Environment Compatibility

The notebook must work in both Colab and non-Colab environments (at minimum: local execution, SLURM cluster). Use IN_COLAB flag for drive mounting and files.download() calls.

Seeds

All stochastic operations must use SEED = 42 by default. This includes PyTorch, NumPy, and Python's random module.

Kernel Restart Safety

Section 1 (Setup) must be idempotent — re-running after a kernel restart should not fail. Checkpoint loading must be conditional on checkpoint existence.
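Conditional checkpoint loading could follow this pattern, assuming the checkpoint is a plain `state_dict` saved with `torch.save` (the helper name is illustrative):

```python
import os

import torch

def load_checkpoint_if_present(model, checkpoint_dir, name="corgan_final.pt"):
    """Resume from a checkpoint only if one exists (kernel-restart safe).

    Re-running this cell after a restart is a no-op when nothing has been
    saved yet, so the setup section never fails on a fresh run.
    """
    path = os.path.join(checkpoint_dir, name)
    if os.path.exists(path):
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state)
        print(f"Resumed from {path}")
        return True
    print("No checkpoint found — starting fresh")
    return False
```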


Acceptance Criteria

  • Notebook runs end-to-end in Colab on MIMIC-III without manual intervention after uploading CSVs
  • Notebook runs end-to-end without MIMIC-III (demo/CI mode using InMemorySampleDataset)
  • Section 2 Configuration is the only place hyperparameters need to be changed
  • Three loss curves (autoencoder, discriminator, generator) are plotted after training
  • Code frequency comparison (real vs synthetic, top 20) is shown in Section 6
  • quality_report.json is saved with all required metrics
  • synthetic_patients.json and synthetic_patients.csv are saved
  • config.json with all hyperparameters is saved alongside checkpoint
  • All-zeros patient detection is present
  • Random seed is set before any stochastic operation
  • Notebook works without Colab-specific imports outside of Drive mounting
  • Preamble explains CorGAN architecture in plain language (not just "what it does" but "why it's different from HALO")
  • Troubleshooting section covers mode collapse and OOM errors specifically

Out of Scope

  • Correlation heatmaps (co-occurrence matrices) — too slow to compute in Colab for full vocabulary; could be a follow-up
  • Conditional generation (generate patients with specific codes) — future work
  • Comparison with HALO on same data — out of scope for this issue
  • Downloading from Google Drive programmatically — use Colab's files.download() UI
