
[medgan-pr-prep] T5: Update generation example (generate_synthetic_mimic3_medgan.py) #41

@jalengg

Description


Goal

Rewrite examples/generate_synthetic_mimic3_medgan.py to use the PyHealth 2.0 API (load a trained model, generate samples, save to CSV).

Template

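"""Generate synthetic EHR data using a pre-trained MedGAN model."""
import csv
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import medgan_generation_mimic3_fn
from pyhealth.models import MedGAN

# Inputs and outputs: MIMIC-III v1.4 root, trained MedGAN checkpoint, output CSV.
root = "/srv/local/data/physionet.org/files/mimiciii/1.4"
checkpoint_path = "./medgan_save/medgan_checkpoint.pt"
output_path = "./medgan_synthetic_1000.csv"

# Load the diagnosis and admission tables, then build the generation task dataset.
dataset = MIMIC3Dataset(root=root, tables=["diagnoses_icd", "admissions"])
sample_dataset = dataset.set_task(medgan_generation_mimic3_fn)

# Restore the trained MedGAN from its checkpoint.
model = MedGAN.load_model(checkpoint_path, dataset=sample_dataset)

# Generate 1,000 synthetic patients.
synthetic = model.synthesize_dataset(num_samples=1000)

# Write one CSV row per synthetic patient; visits are joined with "|".
# This assumes record["visits"] is a flat list of strings; nested per-visit
# code lists would need to be flattened before joining.
with open(output_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["patient_id", "visits"])
    for record in synthetic:
        writer.writerow([record["patient_id"], "|".join(record["visits"])])

print(f"Saved {len(synthetic)} synthetic patients to {output_path}")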
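If each synthetic record turns out to store visits as a list of visits, with each visit a list of diagnosis codes, a single `"|".join(...)` is not enough. The sketch below is one possible round-trippable encoding, not PyHealth API: the record shape, the `,` code separator, and the helper names `encode_visits`, `decode_visits`, `write_synthetic_csv`, and `read_synthetic_csv` are all hypothetical. It uses only the standard library `csv` module and assumes every visit contains at least one code.

```python
"""Hypothetical helpers for writing synthetic patients to CSV and reading them back.

Assumed record shape (not confirmed by PyHealth):
    {"patient_id": str, "visits": [[code, code, ...], [code, ...], ...]}
"""
import csv
from typing import Dict, List

VISIT_SEP = "|"  # separates visits within one patient, as in the template
CODE_SEP = ","   # separates codes within one visit (hypothetical choice)


def encode_visits(visits: List[List[str]]) -> str:
    """Flatten a nested visit list into one string, e.g. "4019,25000|41401"."""
    return VISIT_SEP.join(CODE_SEP.join(visit) for visit in visits)


def decode_visits(encoded: str) -> List[List[str]]:
    """Invert encode_visits; assumes no visit was empty when encoded."""
    if not encoded:
        return []
    return [visit.split(CODE_SEP) for visit in encoded.split(VISIT_SEP)]


def write_synthetic_csv(records: List[Dict], path: str) -> int:
    """Write a patient_id,visits CSV and return the number of rows written."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)  # quotes the visits field because it contains commas
        writer.writerow(["patient_id", "visits"])
        for record in records:
            writer.writerow([record["patient_id"], encode_visits(record["visits"])])
    return len(records)


def read_synthetic_csv(path: str) -> List[Dict]:
    """Read the CSV back into the same nested record shape."""
    with open(path, newline="") as f:
        return [
            {"patient_id": row["patient_id"], "visits": decode_visits(row["visits"])}
            for row in csv.DictReader(f)
        ]
```

Keeping a distinct separator for visits and for codes preserves visit boundaries in a two-column CSV without any numpy or binary-matrix logic, which matches the acceptance criteria.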

Acceptance criteria

  • File is examples/generate_synthetic_mimic3_medgan.py
  • Uses only the PyHealth 2.0 API (plus the standard library csv module)
  • Output is a CSV file with patient_id and visits columns
  • No numpy/binary matrix logic
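Some of these criteria can be checked mechanically. The sketch below is a hypothetical reviewer aid, not part of PyHealth: `check_script` parses the example's source with the standard library `ast` module and flags any top-level import outside an allowlist (assumed here to be `csv` and `pyhealth`), which catches a stray `numpy` import but is only a coarse proxy for "uses only the PyHealth 2.0 API". `check_output_csv` verifies the header is exactly `patient_id,visits` and that every data row has two fields with a non-empty patient ID.

```python
"""Hypothetical acceptance checks for examples/generate_synthetic_mimic3_medgan.py."""
import ast
import csv
from typing import Iterable, List, Set

EXPECTED_HEADER = ["patient_id", "visits"]


def imported_top_level_modules(source: str) -> Set[str]:
    """Return top-level package names imported by a Python source string."""
    modules: Set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            modules.add(node.module.split(".")[0])
    return modules


def check_script(source: str, allowed: Iterable[str] = ("csv", "pyhealth")) -> List[str]:
    """Report imports outside the allowlist (e.g. numpy)."""
    extra = imported_top_level_modules(source) - set(allowed)
    return [f"unexpected imports: {sorted(extra)}"] if extra else []


def check_output_csv(path: str) -> List[str]:
    """Report a wrong header and any row without exactly two fields and a patient ID."""
    problems: List[str] = []
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header != EXPECTED_HEADER:
            problems.append(f"bad header: {header}")
        for line_no, row in enumerate(reader, start=2):
            if len(row) != 2 or not row[0]:
                problems.append(f"malformed row at line {line_no}: {row}")
    return problems
```

In use, a reviewer would call `check_script(open("examples/generate_synthetic_mimic3_medgan.py").read())` and `check_output_csv("./medgan_synthetic_1000.csv")` and expect both to return an empty list.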

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests