diff --git a/examples/ehr_generation/ehr_generation_mimic3_transformer.py b/examples/ehr_generation/ehr_generation_mimic3_transformer.py new file mode 100644 index 000000000..dc599551d --- /dev/null +++ b/examples/ehr_generation/ehr_generation_mimic3_transformer.py @@ -0,0 +1,168 @@ +""" +EHR Generation with a GPT-2 style Transformer on MIMIC-III (PyHealth) +====================================================================== + +This example applies the :class:`~pyhealth.models.generators.EHRGPTBaseline` +model to MIMIC-III data and generates synthetic patient EHR sequences. + +The pipeline: + +1. Load MIMIC-III via **PyHealth** and apply the ``EHRGenerationMIMIC3`` task + to obtain per-patient nested visit sequences. +2. Serialise the nested sequences into plain text using ``VISIT_DELIM`` + separators (e.g. ``"250.00 401.9 VISIT_DELIM 272.0 428.0"``). +3. Train a word-level GPT-2 model via :meth:`EHRGPTBaseline.fit`. +4. Sample synthetic sequences via :meth:`EHRGPTBaseline.generate` and + save the resulting ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame. + +References +---------- +- *Accelerating Reproducible Research in Synthetic EHR Generation* (CHIL 2026) + +Usage +----- +.. 
code-block:: bash + + # Full vocabulary (~6,955 ICD-9 codes) – recommended + python ehr_generation_mimic3_transformer.py \\ + --mimic3_root /path/to/mimic-iii-clinical-database-1.4 \\ + --output_dir ./synthetic_output + + # Optional: replicate the legacy 3-digit vocabulary + python ehr_generation_mimic3_transformer.py \\ + --mimic3_root /path/to/mimic-iii \\ + --truncate_icd \\ + --output_dir ./synthetic_output_3digit +""" + +import argparse +import os + +import torch + +from pyhealth.datasets import MIMIC3Dataset, split_by_patient +from pyhealth.models.generators import EHRGPTBaseline, samples_to_sequences +from pyhealth.tasks import EHRGenerationMIMIC3 + + +def main(args: argparse.Namespace) -> None: + os.makedirs(args.output_dir, exist_ok=True) + print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}") + + # ------------------------------------------------------------------ + # STEP 1: Load MIMIC-III via PyHealth + # ------------------------------------------------------------------ + print("\nSTEP 1: Loading MIMIC-III dataset ...") + base_dataset = MIMIC3Dataset( + root=args.mimic3_root, + tables=["diagnoses_icd"], + ) + base_dataset.stats() + + # ------------------------------------------------------------------ + # STEP 2: Apply EHRGenerationMIMIC3 task + # ------------------------------------------------------------------ + print("\nSTEP 2: Applying EHRGenerationMIMIC3 task ...") + task = EHRGenerationMIMIC3( + min_visits=args.min_visits, + truncate_icd=args.truncate_icd, + ) + sample_dataset = base_dataset.set_task(task) + print(f" Total patients: {len(sample_dataset)}") + + train_dataset, _, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) + print(f" Training patients: {len(train_dataset)}") + + # ------------------------------------------------------------------ + # STEP 3: Serialise to text sequences + # ------------------------------------------------------------------ + print("\nSTEP 3: Serialising patient sequences ...") + 
train_samples = list(train_dataset) + text_data = samples_to_sequences(train_samples) + max_len = max(len(seq.split()) for seq in text_data) + print(f" Max sequence length (tokens): {max_len:,}") + + # ------------------------------------------------------------------ + # STEP 4 - 6: Build tokeniser, initialise GPT-2, train + # ------------------------------------------------------------------ + print("\nSTEP 4-6: Building tokeniser and training GPT-2 ...") + model = EHRGPTBaseline( + n_embd=args.n_embd, + n_layer=args.n_layer, + n_head=args.n_head, + max_seq_len=args.max_seq_len, + ) + model.fit( + sequences=text_data, + output_dir=args.output_dir, + epochs=args.epochs, + batch_size=args.batch_size, + ) + n_params = sum(p.numel() for p in model.model.parameters()) / 1e6 + print(f" Vocabulary size : {len(model.tokenizer)}") + print(f" Model parameters: {n_params:.1f}M") + + # ------------------------------------------------------------------ + # STEP 7: Generate synthetic EHRs + # ------------------------------------------------------------------ + print(f"\nSTEP 7: Generating {args.num_synthetic} synthetic patients ...") + synthetic_df = model.generate( + n_patients=args.num_synthetic, + batch_size=args.gen_batch_size, + ) + print( + f" Generated {synthetic_df['SUBJECT_ID'].nunique()} patients, " + f"{synthetic_df.shape[0]} (patient, visit, code) rows" + ) + + out_csv = os.path.join(args.output_dir, "synthetic_ehr.csv") + synthetic_df.to_csv(out_csv, index=False) + print(f" Synthetic data saved to: {out_csv}") + + +# -- CLI entry point ----------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Train a GPT-2 transformer for synthetic EHR generation (MIMIC-III).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--mimic3_root", + type=str, + required=True, + help="Path to the MIMIC-III root directory (raw CSV/CSV.GZ files).", + ) + 
parser.add_argument( + "--output_dir", + type=str, + default="./ehr_generation_output", + help="Directory to save checkpoints and synthetic data.", + ) + parser.add_argument( + "--min_visits", + type=int, + default=1, + help="Minimum valid admissions a patient must have.", + ) + parser.add_argument( + "--truncate_icd", + action="store_true", + default=False, + help="Truncate ICD-9 codes to 3-digit prefixes.", + ) + parser.add_argument("--n_embd", type=int, default=512, help="Embedding dimension.") + parser.add_argument("--n_layer", type=int, default=8, help="Number of transformer layers.") + parser.add_argument("--n_head", type=int, default=8, help="Number of attention heads.") + parser.add_argument( + "--max_seq_len", type=int, default=512, help="Maximum token sequence length." + ) + parser.add_argument("--epochs", type=int, default=50, help="Training epochs.") + parser.add_argument("--batch_size", type=int, default=64, help="Training batch size.") + parser.add_argument( + "--num_synthetic", type=int, default=10000, help="Synthetic patients to generate." + ) + parser.add_argument( + "--gen_batch_size", type=int, default=512, help="Generation batch size." 
+ ) + main(parser.parse_args()) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..b29cd3b36 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -43,3 +43,10 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .generators import ( + EHRGPTBaseline, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py new file mode 100644 index 000000000..b2e5c1ca5 --- /dev/null +++ b/pyhealth/models/generators/__init__.py @@ -0,0 +1,19 @@ +"""Generative models for synthetic EHR generation.""" + +from .gpt_baseline import ( + VISIT_DELIM, + EHRGPTBaseline, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) + +__all__ = [ + "VISIT_DELIM", + "EHRGPTBaseline", + "EHRTextDataset", + "build_tokenizer", + "samples_to_sequences", + "sequences_to_dataframe", +] diff --git a/pyhealth/models/generators/gpt_baseline.py b/pyhealth/models/generators/gpt_baseline.py new file mode 100644 index 000000000..4daf91d69 --- /dev/null +++ b/pyhealth/models/generators/gpt_baseline.py @@ -0,0 +1,506 @@ +""" +GPT-2 Baseline for Synthetic EHR Generation +============================================ + +This module provides a self-contained GPT-2 decoder-only language model for +generating synthetic longitudinal EHR sequences composed of ICD-9 diagnosis +codes. + +Design +------ +Patient histories are first serialised as space-separated code sequences where +consecutive visits are separated by the special ``VISIT_DELIM`` token: + + ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"`` + +A word-level HuggingFace tokeniser is then trained on this corpus, and a +GPT-2 causal language model is fine-tuned on the resulting token IDs. 
At +inference time, sequences are sampled autoregressively and deserialised back +to a long-form ``(SUBJECT_ID, HADM_ID, ICD9_CODE)`` DataFrame. + +References +---------- +- *Accelerating Reproducible Research in Synthetic EHR Generation* (CHIL 2026) + +Typical usage +------------- +.. code-block:: python + + from pyhealth.models.generators import EHRGPTBaseline + from pyhealth.models.generators import samples_to_sequences + + model = EHRGPTBaseline(n_embd=256, n_layer=4, n_head=4) + model.fit(text_sequences, output_dir="./checkpoints", epochs=20) + synthetic_df = model.generate(n_patients=1000) +""" + +import os +from typing import Optional + +import pandas as pd +import torch +import torch.nn as nn +from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers +from torch.utils.data import Dataset +from tqdm import trange +from transformers import ( + DataCollatorForLanguageModeling, + GPT2Config, + GPT2LMHeadModel, + PreTrainedTokenizerFast, + Trainer, + TrainingArguments, +) + +__all__ = [ + "VISIT_DELIM", + "samples_to_sequences", + "sequences_to_dataframe", + "build_tokenizer", + "EHRTextDataset", + "EHRGPTBaseline", +] + +# ── Constants ────────────────────────────────────────────────────────────────── + +VISIT_DELIM = "VISIT_DELIM" + + +# ── Sequence helpers ─────────────────────────────────────────────────────────── + + +def samples_to_sequences(samples: list) -> list[str]: + """Convert ``EHRGenerationMIMIC3`` samples to VISIT_DELIM-delimited text. + + Each sample's ``conditions`` field is a ``List[List[str]]`` (visits × codes). + Adjacent visits are joined by ``VISIT_DELIM`` so the full patient history + becomes a single space-separated string. + + Args: + samples: List of dicts with at least a ``"conditions"`` key. + + Returns: + One string per patient, e.g. + ``"250.00 401.9 VISIT_DELIM 272.0 428.0 VISIT_DELIM 250.00"``. 
+ + Examples: + >>> samples = [{"conditions": [["250.00", "401.9"], ["272.0"]]}] + >>> samples_to_sequences(samples) + ['250.00 401.9 VISIT_DELIM 272.0'] + """ + sequences: list[str] = [] + for sample in samples: + visit_texts = [" ".join(visit_codes) for visit_codes in sample["conditions"]] + sequences.append(f" {VISIT_DELIM} ".join(visit_texts)) + return sequences + + +def sequences_to_dataframe(sequences: list[str]) -> pd.DataFrame: + """Deserialise generated text sequences to long-form EHR rows. + + Assigns synthetic sequential identifiers; original MIMIC-III IDs are not + preserved (generation is unconditional). + + Args: + sequences: Generated text sequences from :meth:`EHRGPTBaseline.generate`. + + Returns: + A ``pd.DataFrame`` with columns ``SUBJECT_ID``, ``HADM_ID``, + ``ICD9_CODE``. + + Examples: + >>> sequences_to_dataframe(["250.00 VISIT_DELIM 401.9"]) + SUBJECT_ID HADM_ID ICD9_CODE + 0 0 0 250.00 + 1 0 1 401.9 + """ + rows: list[dict] = [] + for subj_idx, seq in enumerate(sequences): + for hadm_idx, visit_str in enumerate(seq.strip().split(VISIT_DELIM)): + for code in visit_str.strip().split(): + if code: + rows.append( + { + "SUBJECT_ID": subj_idx, + "HADM_ID": hadm_idx, + "ICD9_CODE": code, + } + ) + return pd.DataFrame(rows) + + +# ── Tokeniser ────────────────────────────────────────────────────────────────── + + +def build_tokenizer(text_data: list[str]) -> PreTrainedTokenizerFast: + """Build and train a word-level tokeniser on an EHR text corpus. + + Uses the HuggingFace ``tokenizers`` library. Special tokens: + + * ``[UNK]`` – unknown token + * ``[PAD]`` – padding + * ``[BOS]`` – beginning-of-sequence + * ``[EOS]`` – end-of-sequence + + ``VISIT_DELIM`` is treated as a regular vocabulary word so the model + learns its visit-boundary semantics. + + Args: + text_data: List of space-separated code sequences (one per patient). + + Returns: + A ``PreTrainedTokenizerFast`` wrapping the trained word-level model. 
+ + Note: + The ``Whitespace`` pre-tokeniser splits on punctuation, so ICD-9 codes + such as ``"250.00"`` are stored as the sub-tokens ``["250", ".", "00"]``. + This is intentional: it drastically reduces the vocabulary size while + preserving code structure. + """ + tokenizer_obj = Tokenizer(models.WordLevel(unk_token="[UNK]")) + tokenizer_obj.pre_tokenizer = pre_tokenizers.Whitespace() + + special_tokens = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"] + word_trainer = trainers.WordLevelTrainer(special_tokens=special_tokens) + tokenizer_obj.train_from_iterator(text_data, trainer=word_trainer) + + tokenizer_obj.post_processor = processors.TemplateProcessing( + single="[BOS] $A [EOS]", + special_tokens=[ + ("[BOS]", tokenizer_obj.token_to_id("[BOS]")), + ("[EOS]", tokenizer_obj.token_to_id("[EOS]")), + ], + ) + + return PreTrainedTokenizerFast( + tokenizer_object=tokenizer_obj, + unk_token="[UNK]", + pad_token="[PAD]", + bos_token="[BOS]", + eos_token="[EOS]", + ) + + +# ── PyTorch Dataset ──────────────────────────────────────────────────────────── + + +class EHRTextDataset(Dataset): + """Tokenises EHR text sequences for causal language-model training. + + Each sequence is tokenised, truncated/padded to ``max_length``, and stored + as a fixed-length ``LongTensor``. The ``labels`` field mirrors + ``input_ids`` so the HuggingFace ``Trainer`` can compute the standard + next-token prediction loss. + + Args: + sequences: Plain-text patient sequences (one string per patient). + tokenizer: A trained :class:`~transformers.PreTrainedTokenizerFast`. + max_length: Token budget; longer sequences are right-truncated. 
+ + Examples: + >>> from pyhealth.models.generators import build_tokenizer, EHRTextDataset + >>> tok = build_tokenizer(["250.00 VISIT_DELIM 401.9"]) + >>> ds = EHRTextDataset(["250.00 VISIT_DELIM 401.9"], tok, max_length=16) + >>> len(ds) + 1 + >>> ds[0]["input_ids"].shape + torch.Size([16]) + """ + + def __init__( + self, + sequences: list[str], + tokenizer: PreTrainedTokenizerFast, + max_length: int = 512, + ) -> None: + self.input_ids: list[torch.Tensor] = [] + for txt in sequences: + enc = tokenizer( + txt, + truncation=True, + max_length=max_length, + padding="max_length", + ) + self.input_ids.append(torch.tensor(enc["input_ids"])) + + def __len__(self) -> int: + return len(self.input_ids) + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + ids = self.input_ids[idx] + return {"input_ids": ids, "labels": ids} + + +# ── Main model class ─────────────────────────────────────────────────────────── + + +class EHRGPTBaseline(nn.Module): + """GPT-2 decoder-only language model for synthetic EHR generation. + + Wraps a HuggingFace ``GPT2LMHeadModel`` and exposes a high-level API + (:meth:`fit`, :meth:`generate`) that matches the training pipeline + described in *Accelerating Reproducible Research in Synthetic EHR + Generation* (CHIL 2026). + + Architecture + ------------ + * Word-level ICD-9 tokeniser (``VISIT_DELIM`` as vocabulary entry) + * GPT-2 transformer decoder with configurable depth and width + * Autoregressive next-token prediction objective + + Args: + n_embd: Embedding and hidden dimension. Default: 512. + n_layer: Number of transformer decoder layers. Default: 8. + n_head: Number of self-attention heads. Default: 8. + max_seq_len: Maximum token sequence length. Default: 512. + + Attributes: + tokenizer: The fitted :class:`~transformers.PreTrainedTokenizerFast` + (``None`` until :meth:`fit` is called). + model: The underlying :class:`~transformers.GPT2LMHeadModel` + (``None`` until :meth:`fit` is called). + + Examples: + .. 
code-block:: python + + from pyhealth.models.generators import EHRGPTBaseline, samples_to_sequences + + gpt = EHRGPTBaseline(n_embd=256, n_layer=4, n_head=4) + gpt.fit(text_sequences, output_dir="./ckpt", epochs=10, batch_size=32) + df = gpt.generate(n_patients=500) + """ + + def __init__( + self, + n_embd: int = 512, + n_layer: int = 8, + n_head: int = 8, + max_seq_len: int = 512, + ) -> None: + super().__init__() + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.max_seq_len = max_seq_len + + # Populated by fit() + self.tokenizer: Optional[PreTrainedTokenizerFast] = None + self.model: Optional[GPT2LMHeadModel] = None + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward(self, input_ids: torch.Tensor, **kwargs) -> dict: + """Run a forward pass through the GPT-2 model. + + Delegates directly to :class:`~transformers.GPT2LMHeadModel`. + + Args: + input_ids: Token ID tensor of shape ``(batch, seq_len)``. + **kwargs: Additional arguments forwarded to GPT2LMHeadModel. + + Returns: + The ``CausalLMOutputWithCrossAttentions`` dict-like object from + HuggingFace (contains ``logits``, ``loss`` when ``labels`` are + supplied, etc.). + + Raises: + RuntimeError: If called before :meth:`fit`. + """ + if self.model is None: + raise RuntimeError("Call fit() before forward().") + return self.model(input_ids, **kwargs) + + # ------------------------------------------------------------------ + # fit + # ------------------------------------------------------------------ + + def fit( + self, + sequences: list[str], + output_dir: str = "./ehr_gpt_output", + epochs: int = 50, + batch_size: int = 64, + learning_rate: float = 1e-4, + warmup_steps: int = 100, + ) -> "EHRGPTBaseline": + """Build the tokeniser, initialise GPT-2, and train on ``sequences``. 
+ + This method is idempotent: calling it again re-initialises the + tokeniser and model from scratch. + + Args: + sequences: List of VISIT_DELIM-delimited patient text sequences + produced by :func:`samples_to_sequences`. + output_dir: Directory for HuggingFace ``Trainer`` checkpoints. + epochs: Training epochs. + batch_size: Per-device training batch size. + learning_rate: Peak learning rate (cosine schedule). + warmup_steps: Linear warm-up steps. + + Returns: + ``self`` (fluent API). + """ + os.makedirs(output_dir, exist_ok=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ── tokeniser ───────────────────────────────────────────────────────── + self.tokenizer = build_tokenizer(sequences) + + # ── dataset ─────────────────────────────────────────────────────────── + train_ds = EHRTextDataset(sequences, self.tokenizer, max_length=self.max_seq_len) + + # ── model ───────────────────────────────────────────────────────────── + config = GPT2Config( + vocab_size=len(self.tokenizer), + n_positions=self.max_seq_len, + n_ctx=self.max_seq_len, + n_embd=self.n_embd, + n_layer=self.n_layer, + n_head=self.n_head, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + self.model = GPT2LMHeadModel(config).to(device) + + # ── training ────────────────────────────────────────────────────────── + data_collator = DataCollatorForLanguageModeling( + tokenizer=self.tokenizer, mlm=False + ) + + training_args = TrainingArguments( + output_dir=os.path.join(output_dir, "checkpoints"), + overwrite_output_dir=True, + num_train_epochs=epochs, + per_device_train_batch_size=batch_size, + logging_steps=50, + learning_rate=learning_rate, + lr_scheduler_type="cosine", + warmup_steps=warmup_steps, + use_cpu=not torch.cuda.is_available(), + save_strategy="epoch", + ) + + hf_trainer = Trainer( + model=self.model, + args=training_args, + data_collator=data_collator, + train_dataset=train_ds, + ) + hf_trainer.train() + + # 
Persist model and tokeniser side-by-side + model_dir = os.path.join(output_dir, "gpt_ehr_model") + hf_trainer.save_model(model_dir) + self.tokenizer.save_pretrained(model_dir) + + return self + + # ------------------------------------------------------------------ + # generate + # ------------------------------------------------------------------ + + def generate( + self, + n_patients: int = 1000, + batch_size: int = 512, + top_k: int = 50, + top_p: float = 0.95, + ) -> pd.DataFrame: + """Sample synthetic EHR sequences and return a long-form DataFrame. + + Args: + n_patients: Number of synthetic patients to generate. + batch_size: Generation batch size (GPU memory permitting). + top_k: Top-k sampling parameter. + top_p: Nucleus sampling probability threshold. + + Returns: + A ``pd.DataFrame`` with columns ``SUBJECT_ID``, ``HADM_ID``, + ``ICD9_CODE``. + + Raises: + RuntimeError: If called before :meth:`fit`. + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Call fit() before generate().") + + device = next(self.model.parameters()).device + self.model.eval() + + all_dfs: list[pd.DataFrame] = [] + start_subj = 0 + + for batch_start in trange(0, n_patients, batch_size, desc="Generating"): + bsz = min(batch_size, n_patients - batch_start) + prompt = torch.tensor( + [[self.tokenizer.bos_token_id]] * bsz, device=device + ) + with torch.no_grad(): + generated = self.model.generate( + prompt, + max_new_tokens=self.max_seq_len, + do_sample=True, + top_k=top_k, + top_p=top_p, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + decoded = [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in generated + ] + batch_df = sequences_to_dataframe(decoded) + batch_df["SUBJECT_ID"] += start_subj + start_subj += bsz + all_dfs.append(batch_df) + + return pd.concat(all_dfs, ignore_index=True) + + # ------------------------------------------------------------------ + # Persistence helpers + # 
------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Save the GPT-2 model weights and tokeniser to ``path``. + + Args: + path: Directory to save into (created if absent). + """ + if self.model is None or self.tokenizer is None: + raise RuntimeError("Nothing to save – call fit() first.") + os.makedirs(path, exist_ok=True) + self.model.save_pretrained(path) + self.tokenizer.save_pretrained(path) + + @classmethod + def load(cls, path: str, **init_kwargs) -> "EHRGPTBaseline": + """Load a previously saved :class:`EHRGPTBaseline` from ``path``. + + Args: + path: Directory created by :meth:`save`. + **init_kwargs: Forwarded to ``__init__`` (overrides defaults for + ``n_embd``, ``n_layer``, ``n_head``, ``max_seq_len``). + + Returns: + A fully initialised :class:`EHRGPTBaseline` ready for + :meth:`generate`. + """ + instance = cls(**init_kwargs) + instance.tokenizer = PreTrainedTokenizerFast.from_pretrained(path) + instance.model = GPT2LMHeadModel.from_pretrained(path) + return instance + + # ------------------------------------------------------------------ + # repr + # ------------------------------------------------------------------ + + def __repr__(self) -> str: # pragma: no cover + fitted = self.model is not None + return ( + f"EHRGPTBaseline(" + f"n_embd={self.n_embd}, n_layer={self.n_layer}, " + f"n_head={self.n_head}, max_seq_len={self.max_seq_len}, " + f"fitted={fitted})" + ) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..23688453a 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -68,3 +68,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ehr_generation import EHRGenerationMIMIC3 diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py new file mode 100644 index 000000000..77a0408f4 --- /dev/null +++ b/pyhealth/tasks/ehr_generation.py @@ -0,0 +1,143 @@ +from typing 
import Any, Dict, List + +from pyhealth.tasks.base_task import BaseTask + + +class EHRGenerationMIMIC3(BaseTask): + """Task for training synthetic EHR generative models using MIMIC-III. + + Transforms longitudinal patient records into a visit-sequence representation + suitable for generative modeling. Each sample corresponds to one patient and + captures the complete temporal trajectory of ICD-9 diagnosis codes across + admissions. + + Two downstream representations can be derived from the output: + + * **Sequential** (PromptEHR, HALO, GPT): the nested ``conditions`` list + retains full visit boundaries and ordering. + * **Matrix / flattened** (MedGAN, CorGAN): flatten ``conditions`` into a + single list (binary presence) or count vector per patient. + + For standardised evaluation, every synthetic or real record should be + converted to a long-format schema of ``(subject_id, visit_id, code)`` + triplets as recommended by the paper *Accelerating Reproducible Research in + Synthetic EHR Generation*. + + Attributes: + task_name (str): Identifier for this task. + input_schema (Dict[str, str]): ``{"conditions": "nested_sequence"}`` – + tells PyHealth's processor to serialise the variable-length nested + visit list correctly (same convention as + ``DrugRecommendationMIMIC3``). + output_schema (Dict[str, str]): ``{}`` – no supervised label is + produced. + min_visits (int): Minimum number of valid visits a patient must have + to be included. Defaults to ``1``. + truncate_icd (bool): When ``True``, ICD-9 codes are truncated to the + first 3 characters (e.g. ``"250.40"`` → ``"250"``), reducing the + vocabulary from ~6,955 to 1,071 codes. The paper recommends + keeping ``False`` for full clinical fidelity. Defaults to + ``False``. + + Note: + A full end-to-end training example using a GPT-2 style decoder can be + found at ``examples/ehr_generation/ehr_generation_mimic3_transformer.py``. 
+ + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import EHRGenerationMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd"], + ... ) + >>> task = EHRGenerationMIMIC3() + >>> samples = dataset.set_task(task) + >>> # Each sample: {patient_id, conditions} + >>> # conditions is a list of visits; each visit is a list of ICD-9 codes + >>> # e.g. [["250.00", "401.9"], ["272.0", "428.0"]] + """ + + task_name: str = "EHRGenerationMIMIC3" + input_schema: Dict[str, str] = {"conditions": "nested_sequence"} + output_schema: Dict[str, str] = {} + + def __init__( + self, + min_visits: int = 1, + truncate_icd: bool = False, + ) -> None: + """Initialise the task. + + Args: + min_visits (int): Patients with fewer than ``min_visits`` valid + admissions (i.e. admissions that contain at least one ICD-9 + code) are excluded. Defaults to ``1``. + truncate_icd (bool): Truncate ICD-9 codes to 3-digit prefixes. + Useful for reproducing prior work that caps the vocabulary at + 1,071 codes. Defaults to ``False`` (full 6,955-code vocabulary). + """ + self.min_visits = min_visits + self.truncate_icd = truncate_icd + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process a single patient and return a list with one generation sample. + + Each returned sample represents the patient's full longitudinal record + as a nested list of ICD-9 diagnosis code sequences, one inner list per + hospital admission (visit). + + Admissions with no ICD-9 codes are silently skipped. Patients with + fewer valid visits than ``self.min_visits`` return an empty list. + + Args: + patient: A PyHealth ``Patient`` object providing a + ``get_events(event_type, filters)`` interface. + + Returns: + A list containing a single dict with: + + * ``patient_id`` (str): MIMIC-III ``subject_id``. + * ``conditions`` (List[List[str]]): Nested list of ICD-9 diagnosis + codes, grouped by admission. 
Outer index = visit order (chronological); + inner index = code index within that visit. + The number of visits can be derived as ``len(conditions)``. + """ + admissions = patient.get_events(event_type="admissions") + + visit_sequences: List[List[str]] = [] + + for admission in admissions: + diagnoses = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + + codes = [event.icd9_code for event in diagnoses if event.icd9_code] + + if self.truncate_icd: + codes = [code[:3] for code in codes] + + # Deduplicate while preserving order + seen: set = set() + unique_codes: List[str] = [] + for code in codes: + if code not in seen: + seen.add(code) + unique_codes.append(code) + codes = unique_codes + + if not codes: + continue + + visit_sequences.append(codes) + + if len(visit_sequences) < self.min_visits: + return [] + + return [ + { + "patient_id": patient.patient_id, + "conditions": visit_sequences, + } + ] + diff --git a/tests/core/test_mimic3_ehr_generation.py b/tests/core/test_mimic3_ehr_generation.py new file mode 100644 index 000000000..a5dd8b66d --- /dev/null +++ b/tests/core/test_mimic3_ehr_generation.py @@ -0,0 +1,337 @@ +"""Tests for EHRGenerationMIMIC3 task and the sequence helper utilities. + +All tests use a fully synthetic mock dataset so no real MIMIC-III files or +PyHealth's set_task() / litdata pipeline are required. 
+""" + +import unittest + +import pandas as pd + +from pyhealth.tasks import EHRGenerationMIMIC3 + +# ── visit-delimiter helpers (mirrored from the example script) ───────────────── + +VISIT_DELIM = "VISIT_DELIM" + + +def samples_to_sequences(samples: list) -> list: + """Nested visit list → VISIT_DELIM-delimited text string per patient.""" + sequences = [] + for sample in samples: + visit_texts = [" ".join(visit_codes) for visit_codes in sample["conditions"]] + sequences.append(f" {VISIT_DELIM} ".join(visit_texts)) + return sequences + + +def sequences_to_dataframe(sequences: list) -> pd.DataFrame: + """Text sequences → long-form (SUBJECT_ID, HADM_ID, ICD9_CODE) DataFrame.""" + rows = [] + for subj_idx, seq in enumerate(sequences): + for hadm_idx, visit_str in enumerate(seq.strip().split(VISIT_DELIM)): + for code in visit_str.strip().split(): + if code: + rows.append( + {"SUBJECT_ID": subj_idx, "HADM_ID": hadm_idx, "ICD9_CODE": code} + ) + return pd.DataFrame(rows) + + +# ── minimal mock objects that mimic PyHealth's Patient/Event interface ───────── + + +class _MockAdmission: + """Lightweight stand-in for a MIMIC-III admission event.""" + + def __init__(self, hadm_id: str) -> None: + self.hadm_id = hadm_id + + +class _MockDiagnosis: + """Lightweight stand-in for a diagnoses_icd event.""" + + def __init__(self, hadm_id: str, icd9_code: str) -> None: + self.hadm_id = hadm_id + self.icd9_code = icd9_code + + +class _MockPatient: + """Mimics BasePatient.get_events() for admissions and diagnoses_icd tables.""" + + def __init__(self, patient_id: str, visits: list) -> None: + """ + Args: + patient_id: Synthetic subject_id string. + visits: List of visits; each visit is a list of ICD-9 code strings. + Duplicates within a visit are intentional to test dedup logic. 
        """
        self.patient_id = patient_id
        # One mock admission per visit; HADM_IDs start at "100" for readability.
        self._admissions = [
            _MockAdmission(hadm_id=str(100 + i)) for i in range(len(visits))
        ]
        self._diagnoses = []
        for admission, codes in zip(self._admissions, visits):
            for code in codes:
                self._diagnoses.append(_MockDiagnosis(admission.hadm_id, code))

    def get_events(self, event_type: str, filters=None):
        """Mimic the PyHealth patient API for the two tables the task reads.

        Only equality (``==``) filters are honoured, which is all the task
        under test uses; any other event type yields an empty list.
        """
        if event_type == "admissions":
            return list(self._admissions)
        if event_type == "diagnoses_icd":
            result = list(self._diagnoses)
            if filters:
                for field, op, value in filters:
                    if op == "==":
                        result = [e for e in result if getattr(e, field) == value]
            return result
        return []


# ── synthetic patient corpus ───────────────────────────────────────────────────

_PATIENTS = {
    # 3 visits, no duplicates
    "P001": [
        ["250.00", "401.9", "278.00"],
        ["250.00", "272.0"],
        ["428.0", "401.9", "285.9"],
    ],
    # 2 visits with intentional within-visit duplicates
    "P002": [
        ["410.01", "410.01", "412"],  # 410.01 duplicated intentionally
        ["414.01", "V45.81"],
    ],
    # 1 visit (used to test min_visits filtering)
    "P003": [
        ["486", "518.81"],
    ],
    # 4 visits with long codes (used for truncate_icd tests)
    "P004": [
        ["250.40", "250.00"],
        ["401.10", "401.90"],
        ["428.00"],
        ["272.00", "272.10"],
    ],
    # patient with some empty visits (should be silently skipped)
    "P005": [
        [],  # empty → skipped
        ["V15.82"],
        [],  # empty → skipped
        ["401.9"],
    ],
}

ALL_PATIENTS = [_MockPatient(pid, visits) for pid, visits in _PATIENTS.items()]


# ── test class ─────────────────────────────────────────────────────────────────


class TestEHRGenerationMIMIC3Task(unittest.TestCase):
    """Unit tests for EHRGenerationMIMIC3 using synthetic mock patients."""

    def _run_task(self, task, patients=None):
        """Helper: run task over a list of mock patients, flatten results."""
        if patients is None:
            patients = ALL_PATIENTS
        samples = []
        for p in patients:
            samples.extend(task(p))
        return samples

    # ── schema / init ──────────────────────────────────────────────────────────

    def test_task_name(self):
        """Class-level task_name is fixed and used by PyHealth registries."""
        self.assertEqual(EHRGenerationMIMIC3.task_name, "EHRGenerationMIMIC3")

    def test_input_schema(self):
        # nested_sequence required so PyHealth's processor handles variable-length visits
        self.assertEqual(EHRGenerationMIMIC3.input_schema, {"conditions": "nested_sequence"})

    def test_output_schema(self):
        """Generation task has no supervised target, so output schema is empty."""
        self.assertEqual(EHRGenerationMIMIC3.output_schema, {})

    def test_default_init(self):
        """Defaults: keep every patient (min_visits=1), full ICD-9 codes."""
        task = EHRGenerationMIMIC3()
        self.assertEqual(task.min_visits, 1)
        self.assertFalse(task.truncate_icd)

    def test_custom_init(self):
        """Constructor arguments are stored verbatim on the task."""
        task = EHRGenerationMIMIC3(min_visits=3, truncate_icd=True)
        self.assertEqual(task.min_visits, 3)
        self.assertTrue(task.truncate_icd)

    # ── per-patient __call__ output ────────────────────────────────────────────

    def test_returns_one_sample_per_patient(self):
        """Each qualifying patient produces exactly one sample dict."""
        task = EHRGenerationMIMIC3()
        for patient in ALL_PATIENTS:
            result = task(patient)
            self.assertIn(len(result), (0, 1))

    def test_sample_keys_present(self):
        """Each sample must have patient_id and conditions keys."""
        task = EHRGenerationMIMIC3()
        samples = self._run_task(task)
        self.assertGreater(len(samples), 0)
        for sample in samples:
            self.assertIn("patient_id", sample)
            self.assertIn("conditions", sample)

    def test_patient_id_matches(self):
        """sample['patient_id'] must equal the originating patient id."""
        task = EHRGenerationMIMIC3()
        for patient in ALL_PATIENTS:
            for sample in task(patient):
                self.assertEqual(sample["patient_id"], patient.patient_id)

    def test_conditions_is_nested_list_of_strings(self):
        """conditions must be List[List[str]] with no empty inner lists."""
        task = EHRGenerationMIMIC3()
        samples = self._run_task(task)
        for sample in samples:
            conds = sample["conditions"]
            self.assertIsInstance(conds, list)
            self.assertGreater(len(conds), 0)
            for visit in conds:
                self.assertIsInstance(visit, list)
                self.assertGreater(len(visit), 0, "Empty visits must be dropped")
                for code in visit:
                    self.assertIsInstance(code, str)
                    self.assertGreater(len(code), 0)

    def test_empty_visits_skipped(self):
        """Admissions with no ICD-9 codes are silently skipped."""
        task = EHRGenerationMIMIC3()
        p005 = _MockPatient("P005", _PATIENTS["P005"])
        result = task(p005)
        self.assertEqual(len(result), 1)
        # 4 admissions, 2 empty → 2 valid visits
        self.assertEqual(len(result[0]["conditions"]), 2)

    def test_within_visit_deduplication(self):
        """Duplicate ICD-9 codes within a single visit are removed."""
        task = EHRGenerationMIMIC3()
        p002 = _MockPatient("P002", _PATIENTS["P002"])
        result = task(p002)
        self.assertEqual(len(result), 1)
        for visit in result[0]["conditions"]:
            self.assertEqual(len(visit), len(set(visit)),
                             f"Duplicate codes in visit: {visit}")

    def test_visit_order_preserved(self):
        """Visits appear in the same order they were supplied."""
        task = EHRGenerationMIMIC3()
        p001 = _MockPatient("P001", _PATIENTS["P001"])
        result = task(p001)
        self.assertIn("250.00", result[0]["conditions"][0])
        self.assertIn("428.0", result[0]["conditions"][2])

    def test_conditions_length_matches_nonempty_visits(self):
        """len(conditions) equals number of non-empty visits."""
        task = EHRGenerationMIMIC3()
        self.assertEqual(len(task(_MockPatient("P001", _PATIENTS["P001"]))[0]["conditions"]), 3)
        self.assertEqual(len(task(_MockPatient("P005", _PATIENTS["P005"]))[0]["conditions"]), 2)

    # ── min_visits filtering ───────────────────────────────────────────────────

    def test_min_visits_1_includes_single_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=1)
        self.assertEqual(len(task(_MockPatient("P003", _PATIENTS["P003"]))), 1)

    def test_min_visits_2_excludes_single_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=2)
        self.assertEqual(len(task(_MockPatient("P003", _PATIENTS["P003"]))), 0)

    def test_min_visits_2_keeps_multi_visit_patient(self):
        task = EHRGenerationMIMIC3(min_visits=2)
        self.assertEqual(len(task(_MockPatient("P001", _PATIENTS["P001"]))), 1)

    def test_min_visits_too_high_returns_empty_for_all(self):
        # No mock patient has 10 visits, so every patient is filtered out.
        task = EHRGenerationMIMIC3(min_visits=10)
        self.assertEqual(self._run_task(task), [])

    # ── truncate_icd ───────────────────────────────────────────────────────────

    def test_truncate_icd_shortens_codes_to_3_chars(self):
        """All codes must be ≤ 3 characters when truncate_icd=True."""
        task = EHRGenerationMIMIC3(truncate_icd=True)
        for sample in self._run_task(task):
            for visit in sample["conditions"]:
                for code in visit:
                    self.assertLessEqual(len(code), 3,
                                         f"Code '{code}' exceeds 3 chars")

    def test_truncate_icd_false_preserves_full_codes(self):
        """Codes longer than 3 chars must survive when truncate_icd=False."""
        task = EHRGenerationMIMIC3(truncate_icd=False)
        result = task(_MockPatient("P004", _PATIENTS["P004"]))
        all_codes = [c for visit in result[0]["conditions"] for c in visit]
        self.assertTrue(any(len(c) > 3 for c in all_codes),
                        "Expected full-length codes like '250.40'")

    def test_truncate_icd_dedup_after_merge(self):
        """After truncation, merged codes are deduplicated within each visit."""
        # visit 0 of P004: "250.40" and "250.00" both → "250" (only one should survive)
        task = EHRGenerationMIMIC3(truncate_icd=True)
        result = task(_MockPatient("P004", _PATIENTS["P004"]))
        visit_0 = result[0]["conditions"][0]
        self.assertEqual(visit_0, ["250"])

    # ── edge cases ─────────────────────────────────────────────────────────────

    def test_all_empty_visits_returns_empty(self):
        """A patient whose visits are all empty yields no sample at all."""
        task = EHRGenerationMIMIC3()
        self.assertEqual(task(_MockPatient("PEMPTY", [[], [], []])), [])

    def test_no_visits_returns_empty(self):
        """A patient with zero admissions yields no sample at all."""
        task = EHRGenerationMIMIC3()
        self.assertEqual(task(_MockPatient("PNONE", [])), [])

    # ── sequence helper: samples_to_sequences ─────────────────────────────────

    def test_samples_to_sequences_one_string_per_sample(self):
        """Serialisation yields one non-blank string per task sample."""
        samples = self._run_task(EHRGenerationMIMIC3())
        seqs = samples_to_sequences(samples)
        self.assertEqual(len(seqs), len(samples))
        for seq in seqs:
            self.assertIsInstance(seq, str)
            self.assertGreater(len(seq.strip()), 0)

    def test_samples_to_sequences_delimiter_present_for_multi_visit(self):
        sample = EHRGenerationMIMIC3()(_MockPatient("P001", _PATIENTS["P001"]))[0]
        self.assertIn(VISIT_DELIM, samples_to_sequences([sample])[0])

    def test_samples_to_sequences_no_delimiter_for_single_visit(self):
        sample = EHRGenerationMIMIC3()(_MockPatient("P003", _PATIENTS["P003"]))[0]
        self.assertNotIn(VISIT_DELIM, samples_to_sequences([sample])[0])

    # ── sequence helper: sequences_to_dataframe ───────────────────────────────

    def test_sequences_to_dataframe_columns(self):
        """Recovered DataFrame exposes the MIMIC-style column triple."""
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        for col in ("SUBJECT_ID", "HADM_ID", "ICD9_CODE"):
            self.assertIn(col, df.columns)

    def test_round_trip_all_codes_preserved(self):
        """Every code in the original samples must appear in the recovered DataFrame."""
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        original = {c for s in samples for visit in s["conditions"] for c in visit}
        recovered = set(df["ICD9_CODE"].tolist())
        self.assertEqual(original, recovered)

    def test_round_trip_visit_count_per_patient(self):
        """The DataFrame must reconstruct the correct visit count per patient."""
        # NOTE(review): assumes sequences_to_dataframe assigns SUBJECT_ID by
        # sequential input index — verify against the helper's contract.
        samples = self._run_task(EHRGenerationMIMIC3())
        df = sequences_to_dataframe(samples_to_sequences(samples))
        for idx, sample in enumerate(samples):
            syn_visits = df[df["SUBJECT_ID"] == idx]["HADM_ID"].nunique()
            self.assertEqual(syn_visits, len(sample["conditions"]),
                             f"Visit count mismatch at sample index {idx}")


if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_transformer_ehr_helpers.py b/tests/core/test_transformer_ehr_helpers.py new file mode 100644 index 000000000..e93e11dc3 --- /dev/null +++ b/tests/core/test_transformer_ehr_helpers.py @@ -0,0 +1,272 @@ +""" +Tests for the utility functions and classes defined in +pyhealth/models/generators/gpt_baseline.py + +Covered: +* ``samples_to_sequences`` – nested visit lists -> text strings +* ``sequences_to_dataframe`` – text strings -> long-form DataFrame +* ``build_tokenizer`` – word-level HuggingFace tokenizer +* ``EHRTextDataset`` – PyTorch Dataset wrapping tokenized EHR sequences +""" + +import unittest + +import torch +from transformers import PreTrainedTokenizerFast + +from pyhealth.models.generators import ( + VISIT_DELIM, + EHRTextDataset, + build_tokenizer, + samples_to_sequences, + sequences_to_dataframe, +) + +# ── shared test data ─────────────────────────────────────────────────────────── + +_SINGLE_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"]]} +_MULTI_VISIT_SAMPLE = {"conditions": [["250.00", "401.9"], ["272.0", "428.0"], ["250.00"]]} +_EMPTY_VISIT_SAMPLE = {"conditions": []} + +_CORPUS = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0 VISIT_DELIM 250.00", + "401.9 272.0 428.0", +] + +_SEQUENCES = [ + "250.00 401.9 VISIT_DELIM 272.0", + "428.0", + "401.9 272.0 428.0 VISIT_DELIM 250.00 VISIT_DELIM 272.0", +] +_MAX_LENGTH = 16 + + +# ── 1. 
samples_to_sequences ────────────────────────────────────


class TestSamplesToSequences(unittest.TestCase):
    """Behavioural tests for samples_to_sequences (visit lists → text)."""

    def test_returns_one_string_per_sample(self):
        result = samples_to_sequences([_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE])
        self.assertEqual(len(result), 2)

    def test_empty_input_returns_empty_list(self):
        self.assertEqual(samples_to_sequences([]), [])

    def test_single_visit_no_delimiter(self):
        result = samples_to_sequences([_SINGLE_VISIT_SAMPLE])
        self.assertNotIn(VISIT_DELIM, result[0])

    def test_multi_visit_delimiter_count_matches(self):
        # 3 visits -> 2 VISIT_DELIM occurrences
        result = samples_to_sequences([_MULTI_VISIT_SAMPLE])
        self.assertEqual(result[0].count(VISIT_DELIM), 2)

    def test_codes_present_in_output(self):
        result = samples_to_sequences([_MULTI_VISIT_SAMPLE])
        for visit in _MULTI_VISIT_SAMPLE["conditions"]:
            for code in visit:
                self.assertIn(code, result[0])

    def test_single_visit_codes_space_separated(self):
        result = samples_to_sequences([_SINGLE_VISIT_SAMPLE])
        self.assertEqual(result[0], "250.00 401.9")

    def test_multi_visit_format(self):
        result = samples_to_sequences([_MULTI_VISIT_SAMPLE])
        expected = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0 {VISIT_DELIM} 250.00"
        self.assertEqual(result[0], expected)

    def test_empty_conditions_yields_empty_string(self):
        result = samples_to_sequences([_EMPTY_VISIT_SAMPLE])
        self.assertEqual(result[0], "")

    def test_single_code_per_visit(self):
        sample = {"conditions": [["A"], ["B"], ["C"]]}
        result = samples_to_sequences([sample])
        self.assertEqual(result[0], f"A {VISIT_DELIM} B {VISIT_DELIM} C")

    def test_multiple_samples_independent(self):
        result = samples_to_sequences([_SINGLE_VISIT_SAMPLE, _MULTI_VISIT_SAMPLE])
        self.assertNotEqual(result[0], result[1])

    def test_output_is_list_of_strings(self):
        result = samples_to_sequences([_SINGLE_VISIT_SAMPLE])
        self.assertIsInstance(result, list)
        for s in result:
            self.assertIsInstance(s, str)


# ── 2. sequences_to_dataframe ─────────────────────────────────────────────────


class TestSequencesToDataframe(unittest.TestCase):
    """Behavioural tests for sequences_to_dataframe (text → long-form rows)."""

    _SEQ_SINGLE = "250.00 401.9"
    _SEQ_MULTI = f"250.00 401.9 {VISIT_DELIM} 272.0 428.0"

    def test_required_columns_present(self):
        df = sequences_to_dataframe([self._SEQ_SINGLE])
        self.assertEqual(set(df.columns), {"SUBJECT_ID", "HADM_ID", "ICD9_CODE"})

    def test_empty_input_returns_empty_dataframe(self):
        # NOTE(review): also pins the empty frame to have *no* columns — an
        # implementation detail; relax if the helper later adds the schema.
        df = sequences_to_dataframe([])
        self.assertTrue(df.empty)
        self.assertEqual(list(df.columns), [])

    def test_single_visit_produces_correct_codes(self):
        df = sequences_to_dataframe([self._SEQ_SINGLE])
        self.assertEqual(set(df["ICD9_CODE"].tolist()), {"250.00", "401.9"})

    def test_single_visit_single_hadm_id(self):
        df = sequences_to_dataframe([self._SEQ_SINGLE])
        self.assertEqual(df["HADM_ID"].nunique(), 1)
        self.assertEqual(df["HADM_ID"].iloc[0], 0)

    def test_multi_visit_hadm_ids(self):
        df = sequences_to_dataframe([self._SEQ_MULTI])
        self.assertEqual(set(df["HADM_ID"].tolist()), {0, 1})

    def test_subject_ids_sequential(self):
        df = sequences_to_dataframe([self._SEQ_SINGLE, self._SEQ_SINGLE])
        self.assertEqual(set(df["SUBJECT_ID"].tolist()), {0, 1})

    def test_multi_patient_subject_id_mapping(self):
        df = sequences_to_dataframe([self._SEQ_MULTI, self._SEQ_SINGLE])
        self.assertEqual(df[df["SUBJECT_ID"] == 0]["HADM_ID"].nunique(), 2)
        self.assertEqual(df[df["SUBJECT_ID"] == 1]["HADM_ID"].nunique(), 1)

    def test_row_count_matches_codes(self):
        # 4 codes across 2 visits -> 4 long-form rows
        df = sequences_to_dataframe([self._SEQ_MULTI])
        self.assertEqual(len(df), 4)

    def test_whitespace_only_sequence_returns_empty(self):
        df = sequences_to_dataframe([" "])
        self.assertTrue(df.empty)

    def test_round_trip_from_samples(self):
        seqs = samples_to_sequences([_MULTI_VISIT_SAMPLE])
        df = sequences_to_dataframe(seqs)
        all_codes = {c for visit in _MULTI_VISIT_SAMPLE["conditions"] for c in visit}
        self.assertEqual(all_codes, set(df["ICD9_CODE"].tolist()))

    def test_round_trip_visit_count(self):
        seqs = samples_to_sequences([_MULTI_VISIT_SAMPLE])
        df = sequences_to_dataframe(seqs)
        n_visits = df.groupby("SUBJECT_ID")["HADM_ID"].nunique().iloc[0]
        self.assertEqual(n_visits, len(_MULTI_VISIT_SAMPLE["conditions"]))


# ── 3. build_tokenizer ────────────────────────────────────────────────────────


class TestBuildTokenizer(unittest.TestCase):
    """Tests for the word-level HuggingFace tokenizer factory."""

    @classmethod
    def setUpClass(cls):
        # Built once for the class; the tokenizer is read-only afterwards.
        cls.tokenizer = build_tokenizer(_CORPUS)

    def test_special_tokens_in_vocab(self):
        vocab = self.tokenizer.get_vocab()
        for tok in ("[UNK]", "[PAD]", "[BOS]", "[EOS]"):
            self.assertIn(tok, vocab, f"{tok!r} missing from vocab")

    def test_visit_delim_in_vocab(self):
        self.assertIn(VISIT_DELIM, self.tokenizer.get_vocab())

    def test_medical_codes_in_vocab(self):
        # Whitespace pre-tokenizer splits "250.00" -> ["250", ".", "00"]
        vocab = self.tokenizer.get_vocab()
        for sub in ["250", "00", "401", "9", "272", "0", "428", "."]:
            self.assertIn(sub, vocab, f"sub-token {sub!r} missing from vocab")

    def test_vocab_size_at_least_corpus_tokens(self):
        self.assertGreaterEqual(len(self.tokenizer), 10)

    def test_bos_eos_token_ids_set(self):
        self.assertIsNotNone(self.tokenizer.bos_token_id)
        self.assertIsNotNone(self.tokenizer.eos_token_id)

    def test_pad_token_id_set(self):
        self.assertIsNotNone(self.tokenizer.pad_token_id)

    def test_encode_includes_bos_eos(self):
        ids = self.tokenizer("250.00 401.9")["input_ids"]
        self.assertEqual(ids[0], self.tokenizer.bos_token_id)
        self.assertEqual(ids[-1], self.tokenizer.eos_token_id)

    def test_encode_decode_roundtrip(self):
        # NOTE(review): the literal "VISIT_DELIM" here assumes the constant's
        # value is the string "VISIT_DELIM" — keep in sync with the constant.
        text = "250.00 401.9 VISIT_DELIM 272.0"
        ids = self.tokenizer(text, add_special_tokens=True)["input_ids"]
        decoded = self.tokenizer.decode(ids, skip_special_tokens=True)
        # Whitespace splits on '.', so check sub-tokens
        for sub in ["250", "00", "401", "9", VISIT_DELIM, "272", "0"]:
            self.assertIn(sub, decoded.split(), f"{sub!r} missing from decoded")

    def test_unknown_token_maps_to_unk_id(self):
        enc = self.tokenizer("UNKNOWN_CODE_XYZ")["input_ids"]
        inner = enc[1:-1]  # strip BOS/EOS
        self.assertIn(self.tokenizer.unk_token_id, inner)

    def test_returns_pretrained_tokenizer_fast(self):
        self.assertIsInstance(self.tokenizer, PreTrainedTokenizerFast)


# ── 4. EHRTextDataset ─────────────────────────────────────────────────────────


class TestEHRTextDataset(unittest.TestCase):
    """Tests for the PyTorch Dataset wrapping tokenised EHR sequences."""

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = build_tokenizer(_SEQUENCES)
        cls.dataset = EHRTextDataset(_SEQUENCES, cls.tokenizer, max_length=_MAX_LENGTH)

    def test_len_matches_sequences(self):
        self.assertEqual(len(self.dataset), len(_SEQUENCES))

    def test_getitem_returns_dict(self):
        self.assertIsInstance(self.dataset[0], dict)

    def test_getitem_has_input_ids_key(self):
        self.assertIn("input_ids", self.dataset[0])

    def test_getitem_has_labels_key(self):
        self.assertIn("labels", self.dataset[0])

    def test_input_ids_are_tensors(self):
        self.assertIsInstance(self.dataset[0]["input_ids"], torch.Tensor)

    def test_labels_are_tensors(self):
        self.assertIsInstance(self.dataset[0]["labels"], torch.Tensor)

    def test_input_ids_length_equals_max_length(self):
        # Every item is padded/truncated to exactly _MAX_LENGTH tokens.
        for i in range(len(self.dataset)):
            self.assertEqual(self.dataset[i]["input_ids"].shape[0], _MAX_LENGTH)

    def test_labels_equal_input_ids(self):
        # Causal LM training: labels mirror the inputs.
        item = self.dataset[0]
        self.assertTrue(torch.equal(item["input_ids"], item["labels"]))

    def test_all_items_same_length(self):
        lengths = {self.dataset[i]["input_ids"].shape[0] for i in range(len(self.dataset))}
        self.assertEqual(len(lengths), 1)

    def test_empty_sequences_list(self):
        ds = EHRTextDataset([], self.tokenizer, max_length=_MAX_LENGTH)
        self.assertEqual(len(ds), 0)

    def test_single_sequence(self):
        ds = EHRTextDataset(["250.00"], self.tokenizer, max_length=_MAX_LENGTH)
        self.assertEqual(len(ds), 1)
        self.assertEqual(ds[0]["input_ids"].shape[0], _MAX_LENGTH)

    def test_long_sequence_truncated(self):
        # 100 repeated codes tokenise well past _MAX_LENGTH; must be clipped.
        long_seq = " ".join(["250.00"] * 100)
        ds = EHRTextDataset([long_seq], self.tokenizer, max_length=_MAX_LENGTH)
        self.assertEqual(ds[0]["input_ids"].shape[0], _MAX_LENGTH)

    def test_index_out_of_range_raises(self):
        with self.assertRaises(IndexError):
            _ = self.dataset[len(_SEQUENCES)]


if __name__ == "__main__":
    unittest.main()