forked from sunlabuiuc/PyHealth
Goal
Update all docstrings in pyhealth/models/generators/medgan.py to Google/PyHealth style, matching the pattern established by HALO and CorGAN.
Classes to document
MedGAN (main class)
```python
class MedGAN(BaseModel):
    """GAN-based synthetic EHR generator using a two-phase training strategy.

    Implements MedGAN (Choi et al., 2017) adapted for PyHealth 2.0.
    Pretrains a linear autoencoder before adversarial training with
    standard binary cross-entropy loss.

    Args:
        dataset (SampleDataset): PyHealth sample dataset with ``multi_hot``
            schema for the ``"visits"`` key.
        latent_dim (int): Noise dimension for the generator. Default: 128.
        hidden_dim (int): Hidden dimension for the generator's residual
            layers. Default: 128.
        autoencoder_hidden_dim (int): Encoder/decoder bottleneck size.
            Default: 128.
        discriminator_hidden_dim (int): Size of the discriminator's first
            hidden layer. Default: 256.
        minibatch_averaging (bool): If True, the discriminator concatenates
            batch-mean features to each sample. Default: True.
        save_dir (str): Directory for checkpoint files. Default: ``"./save/"``.

    Examples:
        >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset
        >>> samples = [{"patient_id": "p1", "visits": ["401.9", "250.00"]}]
        >>> dataset = InMemorySampleDataset(samples, {"visits": "multi_hot"}, {})
        >>> model = MedGAN(
        ...     dataset, latent_dim=16, hidden_dim=16, autoencoder_hidden_dim=16
        ... )
        >>> isinstance(model, MedGAN)
        True
    """
```
train_model()
Document both training phases (autoencoder pretraining and adversarial training), all parameters, and the return type.
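A minimal sketch of the layout this could take. The signature is not taken from medgan.py: the class name MedGANDocSketch and the parameters train_dataset, autoencoder_epochs, gan_epochs, and batch_size (with their defaults) are hypothetical placeholders. The sketch only shows where the two phases and each Args/Returns entry would go.

```python
class MedGANDocSketch:
    """Hypothetical stand-in for MedGAN, used only to show docstring layout."""

    def train_model(
        self,
        train_dataset,
        autoencoder_epochs: int = 100,
        gan_epochs: int = 1000,
        batch_size: int = 128,
    ) -> None:
        """Trains the model in two phases.

        Phase 1 pretrains the autoencoder on the multi-hot ``"visits"``
        vectors. Phase 2 trains the generator and discriminator
        adversarially.

        Args:
            train_dataset (SampleDataset): Training samples (hypothetical
                parameter name).
            autoencoder_epochs (int): Epochs for phase 1 (hypothetical).
                Default: 100.
            gan_epochs (int): Epochs for phase 2 (hypothetical).
                Default: 1000.
            batch_size (int): Mini-batch size (hypothetical). Default: 128.

        Returns:
            None: Placeholder; document the actual return value here.
        """
        # Body intentionally omitted: only the docstring layout matters here.
        raise NotImplementedError
```

Naming each phase in the summary paragraph keeps the Args list short, because per-phase parameters can then refer back to "phase 1" and "phase 2".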
synthesize_dataset()
```
Returns:
    list of dict: One record per synthetic patient. Each dict has:
        ``"patient_id"`` (str): Identifier, e.g. ``"synthetic_0"``.
        ``"visits"`` (list of str): Decoded ICD code strings.
```
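To make the Returns structure concrete, here is a hypothetical helper, decode_records, that turns generated multi-hot rows into records of that shape. The helper name, the index_to_code vocabulary argument, and the 0.5 threshold are assumptions for illustration, not the actual synthesize_dataset() implementation.

```python
from typing import Dict, List, Sequence


def decode_records(
    rows: Sequence[Sequence[float]],
    index_to_code: Sequence[str],
    threshold: float = 0.5,
) -> List[Dict[str, object]]:
    """Turns generated multi-hot rows into synthesize_dataset()-style records.

    Hypothetical helper; the threshold and vocabulary format are assumptions.

    Args:
        rows: One multi-hot (or probability) vector per synthetic patient.
        index_to_code: Maps each column index to an ICD code string.
        threshold: Values at or above this are treated as present.

    Returns:
        list of dict: One record per synthetic patient. Each dict has:
            ``"patient_id"`` (str): Identifier, e.g. ``"synthetic_0"``.
            ``"visits"`` (list of str): Decoded ICD code strings.
    """
    records = []
    for i, row in enumerate(rows):
        # Keep the code for every column whose value clears the threshold.
        codes = [index_to_code[j] for j, value in enumerate(row) if value >= threshold]
        records.append({"patient_id": f"synthetic_{i}", "visits": codes})
    return records
```

For example, decode_records([[1, 0, 1]], ["401.9", "250.00", "428.0"]) yields one record, {"patient_id": "synthetic_0", "visits": ["401.9", "428.0"]}.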
Sub-modules: MedGANAutoencoder, MedGANGenerator, MedGANDiscriminator
Keep the existing structure; add proper Args and Returns sections.
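The minibatch_averaging flag on the class applies inside the discriminator (presumably MedGANDiscriminator), so that sub-module's docstring may need to explain how it changes the input width. Below is a dependency-free sketch. It assumes, following Choi et al.'s original medGAN, that the mean is taken over the rows of the discriminator input. The name minibatch_average is hypothetical, and medgan.py most likely performs this step on tensors rather than lists.

```python
from typing import List, Sequence


def minibatch_average(batch: Sequence[Sequence[float]]) -> List[List[float]]:
    """Appends the batch-mean feature vector to every row of ``batch``.

    Hypothetical sketch of minibatch averaging: each row of width d
    becomes a row of width 2 * d.
    """
    if not batch:
        return []
    n = len(batch)
    width = len(batch[0])
    # Column-wise mean over the batch dimension.
    mean = [sum(row[j] for row in batch) / n for j in range(width)]
    # Concatenate the shared mean onto each individual row.
    return [list(row) + mean for row in batch]
```

Under this assumption, the discriminator's first layer sees 2 * d inputs for d-dimensional visit vectors, which is worth stating in its Args.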
Acceptance criteria
- All public methods have Google-style docstrings
- The class-level docstring includes Args: and Examples: sections
- The synthesize_dataset() Returns: block uses named sub-fields (patient_id, visits)
- No undocumented public methods
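Most of these criteria can be checked mechanically. Here is a sketch using Python's standard ast module. undocumented_public lists public classes and methods that have no docstring, and missing_sections reports which required section headers a class docstring lacks. Both function names are hypothetical. Treating any name with a leading underscore as private is also an assumption.

```python
import ast
from typing import List, Sequence


def undocumented_public(source: str) -> List[str]:
    """Lists public classes/functions in ``source`` that lack a docstring.

    Names starting with an underscore are treated as private and skipped.
    Methods are reported as ``ClassName.method``.
    """
    missing: List[str] = []

    def visit(node: ast.AST, prefix: str) -> None:
        for child in ast.iter_child_nodes(node):
            if isinstance(child, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                qualname = prefix + child.name
                if not child.name.startswith("_") and ast.get_docstring(child) is None:
                    missing.append(qualname)
                # Recurse into classes only, so helpers nested in functions are ignored.
                if isinstance(child, ast.ClassDef):
                    visit(child, qualname + ".")

    visit(ast.parse(source), "")
    return missing


def missing_sections(
    source: str, class_name: str, sections: Sequence[str] = ("Args:", "Examples:")
) -> List[str]:
    """Returns the section headers absent from ``class_name``'s docstring."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            doc = ast.get_docstring(node) or ""
            return [s for s in sections if s not in doc]
    raise ValueError(f"class {class_name!r} not found")
```

Passing the text of pyhealth/models/generators/medgan.py to undocumented_public would list any remaining gaps. The substring check in missing_sections confirms only that a header is present, not that the section is complete.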