Skip to content
Draft
47 changes: 47 additions & 0 deletions examples/generate_synthetic_mimic3_promptehr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""PromptEHR: Synthetic MIMIC-III Patient Generation.

Load a trained PromptEHR checkpoint and generate synthetic patients.

Reference:
Wang et al. "PromptEHR: Conditional Electronic Healthcare Records
Generation with Prompt Learning." EMNLP 2023.
https://arxiv.org/abs/2211.01761
"""

import json

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import PromptEHR
from pyhealth.tasks import promptehr_generation_mimic3_fn

MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
CHECKPOINT_PATH = "./save/promptehr/checkpoint.pt"
OUTPUT_PATH = "./save/promptehr/synthetic_patients.json"
NUM_SAMPLES = 10_000

# 1. Load dataset + apply task (needed for processor/vocab reconstruction)
dataset = MIMIC3Dataset(
root=MIMIC3_ROOT,
tables=["patients", "admissions", "diagnoses_icd"],
code_mapping={},
)
sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)

# 2. Load checkpoint
model = PromptEHR(dataset=sample_dataset)
model.load_model(CHECKPOINT_PATH)
print(f"Loaded checkpoint from {CHECKPOINT_PATH}")

# 3. Generate
print(f"Generating {NUM_SAMPLES} synthetic patients...")
synthetic = model.synthesize_dataset(num_samples=NUM_SAMPLES)
print(f"Generated {len(synthetic)} patients")

# 4. Save
with open(OUTPUT_PATH, "w") as f:
json.dump(synthetic, f, indent=2)
print(f"Saved to {OUTPUT_PATH}")

# Summary stats
avg_visits = sum(len(p["visits"]) for p in synthetic) / len(synthetic)
print(f"Average visits per patient: {avg_visits:.2f}")
252 changes: 252 additions & 0 deletions examples/promptehr_mimic3_colab.ipynb

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions examples/promptehr_mimic3_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""PromptEHR: Training on MIMIC-III.

Train PromptEHR for synthetic EHR generation using PyHealth 2.0 API.

Reference:
Wang et al. "PromptEHR: Conditional Electronic Health Records Generation
with Prompt Learning." CHIL 2023.
"""

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.models import PromptEHR
from pyhealth.tasks import promptehr_generation_mimic3_fn

MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"

# 1. Load MIMIC-III
dataset = MIMIC3Dataset(
root=MIMIC3_ROOT,
tables=["patients", "admissions", "diagnoses_icd"],
code_mapping={},
)

# 2. Apply generation task
sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
print(f"Patients: {len(sample_dataset)}")
sample_dataset.stat()

# 3. Split
train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])

# 4. Initialize model
model = PromptEHR(
dataset=sample_dataset,
n_num_features=1,
cat_cardinalities=[2],
d_hidden=128,
prompt_length=1,
epochs=20,
batch_size=16,
lr=1e-5,
warmup_steps=1000,
save_dir="./save/promptehr/",
)

# 5. Train
model.train_model(train, val)
print("Training complete. Checkpoint saved to ./save/promptehr/")
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .grasp import GRASP, GRASPLayer
from .medlink import MedLink
from .micron import MICRON, MICRONLayer
from .promptehr import PromptEHR
from .mlp import MLP
from .molerec import MoleRec, MoleRecLayer
from .retain import RETAIN, RETAINLayer
Expand Down
41 changes: 41 additions & 0 deletions pyhealth/models/promptehr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""PromptEHR: Prompt-based BART model for synthetic EHR generation.

This module provides a demographic-conditioned sequence-to-sequence model
for generating realistic synthetic electronic health records.

Main components:
- PromptEHR: Main model class (inherits from BaseModel)
- ConditionalPromptEncoder: Demographic conditioning with reparameterization
- PromptBartEncoder: Modified BART encoder with prompt injection
- PromptBartDecoder: Modified BART decoder with prompt injection
- VisitStructureSampler: Utility for structure-constrained generation
- Generation functions: sample_demographics, parse_sequence_to_visits, etc.
"""

from .model import PromptEHR
from .conditional_prompt import ConditionalPromptEncoder
from .bart_encoder import PromptBartEncoder
from .bart_decoder import PromptBartDecoder
from .visit_sampler import VisitStructureSampler
from .generation import (
DemographicSampler,
sample_demographics,
decode_patient_demographics,
parse_sequence_to_visits,
generate_patient_sequence_conditional,
generate_patient_with_structure_constraints
)

__all__ = [
"PromptEHR",
"ConditionalPromptEncoder",
"PromptBartEncoder",
"PromptBartDecoder",
"VisitStructureSampler",
"DemographicSampler",
"sample_demographics",
"decode_patient_demographics",
"parse_sequence_to_visits",
"generate_patient_sequence_conditional",
"generate_patient_with_structure_constraints",
]
Loading