diff --git a/examples/generate_synthetic_mimic3_promptehr.py b/examples/generate_synthetic_mimic3_promptehr.py
new file mode 100644
index 000000000..5eefb7ff7
--- /dev/null
+++ b/examples/generate_synthetic_mimic3_promptehr.py
@@ -0,0 +1,47 @@
+"""PromptEHR: Synthetic MIMIC-III Patient Generation.
+
+Load a trained PromptEHR checkpoint and generate synthetic patients.
+
+Reference:
+ Wang et al. "PromptEHR: Conditional Electronic Healthcare Records
+    Generation with Prompt Learning." EMNLP 2022.
+ https://arxiv.org/abs/2211.01761
+"""
+
+import json
+
+from pyhealth.datasets import MIMIC3Dataset
+from pyhealth.models import PromptEHR
+from pyhealth.tasks import promptehr_generation_mimic3_fn
+
+MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
+CHECKPOINT_PATH = "./save/promptehr/checkpoint.pt"
+OUTPUT_PATH = "./save/promptehr/synthetic_patients.json"
+NUM_SAMPLES = 10_000
+
+# 1. Load dataset + apply task (needed for processor/vocab reconstruction)
+dataset = MIMIC3Dataset(
+ root=MIMIC3_ROOT,
+ tables=["patients", "admissions", "diagnoses_icd"],
+ code_mapping={},
+)
+sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
+
+# 2. Load checkpoint
+model = PromptEHR(dataset=sample_dataset)
+model.load_model(CHECKPOINT_PATH)
+print(f"Loaded checkpoint from {CHECKPOINT_PATH}")
+
+# 3. Generate
+print(f"Generating {NUM_SAMPLES} synthetic patients...")
+synthetic = model.synthesize_dataset(num_samples=NUM_SAMPLES)
+print(f"Generated {len(synthetic)} patients")
+
+# 4. Save
+with open(OUTPUT_PATH, "w") as f:
+ json.dump(synthetic, f, indent=2)
+print(f"Saved to {OUTPUT_PATH}")
+
+# Summary stats
+avg_visits = sum(len(p["visits"]) for p in synthetic) / len(synthetic)
+print(f"Average visits per patient: {avg_visits:.2f}")
diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb
new file mode 100644
index 000000000..925efc57a
--- /dev/null
+++ b/examples/promptehr_mimic3_colab.ipynb
@@ -0,0 +1,252 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ },
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "preamble",
+ "metadata": {},
+   "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-01_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n   - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n   - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n   - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` \u2014 flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` \u2014 nested visit records for PyHealth downstream tasks\n- `quality_report.json` \u2014 statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30\u201345 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3\u20135 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors \u2014 one per demographic feature \u2014 prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2022. https://arxiv.org/abs/2211.01761"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s1-header",
+ "metadata": {},
+ "source": "---\n# 1. Setup & Installation"
+ },
+ {
+ "cell_type": "code",
+ "id": "s1-setup",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nsubprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"])\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s2-header",
+ "metadata": {},
+ "source": "---\n# 2. Configuration"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s2-desc",
+ "metadata": {},
+ "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** \u2014 5 epochs, 1 K synthetic patients, ~30\u201345 min on T4\n- **`PRESET = \"production\"`** \u2014 20 epochs, 10 K synthetic patients, ~3\u20135 hrs on T4"
+ },
+ {
+ "cell_type": "code",
+ "id": "s2-config",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+   "source": "# ============================================================\n# CONFIGURATION \u2014 All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\"  # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n    EPOCHS = 5\n    BATCH_SIZE = 16\n    N_SYNTHETIC_SAMPLES = 1_000\n    WARMUP_STEPS = 100\nelif PRESET == \"production\":\n    EPOCHS = 20\n    BATCH_SIZE = 16\n    N_SYNTHETIC_SAMPLES = 10_000\n    WARMUP_STEPS = 1_000\n\nLR = 1e-5          # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512  # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128      # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1   # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch \u2014\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True  # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7       # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95            # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n    os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"\u2713 Configuration complete\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s3-header",
+ "metadata": {},
+ "source": "---\n# 3. Data Upload"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s3-desc",
+ "metadata": {},
+ "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO \u2014 `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` \u2014 date of birth and gender\n2. `ADMISSIONS.csv` \u2014 admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode."
+ },
+ {
+ "cell_type": "code",
+ "id": "s3-upload",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "DEMO_MODE = False\n\n# Mount Drive (Colab only)\nif IN_COLAB:\n from google.colab import drive\n drive.mount('/content/drive')\n print(\"\u2713 Google Drive mounted\")\n\n# Check which files exist\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nprint(\"\\nMIMIC-III file status:\")\nfor fname, desc in required_files.items():\n mark = \"\u2713\" if existing[fname] else \"\u2717 MISSING\"\n print(f\" {mark} {fname} \u2014 {desc}\")\n\nif missing and IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n for fname, data in uploaded.items():\n dest = f'{DATA_DIR}/{fname}'\n with open(dest, 'wb') as f:\n f.write(data)\n print(f\" Saved {fname} \u2192 {dest}\")\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\nif missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"\u2192 Activating Demo Mode \u2014 full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\nelse:\n print(\"\\n\u2713 All 3 MIMIC-III files present. Running in MIMIC-III mode.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s3-demo",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"\u2713 Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s3-validate",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"\u2713 PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"\u2713 ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"\u2713 DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n\u2713 All files validated successfully\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s4-header",
+ "metadata": {},
+ "source": "---\n# 4. Training"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s4-desc",
+ "metadata": {},
+ "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input \u2014 steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities."
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-dataset",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"\u2713 Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (\u22652 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-init",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"\u2713 Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"\u2713 PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-train",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"\u2713 Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s4-loss",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} \u2192 Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found \u2014 skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s5-header",
+ "metadata": {},
+ "source": "---\n# 5. Generation"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s5-desc",
+ "metadata": {},
+ "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` \u2014 the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling \u2014 diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding \u2014 deterministic, may repeat common patterns."
+ },
+ {
+ "cell_type": "code",
+ "id": "s5-generate",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n\u2713 Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))"
+ },
+ {
+ "cell_type": "code",
+ "id": "s5-save",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Save as JSON (full nested records \u2014 directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"\u2713 {len(synthetic):,} patients \u2192 {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE \u2014 matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"\u2713 {len(df_synthetic):,} records \u2192 {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s6-header",
+ "metadata": {},
+ "source": "---\n# 6. Results & Evaluation"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-stats",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_visits):.2f} \u00b1 {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_codes):.2f} \u00b1 {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-coverage",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n\u26a0 Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage \u2014 expected for demo preset.\")\n print(\"Production training typically achieves 60\u201380%.\")\nelse:\n print(f\"\\n\u2713 Good vocabulary coverage.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-demographics",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+   "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n    _ages = [a for a, g in model._demo_pool]\n    _genders = [g for a, g in model._demo_pool]\n    _n_male = sum(1 for g in _genders if g == 0)\n    _n_female = sum(1 for g in _genders if g == 1)\n\n    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n    ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n             alpha=0.8, label='Training population')\n    ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n                label=f'Mean age: {np.mean(_ages):.1f}')\n    ax1.set_xlabel('Age at first admission', fontsize=12)\n    ax1.set_ylabel('Density', fontsize=12)\n    ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n    ax1.legend(fontsize=10)\n\n    _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n                    color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n    for _bar, _val in zip(_bars, [_n_male, _n_female]):\n        ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n                 f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n                 ha='center', va='bottom', fontsize=11)\n    ax2.set_ylabel('Patient count', fontsize=12)\n    ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n    plt.tight_layout()\n    plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n    plt.show()\n\n    print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n    print(f\"  Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n          f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n    print(f\"  Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n    print(f\"  Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n    print(\"\\n\u2713 Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n    print(\"_demo_pool is empty \u2014 model was not trained before calling synthesize_dataset.\")\n    print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-freq",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+   "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code  # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n    for _visit in _sample.get(\"visits\", []):\n        for _tok in _visit:\n            _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n            if _idx >= 2:  # skip pad(0) and unk(1)\n                _bart_id = _idx + 5\n                _code = _vocab_map.get(_bart_id)\n                if _code:\n                    _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation \u2014 no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"\u2713 Strong correlation \u2014 good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation \u2014 consider more epochs.\")\nelse: print(\"Weak correlation \u2014 model may need more training.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-empty",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"\u26a0 {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"\u2713 All {len(synthetic):,} patients have at least one visit with at least one code.\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s6-report",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n\u2713 Saved to {report_path}\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s7-header",
+ "metadata": {},
+ "source": "---\n# 7. Download & Next Steps"
+ },
+ {
+ "cell_type": "code",
+ "id": "s7-download",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# Download output files (Colab only \u2014 silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" \u2713 {os.path.basename(_p)}\")\n else:\n print(f\" \u2014 {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")"
+ },
+ {
+ "cell_type": "code",
+ "id": "s7-resume",
+ "metadata": {},
+ "outputs": [],
+ "execution_count": null,
+ "source": "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# CHECKPOINT RESUME \u2014 Run this cell instead of Section 4 if you already trained\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"\u2713 Loaded 
checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template \u2014 uncomment the lines above to use)\")"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "s7-congrats",
+ "metadata": {},
+ "source": "---\n## \ud83c\udf89 Congratulations!\n\nYou've successfully:\n1. \u2705 Trained a PromptEHR model conditioned on patient demographics\n2. \u2705 Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. \u2705 Validated ICD-9 code frequency fidelity against real training data\n4. \u2705 Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime \u2192 Change runtime type \u2192 T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 \u00b7 Branch: `promptehr-pr-integration` \u00b7 jalengg/PyHealth_"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/examples/promptehr_mimic3_training.py b/examples/promptehr_mimic3_training.py
new file mode 100644
index 000000000..8387208db
--- /dev/null
+++ b/examples/promptehr_mimic3_training.py
@@ -0,0 +1,47 @@
+"""PromptEHR: Training on MIMIC-III.
+
+Train PromptEHR for synthetic EHR generation using PyHealth 2.0 API.
+
+Reference:
+    Wang et al. "PromptEHR: Conditional Electronic Healthcare Records
+    Generation with Prompt Learning." EMNLP 2023. https://arxiv.org/abs/2211.01761
+"""
+
+from pyhealth.datasets import MIMIC3Dataset, split_by_patient
+from pyhealth.models import PromptEHR
+from pyhealth.tasks import promptehr_generation_mimic3_fn
+
+MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
+
+# 1. Load MIMIC-III tables needed for generation (patient demographics,
+#    admissions, ICD-9 diagnosis codes). `code_mapping={}` keeps codes in
+#    their native vocabulary (no cross-system translation).
+dataset = MIMIC3Dataset(
+    root=MIMIC3_ROOT,
+    tables=["patients", "admissions", "diagnoses_icd"],
+    code_mapping={},
+)
+
+# 2. Apply the generation task: converts raw events into per-patient visit
+#    sequences suitable for seq2seq training.
+sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
+print(f"Patients: {len(sample_dataset)}")
+sample_dataset.stat()
+
+# 3. Split by patient so no individual appears in more than one split
+train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
+
+# 4. Initialize model
+# NOTE(review): the model is constructed on the FULL sample_dataset rather
+# than the train split — presumably so the code vocabulary also covers
+# val/test patients. Confirm this is intentional (test patients' codes
+# influence the vocabulary the model sees).
+model = PromptEHR(
+    dataset=sample_dataset,
+    n_num_features=1,  # one continuous demographic feature (age — confirm)
+    cat_cardinalities=[2],  # one categorical feature with 2 classes (gender — confirm)
+    d_hidden=128,
+    prompt_length=1,
+    epochs=20,
+    batch_size=16,
+    lr=1e-5,
+    warmup_steps=1000,
+    save_dir="./save/promptehr/",
+)
+
+# 5. Train on the train split, validating on val
+model.train_model(train, val)
+print("Training complete. Checkpoint saved to ./save/promptehr/")
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 14f0bf209..6cdf2ea45 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -16,6 +16,7 @@
from .grasp import GRASP, GRASPLayer
from .medlink import MedLink
from .micron import MICRON, MICRONLayer
+from .promptehr import PromptEHR
from .mlp import MLP
from .molerec import MoleRec, MoleRecLayer
from .retain import RETAIN, RETAINLayer
diff --git a/pyhealth/models/promptehr/__init__.py b/pyhealth/models/promptehr/__init__.py
new file mode 100644
index 000000000..fdf1327a3
--- /dev/null
+++ b/pyhealth/models/promptehr/__init__.py
@@ -0,0 +1,41 @@
+"""PromptEHR: Prompt-based BART model for synthetic EHR generation.
+
+This module provides a demographic-conditioned sequence-to-sequence model
+for generating realistic synthetic electronic health records.
+
+Main components:
+ - PromptEHR: Main model class (inherits from BaseModel)
+ - ConditionalPromptEncoder: Demographic conditioning with reparameterization
+ - PromptBartEncoder: Modified BART encoder with prompt injection
+ - PromptBartDecoder: Modified BART decoder with prompt injection
+ - VisitStructureSampler: Utility for structure-constrained generation
+ - Generation functions: sample_demographics, parse_sequence_to_visits, etc.
+"""
+
+from .model import PromptEHR
+from .conditional_prompt import ConditionalPromptEncoder
+from .bart_encoder import PromptBartEncoder
+from .bart_decoder import PromptBartDecoder
+from .visit_sampler import VisitStructureSampler
+from .generation import (
+ DemographicSampler,
+ sample_demographics,
+ decode_patient_demographics,
+ parse_sequence_to_visits,
+ generate_patient_sequence_conditional,
+ generate_patient_with_structure_constraints
+)
+
+# Public re-export surface of the promptehr subpackage.
+# Keep this list in sync with the imports above.
+__all__ = [
+    "PromptEHR",
+    "ConditionalPromptEncoder",
+    "PromptBartEncoder",
+    "PromptBartDecoder",
+    "VisitStructureSampler",
+    "DemographicSampler",
+    "sample_demographics",
+    "decode_patient_demographics",
+    "parse_sequence_to_visits",
+    "generate_patient_sequence_conditional",
+    "generate_patient_with_structure_constraints",
+]
diff --git a/pyhealth/models/promptehr/bart_decoder.py b/pyhealth/models/promptehr/bart_decoder.py
new file mode 100644
index 000000000..e6d01a70b
--- /dev/null
+++ b/pyhealth/models/promptehr/bart_decoder.py
@@ -0,0 +1,325 @@
+"""BART decoder with prompt injection for demographic conditioning.
+
+This module provides a modified BART decoder that accepts demographic prompt
+embeddings and prepends them to decoder input sequences for conditioning.
+
+Ported from pehr_scratch/prompt_bart_decoder.py (lines 1-207).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+from transformers.models.bart.modeling_bart import BartDecoder
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+
+class PromptBartDecoder(BartDecoder):
+ """BART decoder modified to accept and prepend demographic prompt embeddings.
+
+ Extends the standard BART decoder to support prompt-based conditioning by:
+ 1. Accepting optional prompt embeddings as input
+ 2. Prepending prompts to decoder input token embeddings
+ 3. Extending attention masks to cover prepended prompts
+ 4. Creating causal masks for autoregressive generation
+ 5. Processing through standard BART decoder layers with cross-attention
+
+ This enables demographic conditioning (age + gender) by injecting learned
+ prompt vectors at the decoder input, maintaining demographic alignment
+ during generation (dual prompt injection with encoder).
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Token embedding layer (optional)
+
+ Example:
+ >>> from transformers import BartConfig
+ >>> config = BartConfig.from_pretrained("facebook/bart-base")
+ >>> decoder = PromptBartDecoder(config)
+ >>> # Decode with prompts
+ >>> prompt_embeds = torch.randn(16, 2, 768) # [batch, n_prompts, hidden]
+ >>> input_ids = torch.randint(0, 1000, (16, 50)) # [batch, tgt_len]
+ >>> encoder_outputs = torch.randn(16, 100, 768) # [batch, src_len, hidden]
+ >>> outputs = decoder(
+ ... input_ids,
+ ... encoder_hidden_states=encoder_outputs,
+ ... inputs_prompt_embeds=prompt_embeds
+ ... )
+ """
+
+ def __init__(self, config, embed_tokens=None):
+ """Initialize prompt-aware BART decoder.
+
+ Args:
+ config: BartConfig from transformers
+ embed_tokens: Optional token embedding layer
+ """
+ super().__init__(config, embed_tokens)
+
+ # Initialize embedding scale factor (BART uses sqrt(d_model) scaling)
+ self.embed_scale = None
+ if config.scale_embedding:
+ self.embed_scale = (config.d_model ** 0.5)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ inputs_prompt_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BaseModelOutputWithPastAndCrossAttentions:
+ """Forward pass with optional demographic prompt embeddings.
+
+ Args:
+ input_ids: [batch, tgt_seq_len] decoder token IDs
+ attention_mask: [batch, tgt_seq_len] decoder attention mask (1=attend, 0=ignore)
+ encoder_hidden_states: [batch, src_seq_len, hidden_dim] encoder outputs
+ encoder_attention_mask: [batch, src_seq_len] encoder attention mask
+ head_mask: [num_layers, num_heads] mask for self-attention heads
+ cross_attn_head_mask: [num_layers, num_heads] mask for cross-attention heads
+ past_key_values: Cached key-value states for efficient generation
+ inputs_embeds: [batch, tgt_seq_len, hidden_dim] pre-computed embeddings (optional)
+ inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional)
+ use_cache: Whether to return key-value cache for generation
+ output_attentions: Whether to return attention weights
+ output_hidden_states: Whether to return all hidden states
+ return_dict: Whether to return BaseModelOutputWithPastAndCrossAttentions or tuple
+
+ Returns:
+ BaseModelOutputWithPastAndCrossAttentions with:
+ - last_hidden_state: [batch, n_prompts + tgt_len, hidden_dim]
+ - past_key_values: Cached key-value states (if use_cache=True)
+ - hidden_states: Tuple of all layer outputs (if output_hidden_states=True)
+ - attentions: Tuple of self-attention weights (if output_attentions=True)
+ - cross_attentions: Tuple of cross-attention weights (if output_attentions=True)
+ """
+ # Set output flags from config defaults
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Get decoder input embeddings from token IDs
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # Apply embedding scaling if configured
+ if self.embed_scale is not None:
+ inputs_embeds = inputs_embeds * self.embed_scale
+
+ # Store original sequence length before prepending prompts
+ original_seq_len = inputs_embeds.shape[1]
+
+ # Prepend prompt embeddings if provided
+ if inputs_prompt_embeds is not None:
+ # Concatenate prompts before decoder input embeddings
+ # inputs_prompt_embeds: [batch, n_prompts, hidden_dim]
+ # inputs_embeds: [batch, tgt_len, hidden_dim]
+ # Result: [batch, n_prompts + tgt_len, hidden_dim]
+ inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+
+ # Extend attention mask for prepended prompts
+ batch_size, n_prompts = inputs_prompt_embeds.shape[:2]
+
+ # Create attention mask for prompts (all 1s - always attend to prompts)
+ prompt_attention_mask = torch.ones(
+ batch_size, n_prompts,
+ dtype=attention_mask.dtype if attention_mask is not None else torch.long,
+ device=inputs_embeds.device
+ )
+
+ if attention_mask is not None:
+ # Concatenate prompt mask with decoder attention mask
+ attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+ else:
+ # Create attention mask for all tokens (prompts + decoder input)
+ total_seq_len = inputs_embeds.shape[1]
+ attention_mask = torch.ones(
+ batch_size, total_seq_len,
+ dtype=torch.long,
+ device=inputs_embeds.device
+ )
+
+ # Get positional embeddings for full sequence (prompts + decoder tokens)
+ past_key_values_length = 0
+ if past_key_values is not None:
+ # Handle Cache object (new transformers API) or tuple (old API)
+ if hasattr(past_key_values, 'get_seq_length'):
+ past_key_values_length = past_key_values.get_seq_length()
+ elif isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
+ # Defensive: handle unexpected cache structures gracefully
+ # pehr-scratch-expert confirmed: defaulting to 0 is safe (slightly degrades
+ # quality but prevents crash). BART handles positional errors gracefully.
+ try:
+ if past_key_values[0] is not None and isinstance(past_key_values[0], (tuple, list)):
+ if len(past_key_values[0]) > 0 and past_key_values[0][0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ except (IndexError, TypeError, AttributeError):
+ # Safe fallback: slightly degrades quality but prevents crash
+ # Positional embeddings will be calculated from position 0
+ past_key_values_length = 0
+
+ # Get positional embeddings (BART uses learned positional embeddings)
+ positions = self.embed_positions(inputs_embeds, past_key_values_length)
+
+ # Combine input embeddings + positional embeddings
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # Create combined attention mask (causal + padding)
+ if attention_mask is not None:
+ # Create causal mask for decoder self-attention
+ combined_attention_mask = _make_causal_mask(
+ inputs_embeds.shape[:2],
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+ # Expand padding mask and combine with causal mask
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1])
+ combined_attention_mask = combined_attention_mask + expanded_attn_mask
+ else:
+ # Create causal mask only (no padding)
+ combined_attention_mask = _make_causal_mask(
+ inputs_embeds.shape[:2],
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # Expand encoder attention mask for cross-attention
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [batch, src_len] → [batch, 1, tgt_len, src_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1])
+
+ # Initialize output containers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ # Pass through decoder layers
+ for idx, decoder_layer in enumerate(self.layers):
+ # Save hidden state before layer if requested
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Forward through decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=combined_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ # Update hidden states
+ hidden_states = layer_outputs[0]
+
+ # Save attention weights if requested
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # Save final hidden state if requested
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Cache is handled by past_key_values object, not returned in tuple
+ next_cache = past_key_values if use_cache else None
+
+ # Return tuple format if not using return_dict
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+
+ # Return BaseModelOutputWithPastAndCrossAttentions
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+def _make_causal_mask(
+ input_shape: Tuple[int, int],
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0
+) -> torch.Tensor:
+ """Create causal mask for decoder self-attention.
+
+ Creates a lower-triangular mask that prevents attending to future positions.
+ This is essential for autoregressive generation where each position can only
+ attend to earlier positions.
+
+ Args:
+ input_shape: (batch_size, tgt_len) shape of decoder input
+ dtype: Data type for mask tensor
+ device: Device to create mask on
+ past_key_values_length: Length of cached key-values from previous steps
+
+ Returns:
+ [batch, 1, tgt_len, tgt_len + past_len] causal mask with -inf for future positions
+ """
+ batch_size, tgt_len = input_shape
+
+ # Initialize mask with -inf (prevents attention)
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+
+ # Create lower triangular mask (0 for allowed positions, -inf for future)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ # If using cached key-values, allow attending to all past positions
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # Expand to [batch, 1, tgt_len, tgt_len + past_len]
+ return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
+ """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len].
+
+ Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention.
+
+ Args:
+ mask: [batch, src_len] attention mask (1=attend, 0=ignore)
+ dtype: Target data type for the expanded mask
+ tgt_len: Target sequence length (defaults to src_len)
+
+ Returns:
+ [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions
+ """
+ batch_size, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len]
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
+
+ # Invert mask: 1 (attend) → 0, 0 (ignore) → 1
+ inverted_mask = 1.0 - expanded_mask
+
+ # Fill masked positions with -inf (prevents attention)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
diff --git a/pyhealth/models/promptehr/bart_encoder.py b/pyhealth/models/promptehr/bart_encoder.py
new file mode 100644
index 000000000..726f34cb9
--- /dev/null
+++ b/pyhealth/models/promptehr/bart_encoder.py
@@ -0,0 +1,214 @@
+"""BART encoder with prompt injection for demographic conditioning.
+
+This module provides a modified BART encoder that accepts demographic prompt
+embeddings and prepends them to input sequences for conditioning.
+
+Ported from pehr_scratch/prompt_bart_encoder.py (lines 1-149).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+from transformers.models.bart.modeling_bart import BartEncoder
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class PromptBartEncoder(BartEncoder):
+    """BART encoder modified to accept and prepend demographic prompt embeddings.
+
+    Extends the standard BART encoder to support prompt-based conditioning by:
+    1. Accepting optional prompt embeddings as input
+    2. Prepending prompts to input token embeddings
+    3. Extending attention masks to cover prepended prompts
+    4. Processing through standard BART encoder layers
+
+    This enables demographic conditioning (age + gender) by injecting learned
+    prompt vectors at the encoder input.
+
+    Args:
+        config: BartConfig from transformers
+        embed_tokens: Token embedding layer (optional)
+
+    Example:
+        >>> from transformers import BartConfig
+        >>> config = BartConfig.from_pretrained("facebook/bart-base")
+        >>> encoder = PromptBartEncoder(config)
+        >>> # Encode with prompts
+        >>> prompt_embeds = torch.randn(16, 2, 768)  # [batch, n_prompts, hidden]
+        >>> input_ids = torch.randint(0, 1000, (16, 100))  # [batch, seq_len]
+        >>> outputs = encoder(input_ids, inputs_prompt_embeds=prompt_embeds)
+    """
+
+    def __init__(self, config, embed_tokens=None):
+        """Initialize prompt-aware BART encoder.
+
+        Args:
+            config: BartConfig from transformers
+            embed_tokens: Optional token embedding layer
+        """
+        super().__init__(config, embed_tokens)
+
+        # Initialize embedding scale factor (BART uses sqrt(d_model) scaling)
+        self.embed_scale = None
+        if config.scale_embedding:
+            self.embed_scale = (config.d_model ** 0.5)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        inputs_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BaseModelOutput:
+        """Forward pass with optional demographic prompt embeddings.
+
+        Args:
+            input_ids: [batch, seq_len] token IDs
+            attention_mask: [batch, seq_len] attention mask (1=attend, 0=ignore)
+            head_mask: [num_layers, num_heads] mask for attention heads
+            inputs_embeds: [batch, seq_len, hidden_dim] pre-computed embeddings (optional)
+            inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional)
+            output_attentions: Whether to return attention weights
+            output_hidden_states: Whether to return all hidden states
+            return_dict: Whether to return BaseModelOutput or tuple
+
+        Returns:
+            BaseModelOutput with:
+                - last_hidden_state: [batch, n_prompts + seq_len, hidden_dim]
+                  (note: includes the prompt positions — downstream cross-attention
+                  masks must cover them)
+                - hidden_states: Tuple of all layer outputs (if output_hidden_states=True)
+                - attentions: Tuple of attention weights (if output_attentions=True)
+        """
+        # Set output flags from config defaults
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Get input embeddings from token IDs
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+            # Apply embedding scaling if configured
+            if self.embed_scale is not None:
+                inputs_embeds = inputs_embeds * self.embed_scale
+
+        # Prepend prompt embeddings if provided
+        if inputs_prompt_embeds is not None:
+            # Concatenate prompts before input embeddings
+            # inputs_prompt_embeds: [batch, n_prompts, hidden_dim]
+            # inputs_embeds: [batch, seq_len, hidden_dim]
+            # Result: [batch, n_prompts + seq_len, hidden_dim]
+            inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
+
+            # Extend attention mask to account for prepended prompts
+            batch_size, n_prompts = inputs_prompt_embeds.shape[:2]
+
+            if attention_mask is not None:
+                # Create attention mask for prompts matching existing mask dtype/device
+                prompt_attention_mask = torch.ones(
+                    batch_size, n_prompts,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device
+                )
+                # Concatenate prompt mask with original mask
+                attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+            else:
+                # Create full attention mask for prompts + sequence
+                seq_len = inputs_embeds.shape[1]  # Total length including prompts already prepended
+                attention_mask = torch.ones(
+                    batch_size, seq_len,
+                    dtype=torch.long,
+                    device=inputs_embeds.device
+                )
+
+        # Get positional embeddings (BART uses learned positional embeddings).
+        # NOTE(review): the positional module is given the embeddings tensor,
+        # not token IDs — works when it only reads input.shape[:2]; confirm
+        # against the pinned transformers version.
+        embed_pos = self.embed_positions(inputs_embeds)
+
+        # Combine input embeddings + positional embeddings
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # Expand attention mask from [batch, seq_len] to [batch, 1, tgt_len, src_len]
+        if attention_mask is not None:
+            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+        # Initialize output containers
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # Validate head_mask dimensionality (one entry per encoder layer)
+        if head_mask is not None:
+            if head_mask.size()[0] != len(self.layers):
+                raise ValueError(
+                    f"head_mask should have {len(self.layers)} layers, but has {head_mask.size()[0]}"
+                )
+
+        # Pass through encoder layers
+        for idx, encoder_layer in enumerate(self.layers):
+            # Save hidden state before layer if requested
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+
+            # Get layer-specific head mask
+            layer_head_mask = head_mask[idx] if head_mask is not None else None
+
+            # Forward through encoder layer
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+            # Update hidden states
+            hidden_states = layer_outputs[0]
+
+            # Save attention weights if requested
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        # Save final hidden state if requested
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        # Return tuple format if not using return_dict
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+
+        # Return BaseModelOutput
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+        )
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
+ """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len].
+
+ Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention.
+
+ Args:
+ mask: [batch, src_len] attention mask (1=attend, 0=ignore)
+ dtype: Target data type for the expanded mask
+ tgt_len: Target sequence length (defaults to src_len for encoder self-attention)
+
+ Returns:
+ [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions
+ """
+ batch_size, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len]
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
+
+ # Invert mask: 1 (attend) → 0, 0 (ignore) → 1
+ inverted_mask = 1.0 - expanded_mask
+
+ # Fill masked positions with -inf (prevents attention)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
diff --git a/pyhealth/models/promptehr/conditional_prompt.py b/pyhealth/models/promptehr/conditional_prompt.py
new file mode 100644
index 000000000..4122a5d31
--- /dev/null
+++ b/pyhealth/models/promptehr/conditional_prompt.py
@@ -0,0 +1,251 @@
+"""Conditional prompt encoder for demographic conditioning.
+
+This module provides demographic conditioning through prompt-based learning
+with reparameterization to prevent overfitting.
+
+Ported from pehr_scratch/conditional_prompt.py (lines 1-219).
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+
+
class NumericalConditionalPrompt(nn.Module):
    """Embeds continuous numerical features (e.g., age) with reparameterization.

    Uses intermediate d_hidden=128 dimension for better gradient flow and
    regularization, following PromptEHR's architecture.
    """

    def __init__(
        self,
        n_num_features: int,
        hidden_dim: int,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Initialize numerical prompt encoder with reparameterization.

        Args:
            n_num_features: Number of continuous features (1 for age only)
            hidden_dim: Output dimension size (768 for BART-base)
            d_hidden: Intermediate reparameterization dimension (default: 128)
            prompt_length: Number of prompt vectors per feature (default: 1)
        """
        super().__init__()
        self.n_num_features = n_num_features
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        # Reparameterization: learned weight and bias in d_hidden space
        self.weight = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
        self.bias = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.bias)

        # Project from d_hidden to output dimension
        self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)

    def forward(self, x_num: torch.Tensor) -> torch.Tensor:
        """Embed numerical features with reparameterization.

        Args:
            x_num: [batch, n_num_features] continuous values

        Returns:
            [batch, n_num_features * prompt_length, hidden_dim] embeddings
        """
        # Reparameterization: weight * value + bias
        # x_num: [batch, n_num_features]
        # weight: [n_num_features, d_hidden]
        # Result: [batch, n_num_features, d_hidden]
        x = self.weight[None] * x_num[..., None]
        x = x + self.bias[None]

        # Project to output dimension
        # x: [batch, n_num_features, d_hidden] -> [batch, n_num_features, hidden_dim]
        x = self.proj(x)

        # FIX: prompt_length was previously accepted but ignored, so for
        # prompt_length > 1 the output shape disagreed with this docstring and
        # with get_num_prompts() in ConditionalPromptEncoder. Repeat each
        # feature's prompt vector prompt_length times along the prompt axis.
        # With the default prompt_length=1 this branch is skipped, so existing
        # behavior is unchanged.
        if self.prompt_length > 1:
            x = x.repeat_interleave(self.prompt_length, dim=1)

        # Output: [batch, n_num_features * prompt_length, hidden_dim]
        return x
+
+
class CategoricalConditionalPrompt(nn.Module):
    """Embeds categorical features with offset-based indexing and reparameterization.

    Uses single embedding table with offset-based indexing to prevent category
    collision, following PromptEHR's architecture.
    """

    def __init__(
        self,
        cat_cardinalities: list,
        hidden_dim: int,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Initialize categorical prompt encoder with reparameterization.

        Args:
            cat_cardinalities: List of category counts for each feature
                [2] for gender (M/F) - ethnicity removed
            hidden_dim: Output dimension size (768 for BART-base)
            d_hidden: Intermediate reparameterization dimension (default: 128)
            prompt_length: Number of prompt vectors per feature (default: 1)
        """
        super().__init__()
        assert cat_cardinalities, 'cat_cardinalities must be non-empty'
        self.cat_cardinalities = cat_cardinalities
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        # Compute offset indices to prevent category collision
        # Example: [2] -> offsets = [0]
        # Gender 0 (M) -> index 0, Gender 1 (F) -> index 1
        category_offsets = torch.tensor([0] + cat_cardinalities[:-1]).cumsum(0)
        self.register_buffer('category_offsets', category_offsets, persistent=False)

        # Single embedding table for all categories
        total_categories = sum(cat_cardinalities)
        self.embeddings = nn.Embedding(total_categories, d_hidden)

        # Learned bias per feature (not per category)
        self.bias = nn.Parameter(torch.Tensor(len(cat_cardinalities), d_hidden))
        nn.init.xavier_uniform_(self.bias)

        # Project from d_hidden to output dimension
        self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)

    def forward(self, x_cat: torch.Tensor) -> torch.Tensor:
        """Embed categorical features with offset-based indexing.

        Args:
            x_cat: [batch, n_cat_features] categorical IDs

        Returns:
            [batch, n_cat_features * prompt_length, hidden_dim] embeddings
        """
        # Add offsets to prevent category collision
        # x_cat: [batch, n_cat_features]
        # category_offsets: [n_cat_features]
        x = self.embeddings(x_cat + self.category_offsets[None])

        # Add learned bias per feature
        # x: [batch, n_cat_features, d_hidden]
        # bias: [n_cat_features, d_hidden]
        x = x + self.bias[None]

        # Project to output dimension
        # x: [batch, n_cat_features, d_hidden] -> [batch, n_cat_features, hidden_dim]
        x = self.proj(x)

        # FIX: prompt_length was previously accepted but ignored, so for
        # prompt_length > 1 the output shape disagreed with this docstring and
        # with get_num_prompts() in ConditionalPromptEncoder. Repeat each
        # feature's prompt vector prompt_length times; with the default
        # prompt_length=1 this branch is skipped and behavior is unchanged.
        if self.prompt_length > 1:
            x = x.repeat_interleave(self.prompt_length, dim=1)

        # Output: [batch, n_cat_features * prompt_length, hidden_dim]
        return x
+
+
class ConditionalPromptEncoder(nn.Module):
    """Joint prompt encoder over numerical and categorical demographics.

    Wraps an optional NumericalConditionalPrompt (e.g. age) and an optional
    CategoricalConditionalPrompt (e.g. gender), concatenating their outputs
    into a single prompt sequence that conditions the BART encoder and
    decoder.

    Example:
        >>> # For PromptEHR: age (continuous) + gender (categorical)
        >>> encoder = ConditionalPromptEncoder(
        ...     n_num_features=1,        # age
        ...     cat_cardinalities=[2],   # gender (M/F)
        ...     hidden_dim=768,          # BART dimension
        ...     d_hidden=128             # reparameterization
        ... )
        >>> # Batch of 16 patients
        >>> age = torch.randn(16, 1)               # Normalized ages
        >>> gender = torch.randint(0, 2, (16, 1))  # 0=M, 1=F
        >>> prompts = encoder(x_num=age, x_cat=gender)
        >>> prompts.shape  # [16, 2, 768] - 2 prompts (age + gender)
    """

    def __init__(
        self,
        n_num_features: Optional[int] = None,
        cat_cardinalities: Optional[list] = None,
        hidden_dim: int = 768,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Build the combined prompt encoder.

        Args:
            n_num_features: Number of continuous features (None to disable)
            cat_cardinalities: Category counts for each categorical feature (None to disable)
            hidden_dim: Hidden dimension size (768 for BART-base)
            d_hidden: Intermediate reparameterization dimension (default: 128)
            prompt_length: Number of prompt vectors per feature (default: 1)
        """
        super().__init__()
        self.n_num_features = n_num_features
        self.cat_cardinalities = cat_cardinalities
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        # Sub-encoder for continuous features (age), if enabled.
        has_num = n_num_features is not None and n_num_features > 0
        self.num_prompt = (
            NumericalConditionalPrompt(n_num_features, hidden_dim, d_hidden, prompt_length)
            if has_num
            else None
        )

        # Sub-encoder for categorical features (gender), if enabled.
        has_cat = cat_cardinalities is not None and len(cat_cardinalities) > 0
        self.cat_prompt = (
            CategoricalConditionalPrompt(cat_cardinalities, hidden_dim, d_hidden, prompt_length)
            if has_cat
            else None
        )

    def forward(
        self,
        x_num: Optional[torch.Tensor] = None,
        x_cat: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode demographics to prompt embeddings.

        Args:
            x_num: [batch, n_num_features] continuous values (optional)
            x_cat: [batch, n_cat_features] categorical IDs (optional)

        Returns:
            [batch, total_prompts, hidden_dim] combined prompt embeddings

        Raises:
            ValueError: If neither input yields any prompt embeddings.
        """
        pieces = []

        if x_num is not None and self.num_prompt is not None:
            pieces.append(self.num_prompt(x_num))

        if x_cat is not None and self.cat_prompt is not None:
            pieces.append(self.cat_prompt(x_cat))

        if not pieces:
            raise ValueError("No prompt embeddings generated. Provide x_num or x_cat.")

        # Numerical prompts (if any) come first, then categorical ones.
        return torch.cat(pieces, dim=1)

    def get_num_prompts(self) -> int:
        """Calculate total number of prompt tokens."""
        total = 0

        if self.num_prompt is not None:
            total += self.n_num_features * self.prompt_length

        if self.cat_prompt is not None:
            total += len(self.cat_cardinalities) * self.prompt_length

        return total
diff --git a/pyhealth/models/promptehr/generation.py b/pyhealth/models/promptehr/generation.py
new file mode 100644
index 000000000..3d674d1d1
--- /dev/null
+++ b/pyhealth/models/promptehr/generation.py
@@ -0,0 +1,1070 @@
+"""
+Generate synthetic patient sequences using trained PromptEHR model.
+
+This module provides functions for generating realistic synthetic EHR data
+using various conditioning strategies (demographics, visit structures, etc.).
+"""
+import json
+import math
+import numpy as np
+import torch
+from pathlib import Path
+from typing import Optional, List, Union, Dict
+
+
class DemographicSampler:
    """Sample patient demographics from empirical training distribution.

    Samples age and gender by directly drawing from the observed distribution
    in training data, ensuring synthetic patients match real population.
    """

    def __init__(self, patient_records: List, seed: int = 42):
        """Initialize sampler with empirical demographics from training data.

        Args:
            patient_records: List of patient records from training set.
                Each record should have 'age' and 'gender' attributes (or
                dict keys); records missing either field are skipped.
            seed: Random seed for reproducibility.

        Raises:
            ValueError: If no record provides both 'age' and 'gender' —
                an empty sampler would yield NaN statistics and sample()
                would fail with an opaque randint error.
        """
        self.rng = np.random.RandomState(seed)

        # Extract empirical demographics
        self.ages = []
        self.genders = []

        for patient in patient_records:
            # Handle both dict-like and object-like patient records
            if hasattr(patient, 'age') and hasattr(patient, 'gender'):
                age = patient.age
                gender = patient.gender
            elif isinstance(patient, dict) and 'age' in patient and 'gender' in patient:
                age = patient['age']
                gender = patient['gender']
            else:
                continue

            self.ages.append(float(age))
            # Convert gender to int: M=0, F=1 (any non-'M' string maps to 1)
            if isinstance(gender, str):
                gender_int = 0 if gender == 'M' else 1
            else:
                gender_int = int(gender)
            self.genders.append(gender_int)

        # FIX: fail fast on an empty/unusable record list instead of producing
        # NaN statistics here and crashing later inside sample().
        if not self.ages:
            raise ValueError(
                "patient_records contains no usable records with both "
                "'age' and 'gender'"
            )

        # Convert to numpy arrays
        self.ages = np.array(self.ages)
        self.genders = np.array(self.genders)

        # Compute statistics
        self.stats = {
            'age_mean': np.mean(self.ages),
            'age_std': np.std(self.ages),
            'age_median': np.median(self.ages),
            'age_min': np.min(self.ages),
            'age_max': np.max(self.ages),
            'male_pct': (self.genders == 0).mean(),
            'female_pct': (self.genders == 1).mean(),
        }

    def sample(self) -> dict:
        """Sample demographics from empirical distribution.

        Returns:
            Dictionary with:
                - 'age': float (sampled from training ages)
                - 'sex': int (0=Male, 1=Female, sampled from training)
                - 'sex_str': str ('M' or 'F')
        """
        # Draw a random training record so age/gender stay jointly consistent
        idx = self.rng.randint(0, len(self.ages))

        age = self.ages[idx]
        sex = self.genders[idx]
        sex_str = 'M' if sex == 0 else 'F'

        return {
            'age': float(age),
            'sex': int(sex),
            'sex_str': sex_str
        }

    def __repr__(self):
        return (
            f"DemographicSampler(\n"
            f"  Age: mean={self.stats['age_mean']:.1f}, "
            f"std={self.stats['age_std']:.1f}, "
            f"range=[{self.stats['age_min']:.0f}, {self.stats['age_max']:.0f}]\n"
            f"  Gender: {self.stats['male_pct']:.1%} Male, "
            f"{self.stats['female_pct']:.1%} Female\n"
            f")"
        )
+
+
def build_first_code_prior(
    training_data_path: str,
    age_bins: int = 9
) -> Dict:
    """Build empirical P(first_code | age, gender) from training data.

    Args:
        training_data_path: Path to training data directory with MIMIC-III files
            (ADMISSIONS.csv, PATIENTS.csv, DIAGNOSES_ICD.csv)
        age_bins: Number of age bins (default: 9 for [0-10), [10-20), ..., [80-90]).
            NOTE(review): the cut below uses fixed 10-year bins over [0, 90],
            which yields exactly 9 intervals — any other value of age_bins will
            make pd.cut raise on the label count. Confirm whether this
            parameter is meant to be configurable.

    Returns:
        Dictionary mapping (age_bin, gender) -> {code: probability}

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> first_code = sample_first_code(65, 0, prior)
    """
    # Lazy import: pandas is only needed when building the prior, not at
    # module import time.
    import pandas as pd

    # Load training data
    admissions = pd.read_csv(f'{training_data_path}/ADMISSIONS.csv')
    patients = pd.read_csv(f'{training_data_path}/PATIENTS.csv')
    diagnoses = pd.read_csv(f'{training_data_path}/DIAGNOSES_ICD.csv')

    # Calculate age at first admission
    admissions['ADMITTIME'] = pd.to_datetime(admissions['ADMITTIME'])
    patients['DOB'] = pd.to_datetime(patients['DOB'])

    # Earliest admission per patient (idxmin over ADMITTIME within each subject)
    first_admissions = admissions.loc[
        admissions.groupby('SUBJECT_ID')['ADMITTIME'].idxmin()
    ][['SUBJECT_ID', 'HADM_ID', 'ADMITTIME']]

    demo = pd.merge(
        patients[['SUBJECT_ID', 'GENDER', 'DOB']],
        first_admissions,
        on='SUBJECT_ID',
        how='inner'
    )
    # Year-difference age, clamped to [0, 90]. Ages above 89 are forced to 90
    # — presumably to absorb MIMIC-III's de-identification DOB shift for
    # patients >89, which yields implausibly large year differences; confirm.
    demo['AGE'] = (demo['ADMITTIME'].dt.year - demo['DOB'].dt.year)
    demo['AGE'] = demo['AGE'].apply(lambda x: 90 if x > 89 else max(0, x))

    # Get first diagnosis codes
    first_diag = pd.merge(
        demo[['SUBJECT_ID', 'HADM_ID', 'AGE', 'GENDER']],
        diagnoses[['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE']],
        on=['SUBJECT_ID', 'HADM_ID'],
        how='inner'
    )

    # Keep only first code per patient (seq_num=1 or first alphabetically)
    # NOTE(review): SEQ_NUM is never loaded, so "first" here is always the
    # alphabetically smallest ICD9_CODE of the first admission — confirm this
    # matches the intended "primary diagnosis" semantics.
    first_diag = first_diag.sort_values(['SUBJECT_ID', 'ICD9_CODE'])
    first_diag = first_diag.groupby('SUBJECT_ID').first().reset_index()

    # Bin ages into [0,10], (10,20], ..., (80,90]; labels are bin indices 0..8
    first_diag['age_bin'] = pd.cut(
        first_diag['AGE'],
        bins=list(range(0, 91, 10)),
        labels=list(range(age_bins)),
        include_lowest=True
    )

    # Convert gender to int (0=M, 1=F)
    first_diag['gender_int'] = (first_diag['GENDER'] == 'F').astype(int)

    # Calculate empirical distribution: (age_bin, gender) -> {code: prob}
    dist = {}
    for (age_bin, gender), group in first_diag.groupby(['age_bin', 'gender_int']):
        code_counts = group['ICD9_CODE'].value_counts()
        total = code_counts.sum()
        dist[(int(age_bin), int(gender))] = {
            str(code): count / total
            for code, count in code_counts.items()
        }

    return dist
+
+
def sample_first_code(
    age: float,
    gender: int,
    first_code_prior: Dict
) -> str:
    """Sample first diagnosis code from empirical distribution.

    Args:
        age: Patient age (0-90)
        gender: Patient gender (0=Male, 1=Female)
        first_code_prior: Prior from build_first_code_prior()

    Returns:
        Diagnosis code string (e.g., 'V3000', '41401')

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> code = sample_first_code(65, 0, prior)
        >>> print(code)  # e.g., 'V3000'
    """
    # Bin age
    age_bin = min(int(age // 10), 8)  # [0-9] -> 0, [10-19] -> 1, ..., [80+] -> 8

    # Get distribution for this demographic
    key = (age_bin, gender)
    if key not in first_code_prior:
        # Fallback: any bin with matching gender, else the first bin overall
        fallback_key = None
        for k in first_code_prior.keys():
            if k[1] == gender:
                fallback_key = k
                break
        # FIX: compare against None explicitly rather than by truthiness
        # (keys happen to be non-empty tuples, but the intent is "found").
        if fallback_key is not None:
            key = fallback_key
        else:
            key = next(iter(first_code_prior))

    code_probs = first_code_prior[key]
    codes = list(code_probs.keys())
    probs = np.asarray(list(code_probs.values()), dtype=float)

    # FIX: renormalize so floating-point drift (or a truncated/partial prior)
    # cannot make np.random.choice raise "probabilities do not sum to 1".
    probs = probs / probs.sum()

    # np.random.choice returns np.str_; cast to plain str as documented.
    return str(np.random.choice(codes, p=probs))
+
+
def build_frequency_prior(
    tokenizer,
    frequency_path: Optional[Union[str, Path]] = None,
    epsilon: float = 1e-10,
    vocab_size: Optional[int] = None
) -> torch.Tensor:
    """Build log-frequency prior over vocabulary for frequency-guided generation.

    Args:
        tokenizer: DiagnosisCodeTokenizer with vocab and code_offset attributes.
        frequency_path: Path to training_frequencies.json. If None, uses uniform prior.
        epsilon: Small constant to avoid log(0) (default: 1e-10).
        vocab_size: Model vocabulary size. If None, inferred from tokenizer (not recommended).
            Should match model's lm_head output dimension.

    Returns:
        torch.Tensor of shape [vocab_size] with log-frequencies.
        Special tokens get 0 (neutral prior), diagnosis codes get log(freq + epsilon).

    Example:
        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json', vocab_size=6963)
        >>> logits_guided = logits + alpha * prior  # Blend with model logits
    """
    # Use provided vocab size or infer from tokenizer
    # WARNING: Inferred size may not match model if there's a mismatch!
    if vocab_size is None:
        vocab_size = len(tokenizer.vocab.idx2code)

    log_freqs = torch.zeros(vocab_size)

    if frequency_path is None:
        # Uniform fallback: all codes equally likely
        uniform_log_freq = math.log(1.0 / len(tokenizer.vocab.idx2code))
        log_freqs[tokenizer.code_offset:] = uniform_log_freq
        return log_freqs

    # Load training frequencies
    with open(frequency_path, 'r') as f:
        freq_data = json.load(f)

    frequencies = freq_data['frequencies']

    # Fill in log-frequencies for each code
    # NOTE: We map code_idx directly to token_id without adding code_offset
    # because the model vocabulary doesn't include code_offset
    for code, freq in frequencies.items():
        if code in tokenizer.vocab.code2idx:
            code_idx = tokenizer.vocab.code2idx[code]
            if code_idx < vocab_size:
                log_freqs[code_idx] = math.log(freq + epsilon)

    # Codes not in training data get very low prior.
    # FIX: the previous `log_freqs == 0` sweep also overwrote the special
    # tokens below code_offset, contradicting the documented "special tokens
    # get 0 (neutral prior)" contract and heavily penalizing structural tokens
    # during guided decoding. Only code positions (index >= code_offset) are
    # demoted now.
    min_log_freq = math.log(epsilon)
    unseen = torch.zeros(vocab_size, dtype=torch.bool)
    unseen[tokenizer.code_offset:] = log_freqs[tokenizer.code_offset:] == 0
    log_freqs = torch.where(
        unseen,
        torch.tensor(min_log_freq),
        log_freqs
    )

    return log_freqs
+
+
def sample_demographics(
    age_mean: float = 60.0,
    age_std: float = 20.0,
    male_prob: float = 0.56
) -> dict:
    """Draw one synthetic patient's demographics.

    Age comes from a clipped normal distribution and sex from a Bernoulli
    draw, with defaults matching the MIMIC-III ICU population.

    Args:
        age_mean: Mean age for normal distribution (default: 60).
        age_std: Standard deviation for age (default: 20).
        male_prob: Probability of male gender (default: 0.56).

    Returns:
        Dictionary with:
            - 'age': float in range [0, 90]
            - 'sex': int (0=Male, 1=Female)
            - 'sex_str': str ('M' or 'F')
    """
    # Normal draw, then clamp into the valid [0, 90] range.
    sampled_age = float(np.clip(np.random.normal(age_mean, age_std), 0, 90))

    # Bernoulli draw for sex: 0 = male with probability male_prob.
    is_male = np.random.rand() < male_prob
    sex_id = 0 if is_male else 1

    return {
        'age': sampled_age,
        'sex': sex_id,
        'sex_str': 'M' if is_male else 'F',
    }
+
+
def decode_patient_demographics(age: float, gender: int) -> dict:
    """Decode demographics back to readable format.

    Args:
        age: Normalized age value.
        gender: Gender category index.

    Returns:
        Dictionary with decoded demographics: 'age' formatted to one decimal
        place and 'gender' as 'M'/'F' (or 'UNKNOWN' for any other index).
    """
    # Same encoding as data_loader.py: 0 -> male, 1 -> female.
    labels = {0: "M", 1: "F"}
    return {
        "age": f"{age:.1f}",
        "gender": labels.get(gender, "UNKNOWN"),
    }
+
+
def parse_sequence_to_visits(
    token_ids: List[int],
    tokenizer
) -> List[List[str]]:
    """Parse generated token sequence into visit structure.

    Extracts visits by splitting at visit-start and visit-end markers, and
    decodes diagnosis codes within each visit.

    NOTE(review): the marker strings below look garbled — the visit-start
    lookup and the end-of-sequence lookup both use the literal "" (empty
    string), so v_token_id == end_token_id, and "<\\v>" contains a literal
    backslash. Angle-bracket token names (e.g. <v>, </v>, <end>) appear to
    have been stripped from this file; confirm against the tokenizer's actual
    special-token strings.

    Args:
        token_ids: List of token IDs from model generation.
        tokenizer: PyHealth Tokenizer instance (must have bos_token_id,
            pad_token_id, code_offset, and vocab attributes).

    Returns:
        List of visits, where each visit is a list of ICD-9 code strings.

    Example:
        Input: [BOS, visit-start, 401.9, 250.00, visit-end, visit-start, 428.0, visit-end, EOS]
        Output: [['401.9', '250.00'], ['428.0']]
    """
    visits = []
    current_visit_codes = []

    # Special token IDs (see NOTE(review) in the docstring about the literals)
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    bos_token_id = tokenizer.bos_token_id
    end_token_id = tokenizer.convert_tokens_to_indices([""])[0]

    in_visit = False

    for token_id in token_ids:
        # Branch order matters: v_token_id is tested first, so if it collides
        # with end_token_id (as with the current literals) the start-of-visit
        # branch always wins.
        if token_id == v_token_id:
            # Start of visit (discards any codes collected outside a closed visit)
            in_visit = True
            current_visit_codes = []
        elif token_id == v_end_token_id:
            # End of visit
            if in_visit:
                visits.append(current_visit_codes)
            in_visit = False
        elif token_id in [bos_token_id, end_token_id, tokenizer.pad_token_id]:
            # Skip special tokens
            continue
        elif in_visit and token_id >= tokenizer.code_offset:
            # Diagnosis code token - token_id is already the correct vocab index
            # FIX: code2idx already includes special tokens, so don't subtract offset
            if token_id < len(tokenizer.vocab.idx2code):
                code = tokenizer.vocab.idx2code[token_id]
                current_visit_codes.append(code)

    # Handle case where sequence ends without closing visit marker
    if in_visit and len(current_visit_codes) > 0:
        visits.append(current_visit_codes)

    return visits
+
+
def generate_patient_sequence_conditional(
    model,
    tokenizer,
    target_patient,
    device: torch.device,
    temperature: float = 0.3,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    prompt_prob: float = 0.0,
    max_codes_per_visit: int = 20
) -> dict:
    """Generate synthetic patient via conditional reconstruction (PromptEHR approach).

    Given a real patient from test set, randomly masks codes and reconstructs
    the full visit structure. Default prompt_prob=0.0 means zero-code-prompt
    generation (only demographics provided).

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        target_patient: Patient record from test set to reconstruct.
            Must have attributes: age, gender (or sex), visits.
        device: Device to run on.
        temperature: Sampling temperature (default: 0.3).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        prompt_prob: Probability of keeping each code as prompt (default: 0.0 = zero prompts).
        max_codes_per_visit: Cap visit codes at this number (default: 20).

    Returns:
        Dictionary with:
            - 'generated_visits': List[List[str]] of generated code sequences
            - 'target_visits': List[List[str]] of original codes
            - 'prompt_codes': List[List[str]] of codes provided as prompts
            - 'demographics': dict of patient demographics
    """
    model.eval()

    # Extract demographics (handle both 'gender' and 'sex' attributes)
    if hasattr(target_patient, 'age'):
        age = target_patient.age
    else:
        age = target_patient.get('age', 60.0)

    if hasattr(target_patient, 'gender'):
        gender_str = target_patient.gender
    elif hasattr(target_patient, 'sex'):
        gender_str = target_patient.sex
    else:
        gender_str = target_patient.get('gender', 'M')

    # Encode gender as int (1=F, anything else -> 0=M)
    gender = 1 if gender_str == 'F' else 0

    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[gender]], dtype=torch.long).to(device)

    # Get visits
    if hasattr(target_patient, 'visits'):
        patient_visits = target_patient.visits
    else:
        patient_visits = target_patient.get('visits', [])

    # Initialize accumulators
    generated_visits = []
    prompt_codes_per_visit = []

    # Create dummy encoder input (prompts are in decoder)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    # Special token IDs
    # NOTE(review): the marker strings "" and "<\\v>" look garbled (likely
    # <v> / </v> originally) — confirm against the tokenizer's special tokens.
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]

    with torch.no_grad():
        # Process each visit from target patient
        for visit_idx, target_codes in enumerate(patient_visits):
            # Step 1: Cap codes at max_codes_per_visit
            num_codes = len(target_codes)
            if num_codes > max_codes_per_visit:
                target_codes = list(np.random.choice(target_codes, max_codes_per_visit, replace=False))
                num_codes = max_codes_per_visit

            if num_codes == 0:
                # Empty visit - skip
                generated_visits.append([])
                prompt_codes_per_visit.append([])
                continue

            # Step 2: Randomly mask codes (binomial sampling)
            keep_mask = np.random.binomial(1, prompt_prob, num_codes).astype(bool)
            prompt_codes = [code for i, code in enumerate(target_codes) if keep_mask[i]]

            # Step 3: Encode prompt codes as decoder input
            prompt_token_ids = [tokenizer.bos_token_id, v_token_id]
            for code in prompt_codes:
                # FIX: code2idx already returns token ID with offset included
                code_token_id = tokenizer.vocab.code2idx[code]
                prompt_token_ids.append(code_token_id)

            decoder_input_ids = torch.tensor([prompt_token_ids], dtype=torch.long).to(device)

            # Step 4: Generate to reconstruct full visit
            max_new_tokens = num_codes + 2  # Target length

            # Use model.generate() for automatic handling
            generated_ids = model.generate(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                decoder_input_ids=decoder_input_ids,
                x_num=x_num,
                x_cat=x_cat,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                num_beams=1,  # Disable beam search, use sampling only
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=1,  # Prevents duplicate codes
                eos_token_id=v_end_token_id,  # Stop at the visit-end marker
                pad_token_id=tokenizer.pad_token_id,
                bad_words_ids=[[tokenizer.bos_token_id]]  # Suppress BOS in generation
            )

            # Step 5: Extract generated codes
            visit_token_ids = generated_ids[0].cpu().tolist()

            # Extract code tokens (skip BOS and visit markers)
            generated_code_ids = [
                tid for tid in visit_token_ids
                if tid >= tokenizer.code_offset
            ]

            # Decode codes (convert token IDs back to diagnosis codes)
            # FIX: code2idx already includes special tokens, so don't subtract offset
            generated_codes = []
            for tid in generated_code_ids:
                if tid < len(tokenizer.vocab.idx2code):
                    code = tokenizer.vocab.idx2code[tid]
                    generated_codes.append(code)

            # Step 6: Combine with prompt codes and deduplicate
            # NOTE(review): set() discards ordering, so within-visit code order
            # is not preserved — confirm downstream metrics are order-insensitive.
            all_codes = list(set(generated_codes + prompt_codes))

            # Ensure exactly num_codes by sampling if needed
            if len(all_codes) < num_codes:
                # Not enough unique codes generated - resample with replacement
                needed = num_codes - len(all_codes)
                additional = list(np.random.choice(generated_codes, needed, replace=True)) if len(generated_codes) > 0 else []
                all_codes.extend(additional)
            elif len(all_codes) > num_codes:
                # Too many codes - sample exactly num_codes
                all_codes = list(np.random.choice(all_codes, num_codes, replace=False))

            generated_visits.append(all_codes)
            prompt_codes_per_visit.append(prompt_codes)

    return {
        'generated_visits': generated_visits,
        'target_visits': patient_visits,
        'prompt_codes': prompt_codes_per_visit,
        'demographics': {
            'age': age,
            'gender': gender_str
        }
    }
+
+
def generate_patient_with_structure_constraints(
    model,
    tokenizer,
    device: torch.device,
    target_structure: dict,
    age: Optional[float] = None,
    sex: Optional[int] = None,
    first_code: Optional[str] = None,
    temperature: float = 0.7,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    max_codes_per_visit: int = 25
) -> dict:
    """Generate patient with realistic visit structure constraints.

    This function generates patients visit-by-visit with controlled code counts
    sampled from real data distributions, producing more realistic EHR records.

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        device: Device to run on.
        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
        age: Patient age (if None, sampled from distribution).
        sex: Patient sex ID (0=M, 1=F; if None, sampled).
        first_code: First diagnosis code to condition on (if None, generated by model).
        temperature: Sampling temperature (default: 0.7).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).

    Returns:
        Dictionary with:
            - 'generated_visits': List[List[str]] of diagnosis codes
            - 'demographics': dict with 'age' and 'sex'
            - 'num_visits': int
            - 'num_codes': int
            - 'target_structure': dict (the structure we aimed for)
    """
    model.eval()

    # Sample demographics if not provided
    if age is None or sex is None:
        sampled_demo = sample_demographics()
        age = sampled_demo['age'] if age is None else age
        sex = sampled_demo['sex'] if sex is None else sex

    # Prepare demographic tensors
    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)

    # Special token IDs
    # NOTE(review): the marker strings "" and "<\\v>" look garbled (likely
    # <v> / </v> / <end> originally); "" is used for both the visit-start and
    # end-of-sequence lookup, so v_token_id == end_token_id — confirm against
    # the tokenizer's actual special tokens. bos_token_id and end_token_id are
    # currently unused below.
    bos_token_id = tokenizer.bos_token_id
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    end_token_id = tokenizer.convert_tokens_to_indices([""])[0]

    # Extract target structure
    num_visits = target_structure['num_visits']
    codes_per_visit = target_structure['codes_per_visit']

    # Handle case with no visits
    if num_visits == 0 or len(codes_per_visit) == 0:
        return {
            'generated_visits': [],
            'demographics': {'age': age, 'sex': sex},
            'num_visits': 0,
            'num_codes': 0,
            'target_structure': target_structure
        }

    # Initialize generation with empty sequence (shape [1, 0]).
    # HuggingFace will prepend decoder_start_token_id automatically, matching
    # the training pattern of [start, visit-start, codes...].
    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)

    # If first_code provided, prepopulate decoder with visit-start + first_code
    # (no visit-end yet) so visit 0 starts from the sampled code and generation
    # continues from there.
    first_visit_prepopulated = False
    if first_code is not None and first_code in tokenizer.vocab.code2idx:
        v_token_id_temp = tokenizer.convert_tokens_to_indices([""])[0]
        first_code_id = tokenizer.vocab.code2idx[first_code]

        prepop_ids = torch.tensor([[v_token_id_temp, first_code_id]],
                                  dtype=torch.long).to(device)
        decoder_input_ids = torch.cat([decoder_input_ids, prepop_ids], dim=1)
        first_visit_prepopulated = True

    # Create dummy encoder input (conditioning comes from the decoder prompts)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    all_visits = []

    with torch.no_grad():
        for visit_idx in range(num_visits):
            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)

            # For visit 0 with prepopulated first_code, reduce target by 1 since we already have 1 code
            if visit_idx == 0 and first_visit_prepopulated:
                target_codes = max(1, target_codes - 1)  # At least 1 more code

            # Skip if target is too small
            if target_codes < 1:
                continue

            # Append visit-start token to open the visit
            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)

            # Calculate max tokens to generate for this visit:
            # each code is ~1 token, plus 1 for the visit-end marker,
            # with a 50% buffer for flexibility.
            max_new_tokens_this_visit = int(target_codes * 1.5) + 1

            try:
                # Generate codes for this visit
                generated_visit_ids = model.generate(
                    input_ids=encoder_input_ids,
                    attention_mask=encoder_attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    x_num=x_num,
                    x_cat=x_cat,
                    max_new_tokens=max_new_tokens_this_visit,
                    do_sample=True,
                    num_beams=1,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    no_repeat_ngram_size=1,
                    eos_token_id=v_end_token_id,  # Stop at visit end
                    pad_token_id=tokenizer.pad_token_id
                    # Note: NOT passing bos_token_id - let BART use decoder_start_token_id automatically
                )

                # Extract only the newly generated tokens (after decoder_input_ids)
                new_tokens = generated_visit_ids[0, decoder_input_ids.shape[1]:]

                # Parse the generated visit codes
                visit_codes = []
                for token_id in new_tokens:
                    token_id_val = token_id.item()
                    if token_id_val == v_end_token_id:
                        break  # End of visit
                    elif token_id_val >= tokenizer.code_offset:
                        # Diagnosis code - token_id_val is already the correct vocab index
                        # FIX: code2idx already includes special tokens, so don't subtract offset
                        if token_id_val < len(tokenizer.vocab.idx2code):
                            code = tokenizer.vocab.idx2code[token_id_val]
                            visit_codes.append(code)

                # If we generated codes, add visit
                if len(visit_codes) > 0:
                    # Truncate to target if we over-generated
                    if len(visit_codes) > target_codes:
                        visit_codes = visit_codes[:target_codes]

                    all_visits.append(visit_codes)

                    # Update decoder_input_ids with the full visit by
                    # re-encoding the (possibly truncated) codes plus markers
                    visit_token_ids = [v_token_id]  # visit-start marker
                    for code in visit_codes:
                        if code in tokenizer.vocab.code2idx:
                            # FIX: code2idx already returns token ID with offset included
                            code_token_id = tokenizer.vocab.code2idx[code]
                            visit_token_ids.append(code_token_id)
                    visit_token_ids.append(v_end_token_id)  # visit-end marker

                    # Concatenate, skipping the first marker since it was already appended above
                    visit_tensor = torch.tensor([visit_token_ids[1:]], dtype=torch.long).to(device)
                    decoder_input_ids = torch.cat([decoder_input_ids, visit_tensor], dim=1)

            except Exception as e:
                # NOTE(review): broad catch keeps one failed visit from killing
                # the whole patient, at the cost of hiding the error class —
                # consider logging the traceback instead of print.
                print(f"Warning: Generation failed for visit {visit_idx + 1}: {e}")
                continue

            # Check if we're approaching context limit (512 for BART);
            # 400 leaves headroom for the next visit's tokens
            if decoder_input_ids.shape[1] > 400:
                break  # Stop generating more visits

    # Compute statistics
    total_codes = sum(len(visit) for visit in all_visits)

    return {
        'generated_visits': all_visits,
        'demographics': {'age': age, 'sex': sex},
        'num_visits': len(all_visits),
        'num_codes': total_codes,
        'target_structure': target_structure
    }
+
+
+def generate_with_frequency_prior(
+    model,
+    tokenizer,
+    device: torch.device,
+    target_structure: dict,
+    frequency_prior: torch.Tensor,
+    alpha: float = 1.0,
+    age: Optional[float] = None,
+    sex: Optional[int] = None,
+    temperature: float = 0.7,
+    top_k: int = 0,
+    top_p: float = 0.95,
+    max_codes_per_visit: int = 25,
+    diagnostic_mode: bool = False,
+    diagnostic_path: Optional[str] = None
+) -> dict:
+    """Generate patient with frequency-guided sampling.
+
+    This function is identical to generate_patient_with_structure_constraints,
+    but blends model logits with training frequency prior for realistic code distributions.
+    Decoding is done manually (one forward pass per token) rather than via
+    model.generate(), so the prior can be added to the logits at every step.
+
+    Args:
+        model: Trained PromptBartModel or PromptEHR model.
+        tokenizer: DiagnosisCodeTokenizer instance.
+        device: Device to run on.
+        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
+        frequency_prior: [vocab_size] log-frequency tensor from build_frequency_prior().
+        alpha: Blending weight (0=pure model, higher=more frequency guidance).
+            Recommended: 0.5-2.0. Start with 1.0.
+        age: Patient age (if None, sampled from distribution).
+        sex: Patient sex ID (0=M, 1=F; if None, sampled).
+        temperature: Sampling temperature (default: 0.7).
+        top_k: Top-k sampling parameter (default: 0 = disabled).
+        top_p: Nucleus sampling parameter (default: 0.95).
+        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).
+        diagnostic_mode: Enable detailed logging of generation process (default: False).
+        diagnostic_path: Path to save diagnostic JSON file (required if diagnostic_mode=True).
+
+    Returns:
+        Dictionary with:
+        - 'generated_visits': List[List[str]] of diagnosis codes
+        - 'demographics': dict with 'age' and 'sex'
+        - 'num_visits': int
+        - 'num_codes': int
+        - 'target_structure': dict (the structure we aimed for)
+        - 'alpha': float (frequency prior weight used)
+        - 'diagnostics': dict (if diagnostic_mode=True) with detailed generation logs
+
+    Example:
+        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json')
+        >>> result = generate_with_frequency_prior(
+        ...     model, tokenizer, device,
+        ...     target_structure={'num_visits': 3, 'codes_per_visit': [5, 8, 6]},
+        ...     frequency_prior=prior,
+        ...     alpha=1.0
+        ... )
+    """
+    # Disable dropout etc. for deterministic forward passes (sampling is
+    # still stochastic via torch.multinomial below).
+    model.eval()
+
+    # Sample demographics if not provided
+    if age is None or sex is None:
+        sampled_demo = sample_demographics()
+        age = sampled_demo['age'] if age is None else age
+        sex = sampled_demo['sex'] if sex is None else sex
+
+    # Prepare demographic tensors: shape [1, 1] — single patient, one
+    # continuous feature (age) and one categorical feature (sex).
+    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
+    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)
+
+    # Move frequency prior to device
+    frequency_prior = frequency_prior.to(device)
+
+    # Special token IDs
+    # NOTE(review): both string literals below look stripped by an HTML-eating
+    # step — the visit-start token was presumably "<v>" and the visit-end token
+    # "</v>"; "<\\v>" escapes to the literal "<\v>". Verify against the
+    # tokenizer's special-token vocabulary.
+    bos_token_id = tokenizer.bos_token_id
+    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
+    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
+
+    # Extract target structure
+    num_visits = target_structure['num_visits']
+    codes_per_visit = target_structure['codes_per_visit']
+
+    # Handle case with no visits
+    if num_visits == 0 or len(codes_per_visit) == 0:
+        return {
+            'generated_visits': [],
+            'demographics': {'age': age, 'sex': sex},
+            'num_visits': 0,
+            'num_codes': 0,
+            'target_structure': target_structure,
+            'alpha': alpha
+        }
+
+    # Initialize generation with empty sequence (shape [1, 0])
+    # HuggingFace would prepend decoder_start_token_id automatically when
+    # using .generate(); here we decode manually so the decoder input starts
+    # empty and visit tokens are appended below.
+    # This matches training pattern: [BOS, <v>, codes...] after first <v> is appended
+    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)
+
+    # Create dummy encoder input — conditioning comes from the demographic
+    # prompts (x_num/x_cat), not from encoder tokens.
+    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
+    encoder_attention_mask = torch.ones_like(encoder_input_ids)
+
+    all_visits = []
+
+    # Initialize diagnostic tracking
+    all_diagnostics = {'visits': []} if diagnostic_mode else None
+
+    with torch.no_grad():
+        for visit_idx in range(num_visits):
+            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)
+
+            # Skip if target is too small
+            if target_codes < 1:
+                continue
+
+            # Append visit-start token to open this visit
+            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
+            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)
+
+            # Generate codes for this visit with frequency guidance.
+            # 50% headroom over the target so the model can emit the
+            # visit-end token (or duplicates that get skipped) and still
+            # reach target_codes unique codes.
+            max_new_tokens_this_visit = int(target_codes * 1.5) + 1
+            visit_codes = []
+
+            # Initialize visit diagnostic tracking
+            visit_diagnostics = {'visit_idx': visit_idx, 'steps': []} if diagnostic_mode else None
+
+            for step in range(max_new_tokens_this_visit):
+                # Forward pass (no KV cache — the full decoder prefix is
+                # re-encoded each step)
+                outputs = model(
+                    input_ids=encoder_input_ids,
+                    attention_mask=encoder_attention_mask,
+                    decoder_input_ids=decoder_input_ids,
+                    x_num=x_num,
+                    x_cat=x_cat,
+                    return_dict=True
+                )
+
+                # Get logits for next token (handle both dict and object outputs)
+                if hasattr(outputs, 'logits'):
+                    logits = outputs.logits[0, -1, :]  # [vocab_size]
+                elif isinstance(outputs, dict) and 'logits' in outputs:
+                    logits = outputs['logits'][0, -1, :]  # [vocab_size]
+                else:
+                    raise TypeError(f"Unexpected output type: {type(outputs)}")
+
+                # Diagnostic logging: raw model logits
+                if diagnostic_mode:
+                    step_diagnostics = {
+                        'step': step,
+                        'raw_logits': {
+                            'max': float(logits.max()),
+                            'min': float(logits.min()),
+                            'mean': float(logits.mean()),
+                            'std': float(logits.std()),
+                            'top_5_indices': [int(i) for i in logits.topk(5).indices],
+                            'top_5_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
+                                           for i in logits.topk(5).indices],
+                            'top_5_values': [float(v) for v in logits.topk(5).values]
+                        }
+                    }
+
+                # BLEND with frequency prior: additive in log-space, i.e.
+                # p(token) is reweighted by freq(token)^alpha.
+                logits_guided = logits + alpha * frequency_prior
+
+                # Diagnostic logging: frequency blending
+                if diagnostic_mode:
+                    step_diagnostics['blending'] = {
+                        'alpha': alpha,
+                        'prior_contribution': float((alpha * frequency_prior).abs().mean()),
+                        'logits_shift': float((logits_guided - logits).abs().mean()),
+                        'top_5_after_blend_indices': [int(i) for i in logits_guided.topk(5).indices],
+                        'top_5_after_blend_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
+                                                   for i in logits_guided.topk(5).indices],
+                        'top_5_after_blend_values': [float(v) for v in logits_guided.topk(5).values]
+                    }
+
+                # Apply temperature
+                scaled_logits = logits_guided / temperature
+
+                # Convert to probabilities
+                probs = torch.softmax(scaled_logits, dim=0)
+
+                # Diagnostic logging: probabilities after temperature
+                if diagnostic_mode:
+                    top_probs, top_indices = torch.topk(probs, 20)
+                    step_diagnostics['probabilities'] = {
+                        'temperature': temperature,
+                        'entropy': float(-(probs * torch.log(probs + 1e-10)).sum()),
+                        'top_20': [
+                            {'code': tokenizer.vocab.idx2code.get(int(idx), f"<{idx}>"),
+                             'prob': float(prob),
+                             'idx': int(idx)}
+                            for idx, prob in zip(top_indices, top_probs)
+                        ]
+                    }
+
+                # Apply top-k filtering if enabled: zero out everything
+                # outside the k most likely tokens, then renormalize.
+                if top_k > 0:
+                    top_k_vals, top_k_indices = torch.topk(probs, min(top_k, probs.size(-1)))
+                    probs_filtered = torch.zeros_like(probs)
+                    probs_filtered.scatter_(0, top_k_indices, top_k_vals)
+                    probs = probs_filtered / probs_filtered.sum()
+
+                # Apply nucleus (top-p) sampling: keep the smallest prefix of
+                # sorted tokens whose cumulative mass is <= top_p.
+                if top_p < 1.0:
+                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                    cumsum_probs = torch.cumsum(sorted_probs, dim=0)
+                    nucleus_mask = cumsum_probs <= top_p
+                    nucleus_mask[0] = True  # Always include top token
+
+                    nucleus_indices = sorted_indices[nucleus_mask]
+                    nucleus_probs = sorted_probs[nucleus_mask]
+                    nucleus_probs = nucleus_probs / nucleus_probs.sum()
+
+                    # Sample from nucleus
+                    sampled_idx = torch.multinomial(nucleus_probs, 1)[0]
+                    next_token = int(nucleus_indices[sampled_idx])
+                else:
+                    # Sample directly from filtered probs
+                    next_token = int(torch.multinomial(probs, 1)[0])
+
+                # Diagnostic logging: sampling decision
+                if diagnostic_mode:
+                    selected_code = tokenizer.vocab.idx2code.get(next_token, f"<{next_token}>")
+                    step_diagnostics['selected'] = {
+                        'token': next_token,
+                        'code': selected_code,
+                        'probability': float(probs[next_token]) if next_token < len(probs) else 0.0,
+                        'was_top_1': (next_token == int(probs.argmax())),
+                        'is_special_token': next_token < tokenizer.code_offset
+                    }
+                    visit_diagnostics['steps'].append(step_diagnostics)
+
+                # Check if we hit end-of-visit (the token itself is NOT
+                # appended here; the visit is closed explicitly below)
+                if next_token == v_end_token_id:
+                    break
+
+                # Extract code if it's a diagnosis code
+                # FIX: code2idx already includes special tokens, so don't subtract offset
+                if next_token >= tokenizer.code_offset:
+                    if next_token < len(tokenizer.vocab.idx2code):
+                        code = tokenizer.vocab.idx2code[next_token]
+                        if code not in visit_codes:  # Prevent duplicates
+                            visit_codes.append(code)
+
+                # Append token to decoder input
+                next_token_tensor = torch.tensor([[next_token]], dtype=torch.long).to(device)
+                decoder_input_ids = torch.cat([decoder_input_ids, next_token_tensor], dim=1)
+
+                # Stop if we have enough codes
+                if len(visit_codes) >= target_codes:
+                    break
+
+            # Add visit if we generated codes
+            # NOTE(review): if no codes were generated, the visit-start token
+            # appended above is never closed with a visit-end token — the
+            # decoder prefix then contains a dangling open visit. Confirm this
+            # is intended.
+            if len(visit_codes) > 0:
+                # Truncate to target if over-generated
+                if len(visit_codes) > target_codes:
+                    visit_codes = visit_codes[:target_codes]
+
+                all_visits.append(visit_codes)
+
+                # Add visit diagnostics
+                if diagnostic_mode:
+                    visit_diagnostics['generated_codes'] = visit_codes
+                    visit_diagnostics['target_codes'] = target_codes
+                    all_diagnostics['visits'].append(visit_diagnostics)
+
+                # Append visit-end token to close the visit
+                v_end_tensor = torch.tensor([[v_end_token_id]], dtype=torch.long).to(device)
+                decoder_input_ids = torch.cat([decoder_input_ids, v_end_tensor], dim=1)
+
+            # Check if we're approaching context limit (BART positional
+            # embeddings max out at 512; 400 leaves headroom for prompts)
+            if decoder_input_ids.shape[1] > 400:
+                break
+
+    # Compute statistics
+    total_codes = sum(len(visit) for visit in all_visits)
+
+    # Build result dictionary
+    result = {
+        'generated_visits': all_visits,
+        'demographics': {'age': age, 'sex': sex},
+        'num_visits': len(all_visits),
+        'num_codes': total_codes,
+        'target_structure': target_structure,
+        'alpha': alpha
+    }
+
+    # Add diagnostics if enabled
+    if diagnostic_mode:
+        all_diagnostics['demographics'] = {'age': age, 'sex': sex}
+        all_diagnostics['params'] = {
+            'alpha': alpha,
+            'temperature': temperature,
+            'top_k': top_k,
+            'top_p': top_p
+        }
+        all_diagnostics['generated_codes'] = all_visits
+        result['diagnostics'] = all_diagnostics
+
+        # Save diagnostics to file if path provided
+        if diagnostic_path:
+            import json
+            import os
+            os.makedirs(os.path.dirname(diagnostic_path), exist_ok=True)
+            with open(diagnostic_path, 'w') as f:
+                json.dump(all_diagnostics, f, indent=2)
+
+    return result
diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py
new file mode 100644
index 000000000..2657c0c7d
--- /dev/null
+++ b/pyhealth/models/promptehr/model.py
@@ -0,0 +1,788 @@
+"""PromptEHR: BART-based generative model for synthetic EHR generation.
+
+This module provides the main PromptEHR model that combines demographic-conditioned
+prompts with BART encoder-decoder architecture for realistic patient record generation.
+
+Ported from pehr_scratch/prompt_bart_model.py (lines 16-276, excluding auxiliary losses).
+"""
+
+import os
+import random
+from typing import Dict, List, Optional, Tuple
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pad_sequence
+from transformers import BartConfig, BartForConditionalGeneration
+from transformers.modeling_outputs import Seq2SeqLMOutput
+
+from pyhealth.models import BaseModel
+from .conditional_prompt import ConditionalPromptEncoder
+from .bart_encoder import PromptBartEncoder
+from .bart_decoder import PromptBartDecoder
+
+
+class PromptBartModel(BartForConditionalGeneration):
+    """BART model with demographic prompt conditioning for EHR generation.
+
+    Extends HuggingFace's BartForConditionalGeneration with:
+    1. Dual prompt encoders (separate for encoder/decoder)
+    2. Demographic conditioning via learned prompt vectors
+    3. Label smoothing for diverse generation
+
+    This is the core generative model WITHOUT auxiliary losses (those caused
+    mode collapse and are excluded per implementation decision D003).
+
+    Args:
+        config: BART configuration from transformers
+        n_num_features: Number of continuous features (1 for age)
+        cat_cardinalities: Category counts for categorical features ([2] for gender M/F)
+        d_hidden: Intermediate reparameterization dimension (default: 128)
+        prompt_length: Number of prompt vectors per feature (default: 1)
+
+    Example:
+        >>> from transformers import BartConfig
+        >>> config = BartConfig.from_pretrained("facebook/bart-base")
+        >>> model = PromptBartModel(
+        ...     config,
+        ...     n_num_features=1,  # age
+        ...     cat_cardinalities=[2],  # gender (M/F)
+        ...     d_hidden=128,
+        ...     prompt_length=1
+        ... )
+        >>> # Forward pass with demographics
+        >>> age = torch.randn(16, 1)  # [batch, 1]
+        >>> gender = torch.randint(0, 2, (16, 1))  # [batch, 1]
+        >>> input_ids = torch.randint(0, 1000, (16, 100))
+        >>> labels = torch.randint(0, 1000, (16, 50))
+        >>> output = model(
+        ...     input_ids=input_ids,
+        ...     labels=labels,
+        ...     x_num=age,
+        ...     x_cat=gender
+        ... )
+        >>> loss = output.loss
+    """
+
+    def __init__(
+        self,
+        config: BartConfig,
+        n_num_features: Optional[int] = None,
+        cat_cardinalities: Optional[list] = None,
+        d_hidden: int = 128,
+        prompt_length: int = 1
+    ):
+        """Initialize PromptBART model with dual prompt conditioning.
+
+        Args:
+            config: BART configuration
+            n_num_features: Number of continuous features (e.g., 1 for age)
+            cat_cardinalities: Category counts for categorical features [n_genders]
+            d_hidden: Intermediate reparameterization dimension (default: 128)
+            prompt_length: Number of prompt vectors per feature (default: 1)
+        """
+        super().__init__(config)
+
+        # Replace encoder and decoder with prompt-aware versions.
+        # self.model.shared is BART's shared token embedding table, so the
+        # replacements keep tied input embeddings.
+        self.model.encoder = PromptBartEncoder(config, self.model.shared)
+        self.model.decoder = PromptBartDecoder(config, self.model.shared)
+
+        # Add SEPARATE conditional prompt encoders for encoder and decoder
+        # This provides stronger demographic conditioning than shared prompts (dual injection)
+        if n_num_features is not None or cat_cardinalities is not None:
+            # Encoder prompt encoder
+            self.encoder_prompt_encoder = ConditionalPromptEncoder(
+                n_num_features=n_num_features,
+                cat_cardinalities=cat_cardinalities,
+                hidden_dim=config.d_model,
+                d_hidden=d_hidden,
+                prompt_length=prompt_length
+            )
+            # Decoder prompt encoder (separate parameters for dual injection)
+            self.decoder_prompt_encoder = ConditionalPromptEncoder(
+                n_num_features=n_num_features,
+                cat_cardinalities=cat_cardinalities,
+                hidden_dim=config.d_model,
+                d_hidden=d_hidden,
+                prompt_length=prompt_length
+            )
+            self.num_prompts = self.encoder_prompt_encoder.get_num_prompts()
+        else:
+            # No demographics configured: behaves as plain BART with the
+            # custom encoder/decoder classes.
+            self.encoder_prompt_encoder = None
+            self.decoder_prompt_encoder = None
+            self.num_prompts = 0
+
+        # Initialize weights
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        x_num: Optional[torch.FloatTensor] = None,
+        x_cat: Optional[torch.LongTensor] = None,
+    ) -> Seq2SeqLMOutput:
+        """Forward pass with demographic conditioning.
+
+        Args:
+            input_ids: [batch, seq_len] encoder input token IDs
+            attention_mask: [batch, seq_len] encoder attention mask
+            decoder_input_ids: [batch, tgt_len] decoder input token IDs
+            decoder_attention_mask: [batch, tgt_len] decoder attention mask
+            labels: [batch, tgt_len] target labels for loss computation
+            x_num: [batch, n_num_features] continuous demographic features (e.g., age)
+            x_cat: [batch, n_cat_features] categorical demographic features (e.g., gender)
+            Other args: Standard BART arguments
+
+        Returns:
+            Seq2SeqLMOutput with:
+            - loss: Cross-entropy loss with label smoothing=0.1
+            - logits: [batch, tgt_len, vocab_size] prediction logits
+            - past_key_values: Cached key-value states (if use_cache=True)
+            - decoder_hidden_states: Decoder layer outputs (if output_hidden_states=True)
+            - encoder_last_hidden_state: Final encoder output
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Encode demographic prompts separately for encoder and decoder
+        # Only prepend prompts on first step (when no cache exists) — on
+        # cached generation steps the prompts are already baked into the
+        # past key/value states.
+        encoder_prompt_embeds = None
+        decoder_prompt_embeds = None
+        if (x_num is not None or x_cat is not None) and past_key_values is None:
+            if self.encoder_prompt_encoder is not None:
+                encoder_prompt_embeds = self.encoder_prompt_encoder(x_num=x_num, x_cat=x_cat)
+            if self.decoder_prompt_encoder is not None:
+                decoder_prompt_embeds = self.decoder_prompt_encoder(x_num=x_num, x_cat=x_cat)
+
+        # Prepare decoder input IDs (shift labels right for teacher forcing)
+        if labels is not None:
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        # Encoder forward pass (with encoder prompts)
+        if encoder_outputs is None:
+            encoder_outputs = self.model.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                inputs_prompt_embeds=encoder_prompt_embeds,  # Encoder-specific prompts
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # Extend encoder attention mask for prompts: the encoder output is
+        # n_prompts positions longer than input_ids, so cross-attention needs
+        # the mask widened with ones on the prompt positions.
+        encoder_attention_mask = attention_mask
+        if encoder_prompt_embeds is not None and attention_mask is not None:
+            batch_size, n_prompts = encoder_prompt_embeds.shape[:2]
+            prompt_mask = torch.ones(batch_size, n_prompts, dtype=attention_mask.dtype, device=attention_mask.device)
+            encoder_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
+
+        # Decoder forward pass (with decoder prompts)
+        # NOTE(review): decoder_attention_mask is passed through unextended
+        # even when decoder prompts are prepended — confirm PromptBartDecoder
+        # handles the length mismatch internally.
+        decoder_outputs = self.model.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            inputs_prompt_embeds=decoder_prompt_embeds,  # Decoder-specific prompts
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # Language modeling head
+        lm_logits = self.lm_head(decoder_outputs[0])
+
+        # If decoder prompts were prepended, slice them off before loss computation
+        if decoder_prompt_embeds is not None and labels is not None:
+            # decoder_outputs[0] shape: [batch, n_prompts + seq_len, hidden_dim]
+            # We only want logits for the actual sequence positions
+            n_prompts = decoder_prompt_embeds.shape[1]
+            lm_logits = lm_logits[:, n_prompts:, :]  # Remove prompt positions
+
+        # Compute loss if labels provided
+        loss = None
+        if labels is not None:
+            # Label smoothing = 0.1 to prevent overconfidence and encourage diversity
+            # Softens target distributions: 90% on correct token, 10% distributed to alternatives
+            # (CrossEntropyLoss's default ignore_index=-100 skips padded label positions.)
+            loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1)
+            loss = loss_fct(lm_logits.reshape(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        x_num=None,
+        x_cat=None,
+        **kwargs
+    ):
+        """Prepare inputs for autoregressive generation.
+
+        Overridden so that the demographic tensors (x_num/x_cat) survive the
+        HuggingFace generate() loop and reach forward() on every step.
+
+        Args:
+            decoder_input_ids: [batch, cur_len] current decoder input IDs
+            past_key_values: Cached key-value states from previous steps
+            x_num: [batch, n_num_features] continuous demographics (passed through)
+            x_cat: [batch, n_cat_features] categorical demographics (passed through)
+            Other args: Standard BART generation arguments
+
+        Returns:
+            Dictionary of inputs for next generation step
+        """
+        # Cut decoder_input_ids if past is used (only need last token)
+        if past_key_values is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            "input_ids": None,
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
+            "x_num": x_num,  # Pass demographics through generation
+            "x_cat": x_cat,
+        }
+
+    @staticmethod
+    def _expand_inputs_for_generation(
+        input_ids,
+        expand_size=1,
+        is_encoder_decoder=True,
+        attention_mask=None,
+        encoder_outputs=None,
+        x_num=None,
+        x_cat=None,
+        **model_kwargs,
+    ):
+        """Expand inputs for beam search or multiple samples.
+
+        Overridden so x_num/x_cat are tiled alongside the standard tensors
+        (otherwise beam search would see a batch-size mismatch).
+
+        Args:
+            input_ids: [batch, seq_len] input token IDs
+            expand_size: Number of beams/samples per input
+            x_num: [batch, n_num_features] continuous demographics
+            x_cat: [batch, n_cat_features] categorical demographics
+            Other args: Standard expansion arguments
+
+        Returns:
+            Expanded input_ids and model_kwargs
+        """
+        # Row indices [0,0,...,1,1,...] repeating each batch row expand_size times
+        expanded_return_idx = (
+            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+        )
+
+        if attention_mask is not None:
+            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
+
+        if encoder_outputs is not None:
+            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+            )
+            model_kwargs["encoder_outputs"] = encoder_outputs
+
+        # Expand demographics for beam search
+        if x_num is not None:
+            model_kwargs["x_num"] = x_num.index_select(0, expanded_return_idx)
+
+        if x_cat is not None:
+            model_kwargs["x_cat"] = x_cat.index_select(0, expanded_return_idx)
+
+        return input_ids, model_kwargs
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """Shift input ids one token to the right for teacher forcing.
+
+    Mirrors the standard HuggingFace BART helper: the decoder sees the target
+    sequence delayed by one position, with the decoder-start token in slot 0.
+
+    Args:
+        input_ids: [batch, seq_len] target token IDs
+        pad_token_id: ID for padding token
+        decoder_start_token_id: ID for decoder start token (BOS)
+
+    Returns:
+        [batch, seq_len] shifted token IDs with BOS prepended
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("config.pad_token_id must be defined for sequence generation")
+
+    # Replace -100 in labels with pad_token_id
+    # (-100 is the loss ignore index; it must not be fed to the embedding layer)
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+class _PromptEHRVocab:
+    """Internal vocabulary bridging NestedSequenceProcessor indices to BART token IDs.
+
+    Token layout (7 special tokens + N diagnosis codes):
+        0 = PAD          (BartConfig.pad_token_id)
+        1 = BOS          (BartConfig.bos_token_id / decoder_start_token_id)
+        2 = EOS          (BartConfig.eos_token_id)
+        3 = UNK          (unknown code)
+        4 = VISIT_START  (visit start)
+        5 = VISIT_END    (visit end)
+        6 = SEQ_END      (sequence terminator)
+        7+ = diagnosis codes
+
+    NestedSequenceProcessor uses pad=0, unk=1, codes=2+.
+    Mapping: processor_idx i -> BART token i + 5 (for i >= 2).
+    Total BART vocab size = processor.vocab_size() + 5.
+
+    Args:
+        code_vocab (dict): Mapping of code string to processor index, as
+            returned by ``NestedSequenceProcessor.code_vocab``. Must include
+            the pad token (e.g. ``"<pad>"``) -> 0 and the unknown token
+            (e.g. ``"<unk>"``) -> 1.
+
+    Examples:
+        >>> vocab = _PromptEHRVocab({"<pad>": 0, "<unk>": 1, "428": 2, "410": 3})
+        >>> isinstance(vocab, _PromptEHRVocab)
+        True
+        >>> vocab.total_size
+        9
+    """
+
+    PAD = 0
+    BOS = 1
+    EOS = 2
+    UNK = 3
+    VISIT_START = 4
+    VISIT_END = 5
+    SEQ_END = 6
+    CODE_OFFSET = 7
+
+    def __init__(self, code_vocab: dict):
+        """Build vocab from NestedSequenceProcessor.code_vocab dict."""
+        # Reverse map: BART token ID -> diagnosis code string
+        self._bart_to_code: Dict[int, str] = {}
+        for code, pid in code_vocab.items():
+            if pid >= 2:  # skip pad and unk
+                self._bart_to_code[pid + 5] = code
+        # len(code_vocab) counts pad + unk + N codes; adding 5 yields
+        # 7 special tokens + N codes.
+        self.total_size = len(code_vocab) + 5  # 7 special - 2 reused + N codes
+
+    def encode_visits(self, visits_tensor: torch.Tensor) -> List[int]:
+        """Encode a processed [n_visits, max_codes] LongTensor to a token ID list.
+
+        Args:
+            visits_tensor (torch.Tensor): LongTensor of shape
+                ``(n_visits, max_codes_per_visit)`` from NestedSequenceProcessor.
+                Values 0 = pad, 1 = unk, 2+ = code index.
+
+        Returns:
+            list of int: Token IDs in format
+                ``[VISIT_START, code, ..., VISIT_END, VISIT_START, ..., VISIT_END, SEQ_END]``.
+                Visits containing only pad/unk values are dropped entirely.
+        """
+        tokens = []
+        for visit in visits_tensor:
+            codes_in_visit = [
+                int(c.item()) + 5  # processor idx 2+ → BART idx 7+
+                for c in visit
+                if c.item() >= 2  # skip pad and unk
+            ]
+            if codes_in_visit:
+                tokens.append(self.VISIT_START)
+                tokens.extend(codes_in_visit)
+                tokens.append(self.VISIT_END)
+        tokens.append(self.SEQ_END)
+        return tokens
+
+    def decode_tokens(self, token_ids: List[int]) -> List[List[str]]:
+        """Decode a generated token ID list back to visit structure.
+
+        Scans left to right: VISIT_START opens a visit, VISIT_END closes it,
+        and any terminator/special token (SEQ_END, EOS, PAD, BOS) stops the
+        scan. A trailing unclosed visit is kept if it contains codes.
+
+        Args:
+            token_ids (list of int): Raw generated token IDs from BART.
+
+        Returns:
+            list of list of str: Decoded diagnosis code strings per visit.
+        """
+        visits: List[List[str]] = []
+        current_visit: List[str] = []
+        in_visit = False
+        for tid in token_ids:
+            if tid == self.VISIT_START:
+                in_visit = True
+                current_visit = []
+            elif tid == self.VISIT_END:
+                if in_visit:
+                    visits.append(current_visit)
+                in_visit = False
+            elif tid in (self.SEQ_END, self.EOS, self.PAD, self.BOS):
+                break
+            elif in_visit and tid >= self.CODE_OFFSET:
+                # Unknown token IDs (not in the reverse map) are silently dropped
+                code = self._bart_to_code.get(tid)
+                if code:
+                    current_visit.append(code)
+        if in_visit and current_visit:
+            visits.append(current_visit)
+        return visits
+
+
+def _promptehr_collate_fn(batch):
+    """Collate PromptEHR training samples, padding token sequences in a batch.
+
+    Pads ``input_ids`` and ``labels`` to the longest sequence in the batch using
+    ``pad_sequence``. Builds the attention mask from padded positions.
+
+    Args:
+        batch (list of dict): Each dict has ``"input_ids"``, ``"labels"``,
+            ``"x_num"``, and ``"x_cat"`` tensors.
+
+    Returns:
+        dict: Batched tensors ready for ``PromptBartModel.forward()``.
+    """
+    # Inputs are padded with the PAD token so the attention mask below can
+    # identify real positions.
+    input_ids = pad_sequence(
+        [item["input_ids"] for item in batch],
+        batch_first=True,
+        padding_value=_PromptEHRVocab.PAD,
+    )
+    # Labels are padded with -100 so CrossEntropyLoss ignores those positions.
+    labels = pad_sequence(
+        [item["labels"] for item in batch],
+        batch_first=True,
+        padding_value=-100,
+    )
+    attention_mask = (input_ids != _PromptEHRVocab.PAD).long()
+    # Demographics: each item is [1, n_features]; cat along dim 0 -> [batch, n_features]
+    x_num = torch.cat([item["x_num"] for item in batch], dim=0)
+    x_cat = torch.cat([item["x_cat"] for item in batch], dim=0)
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": labels,
+        "x_num": x_num,
+        "x_cat": x_cat,
+    }
+
+
+class PromptEHR(BaseModel):
+ """PromptEHR: demographic-conditioned BART model for synthetic EHR generation.
+
+ Wraps ``PromptBartModel`` (HuggingFace BART with dual prompt conditioning)
+ in a PyHealth ``BaseModel`` interface. Training is handled by a HuggingFace
+ ``Trainer`` loop; generation is autoregressive token-by-token decoding.
+
+ Demographics (age as continuous, gender as categorical) are injected via
+ learned prompt vectors prepended to both encoder and decoder hidden states.
+
+ Args:
+ dataset (SampleDataset): PyHealth sample dataset produced by
+ ``set_task(promptehr_generation_mimic3_fn)``. Must have
+ ``input_processors["visits"]`` (NestedSequenceProcessor).
+ n_num_features (int): Continuous demographic features (1 for age).
+ Default: 1.
+ cat_cardinalities (list of int): Category counts per categorical
+ feature ([2] for binary gender M/F). Default: [2].
+ d_hidden (int): Reparameterization dimension for prompt encoder.
+ Default: 128.
+ prompt_length (int): Number of prompt vectors per feature. Default: 1.
+ bart_config_name (str): Pretrained BART config to use.
+ Default: ``"facebook/bart-base"``.
+ epochs (int): Training epochs. Default: 20.
+ batch_size (int): Training batch size. Default: 16.
+ lr (float): AdamW learning rate. Default: 1e-5.
+ warmup_steps (int): Linear warmup steps. Default: 1000.
+ max_seq_length (int): Maximum token sequence length. Default: 512.
+ save_dir (str): Directory for checkpoints. Default: ``"./save/"``.
+
+ Examples:
+ >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset
+ >>> samples = [
+ ... {"patient_id": "p1", "visits": [["428", "427"], ["410"]], "age": 65.0, "gender": 0},
+ ... {"patient_id": "p2", "visits": [["250"], ["401", "272"]], "age": 52.0, "gender": 1},
+ ... ]
+ >>> dataset = InMemorySampleDataset(
+ ... samples=samples,
+ ... input_schema={"visits": "nested_sequence"},
+ ... output_schema={},
+ ... )
+ >>> model = PromptEHR(dataset, d_hidden=32, prompt_length=1)
+ >>> isinstance(model, PromptEHR)
+ True
+ """
+
+ def __init__(
+ self,
+ dataset,
+ n_num_features: int = 1,
+ cat_cardinalities: Optional[list] = None,
+ d_hidden: int = 128,
+ prompt_length: int = 1,
+ bart_config_name: "Union[str, BartConfig]" = "facebook/bart-base",
+ epochs: int = 20,
+ batch_size: int = 16,
+ lr: float = 1e-5,
+ warmup_steps: int = 1000,
+ max_seq_length: int = 512,
+ save_dir: str = "./save/",
+ ):
+ """Initialize PromptEHR with vocab derived from the dataset processor."""
+ super().__init__(dataset)
+
+ self.mode = None # skip discriminative evaluation
+ self.save_dir = save_dir
+ self.epochs = epochs
+ self.batch_size = batch_size
+ self.lr = lr
+ self.warmup_steps = warmup_steps
+ self.max_seq_length = max_seq_length
+ self._demo_pool: List[tuple] = [] # (age, gender) pairs from training data
+
+ if cat_cardinalities is None:
+ cat_cardinalities = [2]
+
+ # Derive vocab from the dataset's NestedSequenceProcessor
+ visits_processor = dataset.input_processors["visits"]
+ self._vocab = _PromptEHRVocab(visits_processor.code_vocab)
+ bart_vocab_size = self._vocab.total_size
+
+ # Configure BART with our custom vocab and special token IDs
+ if isinstance(bart_config_name, str):
+ bart_config = BartConfig.from_pretrained(bart_config_name)
+ else:
+ # Accept a BartConfig object directly (useful for tiny test models)
+ bart_config = bart_config_name
+ bart_config.vocab_size = bart_vocab_size
+ bart_config.pad_token_id = _PromptEHRVocab.PAD
+ bart_config.bos_token_id = _PromptEHRVocab.BOS
+ bart_config.eos_token_id = _PromptEHRVocab.EOS
+ bart_config.decoder_start_token_id = _PromptEHRVocab.BOS
+ bart_config.forced_eos_token_id = _PromptEHRVocab.SEQ_END
+ bart_config.dropout = 0.3
+ bart_config.attention_dropout = 0.3
+ bart_config.activation_dropout = 0.3
+
+ self.bart_model = PromptBartModel(
+ config=bart_config,
+ n_num_features=n_num_features,
+ cat_cardinalities=cat_cardinalities,
+ d_hidden=d_hidden,
+ prompt_length=prompt_length,
+ )
+
+ def forward(self, **kwargs) -> Dict:
+ """Not implemented — PromptEHR is a generative model without a discriminative forward.
+
+ Raises:
+ NotImplementedError: Always. Use ``train_model`` and
+ ``synthesize_dataset`` instead.
+ """
+ raise NotImplementedError(
+ "PromptEHR is a generative model. Use train_model() and synthesize_dataset()."
+ )
+
    def train_model(self, train_dataset, val_dataset=None) -> None:
        """Train PromptEHR using a HuggingFace Trainer loop.

        Converts PyHealth SampleDataset samples to BART token sequences and
        trains with HuggingFace ``Trainer``. Demographics (age, gender) are
        passed as ``x_num`` / ``x_cat`` via a custom data collator.

        Named ``train_model`` (not ``train``) to avoid shadowing
        ``nn.Module.train()``.

        Side effects: populates ``self._demo_pool`` with (age, gender)
        pairs for later demographic sampling, and writes a final
        checkpoint to ``<save_dir>/checkpoint.pt``.

        Args:
            train_dataset (SampleDataset): Training set with ``"visits"``,
                ``"age"``, and ``"gender"`` fields.
            val_dataset (SampleDataset, optional): Validation set for loss
                monitoring. Default: None.
                NOTE(review): TrainingArguments sets no eval strategy, so
                the eval_dataset only matters if evaluation is triggered
                explicitly — confirm this is intended.
        """
        from torch.utils.data import Dataset as TorchDataset
        from transformers import Trainer, TrainingArguments

        # Close over vocab/max_len so the nested Dataset class stays small.
        vocab = self._vocab
        max_len = self.max_seq_length

        class _EHRDataset(TorchDataset):
            # Tokenizes one PyHealth sample per __getitem__ call.
            def __init__(self, samples):
                self._samples = list(samples)

            def __len__(self):
                return len(self._samples)

            def __getitem__(self, idx):
                s = self._samples[idx]
                tokens = vocab.encode_visits(s["visits"])
                # Truncate over-long sequences but always terminate with
                # SEQ_END so decoding still finds a sequence boundary.
                if len(tokens) > max_len:
                    tokens = tokens[:max_len - 1] + [vocab.SEQ_END]
                # Defaults match synthesize_dataset's fallbacks.
                age = float(s.get("age", 60.0))
                gender = int(s.get("gender", 0))
                return {
                    "input_ids": torch.tensor(tokens, dtype=torch.long),
                    # labels mirror input_ids: the model is trained to
                    # reconstruct the visit sequence (seq2seq denoising).
                    "labels": torch.tensor(tokens, dtype=torch.long),
                    # Per-sample shapes are (1, n_features); presumably
                    # _promptehr_collate_fn stacks them — confirm there.
                    "x_num": torch.tensor([[age]], dtype=torch.float32),
                    "x_cat": torch.tensor([[gender]], dtype=torch.long),
                }

        train_samples = list(train_dataset)
        # Store demographics pool for synthesize_dataset sampling
        self._demo_pool = [
            (float(s.get("age", 60.0)), int(s.get("gender", 0)))
            for s in train_samples
        ]

        os.makedirs(self.save_dir, exist_ok=True)
        training_args = TrainingArguments(
            output_dir=self.save_dir,
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            learning_rate=self.lr,
            warmup_steps=self.warmup_steps,
            save_strategy="epoch",
            logging_steps=50,
            remove_unused_columns=False,  # essential: keeps x_num/x_cat
            use_cpu=not torch.cuda.is_available(),
        )

        trainer = Trainer(
            model=self.bart_model,
            args=training_args,
            train_dataset=_EHRDataset(train_samples),
            eval_dataset=_EHRDataset(list(val_dataset)) if val_dataset else None,
            data_collator=_promptehr_collate_fn,
        )
        trainer.train()

        # Persist final weights + vocab for later load_model().
        self.save_model(os.path.join(self.save_dir, "checkpoint.pt"))
+
+ def synthesize_dataset(
+ self, num_samples: int, random_sampling: bool = True
+ ) -> List[Dict]:
+ """Generate a synthetic patient dataset.
+
+ Samples demographics from the training data distribution (if available)
+ and generates autoregressive token sequences via BART. Each sequence is
+ decoded back to a nested list of diagnosis code strings.
+
+ Args:
+ num_samples (int): Number of synthetic patients to generate.
+ random_sampling (bool): If True, uses multinomial sampling with
+ ``temperature=0.7, top_p=0.95``. If False, uses greedy decoding.
+ Default: True.
+
+ Returns:
+ list of dict: One record per synthetic patient. Each dict has:
+ ``"patient_id"`` (str): unique identifier, e.g. ``"synthetic_0"``.
+ ``"visits"`` (list of list of str): decoded code strings per visit.
+ """
+ self.bart_model.eval()
+ device = self.device
+
+ results = []
+ with torch.no_grad():
+ for i in range(num_samples):
+ # Sample demographics from training distribution (or defaults)
+ if self._demo_pool:
+ age, gender = self._demo_pool[
+ random.randrange(len(self._demo_pool))
+ ]
+ else:
+ age, gender = 60.0, 0
+
+ x_num = torch.tensor([[age]], dtype=torch.float32, device=device)
+ x_cat = torch.tensor([[gender]], dtype=torch.long, device=device)
+
+ # PAD token as minimal encoder input; prompts carry the signal
+ encoder_input = torch.tensor(
+ [[_PromptEHRVocab.PAD]], dtype=torch.long, device=device
+ )
+
+ output_ids = self.bart_model.generate(
+ input_ids=encoder_input,
+ attention_mask=torch.ones_like(encoder_input),
+ x_num=x_num,
+ x_cat=x_cat,
+ max_length=self.max_seq_length,
+ do_sample=random_sampling,
+ temperature=0.7 if random_sampling else 1.0,
+ top_p=0.95 if random_sampling else 1.0,
+ pad_token_id=_PromptEHRVocab.PAD,
+ eos_token_id=_PromptEHRVocab.SEQ_END,
+ bos_token_id=_PromptEHRVocab.BOS,
+ )
+
+ visits = self._vocab.decode_tokens(output_ids[0].tolist())
+ results.append({
+ "patient_id": f"synthetic_{i}",
+ "visits": visits,
+ })
+
+ return results
+
+ def save_model(self, path: str) -> None:
+ """Save model weights and vocab to a checkpoint file.
+
+ Args:
+ path (str): Destination file path (e.g. ``"./save/checkpoint.pt"``).
+
+ Examples:
+ >>> import tempfile, os
+ >>> tmpdir = tempfile.mkdtemp()
+ >>> model.save_model(os.path.join(tmpdir, "ckpt.pt"))
+ """
+ os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+ torch.save(
+ {
+ "model": self.bart_model.state_dict(),
+ "vocab": self._vocab,
+ "bart_config": self.bart_model.config,
+ },
+ path,
+ )
+
+ def load_model(self, path: str) -> None:
+ """Load model weights from a checkpoint saved by ``save_model``.
+
+ Args:
+ path (str): Path to checkpoint file produced by ``save_model``.
+
+ Examples:
+ >>> model.load_model("./save/checkpoint.pt")
+ """
+ checkpoint = torch.load(path, map_location=self.device, weights_only=False)
+ self.bart_model.load_state_dict(checkpoint["model"])
+ if "vocab" in checkpoint:
+ self._vocab = checkpoint["vocab"]
diff --git a/pyhealth/models/promptehr/utils.py b/pyhealth/models/promptehr/utils.py
new file mode 100644
index 000000000..43e13ca83
--- /dev/null
+++ b/pyhealth/models/promptehr/utils.py
@@ -0,0 +1,29 @@
+"""Utility functions and classes for PromptEHR.
+
+This module contains:
+ - VisitStructureSampler: Samples realistic visit structures for generation
+ - Data collation functions
+ - Helper utilities
+"""
+
+import torch
+import torch.nn as nn
+
+
class VisitStructureSampler:
    """Placeholder for the visit-structure sampler.

    A working implementation lives in
    ``pyhealth.models.promptehr.visit_sampler``; this stub only reserves
    the name in ``utils`` until the port is complete. The sampler draws
    realistic per-patient visit counts and per-visit code counts from
    training-data distributions to constrain generation length.

    Raises:
        NotImplementedError: On any use, until the port lands.
    """

    def __init__(self, **kwargs):
        # TODO: port the implementation here (see visit_sampler.py for the
        # current standalone version).
        raise NotImplementedError("VisitStructureSampler porting in progress")

    def sample(self, **kwargs):
        """Sample a visit structure (unimplemented placeholder)."""
        raise NotImplementedError("VisitStructureSampler porting in progress")
diff --git a/pyhealth/models/promptehr/visit_sampler.py b/pyhealth/models/promptehr/visit_sampler.py
new file mode 100644
index 000000000..03efbf78f
--- /dev/null
+++ b/pyhealth/models/promptehr/visit_sampler.py
@@ -0,0 +1,121 @@
+"""
+Sample realistic visit structures from real MIMIC-III data distributions.
+
+This module provides functionality to sample the number of visits per patient
+and the number of diagnosis codes per visit, matching the empirical distributions
+observed in real EHR data.
+"""
+import numpy as np
+from typing import List
+
+
class VisitStructureSampler:
    """Sample realistic visit and code-count structures from training data.

    Builds two empirical distributions from the training records — visits
    per patient and codes per non-empty visit — and draws from them with a
    seeded RNG so synthetic patients match real structural statistics.

    Args:
        patient_records: Training records; each is either an object with a
            ``visits`` attribute or a dict with a ``'visits'`` key, where
            visits is a list of visits and each visit a list of codes.
            Records with neither shape are silently skipped.
        seed: Seed for the internal ``numpy.random.RandomState``.

    Raises:
        ValueError: If no record yields usable visits, or every visit is
            empty (both would otherwise fail obscurely inside numpy).
    """

    def __init__(self, patient_records: List, seed: int = 42):
        # Legacy RandomState keeps draws reproducible for a given seed.
        self.rng = np.random.RandomState(seed)

        # Extract empirical distributions.
        visit_counts = []
        code_counts = []
        for patient in patient_records:
            # Handle both object-like and dict-like patient records.
            if hasattr(patient, 'visits'):
                visits = patient.visits
            elif isinstance(patient, dict) and 'visits' in patient:
                visits = patient['visits']
            else:
                continue

            visit_counts.append(len(visits))
            # Empty visits carry no code-count signal; exclude them.
            code_counts.extend(len(v) for v in visits if len(v) > 0)

        # Fail fast with a clear message instead of numpy errors later.
        if not visit_counts:
            raise ValueError(
                "patient_records contained no usable 'visits' data"
            )
        if not code_counts:
            raise ValueError("patient_records contained only empty visits")

        self.num_visits_per_patient = np.array(visit_counts)
        self.codes_per_visit_all = np.array(code_counts)

        # Summary statistics, precomputed once (plain floats: JSON-safe).
        self.stats = {
            'visits_mean': float(np.mean(self.num_visits_per_patient)),
            'visits_median': float(np.median(self.num_visits_per_patient)),
            'visits_90th': float(np.percentile(self.num_visits_per_patient, 90)),
            'codes_mean': float(np.mean(self.codes_per_visit_all)),
            'codes_median': float(np.median(self.codes_per_visit_all)),
            'codes_90th': float(np.percentile(self.codes_per_visit_all, 90)),
            'codes_95th': float(np.percentile(self.codes_per_visit_all, 95)),
        }

    def sample_num_visits(self) -> int:
        """Sample a visit count uniformly from the empirical distribution.

        Returns:
            Number of visits (>= 0).
        """
        return int(self.rng.choice(self.num_visits_per_patient))

    def sample_codes_per_visit(self, n_visits: int) -> List[int]:
        """Sample a code count for each of ``n_visits`` visits.

        Args:
            n_visits: Number of visits to sample code counts for.

        Returns:
            List of integers representing codes per visit (empty when
            ``n_visits`` is 0).
        """
        if n_visits == 0:
            return []

        # Sample with replacement from the empirical distribution.
        counts = self.rng.choice(
            self.codes_per_visit_all, size=n_visits, replace=True
        )
        return counts.tolist()

    def sample_structure(self) -> dict:
        """Sample a complete visit structure (visits + codes per visit).

        Returns:
            Dictionary with:
                - 'num_visits': int (number of visits)
                - 'codes_per_visit': List[int] (codes for each visit)
        """
        num_visits = self.sample_num_visits()
        return {
            'num_visits': num_visits,
            'codes_per_visit': self.sample_codes_per_visit(num_visits),
        }

    def get_statistics(self) -> dict:
        """Return a copy of the mean/median/percentile statistics."""
        return self.stats.copy()

    def __repr__(self) -> str:
        """String representation showing distribution statistics."""
        return (
            f"VisitStructureSampler(\n"
            f"  Visits/patient: mean={self.stats['visits_mean']:.2f}, "
            f"median={self.stats['visits_median']:.0f}, "
            f"90th%={self.stats['visits_90th']:.0f}\n"
            f"  Codes/visit: mean={self.stats['codes_mean']:.2f}, "
            f"median={self.stats['codes_median']:.0f}, "
            f"90th%={self.stats['codes_90th']:.0f}, "
            f"95th%={self.stats['codes_95th']:.0f}\n"
            f")"
        )
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 2f4294a19..c0811c77a 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -1,5 +1,11 @@
from .base_task import BaseTask
from .benchmark_ehrshot import BenchmarkEHRShot
+from .ehr_generation import (
+ PromptEHRGenerationMIMIC3,
+ PromptEHRGenerationMIMIC4,
+ promptehr_generation_mimic3_fn,
+ promptehr_generation_mimic4_fn,
+)
from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction
from .bmd_hs_disease_classification import BMDHSDiseaseClassification
from .cardiology_detect import (
diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py
new file mode 100644
index 000000000..788dd0351
--- /dev/null
+++ b/pyhealth/tasks/ehr_generation.py
@@ -0,0 +1,142 @@
+"""Task function for PromptEHR synthetic EHR generation.
+
+Provides task classes for training PromptEHR on MIMIC-III and MIMIC-IV datasets.
+Demographics (age, gender) are extracted alongside visit codes because PromptEHR
+conditions generation on patient-level continuous and categorical features.
+"""
+
+from datetime import datetime
+from typing import Dict, List
+
+import polars as pl
+
+from pyhealth.tasks.base_task import BaseTask
+
+
class PromptEHRGenerationMIMIC3(BaseTask):
    """Task for PromptEHR synthetic data generation using MIMIC-III.

    PromptEHR is a BART-based seq2seq model that conditions generation on
    patient demographics (age, gender) via learned prompt vectors. This task
    extracts per-admission ICD-9 diagnosis codes grouped into a nested visit
    list, along with patient demographics for conditioning.

    Patients with fewer than 2 admissions containing diagnosis codes are
    excluded.

    Attributes:
        task_name (str): Unique task identifier.
        input_schema (dict): ``"visits"`` uses ``"nested_sequence"`` encoding
            (list of lists of code strings).
        output_schema (dict): Empty — generative task, no conditioning label.
        _icd_col (str): Polars column path for ICD codes in MIMIC-III.

    Examples:
        >>> fn = PromptEHRGenerationMIMIC3()
        >>> fn.task_name
        'PromptEHRGenerationMIMIC3'
    """

    task_name = "PromptEHRGenerationMIMIC3"
    input_schema = {"visits": "nested_sequence"}
    output_schema = {}
    _icd_col = "diagnoses_icd/icd9_code"

    def __call__(self, patient) -> List[Dict]:
        """Extract visit sequences and demographics for a single patient.

        Diagnosis codes are grouped per admission into a nested list. Age is
        computed as years between date-of-birth and the first admission date.
        Gender is encoded as 0 (male) or 1 (female). Defaults of
        ``age=60.0, gender=0`` are used when demographics are unavailable
        or unparseable.

        Args:
            patient: A PyHealth Patient object with admissions and
                diagnoses_icd event data.

        Returns:
            list of dict: A single-element list, or empty list if fewer
            than 2 visits have diagnosis codes. Each dict contains:
                ``"patient_id"`` (str): patient identifier.
                ``"visits"`` (list of list of str): ICD codes per visit.
                ``"age"`` (float): patient age at first admission in years.
                ``"gender"`` (int): 0 for male, 1 for female.
        """
        admissions = list(patient.get_events(event_type="admissions"))
        if len(admissions) < 2:
            return []

        age, gender = self._demographics(patient, admissions)
        visits = self._visit_codes(patient, admissions)
        if len(visits) < 2:
            return []

        return [{
            "patient_id": patient.patient_id,
            "visits": visits,
            "age": age,
            "gender": gender,
        }]

    def _demographics(self, patient, admissions):
        """Return ``(age, gender)`` with defaults ``(60.0, 0)`` on missing data.

        A malformed DOB string no longer aborts the task: parsing failures
        fall back to the default age instead of raising.
        """
        age = 60.0
        gender = 0
        patients_df = patient.get_events(event_type="patients", return_df=True)
        if len(patients_df) == 0:
            return age, gender

        if "patients/gender" in patients_df.columns:
            if patients_df["patients/gender"][0] == "F":
                gender = 1

        if "patients/dob" in patients_df.columns:
            dob_val = patients_df["patients/dob"][0]
            first_admit_ts = admissions[0].timestamp
            if dob_val is not None and first_admit_ts is not None:
                # dob_val may be a date/datetime or a string
                if hasattr(dob_val, "year"):
                    dob_dt = datetime(dob_val.year, dob_val.month, dob_val.day)
                else:
                    try:
                        dob_dt = datetime.strptime(str(dob_val)[:10], "%Y-%m-%d")
                    except ValueError:
                        # Unparseable DOB: keep the default age.
                        dob_dt = None
                if dob_dt is not None:
                    raw_age = (first_admit_ts - dob_dt).days / 365.25
                    # Clamp: MIMIC-III shifts >89-year-old DOBs far into the
                    # past; treat those as 90.
                    age = float(min(90.0, max(0.0, raw_age)))
        return age, gender

    def _visit_codes(self, patient, admissions):
        """Return per-admission ICD code lists, dropping code-less admissions."""
        visits = []
        for adm in admissions:
            codes = (
                patient.get_events(
                    event_type="diagnoses_icd",
                    filters=[("hadm_id", "==", adm.hadm_id)],
                    return_df=True,
                )
                .select(pl.col(self._icd_col))
                .to_series()
                .drop_nulls()
                .to_list()
            )
            if codes:
                visits.append(codes)
        return visits
+
+
class PromptEHRGenerationMIMIC4(PromptEHRGenerationMIMIC3):
    """Task for PromptEHR synthetic data generation using MIMIC-IV.

    Inherits all logic from :class:`PromptEHRGenerationMIMIC3`. Overrides only
    the task name and ICD code column to match the MIMIC-IV schema, where the
    column is ``icd_code`` (unversioned) rather than ``icd9_code``.

    NOTE(review): the inherited age computation reads ``patients/dob``, but
    the MIMIC-IV patients table provides ``anchor_age`` rather than a DOB —
    with MIMIC-IV data the age likely always falls back to the 60.0 default.
    Confirm and override ``__call__`` if real ages are needed.

    Attributes:
        task_name (str): Unique task identifier.
        _icd_col (str): Polars column path for ICD codes in MIMIC-IV.

    Examples:
        >>> fn = PromptEHRGenerationMIMIC4()
        >>> fn.task_name
        'PromptEHRGenerationMIMIC4'
    """

    task_name = "PromptEHRGenerationMIMIC4"
    _icd_col = "diagnoses_icd/icd_code"
+
+
# Ready-to-use task instances, following PyHealth's *_fn naming convention
# so they can be passed directly to Dataset.set_task(...).
promptehr_generation_mimic3_fn = PromptEHRGenerationMIMIC3()
promptehr_generation_mimic4_fn = PromptEHRGenerationMIMIC4()
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration/test_promptehr_end_to_end.py b/tests/integration/test_promptehr_end_to_end.py
new file mode 100644
index 000000000..a3c0bdac6
--- /dev/null
+++ b/tests/integration/test_promptehr_end_to_end.py
@@ -0,0 +1,431 @@
+"""End-to-end integration tests for the PromptEHR synthetic EHR generation pipeline.
+
+Category A tests use InMemorySampleDataset with synthetic data — no external
+data required and must always pass.
+
+Category B tests require actual MIMIC-III data and are skipped gracefully when
+the data is unavailable.
+
+The bootstrap pattern mirrors test_corgan_end_to_end.py: load PromptEHR and
+InMemorySampleDataset via importlib while stubbing out heavy optional
+dependencies (litdata, pyarrow) that are not yet in the venv. transformers IS
+available in the venv and is loaded normally.
+"""
+
+import importlib.util
+import os
+import sys
+import tempfile
+import unittest
+from unittest.mock import MagicMock
+
+
+# ---------------------------------------------------------------------------
+# Bootstrap: load PromptEHR, BaseModel, and InMemorySampleDataset without
+# triggering pyhealth.models.__init__ (many models have unavailable deps) or
+# pyhealth.datasets.__init__ (requires litdata, pyarrow, ...).
+# ---------------------------------------------------------------------------
+
+
def _bootstrap():
    """Load PromptEHR, BaseModel, and InMemorySampleDataset via importlib.

    Heavy optional dependencies (litdata, pyarrow) are stubbed out with
    MagicMock so the modules under test can be exec'd straight from their
    source files, bypassing ``pyhealth.models.__init__`` and
    ``pyhealth.datasets.__init__``.

    NOTE(review): the ``sys.modules`` stubs installed here persist for the
    whole test session; the Category B tests below must pop them before
    attempting a real ``pyhealth.datasets`` import.

    Returns:
        (BaseModel, PromptEHR, InMemorySampleDataset)
    """
    import pyhealth  # noqa: F401 — top-level __init__ has no heavy deps

    # Stub pyhealth.datasets so that base_model.py's
    # "from ..datasets import SampleDataset" resolves cleanly.
    if "pyhealth.datasets" not in sys.modules:
        ds_stub = MagicMock()

        class _FakeSampleDataset:  # noqa: N801
            pass

        ds_stub.SampleDataset = _FakeSampleDataset
        sys.modules["pyhealth.datasets"] = ds_stub

    # Stub pyhealth.models so we can control loading without the real __init__.
    if "pyhealth.models" not in sys.modules or isinstance(
        sys.modules["pyhealth.models"], MagicMock
    ):
        models_stub = MagicMock()
        sys.modules["pyhealth.models"] = models_stub
    else:
        models_stub = sys.modules["pyhealth.models"]

    # Processors are safe to import normally.
    from pyhealth.processors import PROCESSOR_REGISTRY  # noqa: F401

    def _load_file(mod_name, filepath):
        # Exec a source file as module `mod_name`; it is registered in
        # sys.modules BEFORE exec so imports inside it can resolve.
        spec = importlib.util.spec_from_file_location(mod_name, filepath)
        mod = importlib.util.module_from_spec(spec)
        sys.modules[mod_name] = mod
        spec.loader.exec_module(mod)
        return mod

    # Repository root = two directory levels above this test file.
    root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    models_dir = os.path.join(root, "pyhealth", "models")
    promptehr_dir = os.path.join(models_dir, "promptehr")

    # Load base_model and expose via stub.
    bm_mod = _load_file(
        "pyhealth.models.base_model", os.path.join(models_dir, "base_model.py")
    )
    BaseModel = bm_mod.BaseModel
    models_stub.BaseModel = BaseModel

    # Create a package stub for pyhealth.models.promptehr so that
    # model.py's relative imports (from .conditional_prompt import ...) work.
    promptehr_pkg_stub = MagicMock()
    sys.modules.setdefault("pyhealth.models.promptehr", promptehr_pkg_stub)

    # Load each PromptEHR submodule in dependency order.
    # Each is standalone (only torch + transformers, no cross-module imports).
    for mod_name in (
        "conditional_prompt",
        "bart_encoder",
        "bart_decoder",
        "visit_sampler",
        "generation",
    ):
        _load_file(
            f"pyhealth.models.promptehr.{mod_name}",
            os.path.join(promptehr_dir, f"{mod_name}.py"),
        )

    # Load model.py last (depends on the submodules loaded above + BaseModel).
    model_mod = _load_file(
        "pyhealth.models.promptehr.model",
        os.path.join(promptehr_dir, "model.py"),
    )
    PromptEHR = model_mod.PromptEHR

    # Stub litdata so sample_dataset.py can be loaded without the full package.
    if "litdata" not in sys.modules:
        litdata_pkg = MagicMock()
        litdata_pkg.StreamingDataset = type(
            "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None}
        )
        litdata_utilities = MagicMock()
        litdata_utilities_train_test = MagicMock()
        litdata_utilities_train_test.deepcopy_dataset = lambda x: x
        litdata_utilities.train_test_split = litdata_utilities_train_test
        litdata_pkg.utilities = litdata_utilities
        sys.modules["litdata"] = litdata_pkg
        sys.modules["litdata.utilities"] = litdata_utilities
        sys.modules["litdata.utilities.train_test_split"] = (
            litdata_utilities_train_test
        )

    # Load sample_dataset.py directly (bypasses datasets/__init__.py).
    ds_file_mod = _load_file(
        "pyhealth.datasets.sample_dataset",
        os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"),
    )
    InMemorySampleDataset = ds_file_mod.InMemorySampleDataset

    return BaseModel, PromptEHR, InMemorySampleDataset
+
+
# Resolve the classes once at import time; every test case below shares them.
BaseModel, PromptEHR, InMemorySampleDataset = _bootstrap()

# Deferred until after _bootstrap() so the stubs are in place first.
import torch  # noqa: E402
from transformers import BartConfig  # noqa: E402
+
+
+# ---------------------------------------------------------------------------
+# Shared helpers
+# ---------------------------------------------------------------------------
+
# Nested lists of code strings — PromptEHR uses nested_sequence schema.
# 8 samples with ≥2 visits each, plus demographics (age float, gender 0/1).
_SMALL_SAMPLES = [
    {"patient_id": "p1", "visits": [["A", "B"], ["C", "D"]], "age": 65.0, "gender": 0},
    {"patient_id": "p2", "visits": [["E"], ["F", "G"]], "age": 45.0, "gender": 1},
    {"patient_id": "p3", "visits": [["A", "C"], ["B", "E"]], "age": 55.0, "gender": 0},
    {"patient_id": "p4", "visits": [["D"], ["A"]], "age": 70.0, "gender": 1},
    {"patient_id": "p5", "visits": [["B", "F"], ["C", "G"]], "age": 40.0, "gender": 0},
    {"patient_id": "p6", "visits": [["E", "A"], ["D"]], "age": 60.0, "gender": 1},
    {"patient_id": "p7", "visits": [["G", "B"], ["F", "A"]], "age": 50.0, "gender": 0},
    {"patient_id": "p8", "visits": [["C"], ["D", "E"]], "age": 35.0, "gender": 1},
]

# Tiny BART config to keep tests fast (avoids downloading/using 768-dim bart-base).
_TINY_BART_CONFIG = BartConfig(
    d_model=32,
    encoder_layers=1,
    decoder_layers=1,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    max_position_embeddings=128,
)

# Minimal model kwargs — tiny architecture and 1 epoch to keep tests fast.
# lr is left at the PromptEHR default; save_dir is supplied per test.
_SMALL_MODEL_KWARGS = dict(
    n_num_features=1,
    cat_cardinalities=[2],
    d_hidden=32,
    prompt_length=1,
    bart_config_name=_TINY_BART_CONFIG,
    epochs=1,
    batch_size=4,
    warmup_steps=0,
    max_seq_length=64,
)
+
+
def _make_dataset(samples=None):
    """Wrap *samples* (default: ``_SMALL_SAMPLES``) in an InMemorySampleDataset."""
    chosen = _SMALL_SAMPLES if samples is None else samples
    return InMemorySampleDataset(
        samples=chosen,
        input_schema={"visits": "nested_sequence"},
        output_schema={},
    )
+
+
def _make_trained_model():
    """Train a tiny PromptEHR for 1 epoch on _SMALL_SAMPLES; return (model, tmpdir)."""
    ds = _make_dataset()
    ckpt_dir = tempfile.mkdtemp()
    trained = PromptEHR(ds, save_dir=ckpt_dir, **_SMALL_MODEL_KWARGS)
    trained.train_model(ds)
    return trained, ckpt_dir
+
+
+# ---------------------------------------------------------------------------
+# Category A: In-Memory Integration Tests (must always pass)
+# ---------------------------------------------------------------------------
+
+
class TestPromptEHRIsBaseModelInstance(unittest.TestCase):
    """The PromptEHR wrapper must subclass PyHealth's BaseModel."""

    def test_model_is_basemodel_instance(self):
        ds = _make_dataset()
        promptehr = PromptEHR(ds, **_SMALL_MODEL_KWARGS)
        self.assertIsInstance(promptehr, BaseModel)
+
+
class TestPromptEHRFeatureKeys(unittest.TestCase):
    """Feature keys derived from the input schema must be exactly ['visits']."""

    def test_feature_keys(self):
        promptehr = PromptEHR(_make_dataset(), **_SMALL_MODEL_KWARGS)
        self.assertEqual(promptehr.feature_keys, ["visits"])
+
+
class TestPromptEHRVocabSize(unittest.TestCase):
    """Model vocabulary = dataset code vocabulary + 5 special tokens."""

    def test_vocab_size_matches_processor(self):
        ds = _make_dataset()
        promptehr = PromptEHR(ds, **_SMALL_MODEL_KWARGS)
        n_codes = ds.input_processors["visits"].vocab_size()
        self.assertEqual(promptehr._vocab.total_size, n_codes + 5)
+
+
class TestPromptEHRForwardRaisesNotImplementedError(unittest.TestCase):
    """forward() must raise NotImplementedError.

    PromptEHR is a generative model; the discriminative forward pass is not
    applicable.
    """

    def test_forward_not_implemented(self):
        promptehr = PromptEHR(_make_dataset(), **_SMALL_MODEL_KWARGS)
        self.assertRaises(NotImplementedError, promptehr.forward)
+
+
class TestPromptEHRTrainModelRuns(unittest.TestCase):
    """One training epoch completes and writes a final checkpoint."""

    def test_train_model_runs_one_epoch(self):
        ds = _make_dataset()
        with tempfile.TemporaryDirectory() as workdir:
            promptehr = PromptEHR(ds, save_dir=workdir, **_SMALL_MODEL_KWARGS)
            try:
                promptehr.train_model(ds, val_dataset=None)
            except Exception as exc:  # noqa: BLE001
                self.fail(f"train_model raised an unexpected exception: {exc}")
            # train_model must leave a checkpoint behind.
            ckpt = os.path.join(workdir, "checkpoint.pt")
            self.assertTrue(os.path.exists(ckpt), f"Expected checkpoint at {ckpt}")
+
+
class TestPromptEHRSynthesizeCount(unittest.TestCase):
    """synthesize_dataset(num_samples=3) returns exactly 3 dicts."""

    @classmethod
    def setUpClass(cls):
        cls.model, cls.tmpdir = _make_trained_model()

    @classmethod
    def tearDownClass(cls):
        # _make_trained_model uses tempfile.mkdtemp(), which is never
        # auto-removed; clean it up so repeated runs don't accumulate
        # checkpoint directories.
        import shutil
        shutil.rmtree(cls.tmpdir, ignore_errors=True)

    def test_synthesize_returns_correct_count(self):
        result = self.model.synthesize_dataset(num_samples=3)
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 3)
+
+
class TestPromptEHRSynthesizeOutputStructure(unittest.TestCase):
    """Each synthesized dict has patient_id (str) and visits (nested list of str).

    PromptEHR outputs nested visit lists — each patient is a list of visits,
    each visit is a list of diagnosis code strings.
    """

    @classmethod
    def setUpClass(cls):
        cls.model, cls.tmpdir = _make_trained_model()

    @classmethod
    def tearDownClass(cls):
        # Remove the mkdtemp() checkpoint directory created in setUpClass;
        # it is otherwise leaked on every test run.
        import shutil
        shutil.rmtree(cls.tmpdir, ignore_errors=True)

    def test_synthesize_output_structure(self):
        result = self.model.synthesize_dataset(num_samples=3)
        for i, item in enumerate(result):
            self.assertIsInstance(item, dict, f"Item {i} is not a dict")
            self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'")
            self.assertIn("visits", item, f"Item {i} missing 'visits'")
            self.assertIsInstance(
                item["patient_id"], str, f"patient_id in item {i} is not a str"
            )
            self.assertIsInstance(
                item["visits"], list, f"visits in item {i} is not a list"
            )
            # visits is a nested list: list of visits, each visit a list of strings
            for visit_idx, visit in enumerate(item["visits"]):
                self.assertIsInstance(
                    visit, list,
                    f"visit {visit_idx} in item {i} is not a list"
                )
                for code in visit:
                    self.assertIsInstance(
                        code, str,
                        f"code '{code}' in visit {visit_idx}, item {i} is not str"
                    )
+
+
class TestPromptEHRSaveLoadRoundtrip(unittest.TestCase):
    """A saved checkpoint can be reloaded and still drives generation."""

    def test_save_load_roundtrip(self):
        ds = _make_dataset()
        with tempfile.TemporaryDirectory() as workdir:
            promptehr = PromptEHR(ds, save_dir=workdir, **_SMALL_MODEL_KWARGS)
            promptehr.train_model(ds)
            ckpt_path = os.path.join(workdir, "test_ckpt.pt")
            promptehr.save_model(ckpt_path)
            self.assertTrue(
                os.path.exists(ckpt_path),
                f"Expected checkpoint at {ckpt_path}",
            )
            promptehr.load_model(ckpt_path)
            self.assertEqual(len(promptehr.synthesize_dataset(num_samples=3)), 3)
+
+
+# ---------------------------------------------------------------------------
+# Category B: MIMIC-III Integration Tests (skipped if data unavailable)
+# ---------------------------------------------------------------------------
+
# Root of a local MIMIC-III v1.4 install; override via PYHEALTH_MIMIC3_PATH.
_MIMIC3_PATH = os.environ.get(
    "PYHEALTH_MIMIC3_PATH",
    "/srv/local/data/physionet.org/files/mimiciii/1.4",
)
+
+
class TestPromptEHRMIMIC3Integration(unittest.TestCase):
    """End-to-end pipeline test with actual MIMIC-III data.

    Skipped automatically when MIMIC-III is not present on this machine.
    setUpClass temporarily removes the bootstrap stubs so the real
    ``pyhealth.datasets`` can be imported; on failure the stub is restored
    so Category A tests keep working.
    """

    @classmethod
    def setUpClass(cls):
        # Default to "run"; any failure below flips the flag and setUp()
        # skips every test in the class with the recorded reason.
        cls.skip_integration = False
        cls.skip_reason = ""
        try:
            # Remove bootstrap stubs so we can attempt a real import.
            _saved_ds_stub = sys.modules.pop("pyhealth.datasets", None)
            try:
                import importlib as _il
                _il.invalidate_caches()
                from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset
                from pyhealth.tasks.ehr_generation import PromptEHRGenerationMIMIC3
            except (ImportError, ModuleNotFoundError) as exc:
                # Put the stub back so Category A tests are unaffected.
                if _saved_ds_stub is not None:
                    sys.modules["pyhealth.datasets"] = _saved_ds_stub
                raise ImportError(str(exc)) from exc

            cls.dataset = _MIMIC3Dataset(
                root=_MIMIC3_PATH,
                tables=["patients", "admissions", "diagnoses_icd"],
            )
            task = PromptEHRGenerationMIMIC3()
            cls.sample_dataset = cls.dataset.set_task(task)
        except (FileNotFoundError, OSError, ImportError, ValueError) as exc:
            cls.skip_integration = True
            cls.skip_reason = str(exc)

    def setUp(self):
        if self.skip_integration:
            self.skipTest(f"MIMIC-III integration test skipped: {self.skip_reason}")

    def test_mimic3_set_task_returns_nonempty_dataset(self):
        """set_task produces at least one sample from MIMIC-III."""
        self.assertGreater(len(self.sample_dataset), 0)

    def test_mimic3_sample_keys(self):
        """Every sample must contain patient_id, visits, age, and gender keys."""
        for sample in self.sample_dataset:
            self.assertIn("patient_id", sample)
            self.assertIn("visits", sample)
            self.assertIn("age", sample)
            self.assertIn("gender", sample)

    def test_mimic3_visits_are_nested_tensors(self):
        """visits must be a list of 1-D int64 tensors (NestedSequenceProcessor output).

        NestedSequenceProcessor encodes each visit as a 1-D LongTensor of
        code indices. This verifies the nested_sequence schema round-trips
        correctly through set_task.
        """
        for sample in self.sample_dataset:
            visits = sample["visits"]
            self.assertIsInstance(visits, list)
            self.assertGreater(len(visits), 0)
            for visit in visits:
                self.assertIsInstance(visit, torch.Tensor)
                self.assertEqual(visit.dtype, torch.long)

    def test_mimic3_full_pipeline_train_and_synthesize(self):
        """Train one epoch on MIMIC-III data and synthesize a small batch."""
        with tempfile.TemporaryDirectory() as tmpdir:
            model = PromptEHR(
                self.sample_dataset,
                d_hidden=64,
                prompt_length=1,
                bart_config_name=_TINY_BART_CONFIG,
                epochs=1,
                batch_size=16,
                warmup_steps=0,
                save_dir=tmpdir,
            )
            model.train_model(self.sample_dataset, val_dataset=None)
            synthetic = model.synthesize_dataset(num_samples=5)
            self.assertEqual(len(synthetic), 5)
            for item in synthetic:
                self.assertIn("patient_id", item)
                self.assertIn("visits", item)
                self.assertIsInstance(item["visits"], list)
+
+
+if __name__ == "__main__":
+ unittest.main()