[medgan-pr-prep] T4: Update training example (medgan_mimic3_training.py) #40

@jalengg

Description

Goal

Create examples/medgan_mimic3_training.py, or rewrite it if it already exists, as a clean PyHealth 2.0 example that uses only the new API.

Template

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.tasks import medgan_generation_mimic3_fn
from pyhealth.models import MedGAN

# Local MIMIC-III v1.4 directory; kept in a variable so it is easy to change.
root = "/srv/local/data/physionet.org/files/mimiciii/1.4"

# Load the tables the task needs and build the task-specific sample dataset.
dataset = MIMIC3Dataset(root=root, tables=["diagnoses_icd", "admissions"])
sample_dataset = dataset.set_task(medgan_generation_mimic3_fn)
print(f"Total samples: {len(sample_dataset)}")

# Split at the patient level (80/10/10) so no patient spans two splits.
# The test split is not used below.
train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])

model = MedGAN(
    dataset=sample_dataset,
    latent_dim=128,
    hidden_dim=128,
    autoencoder_hidden_dim=128,
    discriminator_hidden_dim=256,
    minibatch_averaging=True,
    save_dir="./medgan_save/",
)

# ae_epochs and gan_epochs set the length of the autoencoder and GAN phases.
model.train_model(
    train_dataset=train,
    val_dataset=val,
    ae_epochs=100,
    gan_epochs=100,
    batch_size=32,
)

# Generate synthetic patients and inspect the first one.
synthetic = model.synthesize_dataset(num_samples=1000)
print(f"Generated {len(synthetic)} patients")
print("Sample:", synthetic[0])

Acceptance criteria

  • File is examples/medgan_mimic3_training.py
  • Uses only PyHealth 2.0 API (MIMIC3Dataset, set_task, MedGAN, train_model, synthesize_dataset)
  • No raw CSV/numpy imports
  • MIMIC-III path is a variable, not hardcoded in function calls
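
Two of the criteria above (no raw CSV/numpy imports, and no path literal passed directly to a call) can be checked mechanically. The following is a minimal sketch using only the standard-library `ast` module. The helper name `check_example` and the `FORBIDDEN_MODULES` set are hypothetical and not part of PyHealth, and the sketch assumes the dataset path is passed through a keyword argument named `root`, as in the template.

```python
import ast

# Assumption: "raw CSV/numpy imports" means the top-level modules csv and numpy.
FORBIDDEN_MODULES = {"csv", "numpy"}


def check_example(source: str) -> list[str]:
    """Return a list of human-readable violations found in the example source."""
    problems = []
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # Handles `import numpy` and `import numpy as np`.
            for alias in node.names:
                if alias.name.split(".")[0] in FORBIDDEN_MODULES:
                    problems.append(f"forbidden import: {alias.name}")
        elif isinstance(node, ast.ImportFrom):
            # Handles `from csv import reader`; node.module is None for
            # relative imports such as `from . import x`.
            if node.module and node.module.split(".")[0] in FORBIDDEN_MODULES:
                problems.append(f"forbidden import: {node.module}")
        elif isinstance(node, ast.Call):
            # Flags calls such as MIMIC3Dataset(root="/some/path").
            for kw in node.keywords:
                if (
                    kw.arg == "root"
                    and isinstance(kw.value, ast.Constant)
                    and isinstance(kw.value.value, str)
                ):
                    problems.append(f"hardcoded root path on line {node.lineno}")
    return problems
```

To use it, read the file and pass its text in, for example `check_example(open("examples/medgan_mimic3_training.py").read())`; an empty list means neither check fired. The sketch does not cover the remaining criteria (file location and use of the PyHealth 2.0 API), which still need manual review.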
