From 078ecf9d32b50e9eda37368ba339e800612944d4 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 01:41:40 -0600 Subject: [PATCH 01/37] T1: add PromptEHR source files, task stub, and examples --- examples/promptehr_generate_local.py | 157 +++ examples/promptehr_mimic3.py | 565 +++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/promptehr/__init__.py | 41 + pyhealth/models/promptehr/bart_decoder.py | 325 +++++ pyhealth/models/promptehr/bart_encoder.py | 214 ++++ .../models/promptehr/conditional_prompt.py | 251 ++++ pyhealth/models/promptehr/generation.py | 1070 +++++++++++++++++ pyhealth/models/promptehr/model.py | 548 +++++++++ pyhealth/models/promptehr/utils.py | 29 + pyhealth/models/promptehr/visit_sampler.py | 121 ++ pyhealth/tasks/ehr_generation.py | 30 + 12 files changed, 3352 insertions(+) create mode 100644 examples/promptehr_generate_local.py create mode 100644 examples/promptehr_mimic3.py create mode 100644 pyhealth/models/promptehr/__init__.py create mode 100644 pyhealth/models/promptehr/bart_decoder.py create mode 100644 pyhealth/models/promptehr/bart_encoder.py create mode 100644 pyhealth/models/promptehr/conditional_prompt.py create mode 100644 pyhealth/models/promptehr/generation.py create mode 100644 pyhealth/models/promptehr/model.py create mode 100644 pyhealth/models/promptehr/utils.py create mode 100644 pyhealth/models/promptehr/visit_sampler.py create mode 100644 pyhealth/tasks/ehr_generation.py diff --git a/examples/promptehr_generate_local.py b/examples/promptehr_generate_local.py new file mode 100644 index 000000000..33b9ad41c --- /dev/null +++ b/examples/promptehr_generate_local.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +"""Quick local generation test for PromptEHR (CPU-only). + +This script demonstrates how to: +1. Load a trained PromptEHR checkpoint +2. Generate synthetic patients on CPU (no GPU required) +3. 
Display results in human-readable format + +Usage: + python3 examples/promptehr_generate_local.py +""" + +import sys +sys.path.insert(0, '/u/jalenj4/final/PyHealth') + +import torch +import logging +from pathlib import Path + +# PyHealth imports +from pyhealth.models import PromptEHR +from pyhealth.datasets.promptehr_dataset import load_mimic_data +from pyhealth.models.promptehr import ( + VisitStructureSampler, + generate_patient_with_structure_constraints +) + + +def main(): + """Generate 10 synthetic patients locally on CPU.""" + + # Setup + device = torch.device("cpu") # Force CPU (no GPU required) + logging.basicConfig( + level=logging.WARNING, # Reduce noise, only show warnings/errors + format='%(message)s' + ) + logger = logging.getLogger(__name__) + + print("\n" + "="*80) + print("PromptEHR Local Generation Test (CPU mode)") + print("="*80) + + # Load checkpoint + print("\n[1/4] Loading trained checkpoint...") + checkpoint_path = "./promptehr_outputs/checkpoints/final_model.pt" + + if not Path(checkpoint_path).exists(): + print(f"ERROR: Checkpoint not found at {checkpoint_path}") + print("Please ensure training has completed and checkpoint exists.") + return + + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + tokenizer = checkpoint['tokenizer'] + + # Add convenience properties and methods if not present + # (for compatibility with old checkpoints saved before these were added) + if not hasattr(tokenizer, 'bos_token_id'): + tokenizer.pad_token_id = tokenizer.vocabulary("") # ID 0 + tokenizer.bos_token_id = tokenizer.vocabulary("") # ID 1 + tokenizer.eos_token_id = tokenizer.vocabulary("") # ID 2 + tokenizer.code_offset = 7 # First diagnosis code ID (after 7 special tokens) + if not hasattr(tokenizer, 'convert_tokens_to_ids'): + # Add method alias: pehr_scratch API uses convert_tokens_to_ids(token) → int + def convert_tokens_to_ids(token: str) -> int: + return tokenizer.convert_tokens_to_indices([token])[0] + 
tokenizer.convert_tokens_to_ids = convert_tokens_to_ids + if not hasattr(tokenizer, 'vocab'): + # Add vocab object for idx2code and code2idx mappings + class VocabCompat: + def __init__(self, tok): + self.idx2code = tok.vocabulary.idx2token + self.code2idx = tok.vocabulary.token2idx + def __len__(self): + return len(self.idx2code) + tokenizer.vocab = VocabCompat(tokenizer) + + # Rebuild model + print("[2/4] Rebuilding model from checkpoint...") + config = checkpoint['config'] + model = PromptEHR(**config) + model.bart_model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + model.eval() + + print(f" Model vocabulary size: {config['_custom_vocab_size']}") + print(f" Hidden dimension: {config['d_hidden']}") + print(f" Prompt length: {config['prompt_length']}") + + # Load MIMIC data for structure sampling + print("[3/4] Loading MIMIC-III data for structure sampling...") + print(" (Loading 1000 patients for realistic visit distributions)") + + patient_records, _ = load_mimic_data( + patients_path="/u/jalenj4/pehr_scratch/data_files/PATIENTS.csv", + admissions_path="/u/jalenj4/pehr_scratch/data_files/ADMISSIONS.csv", + diagnoses_path="/u/jalenj4/pehr_scratch/data_files/DIAGNOSES_ICD.csv", + num_patients=1000, + logger=logger + ) + + # Initialize structure sampler + structure_sampler = VisitStructureSampler(patient_records, seed=42) + print(f" {structure_sampler}") + + # Generate synthetic patients + n_patients = 10 + print(f"\n[4/4] Generating {n_patients} synthetic patients...") + print(" (This will take ~10-15 seconds)") + print() + + print("="*80) + print("SYNTHETIC PATIENTS") + print("="*80) + print() + + for i in range(n_patients): + # Sample realistic visit structure + target_structure = structure_sampler.sample_structure() + + # Generate patient + result = generate_patient_with_structure_constraints( + model=model, + tokenizer=tokenizer, + device=device, + target_structure=target_structure, + temperature=0.7, + top_k=40, + top_p=0.9, + 
max_codes_per_visit=25 + ) + + # Display patient + demo = result['demographics'] + print(f"Patient {i+1}:") + print(f" Age: {demo['age']} years") + print(f" Sex: {'Male' if demo['sex'] == 0 else 'Female'}") + print(f" Number of visits: {result['num_visits']}") + print(f" Diagnosis codes:") + + for visit_idx, codes in enumerate(result['generated_visits'], 1): + if codes: + print(f" Visit {visit_idx}: {', '.join(codes)}") + else: + print(f" Visit {visit_idx}: (no diagnoses)") + print() + + print("="*80) + print("Generation complete!") + print("="*80) + print() + print(f"Successfully generated {n_patients} synthetic patients on CPU.") + print() + + +if __name__ == "__main__": + main() diff --git a/examples/promptehr_mimic3.py b/examples/promptehr_mimic3.py new file mode 100644 index 000000000..1f42868d2 --- /dev/null +++ b/examples/promptehr_mimic3.py @@ -0,0 +1,565 @@ +"""PromptEHR: Training and Generation Example on MIMIC-III + +This example demonstrates the complete PromptEHR pipeline: +1. Load MIMIC-III patient records +2. Train PromptEHR model for synthetic EHR generation +3. Generate synthetic patients with realistic visit structures +4. 
Evaluate generation quality + +References: + - Paper: "PromptEHR: Conditional Electronic Health Records Generation with Prompt Learning" + - pehr_scratch implementation: /u/jalenj4/pehr_scratch/ +""" + +import os +import sys +import logging +from pathlib import Path +from typing import List, Dict + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, random_split +from torch.optim import AdamW +from transformers import BartConfig, get_linear_schedule_with_warmup + +# PyHealth imports +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.models import PromptEHR +from pyhealth.trainer import Trainer +from pyhealth.datasets.promptehr_dataset import ( + create_promptehr_tokenizer, + PromptEHRDataset, + load_mimic_data +) +from pyhealth.datasets.promptehr_collator import EHRDataCollator + + +class DeviceAwareCollatorWrapper: + """Wrapper around EHRDataCollator that moves tensors to specified device. + + This wrapper addresses PyHealth Trainer limitation where data is not automatically + moved to device before forward pass. The Trainer directly calls model(**data) at + line 206 without device transfer, requiring collator to handle device placement. + + Args: + collator: Base EHRDataCollator instance + device: Target device ('cuda' or 'cpu') + """ + + def __init__(self, collator: EHRDataCollator, device: str): + """Initialize wrapper with base collator and target device.""" + self.collator = collator + self.device = torch.device(device) + + def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: + """Collate batch and move all tensors to target device. 
+ + Args: + batch: List of sample dictionaries + + Returns: + Dictionary with batched tensors on target device + """ + # Get batched tensors from base collator (CPU tensors) + batched_data = self.collator(batch) + + # Move all tensors to target device + device_data = { + key: value.to(self.device) if isinstance(value, torch.Tensor) else value + for key, value in batched_data.items() + } + + return device_data + + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def train_promptehr( + mimic3_root: str, + output_dir: str = "./promptehr_outputs", + num_patients: int = 46520, # Full MIMIC-III dataset + batch_size: int = 16, + num_epochs: int = 20, + learning_rate: float = 1e-5, + warmup_steps: int = 1000, + val_split: float = 0.2, + device: str = "cuda", + checkpoint_path: str = None +): + """Train PromptEHR model on MIMIC-III dataset. + + Args: + mimic3_root: Path to MIMIC-III data directory containing: + - PATIENTS.csv + - ADMISSIONS.csv + - DIAGNOSES_ICD.csv + output_dir: Directory to save outputs (checkpoints, logs) + num_patients: Number of patients to load (default: full dataset) + batch_size: Training batch size + num_epochs: Number of training epochs + learning_rate: AdamW learning rate + warmup_steps: Linear warmup steps for scheduler + val_split: Validation split ratio + device: Device to use ('cuda' or 'cpu') + checkpoint_path: Path to resume from checkpoint (optional) + + Returns: + Trained PromptEHR model + """ + # Create output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + checkpoint_dir = output_dir / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + + logger.info("=" * 80) + logger.info("PromptEHR Training Pipeline") + logger.info("=" * 80) + logger.info(f"MIMIC-III root: {mimic3_root}") + logger.info(f"Output directory: {output_dir}") + logger.info(f"Device: {device}") + + # Step 
1: Load MIMIC-III patient records + logger.info("\n" + "=" * 80) + logger.info("Loading MIMIC-III Patient Records") + logger.info("=" * 80) + + patients_path = os.path.join(mimic3_root, "PATIENTS.csv") + admissions_path = os.path.join(mimic3_root, "ADMISSIONS.csv") + diagnoses_path = os.path.join(mimic3_root, "DIAGNOSES_ICD.csv") + + patient_records, diagnosis_codes = load_mimic_data( + patients_path=patients_path, + admissions_path=admissions_path, + diagnoses_path=diagnoses_path, + num_patients=num_patients, + logger=logger + ) + + logger.info(f"Loaded {len(patient_records)} patients") + logger.info(f"Vocabulary size: {len(diagnosis_codes)} diagnosis codes") + + # Step 2: Create tokenizer + logger.info("\n" + "=" * 80) + logger.info("Creating Tokenizer") + logger.info("=" * 80) + + tokenizer = create_promptehr_tokenizer(diagnosis_codes) + vocab_size = tokenizer.get_vocabulary_size() + logger.info(f"Tokenizer vocabulary size: {vocab_size}") + logger.info(f" Special tokens: 7") + logger.info(f" Diagnosis codes: {len(diagnosis_codes)}") + logger.info(f" Code offset: 7") + + # Step 3: Create dataset + logger.info("\n" + "=" * 80) + logger.info("Creating Dataset") + logger.info("=" * 80) + + dataset = PromptEHRDataset(patient_records, tokenizer, logger) + logger.info(f"Dataset size: {len(dataset)} patients") + + # Train/validation split + train_size = int((1 - val_split) * len(dataset)) + val_size = len(dataset) - train_size + train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + logger.info(f"Train size: {train_size}, Validation size: {val_size}") + + # Create data collator + # CRITICAL FIX: Disable token replacement to prevent distribution inversion + # Token replacement causes rare codes to be enriched 3.24x and common codes depleted to 0.85x + base_collator = EHRDataCollator( + tokenizer=tokenizer, + max_seq_length=512, + logger=logger, + corruption_prob=0.5, + use_mask_infilling=True, + use_token_deletion=True, + use_token_replacement=False 
# DISABLED: Causes 4700x frequency inversion + ) + + # Wrap collator to handle device placement + # PyHealth Trainer does not move data to device (line 206: model(**data)) + # so we must handle device transfer in the collator + collator = DeviceAwareCollatorWrapper(base_collator, device) + logger.info(f"Using device-aware collator wrapper (target device: {device})") + + # Create data loaders + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=collator + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=collator + ) + + logger.info(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}") + + # Step 4: Initialize model + logger.info("\n" + "=" * 80) + logger.info("Initializing PromptEHR Model") + logger.info("=" * 80) + + model = PromptEHR( + dataset=None, # Generative model, no discriminative task + n_num_features=1, # Age + cat_cardinalities=[2], # Gender (M/F) + d_hidden=128, + prompt_length=1, + bart_config_name="facebook/bart-base", + _custom_vocab_size=vocab_size # Custom vocab size for MIMIC-III + ) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Total parameters: {total_params:,}") + logger.info(f"Trainable parameters: {trainable_params:,}") + + # Step 5: Configure trainer + logger.info("\n" + "=" * 80) + logger.info("Configuring Trainer") + logger.info("=" * 80) + + trainer = Trainer( + model=model, + checkpoint_path=checkpoint_path, + metrics=["loss"], + device=device, + enable_logging=True, + output_path=str(output_dir) + ) + + # Step 6: Train + logger.info("\n" + "=" * 80) + logger.info("Starting Training") + logger.info("=" * 80) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=num_epochs, + optimizer_params={"lr": learning_rate, "weight_decay": 0.01}, + monitor="loss" + ) + + # Step 7: 
Save final model + final_checkpoint = checkpoint_dir / "final_model.pt" + torch.save({ + 'model_state_dict': model.bart_model.state_dict(), # Save BART model state + 'tokenizer': tokenizer, + 'diagnosis_codes': diagnosis_codes, + 'config': { + 'dataset': None, + 'n_num_features': 1, + 'cat_cardinalities': [2], + 'd_hidden': 128, + 'prompt_length': 1, + 'bart_config_name': "facebook/bart-base", + '_custom_vocab_size': vocab_size + } + }, final_checkpoint) + logger.info(f"\nFinal model saved to: {final_checkpoint}") + + logger.info("\n" + "=" * 80) + logger.info("Training Complete!") + logger.info("=" * 80) + + return model, tokenizer + + +def generate_synthetic_patients( + model: PromptEHR, + tokenizer, + patient_records: List, + num_patients: int = 100, + temperature: float = 0.7, + alpha: float = 2.0, + device: str = "cuda", + mimic3_root: str = None +): + """Generate synthetic patients using trained PromptEHR model. + + Args: + model: Trained PromptEHR model + tokenizer: PromptEHR tokenizer + patient_records: Real patient records (for structure sampling) + num_patients: Number of synthetic patients to generate + temperature: Sampling temperature + device: Device to use + mimic3_root: Path to MIMIC-III training data (for first code prior) + + Returns: + List of generated patient dictionaries + """ + from pyhealth.models.promptehr import VisitStructureSampler + from pyhealth.models.promptehr.generation import ( + DemographicSampler, + build_frequency_prior, + generate_with_frequency_prior + ) + + logger.info("\n" + "=" * 80) + logger.info(f"Generating {num_patients} Synthetic Patients") + logger.info("=" * 80) + + # Initialize visit structure sampler + structure_sampler = VisitStructureSampler(patient_records, seed=42) + logger.info(f"Structure sampler: {structure_sampler}") + + # Initialize demographic sampler + demographic_sampler = DemographicSampler(patient_records, seed=42) + logger.info(f"Demographic sampler: {demographic_sampler}") + + # Build frequency 
prior for ALL code generation + frequency_prior = None + freq_path = Path(mimic3_root).parent / "promptehr_outputs" / "training_frequencies.json" + if not freq_path.exists(): + freq_path = Path("promptehr_outputs") / "training_frequencies.json" + + if freq_path.exists(): + logger.info(f"Building frequency prior from {freq_path}...") + try: + frequency_prior = build_frequency_prior( + tokenizer, + frequency_path=str(freq_path), + vocab_size=len(tokenizer.vocab.idx2code) + ) + logger.info(f"Frequency prior built: shape {frequency_prior.shape}") + except Exception as e: + logger.warning(f"Failed to build frequency prior: {e}") + logger.warning("Continuing without frequency guidance...") + else: + logger.warning(f"training_frequencies.json not found at {freq_path}") + logger.warning("Continuing without frequency guidance...") + + # Set model to eval mode + model.eval() + model.to(device) + + # Generate patients + generated_patients = [] + for i in range(num_patients): + if (i + 1) % 20 == 0: + logger.info(f"Generated {i + 1}/{num_patients} patients...") + + # Sample realistic visit structure + target_structure = structure_sampler.sample_structure() + + # Sample demographics from empirical distribution + demographics = demographic_sampler.sample() + age = demographics['age'] + sex = demographics['sex'] + + # Generate patient with frequency-guided sampling + if frequency_prior is not None: + result = generate_with_frequency_prior( + model=model, + tokenizer=tokenizer, + device=device, + target_structure=target_structure, + frequency_prior=frequency_prior, + alpha=alpha, # Frequency prior weight (optimal: 2.0 from diagnostic) + age=age, + sex=sex, + temperature=temperature, # Sampling temperature (optimal: 1.0 from diagnostic) + top_k=0, # Disabled - use full vocabulary + top_p=0.95, # Nucleus sampling for quality + max_codes_per_visit=25 + ) + else: + # Fallback to regular generation if no frequency prior + from pyhealth.models.promptehr import 
generate_patient_with_structure_constraints + result = generate_patient_with_structure_constraints( + model=model, + tokenizer=tokenizer, + device=device, + target_structure=target_structure, + age=age, + sex=sex, + temperature=0.5, + top_k=0, + top_p=0.95, + max_codes_per_visit=25 + ) + + # Store result + demo = result['demographics'] + generated_patients.append({ + 'patient_id': f"SYNTH_{i+1:04d}", + 'age': demo['age'], + 'sex': 'M' if demo['sex'] == 0 else 'F', + 'num_visits': result['num_visits'], + 'visits': result['generated_visits'] + }) + + logger.info(f"\nGeneration complete: {num_patients} patients created") + + # Display statistics + total_visits = sum(p['num_visits'] for p in generated_patients) + total_codes = sum(len(code) for p in generated_patients for visit in p['visits'] for code in visit) + unique_codes = len(set(code for p in generated_patients for visit in p['visits'] for code in visit)) + + logger.info(f"\nDataset Statistics:") + logger.info(f" Total patients: {num_patients}") + logger.info(f" Total visits: {total_visits}") + logger.info(f" Total diagnosis codes: {total_codes}") + logger.info(f" Unique codes: {unique_codes}") + logger.info(f" Average visits/patient: {total_visits/num_patients:.2f}") + logger.info(f" Average codes/patient: {total_codes/num_patients:.1f}") + + return generated_patients + + +def save_synthetic_dataset( + patients: List[Dict], + output_path: str, + format: str = "csv" +): + """Save generated patients to file. 
+ + Args: + patients: List of patient dictionaries + output_path: Path to save file + format: Output format ('csv' or 'json') + """ + import csv + import json + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if format == "csv": + with open(output_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['patient_id', 'age', 'sex', 'num_visits', 'visit_num', 'diagnosis_codes']) + + for patient in patients: + for visit_idx, visit_codes in enumerate(patient['visits']): + codes_str = ';'.join(visit_codes) + writer.writerow([ + patient['patient_id'], + f"{patient['age']:.1f}", + patient['sex'], + patient['num_visits'], + visit_idx + 1, + codes_str + ]) + + logger.info(f"Saved {len(patients)} patients to {output_path} (CSV format)") + + elif format == "json": + with open(output_path, 'w') as f: + json.dump(patients, f, indent=2) + + logger.info(f"Saved {len(patients)} patients to {output_path} (JSON format)") + + +def main(): + """Main entry point for PromptEHR training and generation.""" + import argparse + + parser = argparse.ArgumentParser(description="PromptEHR Training and Generation") + parser.add_argument("--mimic3_root", type=str, required=True, + help="Path to MIMIC-III data directory") + parser.add_argument("--output_dir", type=str, default="./promptehr_outputs", + help="Output directory for checkpoints and results") + parser.add_argument("--num_patients", type=int, default=46520, + help="Number of patients to load for training") + parser.add_argument("--batch_size", type=int, default=16, + help="Training batch size") + parser.add_argument("--num_epochs", type=int, default=20, + help="Number of training epochs") + parser.add_argument("--learning_rate", type=float, default=1e-5, + help="Learning rate") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use (cuda or cpu)") + parser.add_argument("--checkpoint", type=str, default=None, + help="Path to checkpoint to resume 
from") + parser.add_argument("--generate_only", action="store_true", + help="Skip training, only generate (requires --checkpoint)") + parser.add_argument("--num_synthetic", type=int, default=100, + help="Number of synthetic patients to generate") + parser.add_argument("--temperature", type=float, default=0.7, + help="Sampling temperature for generation") + parser.add_argument("--alpha", type=float, default=2.0, + help="Frequency prior weight (alpha) for generation") + + args = parser.parse_args() + + # Training + if not args.generate_only: + model, tokenizer = train_promptehr( + mimic3_root=args.mimic3_root, + output_dir=args.output_dir, + num_patients=args.num_patients, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + learning_rate=args.learning_rate, + device=args.device, + checkpoint_path=args.checkpoint + ) + else: + # Load from checkpoint + if args.checkpoint is None: + raise ValueError("--checkpoint required when using --generate_only") + + logger.info(f"Loading model from checkpoint: {args.checkpoint}") + # PyTorch 2.6+ requires weights_only=False to load checkpoints with custom objects (tokenizer) + checkpoint = torch.load(args.checkpoint, weights_only=False) + tokenizer = checkpoint['tokenizer'] + + model = PromptEHR(**checkpoint['config']) + model.bart_model.load_state_dict(checkpoint['model_state_dict']) + model.to(args.device) + model.eval() + + # Load patient records for structure sampling + patients_path = os.path.join(args.mimic3_root, "PATIENTS.csv") + admissions_path = os.path.join(args.mimic3_root, "ADMISSIONS.csv") + diagnoses_path = os.path.join(args.mimic3_root, "DIAGNOSES_ICD.csv") + + patient_records, _ = load_mimic_data( + patients_path=patients_path, + admissions_path=admissions_path, + diagnoses_path=diagnoses_path, + num_patients=args.num_patients, + logger=logger + ) + + # Generation + generated_patients = generate_synthetic_patients( + model=model, + tokenizer=tokenizer, + patient_records=patient_records, + 
num_patients=args.num_synthetic, + temperature=args.temperature, + alpha=args.alpha, + device=args.device, + mimic3_root=args.mimic3_root + ) + + # Save results + output_csv = Path(args.output_dir) / f"synthetic_patients_{args.num_synthetic}.csv" + save_synthetic_dataset(generated_patients, output_csv, format="csv") + + logger.info("\n" + "=" * 80) + logger.info("PromptEHR Pipeline Complete!") + logger.info("=" * 80) + logger.info(f"Output directory: {args.output_dir}") + logger.info(f"Synthetic dataset: {output_csv}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 14f0bf209..6cdf2ea45 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -16,6 +16,7 @@ from .grasp import GRASP, GRASPLayer from .medlink import MedLink from .micron import MICRON, MICRONLayer +from .promptehr import PromptEHR from .mlp import MLP from .molerec import MoleRec, MoleRecLayer from .retain import RETAIN, RETAINLayer diff --git a/pyhealth/models/promptehr/__init__.py b/pyhealth/models/promptehr/__init__.py new file mode 100644 index 000000000..fdf1327a3 --- /dev/null +++ b/pyhealth/models/promptehr/__init__.py @@ -0,0 +1,41 @@ +"""PromptEHR: Prompt-based BART model for synthetic EHR generation. + +This module provides a demographic-conditioned sequence-to-sequence model +for generating realistic synthetic electronic health records. + +Main components: + - PromptEHR: Main model class (inherits from BaseModel) + - ConditionalPromptEncoder: Demographic conditioning with reparameterization + - PromptBartEncoder: Modified BART encoder with prompt injection + - PromptBartDecoder: Modified BART decoder with prompt injection + - VisitStructureSampler: Utility for structure-constrained generation + - Generation functions: sample_demographics, parse_sequence_to_visits, etc. 
+""" + +from .model import PromptEHR +from .conditional_prompt import ConditionalPromptEncoder +from .bart_encoder import PromptBartEncoder +from .bart_decoder import PromptBartDecoder +from .visit_sampler import VisitStructureSampler +from .generation import ( + DemographicSampler, + sample_demographics, + decode_patient_demographics, + parse_sequence_to_visits, + generate_patient_sequence_conditional, + generate_patient_with_structure_constraints +) + +__all__ = [ + "PromptEHR", + "ConditionalPromptEncoder", + "PromptBartEncoder", + "PromptBartDecoder", + "VisitStructureSampler", + "DemographicSampler", + "sample_demographics", + "decode_patient_demographics", + "parse_sequence_to_visits", + "generate_patient_sequence_conditional", + "generate_patient_with_structure_constraints", +] diff --git a/pyhealth/models/promptehr/bart_decoder.py b/pyhealth/models/promptehr/bart_decoder.py new file mode 100644 index 000000000..e6d01a70b --- /dev/null +++ b/pyhealth/models/promptehr/bart_decoder.py @@ -0,0 +1,325 @@ +"""BART decoder with prompt injection for demographic conditioning. + +This module provides a modified BART decoder that accepts demographic prompt +embeddings and prepends them to decoder input sequences for conditioning. + +Ported from pehr_scratch/prompt_bart_decoder.py (lines 1-207). +""" + +import torch +import torch.nn as nn +from typing import Optional, Tuple +from transformers.models.bart.modeling_bart import BartDecoder +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + + +class PromptBartDecoder(BartDecoder): + """BART decoder modified to accept and prepend demographic prompt embeddings. + + Extends the standard BART decoder to support prompt-based conditioning by: + 1. Accepting optional prompt embeddings as input + 2. Prepending prompts to decoder input token embeddings + 3. Extending attention masks to cover prepended prompts + 4. Creating causal masks for autoregressive generation + 5. 
Processing through standard BART decoder layers with cross-attention + + This enables demographic conditioning (age + gender) by injecting learned + prompt vectors at the decoder input, maintaining demographic alignment + during generation (dual prompt injection with encoder). + + Args: + config: BartConfig from transformers + embed_tokens: Token embedding layer (optional) + + Example: + >>> from transformers import BartConfig + >>> config = BartConfig.from_pretrained("facebook/bart-base") + >>> decoder = PromptBartDecoder(config) + >>> # Decode with prompts + >>> prompt_embeds = torch.randn(16, 2, 768) # [batch, n_prompts, hidden] + >>> input_ids = torch.randint(0, 1000, (16, 50)) # [batch, tgt_len] + >>> encoder_outputs = torch.randn(16, 100, 768) # [batch, src_len, hidden] + >>> outputs = decoder( + ... input_ids, + ... encoder_hidden_states=encoder_outputs, + ... inputs_prompt_embeds=prompt_embeds + ... ) + """ + + def __init__(self, config, embed_tokens=None): + """Initialize prompt-aware BART decoder. 
+ + Args: + config: BartConfig from transformers + embed_tokens: Optional token embedding layer + """ + super().__init__(config, embed_tokens) + + # Initialize embedding scale factor (BART uses sqrt(d_model) scaling) + self.embed_scale = None + if config.scale_embedding: + self.embed_scale = (config.d_model ** 0.5) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + inputs_prompt_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BaseModelOutputWithPastAndCrossAttentions: + """Forward pass with optional demographic prompt embeddings. 
+ + Args: + input_ids: [batch, tgt_seq_len] decoder token IDs + attention_mask: [batch, tgt_seq_len] decoder attention mask (1=attend, 0=ignore) + encoder_hidden_states: [batch, src_seq_len, hidden_dim] encoder outputs + encoder_attention_mask: [batch, src_seq_len] encoder attention mask + head_mask: [num_layers, num_heads] mask for self-attention heads + cross_attn_head_mask: [num_layers, num_heads] mask for cross-attention heads + past_key_values: Cached key-value states for efficient generation + inputs_embeds: [batch, tgt_seq_len, hidden_dim] pre-computed embeddings (optional) + inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional) + use_cache: Whether to return key-value cache for generation + output_attentions: Whether to return attention weights + output_hidden_states: Whether to return all hidden states + return_dict: Whether to return BaseModelOutputWithPastAndCrossAttentions or tuple + + Returns: + BaseModelOutputWithPastAndCrossAttentions with: + - last_hidden_state: [batch, n_prompts + tgt_len, hidden_dim] + - past_key_values: Cached key-value states (if use_cache=True) + - hidden_states: Tuple of all layer outputs (if output_hidden_states=True) + - attentions: Tuple of self-attention weights (if output_attentions=True) + - cross_attentions: Tuple of cross-attention weights (if output_attentions=True) + """ + # Set output flags from config defaults + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get decoder input embeddings from token IDs + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # Apply embedding scaling if configured + if 
self.embed_scale is not None: + inputs_embeds = inputs_embeds * self.embed_scale + + # Store original sequence length before prepending prompts + original_seq_len = inputs_embeds.shape[1] + + # Prepend prompt embeddings if provided + if inputs_prompt_embeds is not None: + # Concatenate prompts before decoder input embeddings + # inputs_prompt_embeds: [batch, n_prompts, hidden_dim] + # inputs_embeds: [batch, tgt_len, hidden_dim] + # Result: [batch, n_prompts + tgt_len, hidden_dim] + inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1) + + # Extend attention mask for prepended prompts + batch_size, n_prompts = inputs_prompt_embeds.shape[:2] + + # Create attention mask for prompts (all 1s - always attend to prompts) + prompt_attention_mask = torch.ones( + batch_size, n_prompts, + dtype=attention_mask.dtype if attention_mask is not None else torch.long, + device=inputs_embeds.device + ) + + if attention_mask is not None: + # Concatenate prompt mask with decoder attention mask + attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) + else: + # Create attention mask for all tokens (prompts + decoder input) + total_seq_len = inputs_embeds.shape[1] + attention_mask = torch.ones( + batch_size, total_seq_len, + dtype=torch.long, + device=inputs_embeds.device + ) + + # Get positional embeddings for full sequence (prompts + decoder tokens) + past_key_values_length = 0 + if past_key_values is not None: + # Handle Cache object (new transformers API) or tuple (old API) + if hasattr(past_key_values, 'get_seq_length'): + past_key_values_length = past_key_values.get_seq_length() + elif isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0: + # Defensive: handle unexpected cache structures gracefully + # pehr-scratch-expert confirmed: defaulting to 0 is safe (slightly degrades + # quality but prevents crash). BART handles positional errors gracefully. 
+ try: + if past_key_values[0] is not None and isinstance(past_key_values[0], (tuple, list)): + if len(past_key_values[0]) > 0 and past_key_values[0][0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + except (IndexError, TypeError, AttributeError): + # Safe fallback: slightly degrades quality but prevents crash + # Positional embeddings will be calculated from position 0 + past_key_values_length = 0 + + # Get positional embeddings (BART uses learned positional embeddings) + positions = self.embed_positions(inputs_embeds, past_key_values_length) + + # Combine input embeddings + positional embeddings + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # Create combined attention mask (causal + padding) + if attention_mask is not None: + # Create causal mask for decoder self-attention + combined_attention_mask = _make_causal_mask( + inputs_embeds.shape[:2], + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + # Expand padding mask and combine with causal mask + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1]) + combined_attention_mask = combined_attention_mask + expanded_attn_mask + else: + # Create causal mask only (no padding) + combined_attention_mask = _make_causal_mask( + inputs_embeds.shape[:2], + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + # Expand encoder attention mask for cross-attention + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [batch, src_len] → [batch, 1, tgt_len, src_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1]) + + # Initialize output containers + all_hidden_states = () if output_hidden_states 
else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Pass through decoder layers + for idx, decoder_layer in enumerate(self.layers): + # Save hidden state before layer if requested + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Forward through decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + # Update hidden states + hidden_states = layer_outputs[0] + + # Save attention weights if requested + if output_attentions: + all_self_attns += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Save final hidden state if requested + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Cache is handled by past_key_values object, not returned in tuple + next_cache = past_key_values if use_cache else None + + # Return tuple format if not using return_dict + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + + # Return BaseModelOutputWithPastAndCrossAttentions + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +def _make_causal_mask( + input_shape: Tuple[int, int], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0 +) -> 
torch.Tensor: + """Create causal mask for decoder self-attention. + + Creates a lower-triangular mask that prevents attending to future positions. + This is essential for autoregressive generation where each position can only + attend to earlier positions. + + Args: + input_shape: (batch_size, tgt_len) shape of decoder input + dtype: Data type for mask tensor + device: Device to create mask on + past_key_values_length: Length of cached key-values from previous steps + + Returns: + [batch, 1, tgt_len, tgt_len + past_len] causal mask with -inf for future positions + """ + batch_size, tgt_len = input_shape + + # Initialize mask with -inf (prevents attention) + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + + # Create lower triangular mask (0 for allowed positions, -inf for future) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + # If using cached key-values, allow attending to all past positions + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # Expand to [batch, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor: + """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len]. + + Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention. 
+ + Args: + mask: [batch, src_len] attention mask (1=attend, 0=ignore) + dtype: Target data type for the expanded mask + tgt_len: Target sequence length (defaults to src_len) + + Returns: + [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions + """ + batch_size, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len] + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype) + + # Invert mask: 1 (attend) → 0, 0 (ignore) → 1 + inverted_mask = 1.0 - expanded_mask + + # Fill masked positions with -inf (prevents attention) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) diff --git a/pyhealth/models/promptehr/bart_encoder.py b/pyhealth/models/promptehr/bart_encoder.py new file mode 100644 index 000000000..726f34cb9 --- /dev/null +++ b/pyhealth/models/promptehr/bart_encoder.py @@ -0,0 +1,214 @@ +"""BART encoder with prompt injection for demographic conditioning. + +This module provides a modified BART encoder that accepts demographic prompt +embeddings and prepends them to input sequences for conditioning. + +Ported from pehr_scratch/prompt_bart_encoder.py (lines 1-149). +""" + +import torch +import torch.nn as nn +from typing import Optional +from transformers.models.bart.modeling_bart import BartEncoder +from transformers.modeling_outputs import BaseModelOutput + + +class PromptBartEncoder(BartEncoder): + """BART encoder modified to accept and prepend demographic prompt embeddings. + + Extends the standard BART encoder to support prompt-based conditioning by: + 1. Accepting optional prompt embeddings as input + 2. Prepending prompts to input token embeddings + 3. Extending attention masks to cover prepended prompts + 4. Processing through standard BART encoder layers + + This enables demographic conditioning (age + gender) by injecting learned + prompt vectors at the encoder input. 
+ + Args: + config: BartConfig from transformers + embed_tokens: Token embedding layer (optional) + + Example: + >>> from transformers import BartConfig + >>> config = BartConfig.from_pretrained("facebook/bart-base") + >>> encoder = PromptBartEncoder(config) + >>> # Encode with prompts + >>> prompt_embeds = torch.randn(16, 2, 768) # [batch, n_prompts, hidden] + >>> input_ids = torch.randint(0, 1000, (16, 100)) # [batch, seq_len] + >>> outputs = encoder(input_ids, inputs_prompt_embeds=prompt_embeds) + """ + + def __init__(self, config, embed_tokens=None): + """Initialize prompt-aware BART encoder. + + Args: + config: BartConfig from transformers + embed_tokens: Optional token embedding layer + """ + super().__init__(config, embed_tokens) + + # Initialize embedding scale factor (BART uses sqrt(d_model) scaling) + self.embed_scale = None + if config.scale_embedding: + self.embed_scale = (config.d_model ** 0.5) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + inputs_prompt_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BaseModelOutput: + """Forward pass with optional demographic prompt embeddings. 
+ + Args: + input_ids: [batch, seq_len] token IDs + attention_mask: [batch, seq_len] attention mask (1=attend, 0=ignore) + head_mask: [num_layers, num_heads] mask for attention heads + inputs_embeds: [batch, seq_len, hidden_dim] pre-computed embeddings (optional) + inputs_prompt_embeds: [batch, n_prompts, hidden_dim] demographic prompts (optional) + output_attentions: Whether to return attention weights + output_hidden_states: Whether to return all hidden states + return_dict: Whether to return BaseModelOutput or tuple + + Returns: + BaseModelOutput with: + - last_hidden_state: [batch, n_prompts + seq_len, hidden_dim] + - hidden_states: Tuple of all layer outputs (if output_hidden_states=True) + - attentions: Tuple of attention weights (if output_attentions=True) + """ + # Set output flags from config defaults + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get input embeddings from token IDs + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # Apply embedding scaling if configured + if self.embed_scale is not None: + inputs_embeds = inputs_embeds * self.embed_scale + + # Prepend prompt embeddings if provided + if inputs_prompt_embeds is not None: + # Concatenate prompts before input embeddings + # inputs_prompt_embeds: [batch, n_prompts, hidden_dim] + # inputs_embeds: [batch, seq_len, hidden_dim] + # Result: [batch, n_prompts + seq_len, hidden_dim] + inputs_embeds = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1) + + # Extend attention mask to account for prepended prompts + batch_size, n_prompts = inputs_prompt_embeds.shape[:2] + + if attention_mask is not None: + # Create attention mask for prompts matching existing mask dtype/device + 
prompt_attention_mask = torch.ones( + batch_size, n_prompts, + dtype=attention_mask.dtype, + device=attention_mask.device + ) + # Concatenate prompt mask with original mask + attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) + else: + # Create full attention mask for prompts + sequence + seq_len = inputs_embeds.shape[1] # Total length including prompts already prepended + attention_mask = torch.ones( + batch_size, seq_len, + dtype=torch.long, + device=inputs_embeds.device + ) + + # Get positional embeddings (BART uses learned positional embeddings) + embed_pos = self.embed_positions(inputs_embeds) + + # Combine input embeddings + positional embeddings + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # Expand attention mask from [batch, seq_len] to [batch, 1, tgt_len, src_len] + if attention_mask is not None: + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + # Initialize output containers + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # Validate head_mask dimensionality + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"head_mask should have {len(self.layers)} layers, but has {head_mask.size()[0]}" + ) + + # Pass through encoder layers + for idx, encoder_layer in enumerate(self.layers): + # Save hidden state before layer if requested + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + # Get layer-specific head mask + layer_head_mask = head_mask[idx] if head_mask is not None else None + + # Forward through encoder layer + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # Update hidden states + hidden_states = layer_outputs[0] + + 
# Save attention weights if requested + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Save final hidden state if requested + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + # Return tuple format if not using return_dict + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + + # Return BaseModelOutput + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor: + """Expand attention mask from [batch, src_len] to [batch, 1, tgt_len, src_len]. + + Inverts the mask (1→0, 0→1) and fills masked positions with -inf to prevent attention. + + Args: + mask: [batch, src_len] attention mask (1=attend, 0=ignore) + dtype: Target data type for the expanded mask + tgt_len: Target sequence length (defaults to src_len for encoder self-attention) + + Returns: + [batch, 1, tgt_len, src_len] expanded mask with -inf for masked positions + """ + batch_size, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + # Expand dimensions: [batch, src_len] → [batch, 1, tgt_len, src_len] + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype) + + # Invert mask: 1 (attend) → 0, 0 (ignore) → 1 + inverted_mask = 1.0 - expanded_mask + + # Fill masked positions with -inf (prevents attention) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) diff --git a/pyhealth/models/promptehr/conditional_prompt.py b/pyhealth/models/promptehr/conditional_prompt.py new file mode 100644 index 000000000..4122a5d31 --- /dev/null +++ b/pyhealth/models/promptehr/conditional_prompt.py @@ -0,0 +1,251 @@ +"""Conditional prompt encoder for demographic conditioning. 
class NumericalConditionalPrompt(nn.Module):
    """Embed continuous demographic features (e.g. age) as prompt vectors.

    Each value is mapped into a small ``d_hidden`` space via a learned
    per-feature affine transform (PromptEHR's "reparameterization"), then
    projected up to the model width. The intermediate space regularizes the
    prompts and improves gradient flow.
    """

    def __init__(
        self,
        n_num_features: int,
        hidden_dim: int,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Set up the affine reparameterization and the up-projection.

        Args:
            n_num_features: Number of continuous features (1 for age only).
            hidden_dim: Output dimension (768 for BART-base).
            d_hidden: Intermediate reparameterization dimension (default: 128).
            prompt_length: Number of prompt vectors per feature (default: 1).
        """
        super().__init__()
        self.n_num_features = n_num_features
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        # Per-feature slope and intercept in the reparameterization space.
        self.weight = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
        self.bias = nn.Parameter(torch.Tensor(n_num_features, d_hidden))
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.bias)

        # Up-projection to the transformer hidden size.
        self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)

    def forward(self, x_num: torch.Tensor) -> torch.Tensor:
        """Map [batch, n_num_features] values to [batch, n_num_features, hidden_dim].

        NOTE(review): ``prompt_length`` is stored but not applied here; with
        the default prompt_length=1 the documented and actual shapes coincide.
        """
        # addcmul computes bias + weight * value, broadcasting:
        # bias/weight [1, n_feat, d_hidden] with x_num[..., None] [batch, n_feat, 1].
        reparam = torch.addcmul(self.bias[None], self.weight[None], x_num[..., None])
        return self.proj(reparam)
class CategoricalConditionalPrompt(nn.Module):
    """Embed categorical demographics (e.g. gender) as prompt vectors.

    All features share a single embedding table; per-feature offsets give
    each feature a private slice of the table so categories of different
    features never collide. A learned per-feature bias and a projection to
    the model width complete the reparameterization.
    """

    def __init__(
        self,
        cat_cardinalities: list,
        hidden_dim: int,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Set up the shared embedding table and the up-projection.

        Args:
            cat_cardinalities: Category counts per feature (e.g. [2] for gender).
            hidden_dim: Output dimension (768 for BART-base).
            d_hidden: Intermediate reparameterization dimension (default: 128).
            prompt_length: Number of prompt vectors per feature (default: 1).
        """
        super().__init__()
        assert cat_cardinalities, 'cat_cardinalities must be non-empty'
        self.cat_cardinalities = cat_cardinalities
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        # Feature i's categories occupy table rows [offset_i, offset_i + card_i).
        offsets = torch.tensor([0] + cat_cardinalities[:-1]).cumsum(0)
        self.register_buffer('category_offsets', offsets, persistent=False)

        # One table for every category of every feature.
        self.embeddings = nn.Embedding(sum(cat_cardinalities), d_hidden)

        # Learned bias per feature (not per category).
        self.bias = nn.Parameter(torch.Tensor(len(cat_cardinalities), d_hidden))
        nn.init.xavier_uniform_(self.bias)

        self.proj = nn.Linear(d_hidden, hidden_dim, bias=False)

    def forward(self, x_cat: torch.Tensor) -> torch.Tensor:
        """Map [batch, n_cat_features] IDs to [batch, n_cat_features, hidden_dim]."""
        # Shift each feature's IDs into its private slice of the table.
        shifted = x_cat + self.category_offsets[None]
        hidden = self.embeddings(shifted) + self.bias[None]
        return self.proj(hidden)
class ConditionalPromptEncoder(nn.Module):
    """Joint prompt encoder for numerical and categorical demographics.

    Encodes patient demographics (age + gender) into prompt vectors that
    condition the BART encoder and decoder. Either modality may be disabled
    by passing None for its configuration.

    Example:
        >>> enc = ConditionalPromptEncoder(
        ...     n_num_features=1,        # age
        ...     cat_cardinalities=[2],   # gender (M/F)
        ...     hidden_dim=768,
        ...     d_hidden=128,
        ... )
        >>> prompts = enc(x_num=torch.randn(16, 1),
        ...               x_cat=torch.randint(0, 2, (16, 1)))
        >>> prompts.shape    # [16, 2, 768] — one prompt each for age and gender
    """

    def __init__(
        self,
        n_num_features: Optional[int] = None,
        cat_cardinalities: Optional[list] = None,
        hidden_dim: int = 768,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Create per-modality sub-encoders for the requested features.

        Args:
            n_num_features: Number of continuous features (None to disable).
            cat_cardinalities: Category counts per categorical feature (None to disable).
            hidden_dim: Hidden dimension (768 for BART-base).
            d_hidden: Intermediate reparameterization dimension (default: 128).
            prompt_length: Prompt vectors per feature (default: 1).
        """
        super().__init__()
        self.n_num_features = n_num_features
        self.cat_cardinalities = cat_cardinalities
        self.hidden_dim = hidden_dim
        self.d_hidden = d_hidden
        self.prompt_length = prompt_length

        wants_num = n_num_features is not None and n_num_features > 0
        self.num_prompt = (
            NumericalConditionalPrompt(n_num_features, hidden_dim, d_hidden, prompt_length)
            if wants_num else None
        )

        wants_cat = cat_cardinalities is not None and len(cat_cardinalities) > 0
        self.cat_prompt = (
            CategoricalConditionalPrompt(cat_cardinalities, hidden_dim, d_hidden, prompt_length)
            if wants_cat else None
        )

    def forward(
        self,
        x_num: Optional[torch.Tensor] = None,
        x_cat: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode demographics into [batch, total_prompts, hidden_dim] prompts.

        Args:
            x_num: [batch, n_num_features] continuous values (optional).
            x_cat: [batch, n_cat_features] categorical IDs (optional).

        Raises:
            ValueError: if neither modality produces any prompt.
        """
        pieces = []
        # Numerical prompts come first, then categorical — order is part of
        # the learned prompt layout.
        if self.num_prompt is not None and x_num is not None:
            pieces.append(self.num_prompt(x_num))
        if self.cat_prompt is not None and x_cat is not None:
            pieces.append(self.cat_prompt(x_cat))

        if not pieces:
            raise ValueError("No prompt embeddings generated. Provide x_num or x_cat.")

        return torch.cat(pieces, dim=1)

    def get_num_prompts(self) -> int:
        """Total number of prompt tokens produced by forward()."""
        total = 0
        if self.num_prompt is not None:
            total += self.n_num_features * self.prompt_length
        if self.cat_prompt is not None:
            total += len(self.cat_cardinalities) * self.prompt_length
        return total
class DemographicSampler:
    """Sample patient demographics from the empirical training distribution.

    Draws (age, gender) pairs directly from the observed training records so
    synthetic patients match the real population's joint distribution.
    """

    def __init__(self, patient_records: List, seed: int = 42):
        """Collect the empirical (age, gender) pool from training records.

        Args:
            patient_records: Patient records from the training set; each
                record exposes 'age' and 'gender' as attributes or dict keys.
                Records missing either field are skipped.
            seed: Random seed for reproducibility.

        Raises:
            ValueError: if no record exposes both 'age' and 'gender'.
        """
        self.rng = np.random.RandomState(seed)

        ages = []
        genders = []
        for patient in patient_records:
            # Accept both attribute-style and dict-style records.
            if hasattr(patient, 'age') and hasattr(patient, 'gender'):
                age, gender = patient.age, patient.gender
            elif isinstance(patient, dict) and 'age' in patient and 'gender' in patient:
                age, gender = patient['age'], patient['gender']
            else:
                continue
            ages.append(float(age))
            # 'M' -> 0, any other string -> 1 (F); numeric values pass through.
            genders.append((0 if gender == 'M' else 1) if isinstance(gender, str) else int(gender))

        if not ages:
            # Fail fast: an empty pool would yield NaN statistics here and a
            # cryptic randint(0, 0) crash later in sample().
            raise ValueError("patient_records contains no usable 'age'/'gender' entries")

        self.ages = np.array(ages)
        self.genders = np.array(genders)

        # Summary statistics, useful for sanity-checking the pool.
        self.stats = {
            'age_mean': np.mean(self.ages),
            'age_std': np.std(self.ages),
            'age_median': np.median(self.ages),
            'age_min': np.min(self.ages),
            'age_max': np.max(self.ages),
            'male_pct': (self.genders == 0).mean(),
            'female_pct': (self.genders == 1).mean(),
        }

    def sample(self) -> dict:
        """Draw one (age, sex) pair uniformly from the empirical pool.

        Returns:
            Dict with 'age' (float), 'sex' (0=Male, 1=Female) and
            'sex_str' ('M' or 'F').
        """
        idx = self.rng.randint(0, len(self.ages))
        age = self.ages[idx]
        sex = self.genders[idx]
        return {
            'age': float(age),
            'sex': int(sex),
            'sex_str': 'M' if sex == 0 else 'F',
        }

    def __repr__(self):
        return (
            f"DemographicSampler(\n"
            f" Age: mean={self.stats['age_mean']:.1f}, "
            f"std={self.stats['age_std']:.1f}, "
            f"range=[{self.stats['age_min']:.0f}, {self.stats['age_max']:.0f}]\n"
            f" Gender: {self.stats['male_pct']:.1%} Male, "
            f"{self.stats['female_pct']:.1%} Female\n"
            f")"
        )
def build_first_code_prior(
    training_data_path: str,
    age_bins: int = 9
) -> Dict:
    """Build empirical P(first_code | age, gender) from MIMIC-III training data.

    Args:
        training_data_path: Directory containing ADMISSIONS.csv, PATIENTS.csv
            and DIAGNOSES_ICD.csv.
        age_bins: Number of 10-year age bins (default: 9 -> [0-10), ..., [80-90]).

    Returns:
        Dict mapping (age_bin, gender_int) -> {code: probability}, containing
        only demographic cells actually observed in the data.

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> first_code = sample_first_code(65, 0, prior)
    """
    import pandas as pd

    admissions = pd.read_csv(f'{training_data_path}/ADMISSIONS.csv')
    patients = pd.read_csv(f'{training_data_path}/PATIENTS.csv')
    diagnoses = pd.read_csv(f'{training_data_path}/DIAGNOSES_ICD.csv')

    # Age at first admission.
    admissions['ADMITTIME'] = pd.to_datetime(admissions['ADMITTIME'])
    patients['DOB'] = pd.to_datetime(patients['DOB'])

    first_admissions = admissions.loc[
        admissions.groupby('SUBJECT_ID')['ADMITTIME'].idxmin()
    ][['SUBJECT_ID', 'HADM_ID', 'ADMITTIME']]

    demo = pd.merge(
        patients[['SUBJECT_ID', 'GENDER', 'DOB']],
        first_admissions,
        on='SUBJECT_ID',
        how='inner'
    )
    demo['AGE'] = demo['ADMITTIME'].dt.year - demo['DOB'].dt.year
    # Clamp to [0, 90]; MIMIC de-identifies ages > 89 as very large values.
    demo['AGE'] = demo['AGE'].clip(lower=0, upper=90)

    # First diagnosis code per patient (first alphabetically within the
    # first admission; SEQ_NUM is not available after this merge).
    first_diag = pd.merge(
        demo[['SUBJECT_ID', 'HADM_ID', 'AGE', 'GENDER']],
        diagnoses[['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE']],
        on=['SUBJECT_ID', 'HADM_ID'],
        how='inner'
    )
    first_diag = first_diag.sort_values(['SUBJECT_ID', 'ICD9_CODE'])
    first_diag = first_diag.groupby('SUBJECT_ID').first().reset_index()

    # Bin ages; bin edges derived from age_bins so non-default values work.
    first_diag['age_bin'] = pd.cut(
        first_diag['AGE'],
        bins=list(range(0, 10 * age_bins + 1, 10)),
        labels=list(range(age_bins)),
        include_lowest=True
    )

    # Gender as int (0=M, 1=F).
    first_diag['gender_int'] = (first_diag['GENDER'] == 'F').astype(int)

    # Empirical distribution per observed demographic cell.
    # observed=True is essential: with the default categorical groupby,
    # unobserved (age_bin, gender) combinations would appear as empty dicts
    # and later crash sample_first_code via np.random.choice([]).
    dist = {}
    for (age_bin, gender), group in first_diag.groupby(['age_bin', 'gender_int'], observed=True):
        code_counts = group['ICD9_CODE'].value_counts()
        total = code_counts.sum()
        dist[(int(age_bin), int(gender))] = {
            str(code): count / total
            for code, count in code_counts.items()
        }

    return dist
def sample_first_code(
    age: float,
    gender: int,
    first_code_prior: Dict
) -> str:
    """Sample a first diagnosis code from the empirical prior.

    Args:
        age: Patient age (0-90).
        gender: Patient gender (0=Male, 1=Female).
        first_code_prior: Output of build_first_code_prior().

    Returns:
        Diagnosis code string (e.g. 'V3000', '41401').

    Example:
        >>> prior = build_first_code_prior('/path/to/train_data')
        >>> sample_first_code(65, 0, prior)
        'V3000'
    """
    # 10-year age bucket, capped at bin 8 ([80+]).
    bucket = min(int(age // 10), 8)

    key = (bucket, gender)
    if key not in first_code_prior:
        # Fall back to any entry with matching gender, else the first entry.
        key = next(
            (k for k in first_code_prior if k[1] == gender),
            next(iter(first_code_prior)),
        )

    code_probs = first_code_prior[key]
    codes = list(code_probs)
    weights = [code_probs[c] for c in codes]
    return np.random.choice(codes, p=weights)
def build_frequency_prior(
    tokenizer,
    frequency_path: Optional[Union[str, Path]] = None,
    epsilon: float = 1e-10,
    vocab_size: Optional[int] = None
) -> torch.Tensor:
    """Build log-frequency prior over vocabulary for frequency-guided generation.

    Args:
        tokenizer: DiagnosisCodeTokenizer with ``vocab`` and ``code_offset``
            attributes (``vocab.code2idx``: code -> token id, already offset;
            ``vocab.idx2code``: token id -> code).
        frequency_path: Path to training_frequencies.json (a dict with a
            'frequencies' key mapping code -> empirical frequency).
            If None, uses a uniform prior over code positions.
        epsilon: Small constant to avoid log(0) (default: 1e-10).
        vocab_size: Model vocabulary size. If None, inferred from tokenizer
            (not recommended). Should match the model's lm_head output dim.

    Returns:
        torch.Tensor of shape [vocab_size] with log-frequencies.
        Special tokens (indices < code_offset) get 0 (neutral prior);
        codes seen in training get log(freq + epsilon); code positions not
        seen in training get the log(epsilon) floor.

    Example:
        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json', vocab_size=6963)
        >>> logits_guided = logits + alpha * prior  # Blend with model logits
    """
    # Use provided vocab size or infer from tokenizer.
    # WARNING: Inferred size may not match the model if there's a mismatch!
    if vocab_size is None:
        vocab_size = len(tokenizer.vocab.idx2code)

    log_freqs = torch.zeros(vocab_size)

    if frequency_path is None:
        # Uniform fallback: all code positions equally likely; specials stay 0.
        uniform_log_freq = math.log(1.0 / len(tokenizer.vocab.idx2code))
        log_freqs[tokenizer.code_offset:] = uniform_log_freq
        return log_freqs

    # Load training frequencies
    with open(frequency_path, 'r') as f:
        freq_data = json.load(f)

    frequencies = freq_data['frequencies']

    # Fill in log-frequencies for each code.
    # NOTE: code2idx already maps directly to the model token id (the
    # special-token offset is baked in), so no extra shift is applied here.
    filled = torch.zeros(vocab_size, dtype=torch.bool)
    for code, freq in frequencies.items():
        if code in tokenizer.vocab.code2idx:
            code_idx = tokenizer.vocab.code2idx[code]
            if code_idx < vocab_size:
                log_freqs[code_idx] = math.log(freq + epsilon)
                filled[code_idx] = True

    # FIX: apply the log(epsilon) floor ONLY to code positions that were not
    # seen in training. The previous `torch.where(log_freqs == 0, ...)` also
    # clobbered special-token positions (< code_offset), giving them a huge
    # negative prior instead of the documented neutral 0 — which suppressed
    # structural tokens (e.g. the visit-end marker) once the prior was
    # blended into the logits. It could also clobber a legitimate code whose
    # log-frequency happened to be exactly 0.
    code_positions = torch.zeros(vocab_size, dtype=torch.bool)
    code_positions[tokenizer.code_offset:] = True
    log_freqs[code_positions & ~filled] = math.log(epsilon)

    return log_freqs
def sample_demographics(
    age_mean: float = 60.0,
    age_std: float = 20.0,
    male_prob: float = 0.56
) -> dict:
    """Draw one set of patient demographics.

    Demographics are drawn from distributions matching the MIMIC-III ICU
    population: normally distributed age (clipped to [0, 90]) and a
    Bernoulli sex assignment.

    Args:
        age_mean: Mean age for the normal distribution (default: 60).
        age_std: Standard deviation for age (default: 20).
        male_prob: Probability of male gender (default: 0.56).

    Returns:
        Dictionary with:
            - 'age': float in range [0, 90]
            - 'sex': int (0=Male, 1=Female)
            - 'sex_str': str ('M' or 'F')
    """
    # Draw age first, then sex, to keep RNG consumption order stable.
    raw_age = np.random.normal(age_mean, age_std)
    clipped_age = np.clip(raw_age, 0, 90)

    is_male = np.random.rand() < male_prob
    return {
        'age': float(clipped_age),
        'sex': 0 if is_male else 1,
        'sex_str': 'M' if is_male else 'F',
    }


def decode_patient_demographics(age: float, gender: int) -> dict:
    """Render demographics in a human-readable form.

    Args:
        age: Normalized age value.
        gender: Gender category index (0=M, 1=F per data_loader.py).

    Returns:
        Dictionary with decoded demographics ('age' and 'gender' strings).
    """
    labels = {0: "M", 1: "F"}  # Fixed: M=0, F=1
    return {
        "age": f"{age:.1f}",
        "gender": labels.get(gender, "UNKNOWN"),
    }
def parse_sequence_to_visits(
    token_ids: List[int],
    tokenizer
) -> List[List[str]]:
    """Split a generated token sequence into per-visit lists of codes.

    Visits are delimited by the visit start/end marker tokens; diagnosis
    code IDs inside a visit are decoded via ``tokenizer.vocab.idx2code``.
    BOS, end, and pad tokens are ignored. An unterminated final visit that
    contains at least one code is still emitted.

    Args:
        token_ids: List of token IDs from model generation.
        tokenizer: PyHealth Tokenizer instance (must have bos_token_id,
            pad_token_id, code_offset, and vocab attributes).

    Returns:
        List of visits, each a list of ICD-9 code strings.
    """
    # NOTE(review): the empty-string token literals below look garbled in
    # this source (probably the visit-start and end markers, e.g. "<v>") --
    # confirm against the tokenizer's special-token list.
    open_id = tokenizer.convert_tokens_to_indices([""])[0]
    close_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    skip_ids = (
        tokenizer.bos_token_id,
        tokenizer.convert_tokens_to_indices([""])[0],
        tokenizer.pad_token_id,
    )
    vocab_len = len(tokenizer.vocab.idx2code)

    parsed: List[List[str]] = []
    buffer: List[str] = []
    inside = False

    for tid in token_ids:
        if tid == open_id:
            # Visit opens; reset the code buffer.
            inside = True
            buffer = []
        elif tid == close_id:
            # Visit closes; flush the buffer if a visit was open.
            if inside:
                parsed.append(buffer)
            inside = False
        elif tid in skip_ids:
            continue
        elif inside and tid >= tokenizer.code_offset and tid < vocab_len:
            # Diagnosis code token -- tid is already the vocab index
            # (code2idx/idx2code include the special tokens, no offset shift).
            buffer.append(tokenizer.vocab.idx2code[tid])

    # Sequence ended without a closing marker: keep the partial visit.
    if inside and buffer:
        parsed.append(buffer)

    return parsed
def generate_patient_sequence_conditional(
    model,
    tokenizer,
    target_patient,
    device: torch.device,
    temperature: float = 0.3,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    prompt_prob: float = 0.0,
    max_codes_per_visit: int = 20
) -> dict:
    """Generate synthetic patient via conditional reconstruction (PromptEHR approach).

    Given a real patient from the test set, randomly keeps each code as a
    prompt with probability ``prompt_prob`` and asks the model to reconstruct
    the full visit structure. Default prompt_prob=0.0 means zero-code-prompt
    generation (only demographics provided).

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        target_patient: Patient record from test set to reconstruct.
            Must expose age, gender (or sex), and visits either as
            attributes or as dict keys.
        device: Device to run on.
        temperature: Sampling temperature (default: 0.3).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        prompt_prob: Probability of keeping each code as prompt (default: 0.0 = zero prompts).
        max_codes_per_visit: Cap visit codes at this number (default: 20).

    Returns:
        Dictionary with:
            - 'generated_visits': List[List[str]] of generated code sequences
            - 'target_visits': List[List[str]] of original codes
            - 'prompt_codes': List[List[str]] of codes provided as prompts
            - 'demographics': dict of patient demographics
    """
    model.eval()

    # Extract demographics (handle both attribute-style and dict-style records)
    if hasattr(target_patient, 'age'):
        age = target_patient.age
    else:
        age = target_patient.get('age', 60.0)

    if hasattr(target_patient, 'gender'):
        gender_str = target_patient.gender
    elif hasattr(target_patient, 'sex'):
        gender_str = target_patient.sex
    else:
        gender_str = target_patient.get('gender', 'M')

    # Gender encoding: 1=F, anything else (incl. 'M') -> 0
    gender = 1 if gender_str == 'F' else 0

    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[gender]], dtype=torch.long).to(device)

    # Get visits (attribute or dict key)
    if hasattr(target_patient, 'visits'):
        patient_visits = target_patient.visits
    else:
        patient_visits = target_patient.get('visits', [])

    # Initialize accumulators
    generated_visits = []
    prompt_codes_per_visit = []

    # Create dummy encoder input (demographic prompts are injected in the decoder)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    # Special token IDs
    # NOTE(review): the empty-string token literal below looks garbled in this
    # source (probably the visit-start marker "<v>") -- confirm against the
    # tokenizer's special-token list.
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]

    with torch.no_grad():
        # Process each visit from the target patient independently
        for visit_idx, target_codes in enumerate(patient_visits):
            # Step 1: Cap codes at max_codes_per_visit (random subset, no replacement)
            num_codes = len(target_codes)
            if num_codes > max_codes_per_visit:
                target_codes = list(np.random.choice(target_codes, max_codes_per_visit, replace=False))
                num_codes = max_codes_per_visit

            if num_codes == 0:
                # Empty visit - keep placeholders so output aligns with input visits
                generated_visits.append([])
                prompt_codes_per_visit.append([])
                continue

            # Step 2: Randomly choose prompt codes (Bernoulli per code)
            keep_mask = np.random.binomial(1, prompt_prob, num_codes).astype(bool)
            prompt_codes = [code for i, code in enumerate(target_codes) if keep_mask[i]]

            # Step 3: Encode prompt codes as decoder input
            prompt_token_ids = [tokenizer.bos_token_id, v_token_id]
            for code in prompt_codes:
                # FIX: code2idx already returns token ID with offset included
                code_token_id = tokenizer.vocab.code2idx[code]
                prompt_token_ids.append(code_token_id)

            decoder_input_ids = torch.tensor([prompt_token_ids], dtype=torch.long).to(device)

            # Step 4: Generate to reconstruct the full visit
            max_new_tokens = num_codes + 2  # Target length (+2 for visit markers)

            # Use model.generate() for automatic sampling/caching
            generated_ids = model.generate(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                decoder_input_ids=decoder_input_ids,
                x_num=x_num,
                x_cat=x_cat,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                num_beams=1,  # Disable beam search, use sampling only
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=1,  # Prevents duplicate codes
                eos_token_id=v_end_token_id,  # Stop at the visit-end marker
                pad_token_id=tokenizer.pad_token_id,
                bad_words_ids=[[tokenizer.bos_token_id]]  # Suppress BOS in generation
            )

            # Step 5: Extract generated codes
            visit_token_ids = generated_ids[0].cpu().tolist()

            # Keep only code tokens (drops BOS, visit markers, pad)
            generated_code_ids = [
                tid for tid in visit_token_ids
                if tid >= tokenizer.code_offset
            ]

            # Decode codes (convert token IDs back to diagnosis codes)
            # FIX: code2idx already includes special tokens, so don't subtract offset
            generated_codes = []
            for tid in generated_code_ids:
                if tid < len(tokenizer.vocab.idx2code):
                    code = tokenizer.vocab.idx2code[tid]
                    generated_codes.append(code)

            # Step 6: Combine with prompt codes and deduplicate
            # (set() makes the resulting code order arbitrary)
            all_codes = list(set(generated_codes + prompt_codes))

            # Force the output visit to contain exactly num_codes codes
            if len(all_codes) < num_codes:
                # Not enough unique codes generated - resample with replacement
                needed = num_codes - len(all_codes)
                additional = list(np.random.choice(generated_codes, needed, replace=True)) if len(generated_codes) > 0 else []
                all_codes.extend(additional)
            elif len(all_codes) > num_codes:
                # Too many codes - sample exactly num_codes
                all_codes = list(np.random.choice(all_codes, num_codes, replace=False))

            generated_visits.append(all_codes)
            prompt_codes_per_visit.append(prompt_codes)

    return {
        'generated_visits': generated_visits,
        'target_visits': patient_visits,
        'prompt_codes': prompt_codes_per_visit,
        'demographics': {
            'age': age,
            'gender': gender_str
        }
    }
def generate_patient_with_structure_constraints(
    model,
    tokenizer,
    device: torch.device,
    target_structure: dict,
    age: Optional[float] = None,
    sex: Optional[int] = None,
    first_code: Optional[str] = None,
    temperature: float = 0.7,
    top_k: int = 0,  # Disabled (test with top_p only)
    top_p: float = 0.95,  # Increased for more diversity
    max_codes_per_visit: int = 25
) -> dict:
    """Generate patient with realistic visit structure constraints.

    Generates patients visit-by-visit with controlled code counts sampled
    from real data distributions, producing more realistic EHR records.

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        device: Device to run on.
        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
        age: Patient age (if None, sampled from distribution).
        sex: Patient sex ID (0=M, 1=F; if None, sampled).
        first_code: First diagnosis code to condition on (if None, generated by model).
        temperature: Sampling temperature (default: 0.7).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).

    Returns:
        Dictionary with:
            - 'generated_visits': List[List[str]] of diagnosis codes
            - 'demographics': dict with 'age' and 'sex'
            - 'num_visits': int
            - 'num_codes': int
            - 'target_structure': dict (the structure we aimed for)
    """
    model.eval()

    # Sample demographics if not provided
    if age is None or sex is None:
        sampled_demo = sample_demographics()
        age = sampled_demo['age'] if age is None else age
        sex = sampled_demo['sex'] if sex is None else sex

    # Prepare demographic tensors
    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)

    # Special token IDs
    # NOTE(review): the empty-string token literals below look garbled in this
    # source (probably "<v>" and an end-of-record marker) -- confirm against
    # the tokenizer's special-token list.
    bos_token_id = tokenizer.bos_token_id  # (unused below; kept for parity)
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]
    end_token_id = tokenizer.convert_tokens_to_indices([""])[0]  # (unused below)

    # Extract target structure
    num_visits = target_structure['num_visits']
    codes_per_visit = target_structure['codes_per_visit']

    # Handle case with no visits
    if num_visits == 0 or len(codes_per_visit) == 0:
        return {
            'generated_visits': [],
            'demographics': {'age': age, 'sex': sex},
            'num_visits': 0,
            'num_codes': 0,
            'target_structure': target_structure
        }

    # Initialize generation with an empty [1, 0] sequence.
    # HuggingFace will prepend decoder_start_token_id automatically, which
    # matches the training pattern: [start, visit-marker, codes...]
    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)

    # If first_code provided, prepopulate decoder with visit-marker + first_code
    # (no end marker yet) so visit 0 starts from the sampled first code and
    # generation continues from there.
    first_visit_prepopulated = False
    if first_code is not None and first_code in tokenizer.vocab.code2idx:
        v_token_id_temp = tokenizer.convert_tokens_to_indices([""])[0]
        first_code_id = tokenizer.vocab.code2idx[first_code]

        prepop_ids = torch.tensor([[v_token_id_temp, first_code_id]],
                                  dtype=torch.long).to(device)
        decoder_input_ids = torch.cat([decoder_input_ids, prepop_ids], dim=1)
        first_visit_prepopulated = True

    # Create dummy encoder input (demographic prompts are injected in the decoder)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    all_visits = []

    with torch.no_grad():
        for visit_idx in range(num_visits):
            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)

            # For visit 0 with prepopulated first_code, reduce target by 1
            # since one code is already in place.
            if visit_idx == 0 and first_visit_prepopulated:
                target_codes = max(1, target_codes - 1)  # At least 1 more code

            # Skip if target is too small
            if target_codes < 1:
                continue

            # Append the visit-start token to open this visit
            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)

            # Max tokens for this visit: ~1 token per code, +1 for the end
            # marker, with a 50% buffer for flexibility.
            max_new_tokens_this_visit = int(target_codes * 1.5) + 1

            try:
                # Generate codes for this visit
                generated_visit_ids = model.generate(
                    input_ids=encoder_input_ids,
                    attention_mask=encoder_attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    x_num=x_num,
                    x_cat=x_cat,
                    max_new_tokens=max_new_tokens_this_visit,
                    do_sample=True,
                    num_beams=1,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    no_repeat_ngram_size=1,
                    eos_token_id=v_end_token_id,  # Stop at visit end
                    pad_token_id=tokenizer.pad_token_id
                    # Note: NOT passing bos_token_id - let BART use decoder_start_token_id automatically
                )

                # Keep only the newly generated tokens (after decoder_input_ids)
                new_tokens = generated_visit_ids[0, decoder_input_ids.shape[1]:]

                # Parse the generated visit codes
                visit_codes = []
                for token_id in new_tokens:
                    token_id_val = token_id.item()
                    if token_id_val == v_end_token_id:
                        break  # End of visit
                    elif token_id_val >= tokenizer.code_offset:
                        # Diagnosis code - token_id_val is already the vocab index
                        # FIX: code2idx already includes special tokens, so don't subtract offset
                        if token_id_val < len(tokenizer.vocab.idx2code):
                            code = tokenizer.vocab.idx2code[token_id_val]
                            visit_codes.append(code)

                # If we generated codes, add visit
                if len(visit_codes) > 0:
                    # Truncate to target if we over-generated
                    if len(visit_codes) > target_codes:
                        visit_codes = visit_codes[:target_codes]

                    all_visits.append(visit_codes)

                    # Rebuild a canonical token run for the accepted visit and
                    # append it to decoder_input_ids (replaces whatever raw
                    # tokens generate() produced beyond the kept codes).
                    visit_token_ids = [v_token_id]  # visit start (already in decoder)
                    for code in visit_codes:
                        if code in tokenizer.vocab.code2idx:
                            # FIX: code2idx already returns token ID with offset included
                            code_token_id = tokenizer.vocab.code2idx[code]
                            visit_token_ids.append(code_token_id)
                    visit_token_ids.append(v_end_token_id)  # visit end

                    # Concatenate, skipping the first token (start marker already appended)
                    visit_tensor = torch.tensor([visit_token_ids[1:]], dtype=torch.long).to(device)
                    decoder_input_ids = torch.cat([decoder_input_ids, visit_tensor], dim=1)

            except Exception as e:
                # Best-effort: if generation fails for this visit, skip it
                print(f"Warning: Generation failed for visit {visit_idx + 1}: {e}")
                continue

            # Stop adding visits when approaching the context limit (512 for BART)
            if decoder_input_ids.shape[1] > 400:
                break

    # Compute statistics
    total_codes = sum(len(visit) for visit in all_visits)

    return {
        'generated_visits': all_visits,
        'demographics': {'age': age, 'sex': sex},
        'num_visits': len(all_visits),
        'num_codes': total_codes,
        'target_structure': target_structure
    }
def generate_with_frequency_prior(
    model,
    tokenizer,
    device: torch.device,
    target_structure: dict,
    frequency_prior: torch.Tensor,
    alpha: float = 1.0,
    age: Optional[float] = None,
    sex: Optional[int] = None,
    temperature: float = 0.7,
    top_k: int = 0,
    top_p: float = 0.95,
    max_codes_per_visit: int = 25,
    diagnostic_mode: bool = False,
    diagnostic_path: Optional[str] = None
) -> dict:
    """Generate patient with frequency-guided sampling.

    Identical to generate_patient_with_structure_constraints, but decodes
    token-by-token and blends model logits with a training-frequency prior
    (``logits + alpha * frequency_prior``) for realistic code distributions.

    Args:
        model: Trained PromptBartModel or PromptEHR model.
        tokenizer: DiagnosisCodeTokenizer instance.
        device: Device to run on.
        target_structure: Dict with 'num_visits' and 'codes_per_visit' list.
        frequency_prior: [vocab_size] log-frequency tensor from build_frequency_prior().
        alpha: Blending weight (0=pure model, higher=more frequency guidance).
            Recommended: 0.5-2.0. Start with 1.0.
        age: Patient age (if None, sampled from distribution).
        sex: Patient sex ID (0=M, 1=F; if None, sampled).
        temperature: Sampling temperature (default: 0.7).
        top_k: Top-k sampling parameter (default: 0 = disabled).
        top_p: Nucleus sampling parameter (default: 0.95).
        max_codes_per_visit: Maximum codes per visit safety cap (default: 25).
        diagnostic_mode: Enable detailed logging of generation process (default: False).
        diagnostic_path: Path to save diagnostic JSON file (required if diagnostic_mode=True).

    Returns:
        Dictionary with:
            - 'generated_visits': List[List[str]] of diagnosis codes
            - 'demographics': dict with 'age' and 'sex'
            - 'num_visits': int
            - 'num_codes': int
            - 'target_structure': dict (the structure we aimed for)
            - 'alpha': float (frequency prior weight used)
            - 'diagnostics': dict (if diagnostic_mode=True) with detailed generation logs

    Example:
        >>> prior = build_frequency_prior(tokenizer, './promptehr_outputs/training_frequencies.json')
        >>> result = generate_with_frequency_prior(
        ...     model, tokenizer, device,
        ...     target_structure={'num_visits': 3, 'codes_per_visit': [5, 8, 6]},
        ...     frequency_prior=prior,
        ...     alpha=1.0
        ... )
    """
    model.eval()

    # Sample demographics if not provided
    if age is None or sex is None:
        sampled_demo = sample_demographics()
        age = sampled_demo['age'] if age is None else age
        sex = sampled_demo['sex'] if sex is None else sex

    # Prepare demographic tensors
    x_num = torch.tensor([[age]], dtype=torch.float32).to(device)
    x_cat = torch.tensor([[sex]], dtype=torch.long).to(device)

    # Move frequency prior to device once (reused every step)
    frequency_prior = frequency_prior.to(device)

    # Special token IDs
    # NOTE(review): the empty-string token literal below looks garbled in
    # this source (probably the visit-start marker "<v>") -- confirm against
    # the tokenizer's special-token list.
    bos_token_id = tokenizer.bos_token_id  # (unused below; kept for parity)
    v_token_id = tokenizer.convert_tokens_to_indices([""])[0]
    v_end_token_id = tokenizer.convert_tokens_to_indices(["<\\v>"])[0]

    # Extract target structure
    num_visits = target_structure['num_visits']
    codes_per_visit = target_structure['codes_per_visit']

    # Handle case with no visits
    if num_visits == 0 or len(codes_per_visit) == 0:
        return {
            'generated_visits': [],
            'demographics': {'age': age, 'sex': sex},
            'num_visits': 0,
            'num_codes': 0,
            'target_structure': target_structure,
            'alpha': alpha
        }

    # Initialize generation with an empty [1, 0] sequence; the model's
    # decoder start token is prepended downstream as in training.
    decoder_input_ids = torch.tensor([[]], dtype=torch.long).to(device)

    # Create dummy encoder input (demographic prompts are injected in the decoder)
    encoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).to(device)
    encoder_attention_mask = torch.ones_like(encoder_input_ids)

    all_visits = []

    # Initialize diagnostic tracking
    all_diagnostics = {'visits': []} if diagnostic_mode else None

    with torch.no_grad():
        for visit_idx in range(num_visits):
            target_codes = min(codes_per_visit[visit_idx], max_codes_per_visit)

            # Skip if target is too small
            if target_codes < 1:
                continue

            # Append the visit-start token to open this visit
            v_token_tensor = torch.tensor([[v_token_id]], dtype=torch.long).to(device)
            decoder_input_ids = torch.cat([decoder_input_ids, v_token_tensor], dim=1)

            # Generate codes for this visit with frequency guidance
            max_new_tokens_this_visit = int(target_codes * 1.5) + 1
            visit_codes = []

            # Initialize visit diagnostic tracking
            visit_diagnostics = {'visit_idx': visit_idx, 'steps': []} if diagnostic_mode else None

            for step in range(max_new_tokens_this_visit):
                # Full forward pass per step (no KV-cache in this manual loop)
                outputs = model(
                    input_ids=encoder_input_ids,
                    attention_mask=encoder_attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    x_num=x_num,
                    x_cat=x_cat,
                    return_dict=True
                )

                # Get logits for next token (handle both dict and object outputs)
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits[0, -1, :]  # [vocab_size]
                elif isinstance(outputs, dict) and 'logits' in outputs:
                    logits = outputs['logits'][0, -1, :]  # [vocab_size]
                else:
                    raise TypeError(f"Unexpected output type: {type(outputs)}")

                # Diagnostic logging: raw model logits
                if diagnostic_mode:
                    step_diagnostics = {
                        'step': step,
                        'raw_logits': {
                            'max': float(logits.max()),
                            'min': float(logits.min()),
                            'mean': float(logits.mean()),
                            'std': float(logits.std()),
                            'top_5_indices': [int(i) for i in logits.topk(5).indices],
                            'top_5_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
                                            for i in logits.topk(5).indices],
                            'top_5_values': [float(v) for v in logits.topk(5).values]
                        }
                    }

                # BLEND with frequency prior
                logits_guided = logits + alpha * frequency_prior

                # Diagnostic logging: frequency blending
                if diagnostic_mode:
                    step_diagnostics['blending'] = {
                        'alpha': alpha,
                        'prior_contribution': float((alpha * frequency_prior).abs().mean()),
                        'logits_shift': float((logits_guided - logits).abs().mean()),
                        'top_5_after_blend_indices': [int(i) for i in logits_guided.topk(5).indices],
                        'top_5_after_blend_codes': [tokenizer.vocab.idx2code.get(int(i), f"<{i}>")
                                                    for i in logits_guided.topk(5).indices],
                        'top_5_after_blend_values': [float(v) for v in logits_guided.topk(5).values]
                    }

                # Apply temperature
                scaled_logits = logits_guided / temperature

                # Convert to probabilities
                probs = torch.softmax(scaled_logits, dim=0)

                # Diagnostic logging: probabilities after temperature
                if diagnostic_mode:
                    top_probs, top_indices = torch.topk(probs, 20)
                    step_diagnostics['probabilities'] = {
                        'temperature': temperature,
                        'entropy': float(-(probs * torch.log(probs + 1e-10)).sum()),
                        'top_20': [
                            {'code': tokenizer.vocab.idx2code.get(int(idx), f"<{idx}>"),
                             'prob': float(prob),
                             'idx': int(idx)}
                            for idx, prob in zip(top_indices, top_probs)
                        ]
                    }

                # Apply top-k filtering if enabled (renormalized afterwards)
                if top_k > 0:
                    top_k_vals, top_k_indices = torch.topk(probs, min(top_k, probs.size(-1)))
                    probs_filtered = torch.zeros_like(probs)
                    probs_filtered.scatter_(0, top_k_indices, top_k_vals)
                    probs = probs_filtered / probs_filtered.sum()

                # Apply nucleus (top-p) sampling
                # NOTE(review): `cumsum <= top_p` drops the first token that
                # crosses the threshold (the common formulation keeps it);
                # the `nucleus_mask[0] = True` patch guarantees a non-empty
                # nucleus -- confirm this variant is intended.
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumsum_probs = torch.cumsum(sorted_probs, dim=0)
                    nucleus_mask = cumsum_probs <= top_p
                    nucleus_mask[0] = True  # Always include top token

                    nucleus_indices = sorted_indices[nucleus_mask]
                    nucleus_probs = sorted_probs[nucleus_mask]
                    nucleus_probs = nucleus_probs / nucleus_probs.sum()

                    # Sample from nucleus
                    sampled_idx = torch.multinomial(nucleus_probs, 1)[0]
                    next_token = int(nucleus_indices[sampled_idx])
                else:
                    # Sample directly from filtered probs
                    next_token = int(torch.multinomial(probs, 1)[0])

                # Diagnostic logging: sampling decision
                if diagnostic_mode:
                    selected_code = tokenizer.vocab.idx2code.get(next_token, f"<{next_token}>")
                    step_diagnostics['selected'] = {
                        'token': next_token,
                        'code': selected_code,
                        'probability': float(probs[next_token]) if next_token < len(probs) else 0.0,
                        'was_top_1': (next_token == int(probs.argmax())),
                        'is_special_token': next_token < tokenizer.code_offset
                    }
                    visit_diagnostics['steps'].append(step_diagnostics)

                # Check if we hit end-of-visit
                if next_token == v_end_token_id:
                    break

                # Extract code if it's a diagnosis code
                # FIX: code2idx already includes special tokens, so don't subtract offset
                if next_token >= tokenizer.code_offset:
                    if next_token < len(tokenizer.vocab.idx2code):
                        code = tokenizer.vocab.idx2code[next_token]
                        if code not in visit_codes:  # Prevent duplicates
                            visit_codes.append(code)

                # Append token to decoder input
                next_token_tensor = torch.tensor([[next_token]], dtype=torch.long).to(device)
                decoder_input_ids = torch.cat([decoder_input_ids, next_token_tensor], dim=1)

                # Stop if we have enough codes
                if len(visit_codes) >= target_codes:
                    break

            # Add visit if we generated codes
            if len(visit_codes) > 0:
                # Truncate to target if over-generated
                if len(visit_codes) > target_codes:
                    visit_codes = visit_codes[:target_codes]

                all_visits.append(visit_codes)

                # Add visit diagnostics
                if diagnostic_mode:
                    visit_diagnostics['generated_codes'] = visit_codes
                    visit_diagnostics['target_codes'] = target_codes
                    all_diagnostics['visits'].append(visit_diagnostics)

                # Append the visit-end marker to close the visit
                v_end_tensor = torch.tensor([[v_end_token_id]], dtype=torch.long).to(device)
                decoder_input_ids = torch.cat([decoder_input_ids, v_end_tensor], dim=1)

            # Stop adding visits when approaching the context limit
            if decoder_input_ids.shape[1] > 400:
                break

    # Compute statistics
    total_codes = sum(len(visit) for visit in all_visits)

    # Build result dictionary
    result = {
        'generated_visits': all_visits,
        'demographics': {'age': age, 'sex': sex},
        'num_visits': len(all_visits),
        'num_codes': total_codes,
        'target_structure': target_structure,
        'alpha': alpha
    }

    # Add diagnostics if enabled
    if diagnostic_mode:
        all_diagnostics['demographics'] = {'age': age, 'sex': sex}
        all_diagnostics['params'] = {
            'alpha': alpha,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p
        }
        all_diagnostics['generated_codes'] = all_visits
        result['diagnostics'] = all_diagnostics

        # Save diagnostics to file if path provided
        if diagnostic_path:
            import json
            import os
            os.makedirs(os.path.dirname(diagnostic_path), exist_ok=True)
            with open(diagnostic_path, 'w') as f:
                json.dump(all_diagnostics, f, indent=2)

    return result
+""" + +from typing import Dict, List, Optional, Tuple +import torch +import torch.nn as nn +from transformers import BartConfig, BartForConditionalGeneration +from transformers.modeling_outputs import Seq2SeqLMOutput + +from pyhealth.models import BaseModel +from .conditional_prompt import ConditionalPromptEncoder +from .bart_encoder import PromptBartEncoder +from .bart_decoder import PromptBartDecoder + + +class PromptBartModel(BartForConditionalGeneration): + """BART model with demographic prompt conditioning for EHR generation. + + Extends HuggingFace's BartForConditionalGeneration with: + 1. Dual prompt encoders (separate for encoder/decoder) + 2. Demographic conditioning via learned prompt vectors + 3. Label smoothing for diverse generation + + This is the core generative model WITHOUT auxiliary losses (those caused + mode collapse and are excluded per implementation decision D003). + + Args: + config: BART configuration from transformers + n_num_features: Number of continuous features (1 for age) + cat_cardinalities: Category counts for categorical features ([2] for gender M/F) + d_hidden: Intermediate reparameterization dimension (default: 128) + prompt_length: Number of prompt vectors per feature (default: 1) + + Example: + >>> from transformers import BartConfig + >>> config = BartConfig.from_pretrained("facebook/bart-base") + >>> model = PromptBartModel( + ... config, + ... n_num_features=1, # age + ... cat_cardinalities=[2], # gender (M/F) + ... d_hidden=128, + ... prompt_length=1 + ... ) + >>> # Forward pass with demographics + >>> age = torch.randn(16, 1) # [batch, 1] + >>> gender = torch.randint(0, 2, (16, 1)) # [batch, 1] + >>> input_ids = torch.randint(0, 1000, (16, 100)) + >>> labels = torch.randint(0, 1000, (16, 50)) + >>> output = model( + ... input_ids=input_ids, + ... labels=labels, + ... x_num=age, + ... x_cat=gender + ... 
    def __init__(
        self,
        config: BartConfig,
        n_num_features: Optional[int] = None,
        cat_cardinalities: Optional[list] = None,
        d_hidden: int = 128,
        prompt_length: int = 1
    ):
        """Initialize PromptBART model with dual prompt conditioning.

        Args:
            config: BART configuration.
            n_num_features: Number of continuous features (e.g., 1 for age).
            cat_cardinalities: Category counts for categorical features [n_genders].
            d_hidden: Intermediate reparameterization dimension (default: 128).
            prompt_length: Number of prompt vectors per feature (default: 1).
        """
        super().__init__(config)

        # Replace encoder and decoder with prompt-aware versions that accept
        # inputs_prompt_embeds (see PromptBartEncoder / PromptBartDecoder);
        # both reuse the shared token embedding table.
        self.model.encoder = PromptBartEncoder(config, self.model.shared)
        self.model.decoder = PromptBartDecoder(config, self.model.shared)

        # Add SEPARATE conditional prompt encoders for encoder and decoder:
        # dual injection provides stronger demographic conditioning than a
        # single shared prompt encoder.
        if n_num_features is not None or cat_cardinalities is not None:
            # Encoder-side prompt encoder
            self.encoder_prompt_encoder = ConditionalPromptEncoder(
                n_num_features=n_num_features,
                cat_cardinalities=cat_cardinalities,
                hidden_dim=config.d_model,
                d_hidden=d_hidden,
                prompt_length=prompt_length
            )
            # Decoder-side prompt encoder (separate parameters for dual injection)
            self.decoder_prompt_encoder = ConditionalPromptEncoder(
                n_num_features=n_num_features,
                cat_cardinalities=cat_cardinalities,
                hidden_dim=config.d_model,
                d_hidden=d_hidden,
                prompt_length=prompt_length
            )
            self.num_prompts = self.encoder_prompt_encoder.get_num_prompts()
        else:
            # No demographic features configured: behaves like plain BART
            # (forward() skips prompt injection when these are None).
            self.encoder_prompt_encoder = None
            self.decoder_prompt_encoder = None
            self.num_prompts = 0

        # Initialize weights per the HuggingFace convention
        self.post_init()
Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + x_num: Optional[torch.FloatTensor] = None, + x_cat: Optional[torch.LongTensor] = None, + ) -> Seq2SeqLMOutput: + """Forward pass with demographic conditioning. + + Args: + input_ids: [batch, seq_len] encoder input token IDs + attention_mask: [batch, seq_len] encoder attention mask + decoder_input_ids: [batch, tgt_len] decoder input token IDs + decoder_attention_mask: [batch, tgt_len] decoder attention mask + labels: [batch, tgt_len] target labels for loss computation + x_num: [batch, n_num_features] continuous demographic features (e.g., age) + x_cat: [batch, n_cat_features] categorical demographic features (e.g., gender) + Other args: Standard BART arguments + + Returns: + Seq2SeqLMOutput with: + - loss: Cross-entropy loss with label smoothing=0.1 + - logits: [batch, tgt_len, vocab_size] prediction logits + - past_key_values: Cached key-value states (if use_cache=True) + - decoder_hidden_states: Decoder layer outputs (if output_hidden_states=True) + - encoder_last_hidden_state: Final encoder output + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode demographic prompts separately for encoder and decoder + # Only prepend prompts on first step (when no cache exists) + encoder_prompt_embeds = None + decoder_prompt_embeds = None + if (x_num is not None or x_cat is not None) 
and past_key_values is None: + if self.encoder_prompt_encoder is not None: + encoder_prompt_embeds = self.encoder_prompt_encoder(x_num=x_num, x_cat=x_cat) + if self.decoder_prompt_encoder is not None: + decoder_prompt_embeds = self.decoder_prompt_encoder(x_num=x_num, x_cat=x_cat) + + # Prepare decoder input IDs (shift labels right for teacher forcing) + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + # Encoder forward pass (with encoder prompts) + if encoder_outputs is None: + encoder_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + inputs_prompt_embeds=encoder_prompt_embeds, # Encoder-specific prompts + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Extend encoder attention mask for prompts + encoder_attention_mask = attention_mask + if encoder_prompt_embeds is not None and attention_mask is not None: + batch_size, n_prompts = encoder_prompt_embeds.shape[:2] + prompt_mask = torch.ones(batch_size, n_prompts, dtype=attention_mask.dtype, device=attention_mask.device) + encoder_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1) + + # Decoder forward pass (with decoder prompts) + decoder_outputs = self.model.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + inputs_prompt_embeds=decoder_prompt_embeds, # Decoder-specific prompts + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + 
+ # Language modeling head + lm_logits = self.lm_head(decoder_outputs[0]) + + # If decoder prompts were prepended, slice them off before loss computation + if decoder_prompt_embeds is not None and labels is not None: + # decoder_outputs[0] shape: [batch, n_prompts + seq_len, hidden_dim] + # We only want logits for the actual sequence positions + n_prompts = decoder_prompt_embeds.shape[1] + lm_logits = lm_logits[:, n_prompts:, :] # Remove prompt positions + + # Compute loss if labels provided + loss = None + if labels is not None: + # Label smoothing = 0.1 to prevent overconfidence and encourage diversity + # Softens target distributions: 90% on correct token, 10% distributed to alternatives + loss_fct = nn.CrossEntropyLoss(label_smoothing=0.1) + loss = loss_fct(lm_logits.reshape(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + x_num=None, + x_cat=None, + **kwargs + ): + """Prepare inputs for autoregressive generation. 
+ + Args: + decoder_input_ids: [batch, cur_len] current decoder input IDs + past_key_values: Cached key-value states from previous steps + x_num: [batch, n_num_features] continuous demographics (passed through) + x_cat: [batch, n_cat_features] categorical demographics (passed through) + Other args: Standard BART generation arguments + + Returns: + Dictionary of inputs for next generation step + """ + # Cut decoder_input_ids if past is used (only need last token) + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "x_num": x_num, # Pass demographics through generation + "x_cat": x_cat, + } + + @staticmethod + def _expand_inputs_for_generation( + input_ids, + expand_size=1, + is_encoder_decoder=True, + attention_mask=None, + encoder_outputs=None, + x_num=None, + x_cat=None, + **model_kwargs, + ): + """Expand inputs for beam search or multiple samples. 
+ + Args: + input_ids: [batch, seq_len] input token IDs + expand_size: Number of beams/samples per input + x_num: [batch, n_num_features] continuous demographics + x_cat: [batch, n_cat_features] categorical demographics + Other args: Standard expansion arguments + + Returns: + Expanded input_ids and model_kwargs + """ + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + + if encoder_outputs is not None: + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) + ) + model_kwargs["encoder_outputs"] = encoder_outputs + + # Expand demographics for beam search + if x_num is not None: + model_kwargs["x_num"] = x_num.index_select(0, expanded_return_idx) + + if x_cat is not None: + model_kwargs["x_cat"] = x_cat.index_select(0, expanded_return_idx) + + return input_ids, model_kwargs + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """Shift input ids one token to the right for teacher forcing. 
+ + Args: + input_ids: [batch, seq_len] target token IDs + pad_token_id: ID for padding token + decoder_start_token_id: ID for decoder start token (BOS) + + Returns: + [batch, seq_len] shifted token IDs with BOS prepended + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("config.pad_token_id must be defined for sequence generation") + + # Replace -100 in labels with pad_token_id + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class PromptEHR(BaseModel): + """PromptEHR: PyHealth wrapper for prompt-based BART EHR generation. + + This class extends PyHealth's BaseModel to integrate PromptBartModel into + the PyHealth ecosystem while maintaining compatibility with PyHealth's + Trainer and evaluation infrastructure. + + Args: + dataset: PyHealth dataset (required by BaseModel, can be None for generative) + n_num_features: Number of continuous features (1 for age) + cat_cardinalities: Category counts for categorical features ([2] for gender) + d_hidden: Intermediate reparameterization dimension (default: 128) + prompt_length: Number of prompt vectors per feature (default: 1) + bart_config_name: Pretrained BART model name (default: "facebook/bart-base") + **kwargs: Additional BaseModel arguments + + Example: + >>> from pyhealth.datasets import PromptEHRDataset + >>> dataset = PromptEHRDataset(...) + >>> model = PromptEHR( + ... dataset=dataset, + ... n_num_features=1, + ... cat_cardinalities=[2], + ... d_hidden=128 + ... ) + >>> # Training + >>> output = model(input_ids=..., labels=..., x_num=..., x_cat=...) + >>> loss = output["loss"] + >>> # Generation + >>> generated = model.generate(input_ids=..., x_num=..., x_cat=...) 
+ """ + + def __init__( + self, + dataset=None, + n_num_features: int = 1, + cat_cardinalities: Optional[list] = None, + d_hidden: int = 128, + prompt_length: int = 1, + bart_config_name: str = "facebook/bart-base", + **kwargs + ): + """Initialize PromptEHR model with PyHealth BaseModel integration. + + Args: + dataset: PyHealth dataset (can be None for generative models) + n_num_features: Number of continuous features (default: 1 for age) + cat_cardinalities: Category counts (default: [2] for gender M/F) + d_hidden: Reparameterization dimension (default: 128) + prompt_length: Prompt vectors per feature (default: 1) + bart_config_name: Pretrained BART model (default: "facebook/bart-base") + **kwargs: Additional BaseModel arguments (including _custom_vocab_size for checkpoint loading) + """ + # Extract custom vocab size if provided (used by load_from_checkpoint) + custom_vocab_size = kwargs.pop('_custom_vocab_size', None) + + super().__init__(dataset=dataset, **kwargs) + + # Set mode to None to skip discriminative evaluation (generative model) + self.mode = None + + # Default categorical cardinalities if not provided + if cat_cardinalities is None: + cat_cardinalities = [2] # Gender (M/F) + + # Initialize BART config from pretrained + bart_config = BartConfig.from_pretrained(bart_config_name) + + # Override vocab_size if loading from custom checkpoint + if custom_vocab_size is not None: + bart_config.vocab_size = custom_vocab_size + + # Apply dropout configuration (increased from BART default 0.1 to 0.3) + bart_config.dropout = 0.3 + bart_config.attention_dropout = 0.3 + bart_config.activation_dropout = 0.3 + + # Initialize PromptBartModel + self.bart_model = PromptBartModel( + config=bart_config, + n_num_features=n_num_features, + cat_cardinalities=cat_cardinalities, + d_hidden=d_hidden, + prompt_length=prompt_length + ) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass for training. 
+ + Args: + **kwargs: Arguments passed to PromptBartModel.forward() + Required: input_ids, labels, x_num, x_cat + Optional: attention_mask, decoder_attention_mask, etc. + + Returns: + Dictionary with: + - loss: Cross-entropy loss with label smoothing + - logits: Prediction logits (optional) + """ + output = self.bart_model(**kwargs) + + # Return PyHealth-compatible dict (minimum: {"loss": ...}) + result = { + "loss": output.loss, + } + + # Add optional fields if available + if hasattr(output, "logits"): + result["logits"] = output.logits + + return result + + def generate(self, **kwargs): + """Generate synthetic patient sequences. + + Args: + **kwargs: Arguments passed to PromptBartModel.generate() + Required: input_ids (demographics encoded), x_num, x_cat + Optional: max_length, num_beams, temperature, etc. + + Returns: + Generated token IDs [batch, seq_len] + """ + return self.bart_model.generate(**kwargs) + + @classmethod + def load_from_checkpoint(cls, checkpoint_path, dataset=None, **model_kwargs): + """Load PromptEHR model from pehr_scratch checkpoint. + + Args: + checkpoint_path: Path to checkpoint file (e.g., best_model.pt) + dataset: PyHealth dataset (optional, can be None for generative models) + **model_kwargs: Model initialization arguments (n_num_features, cat_cardinalities, etc.) + + Returns: + Loaded PromptEHR model with checkpoint weights + + Example: + >>> model = PromptEHR.load_from_checkpoint( + ... "/scratch/jalenj4/promptehr_checkpoints/best_model.pt", + ... n_num_features=1, + ... cat_cardinalities=[2] + ... 
) + """ + import torch + + # Load checkpoint (weights_only=False needed for custom tokenizer class) + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) + + # Extract model state dict (pehr_scratch format has extra keys) + if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: + state_dict = checkpoint['model_state_dict'] + epoch = checkpoint.get('epoch', None) + val_loss = checkpoint.get('val_loss', None) + else: + # Direct state dict + state_dict = checkpoint + epoch = None + val_loss = None + + # Auto-detect vocab_size from checkpoint + # pehr_scratch uses custom vocabulary (6992 tokens) vs BART default (50265) + if 'model.shared.weight' in state_dict: + checkpoint_vocab_size = state_dict['model.shared.weight'].shape[0] + + # Override bart_config_name if vocab size differs from default + if 'bart_config_name' not in model_kwargs: + # Load default config to check vocab size + from transformers import BartConfig + default_config = BartConfig.from_pretrained("facebook/bart-base") + + if checkpoint_vocab_size != default_config.vocab_size: + # Create custom config with detected vocab size + print(f"Detected custom vocab_size={checkpoint_vocab_size} in checkpoint " + f"(BART default: {default_config.vocab_size})") + + # Store custom config by temporarily modifying the config + model_kwargs['_custom_vocab_size'] = checkpoint_vocab_size + + # Create model instance + model = cls(dataset=dataset, **model_kwargs) + + # Load weights + model.bart_model.load_state_dict(state_dict, strict=True) + + # Print checkpoint info + if epoch is not None: + print(f"Loaded checkpoint from epoch {epoch}, val_loss={val_loss:.4f}") + + return model diff --git a/pyhealth/models/promptehr/utils.py b/pyhealth/models/promptehr/utils.py new file mode 100644 index 000000000..43e13ca83 --- /dev/null +++ b/pyhealth/models/promptehr/utils.py @@ -0,0 +1,29 @@ +"""Utility functions and classes for PromptEHR. 
+ +This module contains: + - VisitStructureSampler: Samples realistic visit structures for generation + - Data collation functions + - Helper utilities +""" + +import torch +import torch.nn as nn + + +class VisitStructureSampler: + """Samples realistic visit structures from training data. + + This is a critical component added Nov 21, 2025 that solves the + over-generation problem. Reduces codes/patient from 18.1 → 11.97 (34%). + + Args: + TODO: Add arguments after porting from pehr_scratch + """ + + def __init__(self, **kwargs): + # TODO: Port from ~/pehr_scratch/visit_structure_sampler.py + raise NotImplementedError("VisitStructureSampler porting in progress") + + def sample(self, **kwargs): + """Sample a visit structure.""" + raise NotImplementedError("VisitStructureSampler porting in progress") diff --git a/pyhealth/models/promptehr/visit_sampler.py b/pyhealth/models/promptehr/visit_sampler.py new file mode 100644 index 000000000..03efbf78f --- /dev/null +++ b/pyhealth/models/promptehr/visit_sampler.py @@ -0,0 +1,121 @@ +""" +Sample realistic visit structures from real MIMIC-III data distributions. + +This module provides functionality to sample the number of visits per patient +and the number of diagnosis codes per visit, matching the empirical distributions +observed in real EHR data. +""" +import numpy as np +from typing import List + + +class VisitStructureSampler: + """Sample realistic visit and code count structures from training data.""" + + def __init__(self, patient_records: List, seed: int = 42): + """Initialize sampler with empirical distributions from training data. + + Args: + patient_records: List of patient records from training set. + Each record should have a 'visits' attribute (list of visit codes). + seed: Random seed for reproducibility. 
+ """ + self.rng = np.random.RandomState(seed) + + # Extract empirical distributions + self.num_visits_per_patient = [] + self.codes_per_visit_all = [] + + for patient in patient_records: + # Handle both dict-like and object-like patient records + if hasattr(patient, 'visits'): + visits = patient.visits + elif isinstance(patient, dict) and 'visits' in patient: + visits = patient['visits'] + else: + continue + + num_visits = len(visits) + self.num_visits_per_patient.append(num_visits) + + for visit in visits: + num_codes = len(visit) + if num_codes > 0: # Only include non-empty visits + self.codes_per_visit_all.append(num_codes) + + # Convert to numpy arrays + self.num_visits_per_patient = np.array(self.num_visits_per_patient) + self.codes_per_visit_all = np.array(self.codes_per_visit_all) + + # Compute statistics for logging + self.stats = { + 'visits_mean': np.mean(self.num_visits_per_patient), + 'visits_median': np.median(self.num_visits_per_patient), + 'visits_90th': np.percentile(self.num_visits_per_patient, 90), + 'codes_mean': np.mean(self.codes_per_visit_all), + 'codes_median': np.median(self.codes_per_visit_all), + 'codes_90th': np.percentile(self.codes_per_visit_all, 90), + 'codes_95th': np.percentile(self.codes_per_visit_all, 95), + } + + def sample_num_visits(self) -> int: + """Sample number of visits from empirical distribution. + + Returns: + Number of visits (>= 0). + """ + return int(self.rng.choice(self.num_visits_per_patient)) + + def sample_codes_per_visit(self, n_visits: int) -> List[int]: + """Sample number of codes for each visit from empirical distribution. + + Args: + n_visits: Number of visits to sample code counts for. + + Returns: + List of integers representing codes per visit. 
+ """ + if n_visits == 0: + return [] + + # Sample with replacement from empirical distribution + codes_counts = self.rng.choice(self.codes_per_visit_all, size=n_visits, replace=True) + return codes_counts.tolist() + + def sample_structure(self) -> dict: + """Sample complete visit structure (visits + codes per visit). + + Returns: + Dictionary with: + - 'num_visits': int (number of visits) + - 'codes_per_visit': List[int] (codes for each visit) + """ + num_visits = self.sample_num_visits() + codes_per_visit = self.sample_codes_per_visit(num_visits) + + return { + 'num_visits': num_visits, + 'codes_per_visit': codes_per_visit + } + + def get_statistics(self) -> dict: + """Get statistics about the underlying distributions. + + Returns: + Dictionary with mean/median/percentile statistics. + """ + return self.stats.copy() + + def __repr__(self) -> str: + """String representation showing distribution statistics.""" + return ( + f"VisitStructureSampler(\n" + f" Visits/patient: mean={self.stats['visits_mean']:.2f}, " + f"median={self.stats['visits_median']:.0f}, " + f"90th%={self.stats['visits_90th']:.0f}\n" + f" Codes/visit: mean={self.stats['codes_mean']:.2f}, " + f"median={self.stats['codes_median']:.0f}, " + f"90th%={self.stats['codes_90th']:.0f}, " + f"95th%={self.stats['codes_95th']:.0f}\n" + f")" + ) diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py new file mode 100644 index 000000000..dc523ff5a --- /dev/null +++ b/pyhealth/tasks/ehr_generation.py @@ -0,0 +1,30 @@ +"""EHR generation task function for PromptEHR. + +This module defines the task function for synthetic EHR generation. +""" + +from typing import Dict, List, Optional + + +def ehr_generation_fn(patient_data: Dict) -> Dict: + """Task function for EHR generation. + + This task function prepares patient data for conditional EHR generation, + including demographics and optional visit history for continuation. 
+ + Args: + patient_data: Dictionary containing patient information + + Returns: + Dictionary with input_schema and output_schema attributes + + Examples: + TODO: Add usage examples + """ + # TODO: Port task function logic from pehr_scratch + raise NotImplementedError("ehr_generation_fn porting in progress") + + +# Set task function attributes (PyHealth pattern) +ehr_generation_fn.input_schema = None # TODO: Define schema +ehr_generation_fn.output_schema = None # TODO: Define schema From 6a651427f98aafa7174e17e250bce16263aafd56 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 01:48:45 -0600 Subject: [PATCH 02/37] T2: add PromptEHRGenerationMIMIC3 BaseTask with demographics --- pyhealth/tasks/__init__.py | 6 ++ pyhealth/tasks/ehr_generation.py | 144 +++++++++++++++++++++++++++---- 2 files changed, 133 insertions(+), 17 deletions(-) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..c0811c77a 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,5 +1,11 @@ from .base_task import BaseTask from .benchmark_ehrshot import BenchmarkEHRShot +from .ehr_generation import ( + PromptEHRGenerationMIMIC3, + PromptEHRGenerationMIMIC4, + promptehr_generation_mimic3_fn, + promptehr_generation_mimic4_fn, +) from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction from .bmd_hs_disease_classification import BMDHSDiseaseClassification from .cardiology_detect import ( diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py index dc523ff5a..e6aaca8f9 100644 --- a/pyhealth/tasks/ehr_generation.py +++ b/pyhealth/tasks/ehr_generation.py @@ -1,30 +1,140 @@ -"""EHR generation task function for PromptEHR. +"""Task function for PromptEHR synthetic EHR generation. -This module defines the task function for synthetic EHR generation. +Provides task classes for training PromptEHR on MIMIC-III and MIMIC-IV datasets. 
+Demographics (age, gender) are extracted alongside visit codes because PromptEHR +conditions generation on patient-level continuous and categorical features. """ -from typing import Dict, List, Optional +from datetime import datetime +from typing import Dict, List +import polars as pl -def ehr_generation_fn(patient_data: Dict) -> Dict: - """Task function for EHR generation. +from pyhealth.tasks.base_task import BaseTask - This task function prepares patient data for conditional EHR generation, - including demographics and optional visit history for continuation. - Args: - patient_data: Dictionary containing patient information +class PromptEHRGenerationMIMIC3(BaseTask): + """Task for PromptEHR synthetic data generation using MIMIC-III. - Returns: - Dictionary with input_schema and output_schema attributes + PromptEHR is a BART-based seq2seq model that conditions generation on + patient demographics (age, gender) via learned prompt vectors. This task + extracts per-admission ICD-9 diagnosis codes grouped into a nested visit + list, along with patient demographics for conditioning. + + Patients with fewer than 2 admissions containing diagnosis codes are + excluded. + + Attributes: + task_name (str): Unique task identifier. + input_schema (dict): ``"visits"`` uses ``"nested_sequence"`` encoding + (list of lists of code strings). + output_schema (dict): Empty — generative task, no conditioning label. + _icd_col (str): Polars column path for ICD codes in MIMIC-III. 
Examples: - TODO: Add usage examples + >>> fn = PromptEHRGenerationMIMIC3() + >>> fn.task_name + 'PromptEHRGenerationMIMIC3' """ - # TODO: Port task function logic from pehr_scratch - raise NotImplementedError("ehr_generation_fn porting in progress") + + task_name = "PromptEHRGenerationMIMIC3" + input_schema = {"visits": "nested_sequence"} + output_schema = {} + _icd_col = "diagnoses_icd/icd9_code" + + def __call__(self, patient) -> List[Dict]: + """Extract visit sequences and demographics for a single patient. + + Diagnosis codes are grouped per admission into a nested list. Age is + computed as years between date-of-birth and the first admission date. + Gender is encoded as 0 (male) or 1 (female). Defaults of + ``age=60.0, gender=0`` are used when demographics are unavailable. + + Args: + patient: A PyHealth Patient object with admissions and + diagnoses_icd event data. + + Returns: + list of dict: A single-element list, or empty list if fewer + than 2 visits have diagnosis codes. Each dict contains: + ``"patient_id"`` (str): patient identifier. + ``"visits"`` (list of list of str): ICD codes per visit. + ``"age"`` (float): patient age at first admission in years. + ``"gender"`` (int): 0 for male, 1 for female. 
+ """ + admissions = list(patient.get_events(event_type="admissions")) + + # --- Demographics --- + age = 60.0 + gender = 0 + patients_df = patient.get_events(event_type="patients", return_df=True) + if len(patients_df) > 0: + if "patients/gender" in patients_df.columns: + gender_val = patients_df["patients/gender"][0] + if gender_val == "F": + gender = 1 + if "patients/dob" in patients_df.columns and admissions: + dob_val = patients_df["patients/dob"][0] + first_admit_ts = admissions[0].timestamp + if dob_val is not None and first_admit_ts is not None: + # dob_val may be a date/datetime or a string + if hasattr(dob_val, "year"): + dob_dt = datetime(dob_val.year, dob_val.month, dob_val.day) + else: + dob_dt = datetime.strptime(str(dob_val)[:10], "%Y-%m-%d") + raw_age = (first_admit_ts - dob_dt).days / 365.25 + # Clamp: MIMIC-III shifts >89-year-old DOBs far into the + # past; treat those as 90. + age = float(min(90.0, max(0.0, raw_age))) + + # --- Visit codes --- + visits = [] + for adm in admissions: + codes = ( + patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", adm.hadm_id)], + return_df=True, + ) + .select(pl.col(self._icd_col)) + .to_series() + .drop_nulls() + .to_list() + ) + if codes: + visits.append(codes) + + if len(visits) < 2: + return [] + + return [{ + "patient_id": patient.patient_id, + "visits": visits, + "age": age, + "gender": gender, + }] + + +class PromptEHRGenerationMIMIC4(PromptEHRGenerationMIMIC3): + """Task for PromptEHR synthetic data generation using MIMIC-IV. + + Inherits all logic from :class:`PromptEHRGenerationMIMIC3`. Overrides only + the task name and ICD code column to match the MIMIC-IV schema, where the + column is ``icd_code`` (unversioned) rather than ``icd9_code``. + + Attributes: + task_name (str): Unique task identifier. + _icd_col (str): Polars column path for ICD codes in MIMIC-IV. 
+ + Examples: + >>> fn = PromptEHRGenerationMIMIC4() + >>> fn.task_name + 'PromptEHRGenerationMIMIC4' + """ + + task_name = "PromptEHRGenerationMIMIC4" + _icd_col = "diagnoses_icd/icd_code" -# Set task function attributes (PyHealth pattern) -ehr_generation_fn.input_schema = None # TODO: Define schema -ehr_generation_fn.output_schema = None # TODO: Define schema +promptehr_generation_mimic3_fn = PromptEHRGenerationMIMIC3() +promptehr_generation_mimic4_fn = PromptEHRGenerationMIMIC4() From 39ec171f2583fee586c73fb0a5e7401ebacf50cd Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 02:18:33 -0600 Subject: [PATCH 03/37] T3+fix: Refactor PromptEHR to BaseModel; fix T2 early exit + T3 review issues --- pyhealth/models/promptehr/model.py | 512 +++++++++++++++++++++-------- pyhealth/tasks/ehr_generation.py | 2 + 2 files changed, 381 insertions(+), 133 deletions(-) diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 0ffb7f68e..08b710ab1 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -6,9 +6,12 @@ Ported from pehr_scratch/prompt_bart_model.py (lines 16-276, excluding auxiliary losses). """ +import os +import random from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence from transformers import BartConfig, BartForConditionalGeneration from transformers.modeling_outputs import Seq2SeqLMOutput @@ -356,193 +359,436 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -class PromptEHR(BaseModel): - """PromptEHR: PyHealth wrapper for prompt-based BART EHR generation. +class _PromptEHRVocab: + """Internal vocabulary bridging NestedSequenceProcessor indices to BART token IDs. 
+
+    Token layout (7 special tokens + N diagnosis codes):
+        0 = <pad> (BartConfig.pad_token_id)
+        1 = <s> (BartConfig.bos_token_id / decoder_start_token_id)
+        2 = </s> (BartConfig.eos_token_id)
+        3 = <unk>
+        4 = <v> (visit start)
+        5 = </v> (visit end)
+        6 = <end> (sequence terminator)
+        7+ = diagnosis codes
+
+    NestedSequenceProcessor uses pad=0, unk=1, codes=2+.
+    Mapping: processor_idx i → BART token i + 5 (for i >= 2).
+    Total BART vocab size = processor.vocab_size() + 5.
+    """
+
+    PAD = 0
+    BOS = 1
+    EOS = 2
+    UNK = 3
+    VISIT_START = 4
+    VISIT_END = 5
+    SEQ_END = 6
+    CODE_OFFSET = 7
+
+    def __init__(self, code_vocab: dict):
+        """Build vocab from NestedSequenceProcessor.code_vocab dict.
+
+        Args:
+            code_vocab (dict): Mapping of code string → processor index.
+                Must have ``"<pad>"`` → 0 and ``"<unk>"`` → 1.
+        """
+        self._bart_to_code: Dict[int, str] = {}
+        for code, pid in code_vocab.items():
+            if pid >= 2:  # skip <pad> and <unk>
+                self._bart_to_code[pid + 5] = code
+        self.total_size = len(code_vocab) + 5  # 7 special - 2 reused + N codes
+
+    def encode_visits(self, visits_tensor: torch.Tensor) -> List[int]:
+        """Encode a processed [n_visits, max_codes] LongTensor to a token ID list.
+
+        Args:
+            visits_tensor (torch.Tensor): LongTensor of shape
+                ``(n_visits, max_codes_per_visit)`` from NestedSequenceProcessor.
+                Values 0 = pad, 1 = unk, 2+ = code index.
+
+        Returns:
+            list of int: Token IDs in format
+                ``[<v>, code, ..., </v>, <v>, ..., </v>, <end>]``.
+        """
+        tokens = []
+        for visit in visits_tensor:
+            codes_in_visit = [
+                int(c.item()) + 5  # processor idx 2+ → BART idx 7+
+                for c in visit
+                if c.item() >= 2  # skip pad and unk
+            ]
+            if codes_in_visit:
+                tokens.append(self.VISIT_START)
+                tokens.extend(codes_in_visit)
+                tokens.append(self.VISIT_END)
+        tokens.append(self.SEQ_END)
+        return tokens
+
+    def decode_tokens(self, token_ids: List[int]) -> List[List[str]]:
+        """Decode a generated token ID list back to visit structure. 
- This class extends PyHealth's BaseModel to integrate PromptBartModel into - the PyHealth ecosystem while maintaining compatibility with PyHealth's - Trainer and evaluation infrastructure. + Args: + token_ids (list of int): Raw generated token IDs from BART. + + Returns: + list of list of str: Decoded diagnosis code strings per visit. + """ + visits: List[List[str]] = [] + current_visit: List[str] = [] + in_visit = False + for tid in token_ids: + if tid == self.VISIT_START: + in_visit = True + current_visit = [] + elif tid == self.VISIT_END: + if in_visit: + visits.append(current_visit) + in_visit = False + elif tid in (self.SEQ_END, self.EOS, self.PAD, self.BOS): + break + elif in_visit and tid >= self.CODE_OFFSET: + code = self._bart_to_code.get(tid) + if code: + current_visit.append(code) + if in_visit and current_visit: + visits.append(current_visit) + return visits + + +def _promptehr_collate_fn(batch): + """Collate PromptEHR training samples, padding token sequences in a batch. + + Pads ``input_ids`` and ``labels`` to the longest sequence in the batch using + ``pad_sequence``. Builds the attention mask from padded positions. Args: - dataset: PyHealth dataset (required by BaseModel, can be None for generative) - n_num_features: Number of continuous features (1 for age) - cat_cardinalities: Category counts for categorical features ([2] for gender) - d_hidden: Intermediate reparameterization dimension (default: 128) - prompt_length: Number of prompt vectors per feature (default: 1) - bart_config_name: Pretrained BART model name (default: "facebook/bart-base") - **kwargs: Additional BaseModel arguments + batch (list of dict): Each dict has ``"input_ids"``, ``"labels"``, + ``"x_num"``, and ``"x_cat"`` tensors. - Example: - >>> from pyhealth.datasets import PromptEHRDataset - >>> dataset = PromptEHRDataset(...) - >>> model = PromptEHR( - ... dataset=dataset, - ... n_num_features=1, - ... cat_cardinalities=[2], - ... 
d_hidden=128 + Returns: + dict: Batched tensors ready for ``PromptBartModel.forward()``. + """ + input_ids = pad_sequence( + [item["input_ids"] for item in batch], + batch_first=True, + padding_value=_PromptEHRVocab.PAD, + ) + labels = pad_sequence( + [item["labels"] for item in batch], + batch_first=True, + padding_value=-100, + ) + attention_mask = (input_ids != _PromptEHRVocab.PAD).long() + x_num = torch.cat([item["x_num"] for item in batch], dim=0) + x_cat = torch.cat([item["x_cat"] for item in batch], dim=0) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "x_num": x_num, + "x_cat": x_cat, + } + + +class PromptEHR(BaseModel): + """PromptEHR: demographic-conditioned BART model for synthetic EHR generation. + + Wraps ``PromptBartModel`` (HuggingFace BART with dual prompt conditioning) + in a PyHealth ``BaseModel`` interface. Training is handled by a HuggingFace + ``Trainer`` loop; generation is autoregressive token-by-token decoding. + + Demographics (age as continuous, gender as categorical) are injected via + learned prompt vectors prepended to both encoder and decoder hidden states. + + Args: + dataset (SampleDataset): PyHealth sample dataset produced by + ``set_task(promptehr_generation_mimic3_fn)``. Must have + ``input_processors["visits"]`` (NestedSequenceProcessor). + n_num_features (int): Continuous demographic features (1 for age). + Default: 1. + cat_cardinalities (list of int): Category counts per categorical + feature ([2] for binary gender M/F). Default: [2]. + d_hidden (int): Reparameterization dimension for prompt encoder. + Default: 128. + prompt_length (int): Number of prompt vectors per feature. Default: 1. + bart_config_name (str): Pretrained BART config to use. + Default: ``"facebook/bart-base"``. + epochs (int): Training epochs. Default: 20. + batch_size (int): Training batch size. Default: 16. + lr (float): AdamW learning rate. Default: 1e-5. + warmup_steps (int): Linear warmup steps. 
Default: 1000. + max_seq_length (int): Maximum token sequence length. Default: 512. + save_dir (str): Directory for checkpoints. Default: ``"./save/"``. + + Examples: + >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset + >>> samples = [ + ... {"patient_id": "p1", "visits": [["428", "427"], ["410"]], "age": 65.0, "gender": 0}, + ... {"patient_id": "p2", "visits": [["250"], ["401", "272"]], "age": 52.0, "gender": 1}, + ... ] + >>> dataset = InMemorySampleDataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, ... ) - >>> # Training - >>> output = model(input_ids=..., labels=..., x_num=..., x_cat=...) - >>> loss = output["loss"] - >>> # Generation - >>> generated = model.generate(input_ids=..., x_num=..., x_cat=...) + >>> model = PromptEHR(dataset, d_hidden=32, prompt_length=1) + >>> isinstance(model, PromptEHR) + True """ def __init__( self, - dataset=None, + dataset, n_num_features: int = 1, cat_cardinalities: Optional[list] = None, d_hidden: int = 128, prompt_length: int = 1, bart_config_name: str = "facebook/bart-base", - **kwargs + epochs: int = 20, + batch_size: int = 16, + lr: float = 1e-5, + warmup_steps: int = 1000, + max_seq_length: int = 512, + save_dir: str = "./save/", ): - """Initialize PromptEHR model with PyHealth BaseModel integration. + """Initialize PromptEHR with vocab derived from the dataset processor. 
Args: - dataset: PyHealth dataset (can be None for generative models) - n_num_features: Number of continuous features (default: 1 for age) - cat_cardinalities: Category counts (default: [2] for gender M/F) - d_hidden: Reparameterization dimension (default: 128) - prompt_length: Prompt vectors per feature (default: 1) - bart_config_name: Pretrained BART model (default: "facebook/bart-base") - **kwargs: Additional BaseModel arguments (including _custom_vocab_size for checkpoint loading) + dataset (SampleDataset): PyHealth dataset with + ``input_processors["visits"]`` (NestedSequenceProcessor). + n_num_features (int): Continuous demographic features. Default: 1. + cat_cardinalities (list of int): Category cardinalities. Default: [2]. + d_hidden (int): Prompt encoder hidden dim. Default: 128. + prompt_length (int): Prompt vectors per feature. Default: 1. + bart_config_name (str): Pretrained BART config. Default: + ``"facebook/bart-base"``. + epochs (int): Training epochs. Default: 20. + batch_size (int): Training batch size. Default: 16. + lr (float): AdamW learning rate. Default: 1e-5. + warmup_steps (int): Linear warmup steps. Default: 1000. + max_seq_length (int): Token sequence length cap. Default: 512. + save_dir (str): Checkpoint output directory. Default: ``"./save/"``. 
""" - # Extract custom vocab size if provided (used by load_from_checkpoint) - custom_vocab_size = kwargs.pop('_custom_vocab_size', None) - - super().__init__(dataset=dataset, **kwargs) + super().__init__(dataset) - # Set mode to None to skip discriminative evaluation (generative model) - self.mode = None + self.mode = None # skip discriminative evaluation + self.save_dir = save_dir + self.epochs = epochs + self.batch_size = batch_size + self.lr = lr + self.warmup_steps = warmup_steps + self.max_seq_length = max_seq_length + self._demo_pool: List[tuple] = [] # (age, gender) pairs from training data - # Default categorical cardinalities if not provided if cat_cardinalities is None: - cat_cardinalities = [2] # Gender (M/F) + cat_cardinalities = [2] - # Initialize BART config from pretrained - bart_config = BartConfig.from_pretrained(bart_config_name) - - # Override vocab_size if loading from custom checkpoint - if custom_vocab_size is not None: - bart_config.vocab_size = custom_vocab_size + # Derive vocab from the dataset's NestedSequenceProcessor + visits_processor = dataset.input_processors["visits"] + self._vocab = _PromptEHRVocab(visits_processor.code_vocab) + bart_vocab_size = self._vocab.total_size - # Apply dropout configuration (increased from BART default 0.1 to 0.3) + # Configure BART with our custom vocab and special token IDs + bart_config = BartConfig.from_pretrained(bart_config_name) + bart_config.vocab_size = bart_vocab_size + bart_config.pad_token_id = _PromptEHRVocab.PAD + bart_config.bos_token_id = _PromptEHRVocab.BOS + bart_config.eos_token_id = _PromptEHRVocab.EOS + bart_config.decoder_start_token_id = _PromptEHRVocab.BOS + bart_config.forced_eos_token_id = _PromptEHRVocab.SEQ_END bart_config.dropout = 0.3 bart_config.attention_dropout = 0.3 bart_config.activation_dropout = 0.3 - # Initialize PromptBartModel self.bart_model = PromptBartModel( config=bart_config, n_num_features=n_num_features, cat_cardinalities=cat_cardinalities, d_hidden=d_hidden, 
- prompt_length=prompt_length + prompt_length=prompt_length, ) - def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward pass for training. + def forward(self, **kwargs) -> Dict: + """Not implemented — PromptEHR is a generative model without a discriminative forward. - Args: - **kwargs: Arguments passed to PromptBartModel.forward() - Required: input_ids, labels, x_num, x_cat - Optional: attention_mask, decoder_attention_mask, etc. - - Returns: - Dictionary with: - - loss: Cross-entropy loss with label smoothing - - logits: Prediction logits (optional) + Raises: + NotImplementedError: Always. Use ``train_model`` and + ``synthesize_dataset`` instead. """ - output = self.bart_model(**kwargs) - - # Return PyHealth-compatible dict (minimum: {"loss": ...}) - result = { - "loss": output.loss, - } + raise NotImplementedError( + "PromptEHR is a generative model. Use train_model() and synthesize_dataset()." + ) - # Add optional fields if available - if hasattr(output, "logits"): - result["logits"] = output.logits + def train_model(self, train_dataset, val_dataset=None) -> None: + """Train PromptEHR using a HuggingFace Trainer loop. - return result + Converts PyHealth SampleDataset samples to BART token sequences and + trains with HuggingFace ``Trainer``. Demographics (age, gender) are + passed as ``x_num`` / ``x_cat`` via a custom data collator. - def generate(self, **kwargs): - """Generate synthetic patient sequences. + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Args: - **kwargs: Arguments passed to PromptBartModel.generate() - Required: input_ids (demographics encoded), x_num, x_cat - Optional: max_length, num_beams, temperature, etc. - - Returns: - Generated token IDs [batch, seq_len] + train_dataset (SampleDataset): Training set with ``"visits"``, + ``"age"``, and ``"gender"`` fields. + val_dataset (SampleDataset, optional): Validation set for loss + monitoring. Default: None. 
""" - return self.bart_model.generate(**kwargs) + from torch.utils.data import Dataset as TorchDataset + from transformers import Trainer, TrainingArguments + + vocab = self._vocab + max_len = self.max_seq_length + + class _EHRDataset(TorchDataset): + def __init__(self, samples): + self._samples = list(samples) + + def __len__(self): + return len(self._samples) + + def __getitem__(self, idx): + s = self._samples[idx] + tokens = vocab.encode_visits(s["visits"]) + if len(tokens) > max_len: + tokens = tokens[:max_len - 1] + [vocab.SEQ_END] + age = float(s.get("age", 60.0)) + gender = int(s.get("gender", 0)) + return { + "input_ids": torch.tensor(tokens, dtype=torch.long), + "labels": torch.tensor(tokens, dtype=torch.long), + "x_num": torch.tensor([[age]], dtype=torch.float32), + "x_cat": torch.tensor([[gender]], dtype=torch.long), + } + + train_samples = list(train_dataset) + # Store demographics pool for synthesize_dataset sampling + self._demo_pool = [ + (float(s.get("age", 60.0)), int(s.get("gender", 0))) + for s in train_samples + ] + + os.makedirs(self.save_dir, exist_ok=True) + training_args = TrainingArguments( + output_dir=self.save_dir, + num_train_epochs=self.epochs, + per_device_train_batch_size=self.batch_size, + learning_rate=self.lr, + warmup_steps=self.warmup_steps, + save_strategy="epoch", + logging_steps=50, + remove_unused_columns=False, # essential: keeps x_num/x_cat + use_cpu=not torch.cuda.is_available(), + ) + + trainer = Trainer( + model=self.bart_model, + args=training_args, + train_dataset=_EHRDataset(train_samples), + eval_dataset=_EHRDataset(list(val_dataset)) if val_dataset else None, + data_collator=_promptehr_collate_fn, + ) + trainer.train() + + self.save_model(os.path.join(self.save_dir, "checkpoint.pt")) - @classmethod - def load_from_checkpoint(cls, checkpoint_path, dataset=None, **model_kwargs): - """Load PromptEHR model from pehr_scratch checkpoint. 
+ def synthesize_dataset( + self, num_samples: int, random_sampling: bool = True + ) -> List[Dict]: + """Generate a synthetic patient dataset. + + Samples demographics from the training data distribution (if available) + and generates autoregressive token sequences via BART. Each sequence is + decoded back to a nested list of diagnosis code strings. Args: - checkpoint_path: Path to checkpoint file (e.g., best_model.pt) - dataset: PyHealth dataset (optional, can be None for generative models) - **model_kwargs: Model initialization arguments (n_num_features, cat_cardinalities, etc.) + num_samples (int): Number of synthetic patients to generate. + random_sampling (bool): If True, uses multinomial sampling with + ``temperature=0.7, top_p=0.95``. If False, uses greedy decoding. + Default: True. Returns: - Loaded PromptEHR model with checkpoint weights - - Example: - >>> model = PromptEHR.load_from_checkpoint( - ... "/scratch/jalenj4/promptehr_checkpoints/best_model.pt", - ... n_num_features=1, - ... cat_cardinalities=[2] - ... ) + list of dict: One record per synthetic patient. Each dict has: + ``"patient_id"`` (str): unique identifier, e.g. ``"synthetic_0"``. + ``"visits"`` (list of list of str): decoded code strings per visit. 
""" - import torch - - # Load checkpoint (weights_only=False needed for custom tokenizer class) - checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) + self.bart_model.eval() + device = self.device + + results = [] + with torch.no_grad(): + for i in range(num_samples): + # Sample demographics from training distribution (or defaults) + if self._demo_pool: + age, gender = self._demo_pool[ + random.randrange(len(self._demo_pool)) + ] + else: + age, gender = 60.0, 0 + + x_num = torch.tensor([[age]], dtype=torch.float32, device=device) + x_cat = torch.tensor([[gender]], dtype=torch.long, device=device) + + # PAD token as minimal encoder input; prompts carry the signal + encoder_input = torch.tensor( + [[_PromptEHRVocab.PAD]], dtype=torch.long, device=device + ) - # Extract model state dict (pehr_scratch format has extra keys) - if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: - state_dict = checkpoint['model_state_dict'] - epoch = checkpoint.get('epoch', None) - val_loss = checkpoint.get('val_loss', None) - else: - # Direct state dict - state_dict = checkpoint - epoch = None - val_loss = None + output_ids = self.bart_model.generate( + input_ids=encoder_input, + attention_mask=torch.ones_like(encoder_input), + x_num=x_num, + x_cat=x_cat, + max_length=self.max_seq_length, + do_sample=random_sampling, + temperature=0.7 if random_sampling else 1.0, + top_p=0.95 if random_sampling else 1.0, + pad_token_id=_PromptEHRVocab.PAD, + eos_token_id=_PromptEHRVocab.SEQ_END, + bos_token_id=_PromptEHRVocab.BOS, + ) - # Auto-detect vocab_size from checkpoint - # pehr_scratch uses custom vocabulary (6992 tokens) vs BART default (50265) - if 'model.shared.weight' in state_dict: - checkpoint_vocab_size = state_dict['model.shared.weight'].shape[0] + visits = self._vocab.decode_tokens(output_ids[0].tolist()) + results.append({ + "patient_id": f"synthetic_{i}", + "visits": visits, + }) - # Override bart_config_name if vocab size differs from 
default - if 'bart_config_name' not in model_kwargs: - # Load default config to check vocab size - from transformers import BartConfig - default_config = BartConfig.from_pretrained("facebook/bart-base") + return results - if checkpoint_vocab_size != default_config.vocab_size: - # Create custom config with detected vocab size - print(f"Detected custom vocab_size={checkpoint_vocab_size} in checkpoint " - f"(BART default: {default_config.vocab_size})") + def save_model(self, path: str) -> None: + """Save model weights and vocab to a checkpoint file. - # Store custom config by temporarily modifying the config - model_kwargs['_custom_vocab_size'] = checkpoint_vocab_size + Args: + path (str): Destination file path (e.g. ``"./save/checkpoint.pt"``). - # Create model instance - model = cls(dataset=dataset, **model_kwargs) + Examples: + >>> import tempfile, os + >>> tmpdir = tempfile.mkdtemp() + >>> model.save_model(os.path.join(tmpdir, "ckpt.pt")) + """ + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + torch.save( + { + "model": self.bart_model.state_dict(), + "vocab": self._vocab, + "bart_config": self.bart_model.config, + }, + path, + ) - # Load weights - model.bart_model.load_state_dict(state_dict, strict=True) + def load_model(self, path: str) -> None: + """Load model weights from a checkpoint saved by ``save_model``. - # Print checkpoint info - if epoch is not None: - print(f"Loaded checkpoint from epoch {epoch}, val_loss={val_loss:.4f}") + Args: + path (str): Path to checkpoint file produced by ``save_model``. 
- return model + Examples: + >>> model.load_model("./save/checkpoint.pt") + """ + checkpoint = torch.load(path, map_location=self.device, weights_only=False) + self.bart_model.load_state_dict(checkpoint["model"]) + if "vocab" in checkpoint: + self._vocab = checkpoint["vocab"] diff --git a/pyhealth/tasks/ehr_generation.py b/pyhealth/tasks/ehr_generation.py index e6aaca8f9..788dd0351 100644 --- a/pyhealth/tasks/ehr_generation.py +++ b/pyhealth/tasks/ehr_generation.py @@ -63,6 +63,8 @@ def __call__(self, patient) -> List[Dict]: ``"gender"`` (int): 0 for male, 1 for female. """ admissions = list(patient.get_events(event_type="admissions")) + if len(admissions) < 2: + return [] # --- Demographics --- age = 60.0 From 68f0ca3376774eaf995f25ed434c57e371d08939 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 02:24:15 -0600 Subject: [PATCH 04/37] T5: Add PromptEHR PyHealth 2.0 generation example --- .../generate_synthetic_mimic3_promptehr.py | 47 ++ examples/promptehr_generate_local.py | 157 ----- examples/promptehr_mimic3.py | 565 ------------------ 3 files changed, 47 insertions(+), 722 deletions(-) create mode 100644 examples/generate_synthetic_mimic3_promptehr.py delete mode 100644 examples/promptehr_generate_local.py delete mode 100644 examples/promptehr_mimic3.py diff --git a/examples/generate_synthetic_mimic3_promptehr.py b/examples/generate_synthetic_mimic3_promptehr.py new file mode 100644 index 000000000..5eefb7ff7 --- /dev/null +++ b/examples/generate_synthetic_mimic3_promptehr.py @@ -0,0 +1,47 @@ +"""PromptEHR: Synthetic MIMIC-III Patient Generation. + +Load a trained PromptEHR checkpoint and generate synthetic patients. + +Reference: + Wang et al. "PromptEHR: Conditional Electronic Healthcare Records + Generation with Prompt Learning." EMNLP 2023. 
+ https://arxiv.org/abs/2211.01761 +""" + +import json + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.models import PromptEHR +from pyhealth.tasks import promptehr_generation_mimic3_fn + +MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4" +CHECKPOINT_PATH = "./save/promptehr/checkpoint.pt" +OUTPUT_PATH = "./save/promptehr/synthetic_patients.json" +NUM_SAMPLES = 10_000 + +# 1. Load dataset + apply task (needed for processor/vocab reconstruction) +dataset = MIMIC3Dataset( + root=MIMIC3_ROOT, + tables=["patients", "admissions", "diagnoses_icd"], + code_mapping={}, +) +sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn) + +# 2. Load checkpoint +model = PromptEHR(dataset=sample_dataset) +model.load_model(CHECKPOINT_PATH) +print(f"Loaded checkpoint from {CHECKPOINT_PATH}") + +# 3. Generate +print(f"Generating {NUM_SAMPLES} synthetic patients...") +synthetic = model.synthesize_dataset(num_samples=NUM_SAMPLES) +print(f"Generated {len(synthetic)} patients") + +# 4. Save +with open(OUTPUT_PATH, "w") as f: + json.dump(synthetic, f, indent=2) +print(f"Saved to {OUTPUT_PATH}") + +# Summary stats +avg_visits = sum(len(p["visits"]) for p in synthetic) / len(synthetic) +print(f"Average visits per patient: {avg_visits:.2f}") diff --git a/examples/promptehr_generate_local.py b/examples/promptehr_generate_local.py deleted file mode 100644 index 33b9ad41c..000000000 --- a/examples/promptehr_generate_local.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -"""Quick local generation test for PromptEHR (CPU-only). - -This script demonstrates how to: -1. Load a trained PromptEHR checkpoint -2. Generate synthetic patients on CPU (no GPU required) -3. 
Display results in human-readable format - -Usage: - python3 examples/promptehr_generate_local.py -""" - -import sys -sys.path.insert(0, '/u/jalenj4/final/PyHealth') - -import torch -import logging -from pathlib import Path - -# PyHealth imports -from pyhealth.models import PromptEHR -from pyhealth.datasets.promptehr_dataset import load_mimic_data -from pyhealth.models.promptehr import ( - VisitStructureSampler, - generate_patient_with_structure_constraints -) - - -def main(): - """Generate 10 synthetic patients locally on CPU.""" - - # Setup - device = torch.device("cpu") # Force CPU (no GPU required) - logging.basicConfig( - level=logging.WARNING, # Reduce noise, only show warnings/errors - format='%(message)s' - ) - logger = logging.getLogger(__name__) - - print("\n" + "="*80) - print("PromptEHR Local Generation Test (CPU mode)") - print("="*80) - - # Load checkpoint - print("\n[1/4] Loading trained checkpoint...") - checkpoint_path = "./promptehr_outputs/checkpoints/final_model.pt" - - if not Path(checkpoint_path).exists(): - print(f"ERROR: Checkpoint not found at {checkpoint_path}") - print("Please ensure training has completed and checkpoint exists.") - return - - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - tokenizer = checkpoint['tokenizer'] - - # Add convenience properties and methods if not present - # (for compatibility with old checkpoints saved before these were added) - if not hasattr(tokenizer, 'bos_token_id'): - tokenizer.pad_token_id = tokenizer.vocabulary("") # ID 0 - tokenizer.bos_token_id = tokenizer.vocabulary("") # ID 1 - tokenizer.eos_token_id = tokenizer.vocabulary("") # ID 2 - tokenizer.code_offset = 7 # First diagnosis code ID (after 7 special tokens) - if not hasattr(tokenizer, 'convert_tokens_to_ids'): - # Add method alias: pehr_scratch API uses convert_tokens_to_ids(token) → int - def convert_tokens_to_ids(token: str) -> int: - return tokenizer.convert_tokens_to_indices([token])[0] - 
tokenizer.convert_tokens_to_ids = convert_tokens_to_ids - if not hasattr(tokenizer, 'vocab'): - # Add vocab object for idx2code and code2idx mappings - class VocabCompat: - def __init__(self, tok): - self.idx2code = tok.vocabulary.idx2token - self.code2idx = tok.vocabulary.token2idx - def __len__(self): - return len(self.idx2code) - tokenizer.vocab = VocabCompat(tokenizer) - - # Rebuild model - print("[2/4] Rebuilding model from checkpoint...") - config = checkpoint['config'] - model = PromptEHR(**config) - model.bart_model.load_state_dict(checkpoint['model_state_dict']) - model.to(device) - model.eval() - - print(f" Model vocabulary size: {config['_custom_vocab_size']}") - print(f" Hidden dimension: {config['d_hidden']}") - print(f" Prompt length: {config['prompt_length']}") - - # Load MIMIC data for structure sampling - print("[3/4] Loading MIMIC-III data for structure sampling...") - print(" (Loading 1000 patients for realistic visit distributions)") - - patient_records, _ = load_mimic_data( - patients_path="/u/jalenj4/pehr_scratch/data_files/PATIENTS.csv", - admissions_path="/u/jalenj4/pehr_scratch/data_files/ADMISSIONS.csv", - diagnoses_path="/u/jalenj4/pehr_scratch/data_files/DIAGNOSES_ICD.csv", - num_patients=1000, - logger=logger - ) - - # Initialize structure sampler - structure_sampler = VisitStructureSampler(patient_records, seed=42) - print(f" {structure_sampler}") - - # Generate synthetic patients - n_patients = 10 - print(f"\n[4/4] Generating {n_patients} synthetic patients...") - print(" (This will take ~10-15 seconds)") - print() - - print("="*80) - print("SYNTHETIC PATIENTS") - print("="*80) - print() - - for i in range(n_patients): - # Sample realistic visit structure - target_structure = structure_sampler.sample_structure() - - # Generate patient - result = generate_patient_with_structure_constraints( - model=model, - tokenizer=tokenizer, - device=device, - target_structure=target_structure, - temperature=0.7, - top_k=40, - top_p=0.9, - 
max_codes_per_visit=25 - ) - - # Display patient - demo = result['demographics'] - print(f"Patient {i+1}:") - print(f" Age: {demo['age']} years") - print(f" Sex: {'Male' if demo['sex'] == 0 else 'Female'}") - print(f" Number of visits: {result['num_visits']}") - print(f" Diagnosis codes:") - - for visit_idx, codes in enumerate(result['generated_visits'], 1): - if codes: - print(f" Visit {visit_idx}: {', '.join(codes)}") - else: - print(f" Visit {visit_idx}: (no diagnoses)") - print() - - print("="*80) - print("Generation complete!") - print("="*80) - print() - print(f"Successfully generated {n_patients} synthetic patients on CPU.") - print() - - -if __name__ == "__main__": - main() diff --git a/examples/promptehr_mimic3.py b/examples/promptehr_mimic3.py deleted file mode 100644 index 1f42868d2..000000000 --- a/examples/promptehr_mimic3.py +++ /dev/null @@ -1,565 +0,0 @@ -"""PromptEHR: Training and Generation Example on MIMIC-III - -This example demonstrates the complete PromptEHR pipeline: -1. Load MIMIC-III patient records -2. Train PromptEHR model for synthetic EHR generation -3. Generate synthetic patients with realistic visit structures -4. 
Evaluate generation quality - -References: - - Paper: "PromptEHR: Conditional Electronic Health Records Generation with Prompt Learning" - - pehr_scratch implementation: /u/jalenj4/pehr_scratch/ -""" - -import os -import sys -import logging -from pathlib import Path -from typing import List, Dict - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader, random_split -from torch.optim import AdamW -from transformers import BartConfig, get_linear_schedule_with_warmup - -# PyHealth imports -from pyhealth.datasets import MIMIC3Dataset -from pyhealth.models import PromptEHR -from pyhealth.trainer import Trainer -from pyhealth.datasets.promptehr_dataset import ( - create_promptehr_tokenizer, - PromptEHRDataset, - load_mimic_data -) -from pyhealth.datasets.promptehr_collator import EHRDataCollator - - -class DeviceAwareCollatorWrapper: - """Wrapper around EHRDataCollator that moves tensors to specified device. - - This wrapper addresses PyHealth Trainer limitation where data is not automatically - moved to device before forward pass. The Trainer directly calls model(**data) at - line 206 without device transfer, requiring collator to handle device placement. - - Args: - collator: Base EHRDataCollator instance - device: Target device ('cuda' or 'cpu') - """ - - def __init__(self, collator: EHRDataCollator, device: str): - """Initialize wrapper with base collator and target device.""" - self.collator = collator - self.device = torch.device(device) - - def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: - """Collate batch and move all tensors to target device. 
- - Args: - batch: List of sample dictionaries - - Returns: - Dictionary with batched tensors on target device - """ - # Get batched tensors from base collator (CPU tensors) - batched_data = self.collator(batch) - - # Move all tensors to target device - device_data = { - key: value.to(self.device) if isinstance(value, torch.Tensor) else value - for key, value in batched_data.items() - } - - return device_data - - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - - -def train_promptehr( - mimic3_root: str, - output_dir: str = "./promptehr_outputs", - num_patients: int = 46520, # Full MIMIC-III dataset - batch_size: int = 16, - num_epochs: int = 20, - learning_rate: float = 1e-5, - warmup_steps: int = 1000, - val_split: float = 0.2, - device: str = "cuda", - checkpoint_path: str = None -): - """Train PromptEHR model on MIMIC-III dataset. - - Args: - mimic3_root: Path to MIMIC-III data directory containing: - - PATIENTS.csv - - ADMISSIONS.csv - - DIAGNOSES_ICD.csv - output_dir: Directory to save outputs (checkpoints, logs) - num_patients: Number of patients to load (default: full dataset) - batch_size: Training batch size - num_epochs: Number of training epochs - learning_rate: AdamW learning rate - warmup_steps: Linear warmup steps for scheduler - val_split: Validation split ratio - device: Device to use ('cuda' or 'cpu') - checkpoint_path: Path to resume from checkpoint (optional) - - Returns: - Trained PromptEHR model - """ - # Create output directory - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - checkpoint_dir = output_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) - - logger.info("=" * 80) - logger.info("PromptEHR Training Pipeline") - logger.info("=" * 80) - logger.info(f"MIMIC-III root: {mimic3_root}") - logger.info(f"Output directory: {output_dir}") - logger.info(f"Device: {device}") - - # Step 
1: Load MIMIC-III patient records - logger.info("\n" + "=" * 80) - logger.info("Loading MIMIC-III Patient Records") - logger.info("=" * 80) - - patients_path = os.path.join(mimic3_root, "PATIENTS.csv") - admissions_path = os.path.join(mimic3_root, "ADMISSIONS.csv") - diagnoses_path = os.path.join(mimic3_root, "DIAGNOSES_ICD.csv") - - patient_records, diagnosis_codes = load_mimic_data( - patients_path=patients_path, - admissions_path=admissions_path, - diagnoses_path=diagnoses_path, - num_patients=num_patients, - logger=logger - ) - - logger.info(f"Loaded {len(patient_records)} patients") - logger.info(f"Vocabulary size: {len(diagnosis_codes)} diagnosis codes") - - # Step 2: Create tokenizer - logger.info("\n" + "=" * 80) - logger.info("Creating Tokenizer") - logger.info("=" * 80) - - tokenizer = create_promptehr_tokenizer(diagnosis_codes) - vocab_size = tokenizer.get_vocabulary_size() - logger.info(f"Tokenizer vocabulary size: {vocab_size}") - logger.info(f" Special tokens: 7") - logger.info(f" Diagnosis codes: {len(diagnosis_codes)}") - logger.info(f" Code offset: 7") - - # Step 3: Create dataset - logger.info("\n" + "=" * 80) - logger.info("Creating Dataset") - logger.info("=" * 80) - - dataset = PromptEHRDataset(patient_records, tokenizer, logger) - logger.info(f"Dataset size: {len(dataset)} patients") - - # Train/validation split - train_size = int((1 - val_split) * len(dataset)) - val_size = len(dataset) - train_size - train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) - logger.info(f"Train size: {train_size}, Validation size: {val_size}") - - # Create data collator - # CRITICAL FIX: Disable token replacement to prevent distribution inversion - # Token replacement causes rare codes to be enriched 3.24x and common codes depleted to 0.85x - base_collator = EHRDataCollator( - tokenizer=tokenizer, - max_seq_length=512, - logger=logger, - corruption_prob=0.5, - use_mask_infilling=True, - use_token_deletion=True, - use_token_replacement=False 
# DISABLED: Causes 4700x frequency inversion - ) - - # Wrap collator to handle device placement - # PyHealth Trainer does not move data to device (line 206: model(**data)) - # so we must handle device transfer in the collator - collator = DeviceAwareCollatorWrapper(base_collator, device) - logger.info(f"Using device-aware collator wrapper (target device: {device})") - - # Create data loaders - train_loader = DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=True, - collate_fn=collator - ) - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=collator - ) - - logger.info(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}") - - # Step 4: Initialize model - logger.info("\n" + "=" * 80) - logger.info("Initializing PromptEHR Model") - logger.info("=" * 80) - - model = PromptEHR( - dataset=None, # Generative model, no discriminative task - n_num_features=1, # Age - cat_cardinalities=[2], # Gender (M/F) - d_hidden=128, - prompt_length=1, - bart_config_name="facebook/bart-base", - _custom_vocab_size=vocab_size # Custom vocab size for MIMIC-III - ) - - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Total parameters: {total_params:,}") - logger.info(f"Trainable parameters: {trainable_params:,}") - - # Step 5: Configure trainer - logger.info("\n" + "=" * 80) - logger.info("Configuring Trainer") - logger.info("=" * 80) - - trainer = Trainer( - model=model, - checkpoint_path=checkpoint_path, - metrics=["loss"], - device=device, - enable_logging=True, - output_path=str(output_dir) - ) - - # Step 6: Train - logger.info("\n" + "=" * 80) - logger.info("Starting Training") - logger.info("=" * 80) - - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=num_epochs, - optimizer_params={"lr": learning_rate, "weight_decay": 0.01}, - monitor="loss" - ) - - # Step 7: 
Save final model - final_checkpoint = checkpoint_dir / "final_model.pt" - torch.save({ - 'model_state_dict': model.bart_model.state_dict(), # Save BART model state - 'tokenizer': tokenizer, - 'diagnosis_codes': diagnosis_codes, - 'config': { - 'dataset': None, - 'n_num_features': 1, - 'cat_cardinalities': [2], - 'd_hidden': 128, - 'prompt_length': 1, - 'bart_config_name': "facebook/bart-base", - '_custom_vocab_size': vocab_size - } - }, final_checkpoint) - logger.info(f"\nFinal model saved to: {final_checkpoint}") - - logger.info("\n" + "=" * 80) - logger.info("Training Complete!") - logger.info("=" * 80) - - return model, tokenizer - - -def generate_synthetic_patients( - model: PromptEHR, - tokenizer, - patient_records: List, - num_patients: int = 100, - temperature: float = 0.7, - alpha: float = 2.0, - device: str = "cuda", - mimic3_root: str = None -): - """Generate synthetic patients using trained PromptEHR model. - - Args: - model: Trained PromptEHR model - tokenizer: PromptEHR tokenizer - patient_records: Real patient records (for structure sampling) - num_patients: Number of synthetic patients to generate - temperature: Sampling temperature - device: Device to use - mimic3_root: Path to MIMIC-III training data (for first code prior) - - Returns: - List of generated patient dictionaries - """ - from pyhealth.models.promptehr import VisitStructureSampler - from pyhealth.models.promptehr.generation import ( - DemographicSampler, - build_frequency_prior, - generate_with_frequency_prior - ) - - logger.info("\n" + "=" * 80) - logger.info(f"Generating {num_patients} Synthetic Patients") - logger.info("=" * 80) - - # Initialize visit structure sampler - structure_sampler = VisitStructureSampler(patient_records, seed=42) - logger.info(f"Structure sampler: {structure_sampler}") - - # Initialize demographic sampler - demographic_sampler = DemographicSampler(patient_records, seed=42) - logger.info(f"Demographic sampler: {demographic_sampler}") - - # Build frequency 
prior for ALL code generation - frequency_prior = None - freq_path = Path(mimic3_root).parent / "promptehr_outputs" / "training_frequencies.json" - if not freq_path.exists(): - freq_path = Path("promptehr_outputs") / "training_frequencies.json" - - if freq_path.exists(): - logger.info(f"Building frequency prior from {freq_path}...") - try: - frequency_prior = build_frequency_prior( - tokenizer, - frequency_path=str(freq_path), - vocab_size=len(tokenizer.vocab.idx2code) - ) - logger.info(f"Frequency prior built: shape {frequency_prior.shape}") - except Exception as e: - logger.warning(f"Failed to build frequency prior: {e}") - logger.warning("Continuing without frequency guidance...") - else: - logger.warning(f"training_frequencies.json not found at {freq_path}") - logger.warning("Continuing without frequency guidance...") - - # Set model to eval mode - model.eval() - model.to(device) - - # Generate patients - generated_patients = [] - for i in range(num_patients): - if (i + 1) % 20 == 0: - logger.info(f"Generated {i + 1}/{num_patients} patients...") - - # Sample realistic visit structure - target_structure = structure_sampler.sample_structure() - - # Sample demographics from empirical distribution - demographics = demographic_sampler.sample() - age = demographics['age'] - sex = demographics['sex'] - - # Generate patient with frequency-guided sampling - if frequency_prior is not None: - result = generate_with_frequency_prior( - model=model, - tokenizer=tokenizer, - device=device, - target_structure=target_structure, - frequency_prior=frequency_prior, - alpha=alpha, # Frequency prior weight (optimal: 2.0 from diagnostic) - age=age, - sex=sex, - temperature=temperature, # Sampling temperature (optimal: 1.0 from diagnostic) - top_k=0, # Disabled - use full vocabulary - top_p=0.95, # Nucleus sampling for quality - max_codes_per_visit=25 - ) - else: - # Fallback to regular generation if no frequency prior - from pyhealth.models.promptehr import 
generate_patient_with_structure_constraints - result = generate_patient_with_structure_constraints( - model=model, - tokenizer=tokenizer, - device=device, - target_structure=target_structure, - age=age, - sex=sex, - temperature=0.5, - top_k=0, - top_p=0.95, - max_codes_per_visit=25 - ) - - # Store result - demo = result['demographics'] - generated_patients.append({ - 'patient_id': f"SYNTH_{i+1:04d}", - 'age': demo['age'], - 'sex': 'M' if demo['sex'] == 0 else 'F', - 'num_visits': result['num_visits'], - 'visits': result['generated_visits'] - }) - - logger.info(f"\nGeneration complete: {num_patients} patients created") - - # Display statistics - total_visits = sum(p['num_visits'] for p in generated_patients) - total_codes = sum(len(code) for p in generated_patients for visit in p['visits'] for code in visit) - unique_codes = len(set(code for p in generated_patients for visit in p['visits'] for code in visit)) - - logger.info(f"\nDataset Statistics:") - logger.info(f" Total patients: {num_patients}") - logger.info(f" Total visits: {total_visits}") - logger.info(f" Total diagnosis codes: {total_codes}") - logger.info(f" Unique codes: {unique_codes}") - logger.info(f" Average visits/patient: {total_visits/num_patients:.2f}") - logger.info(f" Average codes/patient: {total_codes/num_patients:.1f}") - - return generated_patients - - -def save_synthetic_dataset( - patients: List[Dict], - output_path: str, - format: str = "csv" -): - """Save generated patients to file. 
- - Args: - patients: List of patient dictionaries - output_path: Path to save file - format: Output format ('csv' or 'json') - """ - import csv - import json - - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - if format == "csv": - with open(output_path, 'w', newline='') as f: - writer = csv.writer(f) - writer.writerow(['patient_id', 'age', 'sex', 'num_visits', 'visit_num', 'diagnosis_codes']) - - for patient in patients: - for visit_idx, visit_codes in enumerate(patient['visits']): - codes_str = ';'.join(visit_codes) - writer.writerow([ - patient['patient_id'], - f"{patient['age']:.1f}", - patient['sex'], - patient['num_visits'], - visit_idx + 1, - codes_str - ]) - - logger.info(f"Saved {len(patients)} patients to {output_path} (CSV format)") - - elif format == "json": - with open(output_path, 'w') as f: - json.dump(patients, f, indent=2) - - logger.info(f"Saved {len(patients)} patients to {output_path} (JSON format)") - - -def main(): - """Main entry point for PromptEHR training and generation.""" - import argparse - - parser = argparse.ArgumentParser(description="PromptEHR Training and Generation") - parser.add_argument("--mimic3_root", type=str, required=True, - help="Path to MIMIC-III data directory") - parser.add_argument("--output_dir", type=str, default="./promptehr_outputs", - help="Output directory for checkpoints and results") - parser.add_argument("--num_patients", type=int, default=46520, - help="Number of patients to load for training") - parser.add_argument("--batch_size", type=int, default=16, - help="Training batch size") - parser.add_argument("--num_epochs", type=int, default=20, - help="Number of training epochs") - parser.add_argument("--learning_rate", type=float, default=1e-5, - help="Learning rate") - parser.add_argument("--device", type=str, default="cuda", - help="Device to use (cuda or cpu)") - parser.add_argument("--checkpoint", type=str, default=None, - help="Path to checkpoint to resume 
from") - parser.add_argument("--generate_only", action="store_true", - help="Skip training, only generate (requires --checkpoint)") - parser.add_argument("--num_synthetic", type=int, default=100, - help="Number of synthetic patients to generate") - parser.add_argument("--temperature", type=float, default=0.7, - help="Sampling temperature for generation") - parser.add_argument("--alpha", type=float, default=2.0, - help="Frequency prior weight (alpha) for generation") - - args = parser.parse_args() - - # Training - if not args.generate_only: - model, tokenizer = train_promptehr( - mimic3_root=args.mimic3_root, - output_dir=args.output_dir, - num_patients=args.num_patients, - batch_size=args.batch_size, - num_epochs=args.num_epochs, - learning_rate=args.learning_rate, - device=args.device, - checkpoint_path=args.checkpoint - ) - else: - # Load from checkpoint - if args.checkpoint is None: - raise ValueError("--checkpoint required when using --generate_only") - - logger.info(f"Loading model from checkpoint: {args.checkpoint}") - # PyTorch 2.6+ requires weights_only=False to load checkpoints with custom objects (tokenizer) - checkpoint = torch.load(args.checkpoint, weights_only=False) - tokenizer = checkpoint['tokenizer'] - - model = PromptEHR(**checkpoint['config']) - model.bart_model.load_state_dict(checkpoint['model_state_dict']) - model.to(args.device) - model.eval() - - # Load patient records for structure sampling - patients_path = os.path.join(args.mimic3_root, "PATIENTS.csv") - admissions_path = os.path.join(args.mimic3_root, "ADMISSIONS.csv") - diagnoses_path = os.path.join(args.mimic3_root, "DIAGNOSES_ICD.csv") - - patient_records, _ = load_mimic_data( - patients_path=patients_path, - admissions_path=admissions_path, - diagnoses_path=diagnoses_path, - num_patients=args.num_patients, - logger=logger - ) - - # Generation - generated_patients = generate_synthetic_patients( - model=model, - tokenizer=tokenizer, - patient_records=patient_records, - 
num_patients=args.num_synthetic, - temperature=args.temperature, - alpha=args.alpha, - device=args.device, - mimic3_root=args.mimic3_root - ) - - # Save results - output_csv = Path(args.output_dir) / f"synthetic_patients_{args.num_synthetic}.csv" - save_synthetic_dataset(generated_patients, output_csv, format="csv") - - logger.info("\n" + "=" * 80) - logger.info("PromptEHR Pipeline Complete!") - logger.info("=" * 80) - logger.info(f"Output directory: {args.output_dir}") - logger.info(f"Synthetic dataset: {output_csv}") - - -if __name__ == "__main__": - main() From 54a083607d39c7867b55c211a2833aa2058f3162 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 02:24:18 -0600 Subject: [PATCH 05/37] T4: Add PromptEHR PyHealth 2.0 training example --- examples/promptehr_mimic3_training.py | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 examples/promptehr_mimic3_training.py diff --git a/examples/promptehr_mimic3_training.py b/examples/promptehr_mimic3_training.py new file mode 100644 index 000000000..8387208db --- /dev/null +++ b/examples/promptehr_mimic3_training.py @@ -0,0 +1,47 @@ +"""PromptEHR: Training on MIMIC-III. + +Train PromptEHR for synthetic EHR generation using PyHealth 2.0 API. + +Reference: + Wang et al. "PromptEHR: Conditional Electronic Health Records Generation + with Prompt Learning." CHIL 2023. +""" + +from pyhealth.datasets import MIMIC3Dataset, split_by_patient +from pyhealth.models import PromptEHR +from pyhealth.tasks import promptehr_generation_mimic3_fn + +MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4" + +# 1. Load MIMIC-III +dataset = MIMIC3Dataset( + root=MIMIC3_ROOT, + tables=["patients", "admissions", "diagnoses_icd"], + code_mapping={}, +) + +# 2. Apply generation task +sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn) +print(f"Patients: {len(sample_dataset)}") +sample_dataset.stat() + +# 3. 
Split +train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) + +# 4. Initialize model +model = PromptEHR( + dataset=sample_dataset, + n_num_features=1, + cat_cardinalities=[2], + d_hidden=128, + prompt_length=1, + epochs=20, + batch_size=16, + lr=1e-5, + warmup_steps=1000, + save_dir="./save/promptehr/", +) + +# 5. Train +model.train_model(train, val) +print("Training complete. Checkpoint saved to ./save/promptehr/") From 9e9589a2a95e885a58426193f3b3cbeda03be868 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 02:24:58 -0600 Subject: [PATCH 06/37] T7: Update PromptEHR docstrings to Google/PyHealth style --- pyhealth/models/promptehr/model.py | 40 +++++++++++------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 08b710ab1..55c9df2ee 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -373,8 +373,20 @@ class _PromptEHRVocab: 7+ = diagnosis codes NestedSequenceProcessor uses pad=0, unk=1, codes=2+. - Mapping: processor_idx i → BART token i + 5 (for i >= 2). + Mapping: processor_idx i -> BART token i + 5 (for i >= 2). Total BART vocab size = processor.vocab_size() + 5. + + Args: + code_vocab (dict): Mapping of code string to processor index, as + returned by ``NestedSequenceProcessor.code_vocab``. Must include + ``""`` -> 0 and ``""`` -> 1. + + Examples: + >>> vocab = _PromptEHRVocab({"": 0, "": 1, "428": 2, "410": 3}) + >>> isinstance(vocab, _PromptEHRVocab) + True + >>> vocab.total_size + 9 """ PAD = 0 @@ -387,12 +399,7 @@ class _PromptEHRVocab: CODE_OFFSET = 7 def __init__(self, code_vocab: dict): - """Build vocab from NestedSequenceProcessor.code_vocab dict. - - Args: - code_vocab (dict): Mapping of code string → processor index. - Must have ``""`` → 0 and ``""`` → 1. 
- """ + """Build vocab from NestedSequenceProcessor.code_vocab dict.""" self._bart_to_code: Dict[int, str] = {} for code, pid in code_vocab.items(): if pid >= 2: # skip and @@ -552,24 +559,7 @@ def __init__( max_seq_length: int = 512, save_dir: str = "./save/", ): - """Initialize PromptEHR with vocab derived from the dataset processor. - - Args: - dataset (SampleDataset): PyHealth dataset with - ``input_processors["visits"]`` (NestedSequenceProcessor). - n_num_features (int): Continuous demographic features. Default: 1. - cat_cardinalities (list of int): Category cardinalities. Default: [2]. - d_hidden (int): Prompt encoder hidden dim. Default: 128. - prompt_length (int): Prompt vectors per feature. Default: 1. - bart_config_name (str): Pretrained BART config. Default: - ``"facebook/bart-base"``. - epochs (int): Training epochs. Default: 20. - batch_size (int): Training batch size. Default: 16. - lr (float): AdamW learning rate. Default: 1e-5. - warmup_steps (int): Linear warmup steps. Default: 1000. - max_seq_length (int): Token sequence length cap. Default: 512. - save_dir (str): Checkpoint output directory. Default: ``"./save/"``. - """ + """Initialize PromptEHR with vocab derived from the dataset processor.""" super().__init__(dataset) self.mode = None # skip discriminative evaluation From ad0b27b72a75e41dbfaec7f5d28f2bf141747c0a Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 15:46:06 -0600 Subject: [PATCH 07/37] T8: Add PromptEHR integration tests (8 pass, 4 skip MIMIC-III) Also: accept BartConfig object as bart_config_name for tiny test models. 
--- pyhealth/models/promptehr/model.py | 8 +- tests/integration/__init__.py | 0 .../integration/test_promptehr_end_to_end.py | 431 ++++++++++++++++++ 3 files changed, 437 insertions(+), 2 deletions(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_promptehr_end_to_end.py diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 55c9df2ee..2657c0c7d 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -551,7 +551,7 @@ def __init__( cat_cardinalities: Optional[list] = None, d_hidden: int = 128, prompt_length: int = 1, - bart_config_name: str = "facebook/bart-base", + bart_config_name: "Union[str, BartConfig]" = "facebook/bart-base", epochs: int = 20, batch_size: int = 16, lr: float = 1e-5, @@ -580,7 +580,11 @@ def __init__( bart_vocab_size = self._vocab.total_size # Configure BART with our custom vocab and special token IDs - bart_config = BartConfig.from_pretrained(bart_config_name) + if isinstance(bart_config_name, str): + bart_config = BartConfig.from_pretrained(bart_config_name) + else: + # Accept a BartConfig object directly (useful for tiny test models) + bart_config = bart_config_name bart_config.vocab_size = bart_vocab_size bart_config.pad_token_id = _PromptEHRVocab.PAD bart_config.bos_token_id = _PromptEHRVocab.BOS diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_promptehr_end_to_end.py b/tests/integration/test_promptehr_end_to_end.py new file mode 100644 index 000000000..a3c0bdac6 --- /dev/null +++ b/tests/integration/test_promptehr_end_to_end.py @@ -0,0 +1,431 @@ +"""End-to-end integration tests for the PromptEHR synthetic EHR generation pipeline. + +Category A tests use InMemorySampleDataset with synthetic data — no external +data required and must always pass. 
+ +Category B tests require actual MIMIC-III data and are skipped gracefully when +the data is unavailable. + +The bootstrap pattern mirrors test_corgan_end_to_end.py: load PromptEHR and +InMemorySampleDataset via importlib while stubbing out heavy optional +dependencies (litdata, pyarrow) that are not yet in the venv. transformers IS +available in the venv and is loaded normally. +""" + +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Bootstrap: load PromptEHR, BaseModel, and InMemorySampleDataset without +# triggering pyhealth.models.__init__ (many models have unavailable deps) or +# pyhealth.datasets.__init__ (requires litdata, pyarrow, ...). +# --------------------------------------------------------------------------- + + +def _bootstrap(): + """Load PromptEHR, BaseModel, and InMemorySampleDataset via importlib. + + Returns: + (BaseModel, PromptEHR, InMemorySampleDataset) + """ + import pyhealth # noqa: F401 — top-level __init__ has no heavy deps + + # Stub pyhealth.datasets so that base_model.py's + # "from ..datasets import SampleDataset" resolves cleanly. + if "pyhealth.datasets" not in sys.modules: + ds_stub = MagicMock() + + class _FakeSampleDataset: # noqa: N801 + pass + + ds_stub.SampleDataset = _FakeSampleDataset + sys.modules["pyhealth.datasets"] = ds_stub + + # Stub pyhealth.models so we can control loading without the real __init__. + if "pyhealth.models" not in sys.modules or isinstance( + sys.modules["pyhealth.models"], MagicMock + ): + models_stub = MagicMock() + sys.modules["pyhealth.models"] = models_stub + else: + models_stub = sys.modules["pyhealth.models"] + + # Processors are safe to import normally. 
+ from pyhealth.processors import PROCESSOR_REGISTRY # noqa: F401 + + def _load_file(mod_name, filepath): + spec = importlib.util.spec_from_file_location(mod_name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + models_dir = os.path.join(root, "pyhealth", "models") + promptehr_dir = os.path.join(models_dir, "promptehr") + + # Load base_model and expose via stub. + bm_mod = _load_file( + "pyhealth.models.base_model", os.path.join(models_dir, "base_model.py") + ) + BaseModel = bm_mod.BaseModel + models_stub.BaseModel = BaseModel + + # Create a package stub for pyhealth.models.promptehr so that + # model.py's relative imports (from .conditional_prompt import ...) work. + promptehr_pkg_stub = MagicMock() + sys.modules.setdefault("pyhealth.models.promptehr", promptehr_pkg_stub) + + # Load each PromptEHR submodule in dependency order. + # Each is standalone (only torch + transformers, no cross-module imports). + for mod_name in ( + "conditional_prompt", + "bart_encoder", + "bart_decoder", + "visit_sampler", + "generation", + ): + _load_file( + f"pyhealth.models.promptehr.{mod_name}", + os.path.join(promptehr_dir, f"{mod_name}.py"), + ) + + # Load model.py last (depends on the submodules loaded above + BaseModel). + model_mod = _load_file( + "pyhealth.models.promptehr.model", + os.path.join(promptehr_dir, "model.py"), + ) + PromptEHR = model_mod.PromptEHR + + # Stub litdata so sample_dataset.py can be loaded without the full package. 
+ if "litdata" not in sys.modules: + litdata_pkg = MagicMock() + litdata_pkg.StreamingDataset = type( + "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None} + ) + litdata_utilities = MagicMock() + litdata_utilities_train_test = MagicMock() + litdata_utilities_train_test.deepcopy_dataset = lambda x: x + litdata_utilities.train_test_split = litdata_utilities_train_test + litdata_pkg.utilities = litdata_utilities + sys.modules["litdata"] = litdata_pkg + sys.modules["litdata.utilities"] = litdata_utilities + sys.modules["litdata.utilities.train_test_split"] = ( + litdata_utilities_train_test + ) + + # Load sample_dataset.py directly (bypasses datasets/__init__.py). + ds_file_mod = _load_file( + "pyhealth.datasets.sample_dataset", + os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"), + ) + InMemorySampleDataset = ds_file_mod.InMemorySampleDataset + + return BaseModel, PromptEHR, InMemorySampleDataset + + +BaseModel, PromptEHR, InMemorySampleDataset = _bootstrap() + +import torch # noqa: E402 +from transformers import BartConfig # noqa: E402 + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +# Nested lists of code strings — PromptEHR uses nested_sequence schema. +# 8 samples with ≥2 visits each, plus demographics. 
+_SMALL_SAMPLES = [ + {"patient_id": "p1", "visits": [["A", "B"], ["C", "D"]], "age": 65.0, "gender": 0}, + {"patient_id": "p2", "visits": [["E"], ["F", "G"]], "age": 45.0, "gender": 1}, + {"patient_id": "p3", "visits": [["A", "C"], ["B", "E"]], "age": 55.0, "gender": 0}, + {"patient_id": "p4", "visits": [["D"], ["A"]], "age": 70.0, "gender": 1}, + {"patient_id": "p5", "visits": [["B", "F"], ["C", "G"]], "age": 40.0, "gender": 0}, + {"patient_id": "p6", "visits": [["E", "A"], ["D"]], "age": 60.0, "gender": 1}, + {"patient_id": "p7", "visits": [["G", "B"], ["F", "A"]], "age": 50.0, "gender": 0}, + {"patient_id": "p8", "visits": [["C"], ["D", "E"]], "age": 35.0, "gender": 1}, +] + +# Tiny BART config to keep tests fast (avoids downloading/using 768-dim bart-base). +_TINY_BART_CONFIG = BartConfig( + d_model=32, + encoder_layers=1, + decoder_layers=1, + encoder_ffn_dim=64, + decoder_ffn_dim=64, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=128, +) + +# Minimal model kwargs — tiny architecture and 1 epoch to keep tests fast. 
+_SMALL_MODEL_KWARGS = dict( + n_num_features=1, + cat_cardinalities=[2], + d_hidden=32, + prompt_length=1, + bart_config_name=_TINY_BART_CONFIG, + epochs=1, + batch_size=4, + warmup_steps=0, + max_seq_length=64, +) + + +def _make_dataset(samples=None): + if samples is None: + samples = _SMALL_SAMPLES + return InMemorySampleDataset( + samples=samples, + input_schema={"visits": "nested_sequence"}, + output_schema={}, + ) + + +def _make_trained_model(): + """Return a PromptEHR model trained for 1 epoch on _SMALL_SAMPLES.""" + dataset = _make_dataset() + tmpdir = tempfile.mkdtemp() + model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + model.train_model(dataset) + return model, tmpdir + + +# --------------------------------------------------------------------------- +# Category A: In-Memory Integration Tests (must always pass) +# --------------------------------------------------------------------------- + + +class TestPromptEHRIsBaseModelInstance(unittest.TestCase): + """PromptEHR model is an instance of BaseModel.""" + + def test_model_is_basemodel_instance(self): + dataset = _make_dataset() + model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS) + self.assertIsInstance(model, BaseModel) + + +class TestPromptEHRFeatureKeys(unittest.TestCase): + """model.feature_keys equals ['visits'].""" + + def test_feature_keys(self): + dataset = _make_dataset() + model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS) + self.assertEqual(model.feature_keys, ["visits"]) + + +class TestPromptEHRVocabSize(unittest.TestCase): + """_vocab.total_size equals processor.vocab_size() + 5.""" + + def test_vocab_size_matches_processor(self): + dataset = _make_dataset() + processor = dataset.input_processors["visits"] + model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS) + expected = processor.vocab_size() + 5 + self.assertEqual(model._vocab.total_size, expected) + + +class TestPromptEHRForwardRaisesNotImplementedError(unittest.TestCase): + """Calling forward() raises NotImplementedError. 
+ + PromptEHR is a generative model; the discriminative forward pass is not + applicable. + """ + + def test_forward_not_implemented(self): + dataset = _make_dataset() + model = PromptEHR(dataset, **_SMALL_MODEL_KWARGS) + with self.assertRaises(NotImplementedError): + model.forward() + + +class TestPromptEHRTrainModelRuns(unittest.TestCase): + """train_model completes one epoch without error.""" + + def test_train_model_runs_one_epoch(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + try: + model.train_model(dataset, val_dataset=None) + except Exception as exc: # noqa: BLE001 + self.fail(f"train_model raised an unexpected exception: {exc}") + # A checkpoint must be saved after training + ckpt = os.path.join(tmpdir, "checkpoint.pt") + self.assertTrue(os.path.exists(ckpt), f"Expected checkpoint at {ckpt}") + + +class TestPromptEHRSynthesizeCount(unittest.TestCase): + """synthesize_dataset(num_samples=3) returns exactly 3 dicts.""" + + @classmethod + def setUpClass(cls): + cls.model, cls.tmpdir = _make_trained_model() + + def test_synthesize_returns_correct_count(self): + result = self.model.synthesize_dataset(num_samples=3) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 3) + + +class TestPromptEHRSynthesizeOutputStructure(unittest.TestCase): + """Each synthesized dict has patient_id (str) and visits (nested list of str). + + PromptEHR outputs nested visit lists — each patient is a list of visits, + each visit is a list of diagnosis code strings. 
+ """ + + @classmethod + def setUpClass(cls): + cls.model, cls.tmpdir = _make_trained_model() + + def test_synthesize_output_structure(self): + result = self.model.synthesize_dataset(num_samples=3) + for i, item in enumerate(result): + self.assertIsInstance(item, dict, f"Item {i} is not a dict") + self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'") + self.assertIn("visits", item, f"Item {i} missing 'visits'") + self.assertIsInstance( + item["patient_id"], str, f"patient_id in item {i} is not a str" + ) + self.assertIsInstance( + item["visits"], list, f"visits in item {i} is not a list" + ) + # visits is a nested list: list of visits, each visit a list of strings + for visit_idx, visit in enumerate(item["visits"]): + self.assertIsInstance( + visit, list, + f"visit {visit_idx} in item {i} is not a list" + ) + for code in visit: + self.assertIsInstance( + code, str, + f"code '{code}' in visit {visit_idx}, item {i} is not str" + ) + + +class TestPromptEHRSaveLoadRoundtrip(unittest.TestCase): + """save_model then load_model; synthesize_dataset returns correct count.""" + + def test_save_load_roundtrip(self): + dataset = _make_dataset() + with tempfile.TemporaryDirectory() as tmpdir: + model = PromptEHR(dataset, save_dir=tmpdir, **_SMALL_MODEL_KWARGS) + model.train_model(dataset) + ckpt_path = os.path.join(tmpdir, "test_ckpt.pt") + model.save_model(ckpt_path) + self.assertTrue( + os.path.exists(ckpt_path), + f"Expected checkpoint at {ckpt_path}", + ) + model.load_model(ckpt_path) + result = model.synthesize_dataset(num_samples=3) + self.assertEqual(len(result), 3) + + +# --------------------------------------------------------------------------- +# Category B: MIMIC-III Integration Tests (skipped if data unavailable) +# --------------------------------------------------------------------------- + +_MIMIC3_PATH = os.environ.get( + "PYHEALTH_MIMIC3_PATH", + "/srv/local/data/physionet.org/files/mimiciii/1.4", +) + + +class 
TestPromptEHRMIMIC3Integration(unittest.TestCase): + """End-to-end pipeline test with actual MIMIC-III data. + + Skipped automatically when MIMIC-III is not present on this machine. + """ + + @classmethod + def setUpClass(cls): + cls.skip_integration = False + cls.skip_reason = "" + try: + # Remove bootstrap stubs so we can attempt a real import. + _saved_ds_stub = sys.modules.pop("pyhealth.datasets", None) + try: + import importlib as _il + _il.invalidate_caches() + from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset + from pyhealth.tasks.ehr_generation import PromptEHRGenerationMIMIC3 + except (ImportError, ModuleNotFoundError) as exc: + if _saved_ds_stub is not None: + sys.modules["pyhealth.datasets"] = _saved_ds_stub + raise ImportError(str(exc)) from exc + + cls.dataset = _MIMIC3Dataset( + root=_MIMIC3_PATH, + tables=["patients", "admissions", "diagnoses_icd"], + ) + task = PromptEHRGenerationMIMIC3() + cls.sample_dataset = cls.dataset.set_task(task) + except (FileNotFoundError, OSError, ImportError, ValueError) as exc: + cls.skip_integration = True + cls.skip_reason = str(exc) + + def setUp(self): + if self.skip_integration: + self.skipTest(f"MIMIC-III integration test skipped: {self.skip_reason}") + + def test_mimic3_set_task_returns_nonempty_dataset(self): + """set_task produces at least one sample from MIMIC-III.""" + self.assertGreater(len(self.sample_dataset), 0) + + def test_mimic3_sample_keys(self): + """Every sample must contain patient_id, visits, age, and gender keys.""" + for sample in self.sample_dataset: + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + self.assertIn("age", sample) + self.assertIn("gender", sample) + + def test_mimic3_visits_are_nested_tensors(self): + """visits must be a list of 1-D int64 tensors (NestedSequenceProcessor output). + + NestedSequenceProcessor encodes each visit as a 1-D LongTensor of + code indices. This verifies the nested_sequence schema round-trips + correctly through set_task. 
+ """ + for sample in self.sample_dataset: + visits = sample["visits"] + self.assertIsInstance(visits, list) + self.assertGreater(len(visits), 0) + for visit in visits: + self.assertIsInstance(visit, torch.Tensor) + self.assertEqual(visit.dtype, torch.long) + + def test_mimic3_full_pipeline_train_and_synthesize(self): + """Train one epoch on MIMIC-III data and synthesize a small batch.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = PromptEHR( + self.sample_dataset, + d_hidden=64, + prompt_length=1, + bart_config_name=_TINY_BART_CONFIG, + epochs=1, + batch_size=16, + warmup_steps=0, + save_dir=tmpdir, + ) + model.train_model(self.sample_dataset, val_dataset=None) + synthetic = model.synthesize_dataset(num_samples=5) + self.assertEqual(len(synthetic), 5) + for item in synthetic: + self.assertIn("patient_id", item) + self.assertIn("visits", item) + self.assertIsInstance(item["visits"], list) + + +if __name__ == "__main__": + unittest.main() From a7be2973c2386a160b740dfc3d95bf0cb6430e02 Mon Sep 17 00:00:00 2001 From: jalengg Date: Sun, 1 Mar 2026 18:42:46 -0600 Subject: [PATCH 08/37] Add PromptEHR Colab notebook: demographic-conditioned synthetic EHR generation --- examples/promptehr_mimic3_colab.ipynb | 252 ++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 examples/promptehr_mimic3_colab.ipynb diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb new file mode 100644 index 000000000..925efc57a --- /dev/null +++ b/examples/promptehr_mimic3_colab.ipynb @@ -0,0 +1,252 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "id": "preamble", + "metadata": {}, + "source": "# PromptEHR: 
Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-01_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` \u2014 flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` \u2014 nested visit records for PyHealth downstream tasks\n- `quality_report.json` \u2014 statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30\u201345 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3\u20135 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors \u2014 one per demographic feature \u2014 prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. 
During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2022. https://arxiv.org/abs/2211.01761" + }, + { + "cell_type": "markdown", + "id": "s1-header", + "metadata": {}, + "source": "---\n# 1. Setup & Installation" + }, + { + "cell_type": "code", + "id": "s1-setup", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nsubprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"])\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. 
Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" + }, + { + "cell_type": "markdown", + "id": "s2-header", + "metadata": {}, + "source": "---\n# 2. Configuration" + }, + { + "cell_type": "markdown", + "id": "s2-desc", + "metadata": {}, + "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** \u2014 5 epochs, 1 K synthetic patients, ~30\u201345 min on T4\n- **`PRESET = \"production\"`** \u2014 20 epochs, 10 K synthetic patients, ~3\u20135 hrs on T4" + }, + { + "cell_type": "code", + "id": "s2-config", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# ============================================================\n# CONFIGURATION \u2014 All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# 
\"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch \u2014\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"\u2713 Configuration complete\")" + }, + { + "cell_type": "markdown", + "id": "s3-header", + "metadata": {}, + "source": "---\n# 3. Data Upload" + }, + { + "cell_type": "markdown", + "id": "s3-desc", + "metadata": {}, + "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO \u2014 `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` \u2014 date of birth and gender\n2. `ADMISSIONS.csv` \u2014 admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." 
+ }, + { + "cell_type": "code", + "id": "s3-upload", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "DEMO_MODE = False\n\n# Mount Drive (Colab only)\nif IN_COLAB:\n from google.colab import drive\n drive.mount('/content/drive')\n print(\"\u2713 Google Drive mounted\")\n\n# Check which files exist\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nprint(\"\\nMIMIC-III file status:\")\nfor fname, desc in required_files.items():\n mark = \"\u2713\" if existing[fname] else \"\u2717 MISSING\"\n print(f\" {mark} {fname} \u2014 {desc}\")\n\nif missing and IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n for fname, data in uploaded.items():\n dest = f'{DATA_DIR}/{fname}'\n with open(dest, 'wb') as f:\n f.write(data)\n print(f\" Saved {fname} \u2192 {dest}\")\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\nif missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"\u2192 Activating Demo Mode \u2014 full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\nelse:\n print(\"\\n\u2713 All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" + }, + { + "cell_type": "code", + "id": "s3-demo", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"\u2713 Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" + }, + { + "cell_type": "code", + "id": "s3-validate", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"\u2713 PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"\u2713 ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n 
_diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"\u2713 DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n\u2713 All files validated successfully\")" + }, + { + "cell_type": "markdown", + "id": "s4-header", + "metadata": {}, + "source": "---\n# 4. Training" + }, + { + "cell_type": "markdown", + "id": "s4-desc", + "metadata": {}, + "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input \u2014 steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." 
+ }, + { + "cell_type": "code", + "id": "s4-dataset", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"\u2713 Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (\u22652 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" + }, + { + "cell_type": "code", + "id": "s4-init", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') 
as f:\n json.dump(_config, f, indent=2)\nprint(f\"\u2713 Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"\u2713 PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" + }, + { + "cell_type": "code", + "id": "s4-train", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"\u2713 Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" + }, + { + "cell_type": "code", + "id": "s4-loss", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n 
ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} \u2192 Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found \u2014 skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")" + }, + { + "cell_type": "markdown", + "id": "s5-header", + "metadata": {}, + "source": "---\n# 5. Generation" + }, + { + "cell_type": "markdown", + "id": "s5-desc", + "metadata": {}, + "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` \u2014 the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling \u2014 diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding \u2014 deterministic, may repeat common patterns." 
+ }, + { + "cell_type": "code", + "id": "s5-generate", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n\u2713 Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" + }, + { + "cell_type": "code", + "id": "s5-save", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Save as JSON (full nested records \u2014 directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"\u2713 {len(synthetic):,} patients \u2192 {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE \u2014 matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"\u2713 {len(df_synthetic):,} records \u2192 {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample 
rows:\")\ndisplay(df_synthetic.head(8))" + }, + { + "cell_type": "markdown", + "id": "s6-header", + "metadata": {}, + "source": "---\n# 6. Results & Evaluation" + }, + { + "cell_type": "code", + "id": "s6-stats", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_visits):.2f} \u00b1 {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_codes):.2f} \u00b1 {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" + }, + { + "cell_type": "code", + "id": "s6-coverage", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: 
{len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n\u26a0 Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage \u2014 expected for demo preset.\")\n print(\"Production training typically achieves 60\u201380%.\")\nelse:\n print(f\"\\n\u2713 Good vocabulary coverage.\")" + }, + { + "cell_type": "code", + "id": "s6-demographics", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n 
plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n\u2713 Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty \u2014 model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" + }, + { + "cell_type": "code", + "id": "s6-freq", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', 
color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation \u2014 no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"\u2713 Strong correlation \u2014 good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation \u2014 consider more epochs.\")\nelse: print(\"Weak correlation \u2014 model may need more training.\")" + }, + { + "cell_type": "code", + "id": "s6-empty", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"\u26a0 {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"\u2713 All {len(synthetic):,} patients have at least one visit with at least one code.\")" + }, + { + "cell_type": "code", + "id": "s6-report", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": 
round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n\u2713 Saved to {report_path}\")" + }, + { + "cell_type": "markdown", + "id": "s7-header", + "metadata": {}, + "source": "---\n# 7. 
Download & Next Steps" + }, + { + "cell_type": "code", + "id": "s7-download", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# Download output files (Colab only \u2014 silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" \u2713 {os.path.basename(_p)}\")\n else:\n print(f\" \u2014 {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" + }, + { + "cell_type": "code", + "id": "s7-resume", + "metadata": {}, + "outputs": [], + "execution_count": null, + "source": "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# CHECKPOINT RESUME \u2014 Run this cell instead of Section 4 if you already trained\n# 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"\u2713 Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template \u2014 uncomment the lines above to use)\")" + }, + { + "cell_type": "markdown", + "id": "s7-congrats", + "metadata": {}, + "source": "---\n## \ud83c\udf89 Congratulations!\n\nYou've successfully:\n1. \u2705 Trained a PromptEHR model conditioned on patient demographics\n2. \u2705 Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. \u2705 Validated ICD-9 code frequency fidelity against real training data\n4. 
\u2705 Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime \u2192 Change runtime type \u2192 T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Z., and Sun, J. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2022*. https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 \u00b7 Branch: `promptehr-pr-integration` \u00b7 jalengg/PyHealth_" + } + ] +} \ No newline at end of file From b1cc36d05ce9f5567bf10e8920f6461c6d0a606f Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 00:13:06 -0600 Subject: [PATCH 09/37] Fix: idempotent Drive mount in Colab notebook Guard drive.mount() with os.path.isdir('/content/drive/MyDrive') check so re-running the cell does not raise ValueError: Mountpoint must not already contain files. 
--- examples/promptehr_mimic3_colab.ipynb | 50 +++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 925efc57a..fec713ccd 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-01_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. 
The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` \u2014 flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` \u2014 nested visit records for PyHealth downstream tasks\n- `quality_report.json` \u2014 statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30\u201345 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3\u20135 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors \u2014 one per demographic feature \u2014 prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-01_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. 
**MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. 
PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2022. https://arxiv.org/abs/2211.01761" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nsubprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"])\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nsubprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"])\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", @@ -48,7 +48,7 @@ "cell_type": "markdown", "id": "s2-desc", "metadata": {}, - "source": "Configure all parameters here. 
**This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** \u2014 5 epochs, 1 K synthetic patients, ~30\u201345 min on T4\n- **`PRESET = \"production\"`** \u2014 20 epochs, 10 K synthetic patients, ~3\u20135 hrs on T4" + "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** — 5 epochs, 1 K synthetic patients, ~30–45 min on T4\n- **`PRESET = \"production\"`** — 20 epochs, 10 K synthetic patients, ~3–5 hrs on T4" }, { "cell_type": "code", @@ -56,7 +56,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# ============================================================\n# CONFIGURATION \u2014 All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch \u2014\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. 
Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"\u2713 Configuration complete\")" + "source": "# ============================================================\n# CONFIGURATION — All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch —\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy 
(deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"✓ Configuration complete\")" }, { "cell_type": "markdown", @@ -68,7 +68,7 @@ "cell_type": "markdown", "id": "s3-desc", "metadata": {}, - "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO \u2014 `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` \u2014 date of birth and gender\n2. `ADMISSIONS.csv` \u2014 admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." + "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO — `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` — date of birth and gender\n2. `ADMISSIONS.csv` — admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." 
}, { "cell_type": "code", @@ -76,7 +76,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "DEMO_MODE = False\n\n# Mount Drive (Colab only)\nif IN_COLAB:\n from google.colab import drive\n drive.mount('/content/drive')\n print(\"\u2713 Google Drive mounted\")\n\n# Check which files exist\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nprint(\"\\nMIMIC-III file status:\")\nfor fname, desc in required_files.items():\n mark = \"\u2713\" if existing[fname] else \"\u2717 MISSING\"\n print(f\" {mark} {fname} \u2014 {desc}\")\n\nif missing and IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n for fname, data in uploaded.items():\n dest = f'{DATA_DIR}/{fname}'\n with open(dest, 'wb') as f:\n f.write(data)\n print(f\" Saved {fname} \u2192 {dest}\")\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\nif missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"\u2192 Activating Demo Mode \u2014 full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\nelse:\n print(\"\\n\u2713 All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" + "source": "DEMO_MODE = False\n\n# Mount Drive (Colab only) — guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n if not os.path.isdir('/content/drive/MyDrive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"✓ Google Drive mounted\")\n\n# Check which files exist\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nprint(\"\\nMIMIC-III file status:\")\nfor fname, desc in required_files.items():\n mark = \"✓\" if existing[fname] else \"✗ MISSING\"\n print(f\" {mark} {fname} — {desc}\")\n\nif missing and IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n for fname, data in uploaded.items():\n dest = f'{DATA_DIR}/{fname}'\n with open(dest, 'wb') as f:\n f.write(data)\n print(f\" Saved {fname} → {dest}\")\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\nif missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"→ Activating Demo Mode — full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\nelse:\n print(\"\\n✓ All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" }, { "cell_type": "code", @@ -84,7 +84,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"\u2713 Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" + "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 
89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"✓ Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" }, { "cell_type": "code", @@ -92,7 +92,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"\u2713 PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"\u2713 ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"\u2713 DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n\u2713 All files validated successfully\")" + "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"✓ PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"✓ ADMISSIONS.csv: 
{len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"✓ DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n✓ All files validated successfully\")" }, { "cell_type": "markdown", @@ -104,7 +104,7 @@ "cell_type": "markdown", "id": "s4-desc", "metadata": {}, - "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input \u2014 steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." + "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input — steering the model toward age/gender-appropriate diagnosis patterns.\n4. 
**Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." }, { "cell_type": "code", @@ -112,7 +112,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"\u2713 Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (\u22652 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" + "source": "# Set all random seeds before any stochastic 
operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" }, { "cell_type": "code", @@ -120,7 +120,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"\u2713 Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n 
dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"\u2713 PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" + "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"✓ Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"✓ PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + 
{n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" }, { "cell_type": "code", @@ -128,7 +128,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"\u2713 Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" + "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"✓ Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" }, { "cell_type": "code", @@ -136,7 +136,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} \u2192 Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found \u2014 skipping loss curve.\")\n print(\"(Expected for 
very short demo runs.)\")" + "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} → Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found — skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")" }, { "cell_type": "markdown", @@ -148,7 +148,7 @@ "cell_type": "markdown", "id": "s5-desc", "metadata": {}, - "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` \u2014 the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. 
**Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling \u2014 diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding \u2014 deterministic, may repeat common patterns." + "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` — the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling — diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding — deterministic, may repeat common patterns." 
}, { "cell_type": "code", @@ -156,7 +156,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n\u2713 Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" + "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n✓ Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" }, { "cell_type": "code", @@ -164,7 +164,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save as 
JSON (full nested records \u2014 directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"\u2713 {len(synthetic):,} patients \u2192 {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE \u2014 matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"\u2713 {len(df_synthetic):,} records \u2192 {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" + "source": "# Save as JSON (full nested records — directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"✓ {len(synthetic):,} patients → {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE — matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"✓ {len(df_synthetic):,} records → {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" }, { "cell_type": "markdown", @@ -178,7 +178,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in 
synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_visits):.2f} \u00b1 {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_codes):.2f} \u00b1 {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" + "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_visits):.2f} ± {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_codes):.2f} ± {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); 
ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" }, { "cell_type": "code", @@ -186,7 +186,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n\u26a0 Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage \u2014 expected for demo preset.\")\n print(\"Production training typically achieves 60\u201380%.\")\nelse:\n print(f\"\\n\u2713 Good vocabulary coverage.\")" + "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n⚠ Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage — expected for demo preset.\")\n print(\"Production training typically achieves 60–80%.\")\nelse:\n 
print(f\"\\n✓ Good vocabulary coverage.\")" }, { "cell_type": "code", @@ -194,7 +194,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} 
({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n\u2713 Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty \u2014 model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" + "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n 
f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n✓ Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty — model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" }, { "cell_type": "code", @@ -202,7 +202,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', 
alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation \u2014 no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"\u2713 Strong correlation \u2014 good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation \u2014 consider more epochs.\")\nelse: print(\"Weak correlation \u2014 model may need more training.\")" + "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', 
alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation — no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"✓ Strong correlation — good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation — consider more epochs.\")\nelse: print(\"Weak correlation — model may need more training.\")" }, { "cell_type": "code", @@ -210,7 +210,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"\u26a0 {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"\u2713 All {len(synthetic):,} patients have at least one visit with at least one code.\")" + "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"⚠ {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n 
print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"✓ All {len(synthetic):,} patients have at least one visit with at least one code.\")" }, { "cell_type": "code", @@ -218,7 +218,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n\u2713 Saved to {report_path}\")" + "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n 
\"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n✓ Saved to {report_path}\")" }, { "cell_type": "markdown", @@ -232,7 +232,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Download output files (Colab only \u2014 silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" \u2713 {os.path.basename(_p)}\")\n else:\n print(f\" \u2014 {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" + "source": "# Download output files (Colab only — silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" ✓ 
{os.path.basename(_p)}\")\n else:\n print(f\" — {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" }, { "cell_type": "code", @@ -240,13 +240,13 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# CHECKPOINT RESUME \u2014 Run this cell instead of Section 4 if you already trained\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, 
[0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"\u2713 Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template \u2014 uncomment the lines above to use)\")" + "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. 
Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")" }, { "cell_type": "markdown", "id": "s7-congrats", "metadata": {}, - "source": "---\n## \ud83c\udf89 Congratulations!\n\nYou've successfully:\n1. \u2705 Trained a PromptEHR model conditioned on patient demographics\n2. \u2705 Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. \u2705 Validated ICD-9 code frequency fidelity against real training data\n4. \u2705 Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime \u2192 Change runtime type \u2192 T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. 
https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 \u00b7 Branch: `promptehr-pr-integration` \u00b7 jalengg/PyHealth_" + "source": "---\n## 🎉 Congratulations!\n\nYou've successfully:\n1. ✅ Trained a PromptEHR model conditioned on patient demographics\n2. ✅ Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. ✅ Validated ICD-9 code frequency fidelity against real training data\n4. ✅ Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime → Change runtime type → T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. 
https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 · Branch: `promptehr-pr-integration` · jalengg/PyHealth_" } ] } \ No newline at end of file From 5ab75969e675959d26f48a149aaea2df49bd01a3 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 01:00:36 -0600 Subject: [PATCH 10/37] Fix: guard scipy/mne-dependent task imports to fix Colab numpy 2.x cascade Wrap cardiology_detect (scipy), EEG_abnormal/events (mne), sleep_staging variants (mne), and temple_university_EEG_tasks (mne) in try/except so that pyhealth.tasks import does not fail in Colab where numpy 2.x breaks scipy._lib._util. Mirrors the identical fix in halo-pr-528. --- pyhealth/tasks/__init__.py | 56 +++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index c0811c77a..f9834777d 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -8,13 +8,16 @@ ) from .cancer_survival import CancerMutationBurden, CancerSurvivalPrediction from .bmd_hs_disease_classification import BMDHSDiseaseClassification -from .cardiology_detect import ( - cardiology_isAD_fn, - cardiology_isAR_fn, - cardiology_isBBBFB_fn, - cardiology_isCD_fn, - cardiology_isWA_fn, -) +try: + from .cardiology_detect import ( + cardiology_isAD_fn, + cardiology_isAR_fn, + cardiology_isBBBFB_fn, + cardiology_isCD_fn, + cardiology_isWA_fn, + ) +except ImportError: + pass # scipy unavailable; cardiology tasks not registered from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification @@ -27,8 +30,14 @@ drug_recommendation_mimic4_fn, drug_recommendation_omop_fn, ) -from .EEG_abnormal import EEG_isAbnormal_fn -from .EEG_events import EEG_events_fn +try: + from .EEG_abnormal import EEG_isAbnormal_fn +except ImportError: + pass # mne unavailable 
+try: + from .EEG_events import EEG_events_fn +except ImportError: + pass # mne unavailable from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, @@ -59,16 +68,25 @@ ReadmissionPredictionMIMIC4, ReadmissionPredictionOMOP, ) -from .sleep_staging import ( - sleep_staging_isruc_fn, - sleep_staging_shhs_fn, - sleep_staging_sleepedf_fn, -) -from .sleep_staging_v2 import SleepStagingSleepEDF -from .temple_university_EEG_tasks import ( - EEGEventsTUEV, - EEGAbnormalTUAB -) +try: + from .sleep_staging import ( + sleep_staging_isruc_fn, + sleep_staging_shhs_fn, + sleep_staging_sleepedf_fn, + ) +except ImportError: + pass # mne unavailable +try: + from .sleep_staging_v2 import SleepStagingSleepEDF +except ImportError: + pass # mne unavailable +try: + from .temple_university_EEG_tasks import ( + EEGEventsTUEV, + EEGAbnormalTUAB + ) +except ImportError: + pass # mne unavailable from .variant_classification import ( MutationPathogenicityPrediction, VariantClassificationClinVar, From 4f7edb58238804b0fa9a1982a8d60e0d998be0c5 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 01:25:34 -0600 Subject: [PATCH 11/37] Feat: persist MIMIC-III files to Drive, skip re-upload on reconnect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - When all 3 files exist in DATA_DIR (Drive-backed), print sizes and skip upload entirely — mirrors HALO notebook UX - Normalize uploaded filenames via shutil.copy so Colab's duplicate rename (e.g. 
ADMISSIONS (1).csv) maps to canonical name in Drive - Keep idempotent drive.mount() guard from previous fix --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index fec713ccd..23401a08c 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -76,7 +76,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "DEMO_MODE = False\n\n# Mount Drive (Colab only) — guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n if not os.path.isdir('/content/drive/MyDrive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"✓ Google Drive mounted\")\n\n# Check which files exist\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nprint(\"\\nMIMIC-III file status:\")\nfor fname, desc in required_files.items():\n mark = \"✓\" if existing[fname] else \"✗ MISSING\"\n print(f\" {mark} {fname} — {desc}\")\n\nif missing and IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n for fname, data in uploaded.items():\n dest = f'{DATA_DIR}/{fname}'\n with open(dest, 'wb') as f:\n f.write(data)\n print(f\" Saved {fname} → {dest}\")\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\nif missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"→ Activating Demo Mode — full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\nelse:\n print(\"\\n✓ All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" + "source": "import shutil\nDEMO_MODE = False\n\n# Mount Drive (Colab only) — guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n if not os.path.isdir('/content/drive/MyDrive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"✓ Google Drive mounted\")\n\n# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n # All files already in Drive — no upload needed\n print(\"✓ All MIMIC-III files found in Drive (no upload needed):\")\n for fname in required_files:\n size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n print(f\" {fname} ({size_mb:.1f} MB)\")\n print(f\"\\nFiles are reused from: {DATA_DIR}\")\n print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n print(\"MIMIC-III file status:\")\n for fname, desc in required_files.items():\n mark = \"✓\" if existing[fname] else \"✗ MISSING\"\n print(f\" {mark} {fname} — {desc}\")\n\n if IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n\n # Normalize filenames — Colab renames duplicates as \"ADMISSIONS (1).csv\".\n # Match each upload to the required file it belongs to, then copy with\n # the canonical name so subsequent runs find the file in Drive.\n for uploaded_name, data in uploaded.items():\n matched = None\n for req in required_files:\n base = req.replace('.csv', '')\n if base in uploaded_name and uploaded_name.endswith('.csv'):\n matched = req\n break\n if matched:\n # Write upload bytes to /content/ then copy to Drive-backed 
dest\n tmp = f'/content/{uploaded_name}'\n with open(tmp, 'wb') as f:\n f.write(data)\n dest = f'{DATA_DIR}/{matched}'\n shutil.copy(tmp, dest)\n size_mb = os.path.getsize(dest) / 1024 / 1024\n print(f\" ✓ Saved {matched} ({size_mb:.1f} MB) → {dest}\")\n else:\n print(f\" ⚠ Unrecognised file: {uploaded_name} (skipped)\")\n\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\n if missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"→ Activating Demo Mode — full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\n else:\n print(\"\\n✓ All 3 MIMIC-III files present. Running in MIMIC-III mode.\")" }, { "cell_type": "code", From 394e128e225cbba12090b61a856648d4f9e0fea9 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 01:30:46 -0600 Subject: [PATCH 12/37] Fix: guard PIL/mne-dependent dataset imports to fix Colab import cascade Wrap ChestXray14Dataset, COVID19CXRDataset (PIL/torchvision), SleepEDFDataset, TUABDataset, TUEVDataset (mne) in try/except so datasets/__init__ does not fail when optional deps are absent. TUABDataset was the immediate cause: tuab.py imports EEGAbnormalTUAB from pyhealth.tasks, which is now silently absent when mne is unavailable. Mirrors identical guards in halo-pr-528. 
--- pyhealth/datasets/__init__.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index b38c575c2..00a5e1884 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -48,10 +48,16 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset -from .chestxray14 import ChestXray14Dataset +try: + from .chestxray14 import ChestXray14Dataset +except ImportError: + pass # PIL/torchvision unavailable from .clinvar import ClinVarDataset from .cosmic import COSMICDataset -from .covid19_cxr import COVID19CXRDataset +try: + from .covid19_cxr import COVID19CXRDataset +except ImportError: + pass # PIL/torchvision unavailable from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset @@ -63,7 +69,10 @@ def __init__(self, *args, **kwargs): from .omop import OMOPDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset -from .sleepedf import SleepEDFDataset +try: + from .sleepedf import SleepEDFDataset +except ImportError: + pass # mne unavailable from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset @@ -76,8 +85,14 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) -from .tuab import TUABDataset -from .tuev import TUEVDataset +try: + from .tuab import TUABDataset +except ImportError: + pass # mne unavailable; TUABDataset not registered +try: + from .tuev import TUEVDataset +except ImportError: + pass # mne unavailable; TUEVDataset not registered from .utils import ( collate_fn_dict, collate_fn_dict_with_padding, From 8c176e9c159a3df234797a97c61f2fb668e76b0b Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 01:50:35 -0600 Subject: [PATCH 13/37] Fix: force-reinstall PyHealth in setup cell; update 
preamble SHA - Add --force-reinstall to pip install so Colab never loads a stale cached build that lacks the try/except import guards - Switch to subprocess.run with returncode check (mirrors HALO pattern) - Update preamble: last_modified 2026-03-03, commit 394e128e --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 23401a08c..d4ec750c0 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-01_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. 
The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-03 · commit `394e128e`_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. 
**MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. 
PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nsubprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"])\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From e77178a72b89f47e473c0fb184f1c7609fc5524a Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 01:56:50 -0600 Subject: [PATCH 14/37] Chore: update preamble SHA to 8c176e9c --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index d4ec750c0..a64b7019e 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-03 · commit `394e128e`_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. 
The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-03 · commit `8c176e9c`_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. 
**MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. 
PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" }, { "cell_type": "markdown", From 6a4e1c8244dbb7116b9e2b583949bda0686404f1 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 02:03:07 -0600 Subject: [PATCH 15/37] Chore: switch preamble timestamp to UTC (no SHA lag) --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index a64b7019e..895ebd23a 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-03 · commit `8c176e9c`_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. 
The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-04 07:59:46 (UTC)_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. 
**MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. 
PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" }, { "cell_type": "markdown", From 6bb347af13638ac15850f9fbc996295940c9e10d Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 02:21:13 -0600 Subject: [PATCH 16/37] Fix: guard optional-dep model imports to fix Colab sklearn/scipy cascade MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wrap biot (einops), cnn/graph_torchvision/torchvision/vision_embedding (PIL/torchvision), grasp (sklearn→scipy cascade), molerec/safedrug (rdkit), tfm_tokenizer (einops), transformers_model/text_embedding/sdoh (transformers) in try/except — mirrors halo-pr-528. Also removes duplicate medlink import. --- pyhealth/models/__init__.py | 77 +++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 6cdf2ea45..c8dcbe282 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,8 +1,14 @@ from .adacare import AdaCare, AdaCareLayer from .agent import Agent, AgentLayer from .base_model import BaseModel -from .biot import BIOT -from .cnn import CNN, CNNLayer +try: + from .biot import BIOT +except ImportError: + pass # einops unavailable +try: + from .cnn import CNN, CNNLayer +except ImportError: + pass # PIL/torchvision unavailable from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D from .deepr import Deepr, DeeprLayer @@ -12,34 +18,63 @@ from .logistic_regression import LogisticRegression from .gan import GAN from .gnn import GAT, GCN -from .graph_torchvision_model import Graph_TorchvisionModel -from .grasp import GRASP, GRASPLayer +try: + from .graph_torchvision_model import Graph_TorchvisionModel +except ImportError: + pass # torchvision unavailable +try: + from .grasp 
import GRASP, GRASPLayer +except ImportError: + pass # sklearn unavailable from .medlink import MedLink from .micron import MICRON, MICRONLayer -from .promptehr import PromptEHR from .mlp import MLP -from .molerec import MoleRec, MoleRecLayer +try: + from .molerec import MoleRec, MoleRecLayer +except ImportError: + pass # rdkit unavailable +from .promptehr import PromptEHR from .retain import RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer -from .safedrug import SafeDrug, SafeDrugLayer +try: + from .safedrug import SafeDrug, SafeDrugLayer +except ImportError: + pass # rdkit unavailable from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer from .tcn import TCN, TCNLayer -from .tfm_tokenizer import ( - TFMTokenizer, - TFM_VQVAE2_deep, - TFM_TOKEN_Classifier, - get_tfm_tokenizer_2x2x8, - get_tfm_token_classifier_64x4, - load_embedding_weights, -) -from .torchvision_model import TorchvisionModel +try: + from .tfm_tokenizer import ( + TFMTokenizer, + TFM_VQVAE2_deep, + TFM_TOKEN_Classifier, + get_tfm_tokenizer_2x2x8, + get_tfm_token_classifier_64x4, + load_embedding_weights, + ) +except ImportError: + pass # einops unavailable +try: + from .torchvision_model import TorchvisionModel +except ImportError: + pass # torchvision unavailable from .transformer import Transformer, TransformerLayer -from .transformers_model import TransformersModel +try: + from .transformers_model import TransformersModel +except ImportError: + pass # transformers unavailable from .ehrmamba import EHRMamba, MambaBlock from .vae import VAE -from .vision_embedding import VisionEmbeddingModel -from .text_embedding import TextEmbedding -from .sdoh import SdohClassifier -from .medlink import MedLink +try: + from .vision_embedding import VisionEmbeddingModel +except ImportError: + pass # PIL/torchvision unavailable +try: + from .text_embedding import 
TextEmbedding +except ImportError: + pass # transformers unavailable +try: + from .sdoh import SdohClassifier +except ImportError: + pass # transformers/peft unavailable From 2dad59c91432d5d874a4b285a831e337a91ac6e1 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 02:29:44 -0600 Subject: [PATCH 17/37] Chore: update notebook timestamp 2026-03-04 08:21:17 UTC --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 895ebd23a..fc19ea8ed 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-04 07:59:46 (UTC)_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. 
The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-04 08:21:17 (UTC)_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. 
**MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. 
PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. https://arxiv.org/abs/2211.01761" }, { "cell_type": "markdown", From 517260fd682e8c2ed74878f38fa0247bd0f61126 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 02:42:55 -0600 Subject: [PATCH 18/37] Fix: guard EEGAbnormalTUAB/EEGEventsTUEV at source; rewrite preamble - tuab.py, tuev.py: wrap task imports in try/except (= None fallback) so TUABDataset/TUEVDataset load cleanly when mne is unavailable. Mirrors halo-pr-528 commit b1470ad4. - Notebook preamble: restructured to match HALO layout (What You'll Need / How It Works / Important Notes / References); removed 'Why PromptEHR is different from HALO' section per user request. - Timestamp: 2026-03-04 08:37:50 UTC --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- pyhealth/datasets/tuab.py | 5 ++++- pyhealth/datasets/tuev.py | 5 ++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index fc19ea8ed..77789d000 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR: Demographic-Conditioned Synthetic EHR Generation\n\n_Last updated: 2026-03-04 08:21:17 (UTC)_\n\nTrain **PromptEHR** on your MIMIC-III data and generate synthetic patients whose demographic distributions mirror the real population.\n\n## What You'll Need\n\n1. **MIMIC-III Access** (or run in Demo Mode without it). Download 3 files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab** (or local environment): Free tier works; GPU recommended.\n\n> **Demo Mode**: No MIMIC-III? 
Set `PRESET = \"demo\"` and skip the file upload step. The notebook runs the full pipeline with synthetic stand-in data.\n\n## What You'll Get\n\n- A trained PromptEHR model conditioned on patient age and gender\n- Synthetic patients whose age/gender distributions mirror the MIMIC-III population\n- `synthetic_patients.csv` — flat `SUBJECT_ID, VISIT_NUM, ICD9_CODE` records\n- `synthetic_patients.json` — nested visit records for PyHealth downstream tasks\n- `quality_report.json` — statistics for automated evaluation and CI\n\n## How Long It Takes\n\n| Preset | Epochs | Time (T4 GPU) | Use case |\n|--------|--------|----------------|----------|\n| `\"demo\"` | 5 | ~30–45 min | First run, CI smoke test |\n| `\"production\"` | 20 | ~3–5 hrs | Publication-quality results |\n\n## What Makes PromptEHR Different from HALO\n\nUnlike HALO (which generates patients from a shared unconditional distribution), **PromptEHR conditions generation on patient demographics**. It uses a BART Seq2Seq Transformer with learned \"prompt\" vectors — one per demographic feature — prepended to the encoder input. During training, the model learns that older male patients tend to have different diagnosis patterns than young female patients. During generation, demographics are sampled from the real training distribution, so the synthetic cohort's age/gender profile automatically mirrors MIMIC-III.\n\nThis matters clinically: synthetic datasets used for fairness research or subgroup analysis must preserve demographic distributions. PromptEHR provides this guarantee by design.\n\n**Reference**: Wang et al., \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" EMNLP 2023. 
https://arxiv.org/abs/2211.01761" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 08:37:50 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. 
Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Fix numpy 2.x incompatibility: upgrade scipy and scikit-learn before anything else.\n# transformers.generation pulls in sklearn at import time; old scipy breaks on numpy 2.x.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy and scikit-learn upgraded for numpy 2.x compatibility\")\n\n# Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: 
{IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", diff --git a/pyhealth/datasets/tuab.py b/pyhealth/datasets/tuab.py index e2a3fc69c..1ba6cc3c8 100644 --- a/pyhealth/datasets/tuab.py +++ b/pyhealth/datasets/tuab.py @@ -5,7 +5,10 @@ from typing import Optional from .base_dataset import BaseDataset -from pyhealth.tasks import EEGAbnormalTUAB +try: + from pyhealth.tasks import EEGAbnormalTUAB +except ImportError: + EEGAbnormalTUAB = None # mne unavailable; TUABDataset.default_task will raise if called logger = logging.getLogger(__name__) diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py index 7e8dacf98..7dd30fd58 100644 --- a/pyhealth/datasets/tuev.py +++ b/pyhealth/datasets/tuev.py @@ -5,7 +5,10 @@ from typing import Optional from .base_dataset import BaseDataset -from pyhealth.tasks import EEGEventsTUEV +try: + from pyhealth.tasks import EEGEventsTUEV +except ImportError: + EEGEventsTUEV = None # mne unavailable; TUEVDataset.default_task will raise if called logger = logging.getLogger(__name__) From c3618c72b6f69d4579b9ba1bcdfd0a3cc3c87713 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 03:04:11 -0600 Subject: [PATCH 19/37] Fix: install scipy>=1.14 AFTER PyHealth to prevent 
--force-reinstall clobber MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --force-reinstall reinstalls all transitive deps, which could downgrade scipy back to the old Colab binary. Installing scipy>=1.14 in a second pip call after PyHealth ensures it is the final version on disk when s4-dataset later triggers the transformers→sklearn→scipy import chain. --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 77789d000..4e8e6d47e 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 08:37:50 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. 
**Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:03:48 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Fix numpy 2.x incompatibility: upgrade scipy and scikit-learn before anything else.\n# transformers.generation pulls in sklearn at import time; old scipy breaks on numpy 2.x.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy and scikit-learn upgraded for numpy 2.x compatibility\")\n\n# Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as 
plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old scipy as a transitive dep — we upgrade it in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. 
Upgrade scipy and scikit-learn AFTER PyHealth.\n# --force-reinstall above may have pulled in old scipy; upgrading here ensures\n# scipy>=1.14 (which supports numpy 2.x) is on disk before any import happens.\n# transformers.generation pulls in sklearn at import time; old scipy breaks on numpy 2.x.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14 and scikit-learn>=1.5 installed\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From 2d4b5bea8fedab9cd2a581f875238cca5d2fa787 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 03:30:47 -0600 Subject: [PATCH 20/37] Fix: add Pillow>=10.4.0 to post-install step to fix mixed PIL state PIL._typing._Ink moved between Pillow versions; --force-reinstall can leave the package in an inconsistent state. Pinning Pillow>=10.4.0 in the post-PyHealth upgrade step ensures consistent PIL internals. --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 4e8e6d47e..3c20f0966 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:03:48 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. 
**Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:29:09 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. 
**Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old scipy as a transitive dep — we upgrade it in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. 
Upgrade scipy and scikit-learn AFTER PyHealth.\n# --force-reinstall above may have pulled in old scipy; upgrading here ensures\n# scipy>=1.14 (which supports numpy 2.x) is on disk before any import happens.\n# transformers.generation pulls in sklearn at import time; old scipy breaks on numpy 2.x.\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14 and scikit-learn>=1.5 installed\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# 1. 
Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps — we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Normalize transitive deps AFTER PyHealth install.\n# --force-reinstall above can leave packages in mixed-version states:\n# - scipy>=1.14 supports numpy 2.x (fixes transformers→sklearn→scipy cascade)\n# - scikit-learn>=1.5 matches scipy>=1.14\n# - Pillow>=10.4.0 ensures consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 installed\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. 
Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From b30c27b0c0ba2e8deaafdaffe5bcd6b8255a12e6 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 03:54:35 -0600 Subject: [PATCH 21/37] Fix: guard ImageProcessor/TimeImageProcessor in processors/__init__.py (PIL/torchvision unavailable) --- pyhealth/processors/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index 283354f80..70762dd9c 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -18,7 +18,10 @@ def get_processor(name: str): # Import all processors so they register themselves -from .image_processor import ImageProcessor +try: + from .image_processor import ImageProcessor +except ImportError: + pass # PIL/torchvision unavailable from .label_processor import ( BinaryLabelProcessor, MultiClassLabelProcessor, @@ -44,7 +47,10 @@ def get_processor(name: str): from .tensor_processor import TensorProcessor from .text_processor import TextProcessor from .timeseries_processor import TimeseriesProcessor -from .time_image_processor import TimeImageProcessor +try: + from .time_image_processor import TimeImageProcessor +except ImportError: + pass # PIL/torchvision unavailable from .audio_processor import AudioProcessor from .ignore_processor import IgnoreProcessor from .tuple_time_text_processor import TupleTimeTextProcessor From 2a50760a046ac5ffe65b826641952a99ac875145 Mon Sep 17 00:00:00 2001 From: 
jalengg Date: Wed, 4 Mar 2026 03:55:27 -0600 Subject: [PATCH 22/37] Chore: update notebook timestamp to 2026-03-04 09:55:13 (UTC) --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 3c20f0966..9f9980859 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:29:09 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:55:13 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", From 5aa78009a4cadabf983943aa4095904779abd208 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 04:13:21 -0600 Subject: [PATCH 23/37] =?UTF-8?q?Fix:=20Drive=20mount=20guard=20+=20makedi?= =?UTF-8?q?rs=20ordering=20=E2=80=94=20files=20no=20longer=20re-uploaded?= =?UTF-8?q?=20each=20session?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause: s2-config called os.makedirs(DATA_DIR) before Drive was mounted, creating a local /content/drive/MyDrive directory. The s3-upload guard then saw isdir('/content/drive/MyDrive') == True and skipped drive.mount(), so all file checks ran against an empty local path. 
Fix: - s2-config: skip makedirs in Colab (Drive not yet mounted) - s3-upload: use os.path.ismount('/content/drive') guard (checks actual filesystem mount, not directory existence); makedirs after mount --- examples/promptehr_mimic3_colab.ipynb | 54 ++++++++++++++------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 9f9980859..b9b2cf8d0 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 09:55:13 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:07:18 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps — we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. 
Normalize transitive deps AFTER PyHealth install.\n# --force-reinstall above can leave packages in mixed-version states:\n# - scipy>=1.14 supports numpy 2.x (fixes transformers→sklearn→scipy cascade)\n# - scikit-learn>=1.5 matches scipy>=1.14\n# - Pillow>=10.4.0 ensures consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 installed\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# 1. 
Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Normalize transitive deps AFTER PyHealth install.\n# --force-reinstall above can leave packages in mixed-version states:\n# - scipy>=1.14 supports numpy 2.x (fixes transformers\u2192sklearn\u2192scipy cascade)\n# - scikit-learn>=1.5 matches scipy>=1.14\n# - Pillow>=10.4.0 ensures consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"\u2713 scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 installed\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n 
print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" }, { "cell_type": "markdown", @@ -48,7 +48,7 @@ "cell_type": "markdown", "id": "s2-desc", "metadata": {}, - "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** — 5 epochs, 1 K synthetic patients, ~30–45 min on T4\n- **`PRESET = \"production\"`** — 20 epochs, 10 K synthetic patients, ~3–5 hrs on T4" + "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** \u2014 5 epochs, 1 K synthetic patients, ~30\u201345 min on T4\n- **`PRESET = \"production\"`** \u2014 20 epochs, 10 K synthetic patients, ~3\u20135 hrs on T4" }, { "cell_type": "code", @@ -56,7 +56,9 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# ============================================================\n# CONFIGURATION — All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model 
architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch —\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\nfor d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"✓ Configuration complete\")" + "source": [ + "# ============================================================\n# CONFIGURATION \u2014 All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # 
Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch \u2014\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\n# In Colab, Drive-backed dirs are created after mount (in s3-upload).\n# In local/SLURM environments, create them immediately.\nif not IN_COLAB:\n for d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"\u2713 Configuration complete\")" + ] }, { "cell_type": "markdown", @@ -68,7 +70,7 @@ "cell_type": "markdown", "id": "s3-desc", "metadata": {}, - "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO — `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` — date of birth and gender\n2. `ADMISSIONS.csv` — admission timestamps (used to compute age at first admission)\n3. 
`DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." + "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO \u2014 `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` \u2014 date of birth and gender\n2. `ADMISSIONS.csv` \u2014 admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." }, { "cell_type": "code", @@ -76,7 +78,9 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import shutil\nDEMO_MODE = False\n\n# Mount Drive (Colab only) — guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n if not os.path.isdir('/content/drive/MyDrive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"✓ Google Drive mounted\")\n\n# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n # All files already in Drive — no upload needed\n print(\"✓ All MIMIC-III files found in Drive (no upload needed):\")\n for fname in required_files:\n size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n print(f\" {fname} ({size_mb:.1f} MB)\")\n print(f\"\\nFiles are reused from: {DATA_DIR}\")\n print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n print(\"MIMIC-III file status:\")\n for fname, desc in required_files.items():\n mark = \"✓\" if 
existing[fname] else \"✗ MISSING\"\n print(f\" {mark} {fname} — {desc}\")\n\n if IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n\n # Normalize filenames — Colab renames duplicates as \"ADMISSIONS (1).csv\".\n # Match each upload to the required file it belongs to, then copy with\n # the canonical name so subsequent runs find the file in Drive.\n for uploaded_name, data in uploaded.items():\n matched = None\n for req in required_files:\n base = req.replace('.csv', '')\n if base in uploaded_name and uploaded_name.endswith('.csv'):\n matched = req\n break\n if matched:\n # Write upload bytes to /content/ then copy to Drive-backed dest\n tmp = f'/content/{uploaded_name}'\n with open(tmp, 'wb') as f:\n f.write(data)\n dest = f'{DATA_DIR}/{matched}'\n shutil.copy(tmp, dest)\n size_mb = os.path.getsize(dest) / 1024 / 1024\n print(f\" ✓ Saved {matched} ({size_mb:.1f} MB) → {dest}\")\n else:\n print(f\" ⚠ Unrecognised file: {uploaded_name} (skipped)\")\n\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\n if missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"→ Activating Demo Mode — full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\n else:\n print(\"\\n✓ All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" + "source": [ + "import shutil\nDEMO_MODE = False\n\n# Mount Drive (Colab only) \u2014 guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n # os.path.ismount checks the actual filesystem mount, not just if\n # the directory path exists (which makedirs in s2-config may have created).\n if not os.path.ismount('/content/drive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"\u2713 Google Drive mounted\")\n # Create Drive-backed directories now that Drive is mounted.\n for _d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(_d, exist_ok=True)\n\n# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n # All files already in Drive \u2014 no upload needed\n print(\"\u2713 All MIMIC-III files found in Drive (no upload needed):\")\n for fname in required_files:\n size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n print(f\" {fname} ({size_mb:.1f} MB)\")\n print(f\"\\nFiles are reused from: {DATA_DIR}\")\n print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n print(\"MIMIC-III file status:\")\n for fname, desc in required_files.items():\n mark = \"\u2713\" if existing[fname] else \"\u2717 MISSING\"\n print(f\" {mark} {fname} \u2014 {desc}\")\n\n if IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n\n # Normalize filenames \u2014 Colab renames duplicates as \"ADMISSIONS (1).csv\".\n # Match each upload to the required file it belongs to, then copy with\n # the canonical name 
so subsequent runs find the file in Drive.\n for uploaded_name, data in uploaded.items():\n matched = None\n for req in required_files:\n base = req.replace('.csv', '')\n if base in uploaded_name and uploaded_name.endswith('.csv'):\n matched = req\n break\n if matched:\n # Write upload bytes to /content/ then copy to Drive-backed dest\n tmp = f'/content/{uploaded_name}'\n with open(tmp, 'wb') as f:\n f.write(data)\n dest = f'{DATA_DIR}/{matched}'\n shutil.copy(tmp, dest)\n size_mb = os.path.getsize(dest) / 1024 / 1024\n print(f\" \u2713 Saved {matched} ({size_mb:.1f} MB) \u2192 {dest}\")\n else:\n print(f\" \u26a0 Unrecognised file: {uploaded_name} (skipped)\")\n\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\n if missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"\u2192 Activating Demo Mode \u2014 full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\n else:\n print(\"\\n\u2713 All 3 MIMIC-III files present. Running in MIMIC-III mode.\")" + ] }, { "cell_type": "code", @@ -84,7 +88,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 
1),\n })\n print(f\"✓ Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" + "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"\u2713 Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" }, { "cell_type": "code", @@ -92,7 +96,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"✓ PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"✓ ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n 
_diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"✓ DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n✓ All files validated successfully\")" + "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"\u2713 PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"\u2713 ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"\u2713 DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n\u2713 All files validated successfully\")" }, { "cell_type": "markdown", @@ -104,7 +108,7 @@ "cell_type": "markdown", "id": "s4-desc", "metadata": {}, - "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. 
**Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input — steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." + "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input \u2014 steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." 
}, { "cell_type": "code", @@ -112,7 +116,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" + "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"\u2713 Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import 
promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (\u22652 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" }, { "cell_type": "code", @@ -120,7 +124,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"✓ Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, 
BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"✓ PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" + "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"\u2713 Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"\u2713 PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" }, { "cell_type": "code", @@ -128,7 +132,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"✓ Training 
complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" + "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"\u2713 Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" }, { "cell_type": "code", @@ -136,7 +140,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} → Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found — skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")" + "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 
4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} \u2192 Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found \u2014 skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")" }, { "cell_type": "markdown", @@ -148,7 +152,7 @@ "cell_type": "markdown", "id": "s5-desc", "metadata": {}, - "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` — the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling — diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding — deterministic, may repeat common patterns." + "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` \u2014 the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. 
**Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling \u2014 diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding \u2014 deterministic, may repeat common patterns." }, { "cell_type": "code", @@ -156,7 +160,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n✓ Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" + "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n\u2713 Generated 
{len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" }, { "cell_type": "code", @@ -164,7 +168,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save as JSON (full nested records — directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"✓ {len(synthetic):,} patients → {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE — matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"✓ {len(df_synthetic):,} records → {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" + "source": "# Save as JSON (full nested records \u2014 directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"\u2713 {len(synthetic):,} patients \u2192 {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE \u2014 matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n 
\"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"\u2713 {len(df_synthetic):,} records \u2192 {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" }, { "cell_type": "markdown", @@ -178,7 +182,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_visits):.2f} ± {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_codes):.2f} ± {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" + "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean \u00b1 SD : 
{np.mean(n_visits):.2f} \u00b1 {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_codes):.2f} \u00b1 {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" }, { "cell_type": "code", @@ -186,7 +190,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n⚠ Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage — expected for demo preset.\")\n print(\"Production training typically achieves 60–80%.\")\nelse:\n print(f\"\\n✓ Good vocabulary coverage.\")" + "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / 
n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n\u26a0 Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage \u2014 expected for demo preset.\")\n print(\"Production training typically achieves 60\u201380%.\")\nelse:\n print(f\"\\n\u2713 Good vocabulary coverage.\")" }, { "cell_type": "code", @@ -194,7 +198,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n 
ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n✓ Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty — model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" + "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n 
ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n\u2713 Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty \u2014 model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" }, { "cell_type": "code", @@ -202,7 +206,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 
0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation — no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"✓ Strong correlation — good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation — consider more epochs.\")\nelse: print(\"Weak correlation — model may need more training.\")" + "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c 
for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation \u2014 no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"\u2713 Strong correlation \u2014 good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation \u2014 consider more epochs.\")\nelse: print(\"Weak correlation \u2014 model may need more training.\")" }, { "cell_type": "code", @@ -210,7 +214,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"⚠ {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try 
TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"✓ All {len(synthetic):,} patients have at least one visit with at least one code.\")" + "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"\u26a0 {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"\u2713 All {len(synthetic):,} patients have at least one visit with at least one code.\")" }, { "cell_type": "code", @@ -218,7 +222,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n✓ Saved to {report_path}\")" + "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": 
round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n\u2713 Saved to {report_path}\")" }, { "cell_type": "markdown", @@ -232,7 +236,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Download output files (Colab only — silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" ✓ {os.path.basename(_p)}\")\n else:\n print(f\" — {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" + "source": "# Download output files (Colab only \u2014 silently skipped in 
local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" \u2713 {os.path.basename(_p)}\")\n else:\n print(f\" \u2014 {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" }, { "cell_type": "code", @@ -240,13 +244,13 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, 
warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")" + "source": "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# CHECKPOINT RESUME \u2014 Run this cell instead of Section 4 if you already trained\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, 
prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"\u2713 Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template \u2014 uncomment the lines above to use)\")" }, { "cell_type": "markdown", "id": "s7-congrats", "metadata": {}, - "source": "---\n## 🎉 Congratulations!\n\nYou've successfully:\n1. ✅ Trained a PromptEHR model conditioned on patient demographics\n2. ✅ Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. ✅ Validated ICD-9 code frequency fidelity against real training data\n4. ✅ Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime → Change runtime type → T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the 
same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 · Branch: `promptehr-pr-integration` · jalengg/PyHealth_" + "source": "---\n## \ud83c\udf89 Congratulations!\n\nYou've successfully:\n1. \u2705 Trained a PromptEHR model conditioned on patient demographics\n2. \u2705 Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. \u2705 Validated ICD-9 code frequency fidelity against real training data\n4. \u2705 Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime \u2192 Change runtime type \u2192 T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## 
Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 \u00b7 Branch: `promptehr-pr-integration` \u00b7 jalengg/PyHealth_" } ] } \ No newline at end of file From bf203a68b7cab3f33ef8ac3b1a2979c6932f25f7 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 04:20:35 -0600 Subject: [PATCH 24/37] Fix: force-reinstall numpy+scipy post-PyHealth to clear mixed-version state PyHealth --force-reinstall can leave numpy/scipy in a mixed state where Python files and compiled .so extensions are from different versions, causing 'cannot import name _center from numpy._core.umath'. Fix: add --force-reinstall and explicit numpy~=2.2.0 to the post-PyHealth pip upgrade step, guaranteeing all numpy/scipy files are from consistent versions that support each other and numpy 2.x. --- examples/promptehr_mimic3_colab.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index b9b2cf8d0..3b195ec7d 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:07:18 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. 
**Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:29:15 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. 
**Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,9 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# 1. 
Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Normalize transitive deps AFTER PyHealth install.\n# --force-reinstall above can leave packages in mixed-version states:\n# - scipy>=1.14 supports numpy 2.x (fixes transformers\u2192sklearn\u2192scipy cascade)\n# - scikit-learn>=1.5 matches scipy>=1.14\n# - Pillow>=10.4.0 ensures consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\"],\n check=True,\n)\nprint(\"\u2713 scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 installed\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n 
print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" + "source": [ + "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Force-reinstall key transitive deps AFTER PyHealth install.\n# --force-reinstall on PyHealth above can leave numpy/scipy in a\n# mixed-version state (Python files from one version, compiled\n# extensions from another). 
We force-reinstall them here to guarantee\n# a consistent, fully-installed set:\n# - numpy~=2.2.0: Colab numpy 2.x has _center in _core.umath (required by scipy)\n# - scipy>=1.14: first scipy release with full numpy 2.x support\n# - scikit-learn>=1.5: compatible with scipy>=1.14\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"numpy~=2.2.0\", \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n check=True,\n)\nprint(\"\u2713 numpy~=2.2.0, scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 reinstalled\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" + ] }, { "cell_type": "markdown", From 9ac2960b26f9991612a8a2c1af631bcece6a9598 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 04:20:45 -0600 Subject: [PATCH 25/37] Chore: fix notebook timestamp to 2026-03-04 10:20:35 (UTC) --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 3b195ec7d..ffd3a568b 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:29:15 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. 
**Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:20:35 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. 
**Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", From 3a9c4cd23ae05f3a52fcc3ff305a7dc9bb0986d9 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 04:31:21 -0600 Subject: [PATCH 26/37] Fix: processors/__init__ __all__ + RuntimeError guard + numpy>=2.0.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add FeatureProcessor import (was in __all__ but never imported) - Remove LabelProcessor from __all__ (class does not exist) - Guard ImageProcessor/TimeImageProcessor with (ImportError, RuntimeError) to catch broken Pillow installs that raise RuntimeError, not ImportError - Build __all__ dynamically so guarded processors are only listed when their imports succeed - Change numpy~=2.2.0 → numpy>=2.0.0 in notebook post-install to avoid hard ceiling at <2.3 that would downgrade as Colab numpy advances --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- pyhealth/processors/__init__.py | 24 ++++++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index ffd3a568b..9c4d4a02c 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# 
PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:20:35 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:45:22 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -37,7 +37,7 @@ "outputs": [], "execution_count": null, "source": [ - "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Force-reinstall key transitive deps AFTER PyHealth install.\n# --force-reinstall on PyHealth above can leave numpy/scipy in a\n# mixed-version state (Python files from one version, compiled\n# extensions from another). 
We force-reinstall them here to guarantee\n# a consistent, fully-installed set:\n# - numpy~=2.2.0: Colab numpy 2.x has _center in _core.umath (required by scipy)\n# - scipy>=1.14: first scipy release with full numpy 2.x support\n# - scikit-learn>=1.5: compatible with scipy>=1.14\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"numpy~=2.2.0\", \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n check=True,\n)\nprint(\"\u2713 numpy~=2.2.0, scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 reinstalled\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" + "import subprocess\nimport sys\n\n# 1. 
Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Force-reinstall key transitive deps AFTER PyHealth install.\n# --force-reinstall on PyHealth above can leave numpy/scipy in a\n# mixed-version state (Python files from one version, compiled\n# extensions from another). We force-reinstall them here to guarantee\n# a consistent, fully-installed set:\n# - numpy>=2.0.0: ensures numpy 2.x (_center in _core.umath, required by scipy)\n# - scipy>=1.14: first scipy release with full numpy 2.x support\n# - scikit-learn>=1.5: compatible with scipy>=1.14\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"numpy>=2.0.0\", \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n check=True,\n)\nprint(\"\u2713 numpy>=2.0.0, scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 reinstalled\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import 
display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" ] }, { diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index 70762dd9c..f24f1fa0c 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -18,10 +18,12 @@ def get_processor(name: str): # Import all processors so they register themselves +from .base_processor import FeatureProcessor try: from .image_processor import ImageProcessor -except ImportError: - pass # PIL/torchvision unavailable + _has_image_processor = True +except (ImportError, RuntimeError): + _has_image_processor = False # PIL/torchvision unavailable or broken from .label_processor import ( BinaryLabelProcessor, MultiClassLabelProcessor, @@ -49,17 +51,20 @@ def get_processor(name: str): from .timeseries_processor import TimeseriesProcessor try: from .time_image_processor import TimeImageProcessor -except ImportError: - pass # PIL/torchvision unavailable + _has_time_image_processor = True +except (ImportError, RuntimeError): + _has_time_image_processor = False # PIL/torchvision unavailable or broken from .audio_processor import AudioProcessor from .ignore_processor import IgnoreProcessor from 
.tuple_time_text_processor import TupleTimeTextProcessor -# Expose public API +# Expose public API — optional processors only listed if successfully imported __all__ = [ "FeatureProcessor", - "ImageProcessor", - "LabelProcessor", + "BinaryLabelProcessor", + "MultiClassLabelProcessor", + "MultiLabelProcessor", + "RegressionLabelProcessor", "MultiHotProcessor", "NestedFloatsProcessor", "NestedSequenceProcessor", @@ -71,7 +76,10 @@ def get_processor(name: str): "TensorProcessor", "TextProcessor", "TimeseriesProcessor", - "TimeImageProcessor", "AudioProcessor", "TupleTimeTextProcessor", ] +if _has_image_processor: + __all__.append("ImageProcessor") +if _has_time_image_processor: + __all__.append("TimeImageProcessor") From cbdd115b29bf7a83564fa8f2bac1cde3b1462e9b Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 04:31:35 -0600 Subject: [PATCH 27/37] Chore: fix notebook timestamp to 2026-03-04 10:31:20 (UTC) --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 9c4d4a02c..be4756f97 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:45:22 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. 
**Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:31:20 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. 
**Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", From 8872b7d91d99801164fb6e6bdb93a91c521262bb Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 11:15:55 -0600 Subject: [PATCH 28/37] Fix: numpy version ceiling, force-reinstall cascade, Drive stale mount MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pyproject.toml: numpy~=2.2.0 → numpy>=2.0.0 (removes <2.3 ceiling; prevents downgrade when Colab has numpy 2.3.x, which was the root cause of the recurring _center ImportError) - s1-setup: remove --force-reinstall and numpy from post-install step; use --upgrade instead (force-reinstall of scipy force-reinstalls numpy transitively, creating mixed-version compiled/Python state) - s3-upload: drive.mount(..., force_remount=True) to handle stale FUSE 
mount state that raised "Mountpoint must not already contain files" --- examples/promptehr_mimic3_colab.ipynb | 54 +++++++++++++-------------- pyproject.toml | 2 +- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index be4756f97..14d6bc1fb 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 10:31:20 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` \u2014 patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` \u2014 hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime \u2192 Change runtime type \u2192 GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30\u201345 min on GPU\n - Production (20 epochs, 10K samples): ~3\u20135 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) \u2014 Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 17:15:17 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,9 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": [ - "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub \u2014 force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps \u2014 we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed \u2014 see error above.\")\nprint(f\"\u2713 PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Force-reinstall key transitive deps AFTER PyHealth install.\n# --force-reinstall on PyHealth above can leave numpy/scipy in a\n# mixed-version state (Python files from one version, compiled\n# extensions from another). 
We force-reinstall them here to guarantee\n# a consistent, fully-installed set:\n# - numpy>=2.0.0: ensures numpy 2.x (_center in _core.umath, required by scipy)\n# - scipy>=1.14: first scipy release with full numpy 2.x support\n# - scikit-learn>=1.5: compatible with scipy>=1.14\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"numpy>=2.0.0\", \"scipy>=1.14\", \"scikit-learn>=1.5\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n check=True,\n)\nprint(\"\u2713 numpy>=2.0.0, scipy>=1.14, scikit-learn>=1.5, Pillow>=10.4.0 reinstalled\")\n\n# Environment detection \u2014 MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" \u2192 Runtime \u2192 Change runtime type \u2192 T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} \u2713\")\nprint(\"\u2713 All setup complete\")" - ] + "source": "import subprocess\nimport sys\n\n# 1. 
Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps — we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Upgrade optional deps that Colab may have at outdated versions.\n# We do NOT use --force-reinstall here — that would also force-reinstall\n# numpy/scipy and their transitive deps, creating a mixed-version state\n# (Python files from one version, compiled extensions from another).\n# PyHealth's pyproject.toml already pins numpy>=2.0.0, so step 1 ensures\n# a compatible numpy is installed. 
We just upgrade scipy and Pillow:\n# - scipy>=1.14: first release with full numpy 2.x support\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--upgrade\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, Pillow>=10.4.0 upgraded\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", @@ -50,7 +48,7 @@ "cell_type": "markdown", "id": "s2-desc", "metadata": {}, - "source": "Configure all parameters here. 
**This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** \u2014 5 epochs, 1 K synthetic patients, ~30\u201345 min on T4\n- **`PRESET = \"production\"`** \u2014 20 epochs, 10 K synthetic patients, ~3\u20135 hrs on T4" + "source": "Configure all parameters here. **This is the only cell you need to modify.**\n\n- **`PRESET = \"demo\"`** — 5 epochs, 1 K synthetic patients, ~30–45 min on T4\n- **`PRESET = \"production\"`** — 20 epochs, 10 K synthetic patients, ~3–5 hrs on T4" }, { "cell_type": "code", @@ -59,7 +57,7 @@ "outputs": [], "execution_count": null, "source": [ - "# ============================================================\n# CONFIGURATION \u2014 All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch \u2014\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = \"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. 
Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\n# In Colab, Drive-backed dirs are created after mount (in s3-upload).\n# In local/SLURM environments, create them immediately.\nif not IN_COLAB:\n for d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"\u2713 Configuration complete\")" + "# ============================================================\n# CONFIGURATION — All modifiable parameters in one place\n# ============================================================\n\n# --- Preset ---\nPRESET = \"demo\" # \"demo\" or \"production\"\n\n# --- Training parameters ---\nif PRESET == \"demo\":\n EPOCHS = 5\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 1_000\n WARMUP_STEPS = 100\nelif PRESET == \"production\":\n EPOCHS = 20\n BATCH_SIZE = 16\n N_SYNTHETIC_SAMPLES = 10_000\n WARMUP_STEPS = 1_000\n\nLR = 1e-5 # Paper LR; low to avoid catastrophic forgetting of BART weights\nMAX_SEQ_LENGTH = 512 # Max tokens per patient (visits + special tokens)\n\n# --- Model architecture ---\nD_HIDDEN = 128 # Hidden dim for demographic prompt encoder\nPROMPT_LENGTH = 1 # Prompt vectors per demographic feature (1 is sufficient per paper)\n\n# --- BART backbone ---\n# \"facebook/bart-base\": pretrained BART (139 M params, 768 hidden dim).\n# PromptEHR fine-tunes these weights rather than training from scratch —\n# the pretrained sequence modeling prior means even 20 epochs can produce good results.\nBART_CONFIG_NAME = 
\"facebook/bart-base\"\n\n# --- Generation parameters ---\nRANDOM_SAMPLING = True # True: nucleus sampling (diverse), False: greedy (deterministic)\nTEMPERATURE = 0.7 # Lower = more common codes. Higher = more rare/diverse codes.\nTOP_P = 0.95 # Nucleus sampling: sample from top 95% probability mass.\n\n# --- Reproducibility ---\nSEED = 42\n\n# --- Paths (all derived from BASE_DIR) ---\nBASE_DIR = '/content/drive/MyDrive/PromptEHR_Training' if IN_COLAB else './promptehr_training'\nDATA_DIR = f'{BASE_DIR}/data'\nCHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'\nOUTPUT_DIR = f'{BASE_DIR}/output'\n\n# In Colab, Drive-backed dirs are created after mount (in s3-upload).\n# In local/SLURM environments, create them immediately.\nif not IN_COLAB:\n for d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(d, exist_ok=True)\n\nprint(f\"Preset: {PRESET}\")\nprint(f\"Epochs: {EPOCHS} | Batch size: {BATCH_SIZE} | LR: {LR}\")\nprint(f\"Synthetic: {N_SYNTHETIC_SAMPLES:,} patients\")\nprint(f\"Base directory: {BASE_DIR}\")\nprint(\"✓ Configuration complete\")" ] }, { @@ -72,7 +70,7 @@ "cell_type": "markdown", "id": "s3-desc", "metadata": {}, - "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO \u2014 `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` \u2014 date of birth and gender\n2. `ADMISSIONS.csv` \u2014 admission timestamps (used to compute age at first admission)\n3. `DIAGNOSES_ICD.csv` \u2014 ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." + "source": "Upload your MIMIC-III CSV files. PromptEHR needs **3 files** (one more than HALO — `PATIENTS.csv` is required for demographic conditioning):\n\n1. `PATIENTS.csv` — date of birth and gender\n2. `ADMISSIONS.csv` — admission timestamps (used to compute age at first admission)\n3. 
`DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\nFiles persist across Colab sessions when saved to Google Drive.\n\n**No MIMIC-III?** The next cell automatically activates Demo Mode." }, { "cell_type": "code", @@ -80,9 +78,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": [ - "import shutil\nDEMO_MODE = False\n\n# Mount Drive (Colab only) \u2014 guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n # os.path.ismount checks the actual filesystem mount, not just if\n # the directory path exists (which makedirs in s2-config may have created).\n if not os.path.ismount('/content/drive'):\n drive.mount('/content/drive')\n else:\n print(\"Drive already mounted\")\n print(\"\u2713 Google Drive mounted\")\n # Create Drive-backed directories now that Drive is mounted.\n for _d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(_d, exist_ok=True)\n\n# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n # All files already in Drive \u2014 no upload needed\n print(\"\u2713 All MIMIC-III files found in Drive (no upload needed):\")\n for fname in required_files:\n size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n print(f\" {fname} ({size_mb:.1f} MB)\")\n print(f\"\\nFiles are reused from: {DATA_DIR}\")\n print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n print(\"MIMIC-III file status:\")\n for fname, desc in required_files.items():\n mark = \"\u2713\" if existing[fname] else \"\u2717 MISSING\"\n print(f\" {mark} {fname} \u2014 {desc}\")\n\n if IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from 
google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n\n # Normalize filenames \u2014 Colab renames duplicates as \"ADMISSIONS (1).csv\".\n # Match each upload to the required file it belongs to, then copy with\n # the canonical name so subsequent runs find the file in Drive.\n for uploaded_name, data in uploaded.items():\n matched = None\n for req in required_files:\n base = req.replace('.csv', '')\n if base in uploaded_name and uploaded_name.endswith('.csv'):\n matched = req\n break\n if matched:\n # Write upload bytes to /content/ then copy to Drive-backed dest\n tmp = f'/content/{uploaded_name}'\n with open(tmp, 'wb') as f:\n f.write(data)\n dest = f'{DATA_DIR}/{matched}'\n shutil.copy(tmp, dest)\n size_mb = os.path.getsize(dest) / 1024 / 1024\n print(f\" \u2713 Saved {matched} ({size_mb:.1f} MB) \u2192 {dest}\")\n else:\n print(f\" \u26a0 Unrecognised file: {uploaded_name} (skipped)\")\n\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\n if missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"\u2192 Activating Demo Mode \u2014 full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\n else:\n print(\"\\n\u2713 All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" - ] + "source": "import shutil\nDEMO_MODE = False\n\n# Mount Drive (Colab only) — guard makes this cell idempotent (safe to re-run)\nif IN_COLAB:\n from google.colab import drive\n # os.path.ismount checks the actual filesystem mount, not just if\n # the directory path exists (which makedirs in s2-config may have created).\n # force_remount=True handles stale mount state that would otherwise raise\n # ValueError: \"Mountpoint must not already contain files\".\n if not os.path.ismount('/content/drive'):\n drive.mount('/content/drive', force_remount=True)\n else:\n print(\"Drive already mounted\")\n print(\"✓ Google Drive mounted\")\n # Create Drive-backed directories now that Drive is mounted.\n for _d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n os.makedirs(_d, exist_ok=True)\n\n# Check which files exist in the Drive-backed DATA_DIR\nrequired_files = {\n 'PATIENTS.csv': 'Patient demographics (DOB, gender)',\n 'ADMISSIONS.csv': 'Admission records (timestamps)',\n 'DIAGNOSES_ICD.csv': 'ICD-9 diagnosis codes',\n}\nexisting = {f: os.path.exists(f'{DATA_DIR}/{f}') for f in required_files}\nmissing = [f for f, ok in existing.items() if not ok]\n\nif not missing:\n # All files already in Drive — no upload needed\n print(\"✓ All MIMIC-III files found in Drive (no upload needed):\")\n for fname in required_files:\n size_mb = os.path.getsize(f'{DATA_DIR}/{fname}') / 1024 / 1024\n print(f\" {fname} ({size_mb:.1f} MB)\")\n print(f\"\\nFiles are reused from: {DATA_DIR}\")\n print(\"To force re-upload, delete files from that folder and re-run this cell.\")\nelse:\n print(\"MIMIC-III file status:\")\n for fname, desc in required_files.items():\n mark = \"✓\" if existing[fname] else \"✗ MISSING\"\n print(f\" {mark} {fname} — {desc}\")\n\n if IN_COLAB:\n print(f\"\\nUploading {len(missing)} missing file(s)...\")\n from google.colab import files as _colab_files\n uploaded = _colab_files.upload()\n\n # Normalize filenames — Colab renames duplicates as 
\"ADMISSIONS (1).csv\".\n # Match each upload to the required file it belongs to, then copy with\n # the canonical name so subsequent runs find the file in Drive.\n for uploaded_name, data in uploaded.items():\n matched = None\n for req in required_files:\n base = req.replace('.csv', '')\n if base in uploaded_name and uploaded_name.endswith('.csv'):\n matched = req\n break\n if matched:\n # Write upload bytes to /content/ then copy to Drive-backed dest\n tmp = f'/content/{uploaded_name}'\n with open(tmp, 'wb') as f:\n f.write(data)\n dest = f'{DATA_DIR}/{matched}'\n shutil.copy(tmp, dest)\n size_mb = os.path.getsize(dest) / 1024 / 1024\n print(f\" ✓ Saved {matched} ({size_mb:.1f} MB) → {dest}\")\n else:\n print(f\" ⚠ Unrecognised file: {uploaded_name} (skipped)\")\n\n missing = [f for f in required_files if not os.path.exists(f'{DATA_DIR}/{f}')]\n\n if missing:\n print(f\"\\nMIMIC-III files not available ({missing}).\")\n print(\"→ Activating Demo Mode — full pipeline with synthetic stand-in data.\")\n DEMO_MODE = True\n else:\n print(\"\\n✓ All 3 MIMIC-III files present. 
Running in MIMIC-III mode.\")" }, { "cell_type": "code", @@ -90,7 +86,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"\u2713 Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" + "source": "if DEMO_MODE:\n print(\"Setting up Demo Mode data...\")\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n\n # Synthetic stand-in: 200 patients, 2-6 visits, realistic ICD-9 codes.\n # Exercises the full pipeline without any real patient data.\n random.seed(SEED)\n icd9_pool = [\n \"428.0\", \"401.9\", \"250.00\", \"272.4\", \"410.71\", \"486\",\n \"585.3\", \"V58.61\", \"412\", \"414.01\", \"276.1\", \"285.9\",\n \"584.9\", \"305.1\", \"290.0\", \"427.31\", \"518.81\", \"496\",\n \"038.9\", \"599.0\",\n ]\n demo_samples = []\n for i in range(200):\n n_visits = random.randint(2, 6)\n visits = [random.sample(icd9_pool, random.randint(1, 5)) for _ in range(n_visits)]\n demo_samples.append({\n \"patient_id\": f\"DEMO_{i:04d}\",\n \"visits\": visits,\n \"age\": float(random.randint(18, 
89)),\n \"gender\": random.randint(0, 1),\n })\n print(f\"✓ Demo dataset: {len(demo_samples)} patients, up to 6 visits each\")\n print(\" (Replace with real MIMIC-III data for publication-quality results)\")" }, { "cell_type": "code", @@ -98,7 +94,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"\u2713 PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"\u2713 ADMISSIONS.csv: {len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"\u2713 DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n\u2713 All files validated successfully\")" + "source": "if not DEMO_MODE:\n print(\"Validating MIMIC-III files...\")\n _patients = pd.read_csv(f'{DATA_DIR}/PATIENTS.csv')\n assert 'SUBJECT_ID' in _patients.columns, \"PATIENTS.csv missing SUBJECT_ID\"\n assert 'GENDER' in _patients.columns, \"PATIENTS.csv missing GENDER\"\n assert 'DOB' in _patients.columns, \"PATIENTS.csv missing DOB\"\n print(f\"✓ PATIENTS.csv: {len(_patients):>8,} rows\")\n\n _admissions = pd.read_csv(f'{DATA_DIR}/ADMISSIONS.csv')\n assert 'SUBJECT_ID' in _admissions.columns, \"ADMISSIONS.csv missing SUBJECT_ID\"\n assert 'HADM_ID' in _admissions.columns, \"ADMISSIONS.csv missing HADM_ID\"\n print(f\"✓ ADMISSIONS.csv: 
{len(_admissions):>8,} rows\")\n\n _diagnoses = pd.read_csv(f'{DATA_DIR}/DIAGNOSES_ICD.csv')\n assert 'ICD9_CODE' in _diagnoses.columns, \"DIAGNOSES_ICD.csv missing ICD9_CODE\"\n print(f\"✓ DIAGNOSES_ICD.csv: {len(_diagnoses):>8,} rows\")\n\n del _patients, _admissions, _diagnoses # free memory\n print(\"\\n✓ All files validated successfully\")" }, { "cell_type": "markdown", @@ -110,7 +106,7 @@ "cell_type": "markdown", "id": "s4-desc", "metadata": {}, - "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input \u2014 steering the model toward age/gender-appropriate diagnosis patterns.\n4. **Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." + "source": "**What happens during training:**\n\n1. **Dataset loading**: PyHealth reads MIMIC-III and creates one sample per patient (nested visit sequences + demographics: age at first admission, gender).\n2. **Tokenization**: Each ICD-9 code is mapped to a unique BART token ID. Special tokens mark visit boundaries: `[VISIT_START]`, `[VISIT_END]`, `[SEQ_END]`.\n3. **Demographic prompts**: Age and gender are encoded into learned prompt vectors prepended to the BART encoder input — steering the model toward age/gender-appropriate diagnosis patterns.\n4. 
**Fine-tuning**: HuggingFace Trainer fine-tunes the BART Seq2Seq model to predict the next token conditioned on the demographic prompts.\n5. **Checkpoint**: Saved to `{CHECKPOINT_DIR}/checkpoint.pt` after training.\n\nThe `WARMUP_STEPS` ramp up the learning rate gradually during early training, preventing catastrophic forgetting of BART's pretrained sequence modeling capabilities." }, { "cell_type": "code", @@ -118,7 +114,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"\u2713 Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (\u22652 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" + "source": "# Set all random seeds before any stochastic 
operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" }, { "cell_type": "code", @@ -126,7 +122,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"\u2713 Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n 
dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"\u2713 PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + {n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" + "source": "# Save config alongside checkpoint for reproducibility\n_config = {k: str(v) for k, v in globals().items()\n if k.isupper() and not k.startswith('_')\n and isinstance(v, (str, int, float, bool))}\n_config['timestamp'] = datetime.now().isoformat()\n_config_path = f'{CHECKPOINT_DIR}/config.json'\nwith open(_config_path, 'w') as f:\n json.dump(_config, f, indent=2)\nprint(f\"✓ Config saved to {_config_path}\")\n\n# Initialize model\nprint(\"\\nInitializing PromptEHR model...\")\nmodel = PromptEHR(\n dataset=train_dataset,\n n_num_features=1, # 1 continuous demographic feature: age\n cat_cardinalities=[2], # 1 categorical feature: gender (binary: 0=male, 1=female)\n d_hidden=D_HIDDEN,\n prompt_length=PROMPT_LENGTH,\n bart_config_name=BART_CONFIG_NAME,\n epochs=EPOCHS,\n batch_size=BATCH_SIZE,\n lr=LR,\n warmup_steps=WARMUP_STEPS,\n max_seq_length=MAX_SEQ_LENGTH,\n save_dir=CHECKPOINT_DIR,\n)\n\nn_special = 7 # PAD, BOS, EOS, UNK, VISIT_START, VISIT_END, SEQ_END\nn_codes = model._vocab.total_size - n_special\ntotal_params = sum(p.numel() for p in model.parameters())\nprint(f\"✓ PromptEHR initialized\")\nprint(f\" Vocabulary: {model._vocab.total_size} tokens \"\n f\"({n_codes} ICD-9 codes + 
{n_special} special tokens)\")\nprint(f\" Parameters: {total_params:,}\")" }, { "cell_type": "code", @@ -134,7 +130,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"\u2713 Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" + "source": "print(\"Starting training...\")\nprint(\"HuggingFace Trainer will print step-by-step progress below.\")\nprint(\"=\" * 60)\n\nmodel.train_model(train_dataset, val_dataset=val_dataset)\n\nprint(\"=\" * 60)\nprint(\"✓ Training complete!\")\nprint(f\" Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")" }, { "cell_type": "code", @@ -142,7 +138,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} \u2192 Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found \u2014 skipping loss curve.\")\n print(\"(Expected for 
very short demo runs.)\")" + "source": "# Plot training loss from HuggingFace Trainer logs\n_state_files = glob.glob(f'{CHECKPOINT_DIR}/**/trainer_state.json', recursive=True)\n\nif _state_files:\n with open(_state_files[0]) as f:\n _log = json.load(f)['log_history']\n _steps = [e['step'] for e in _log if 'loss' in e]\n _losses = [e['loss'] for e in _log if 'loss' in e]\n\n if _steps:\n fig, ax = plt.subplots(figsize=(9, 4))\n ax.plot(_steps, _losses, 'b-o', linewidth=1.5, markersize=4, label='Training loss')\n ax.set_xlabel('Training step', fontsize=12)\n ax.set_ylabel('Cross-entropy loss', fontsize=12)\n ax.set_title('PromptEHR Training Loss', fontsize=14)\n ax.legend(); ax.grid(alpha=0.3)\n plt.tight_layout()\n _loss_plot = f'{OUTPUT_DIR}/training_loss.png'\n plt.savefig(_loss_plot, dpi=150); plt.show()\n print(f\"Initial loss: {_losses[0]:.4f} → Final loss: {_losses[-1]:.4f}\")\n print(f\"Plot saved to: {_loss_plot}\")\n else:\n print(\"No loss values recorded (too few steps for demo preset).\")\nelse:\n print(\"trainer_state.json not found — skipping loss curve.\")\n print(\"(Expected for very short demo runs.)\")" }, { "cell_type": "markdown", @@ -154,7 +150,7 @@ "cell_type": "markdown", "id": "s5-desc", "metadata": {}, - "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` \u2014 the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. 
**Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling \u2014 diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding \u2014 deterministic, may repeat common patterns." + "source": "**How generation works:**\n\n1. **Demographic sampling**: For each synthetic patient, `synthesize_dataset` draws an `(age, gender)` pair from `model._demo_pool` — the real training population. This ensures the synthetic cohort's demographic profile mirrors MIMIC-III.\n2. **Prompt conditioning**: The sampled demographics are encoded into prompt vectors and prepended to the BART encoder input.\n3. **Autoregressive decoding**: BART generates tokens one at a time. Special tokens `[VISIT_START]` and `[VISIT_END]` structure the output into visits; `[SEQ_END]` ends the patient sequence.\n4. **Decoding**: Token IDs are mapped back to ICD-9 code strings.\n\n`RANDOM_SAMPLING = True` (default): nucleus sampling — diverse, realistic output. \n`RANDOM_SAMPLING = False`: greedy decoding — deterministic, may repeat common patterns." 
}, { "cell_type": "code", @@ -162,7 +158,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n\u2713 Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" + "source": "print(f\"Generating {N_SYNTHETIC_SAMPLES:,} synthetic patients...\")\nprint(f\" Sampling: {'nucleus (random)' if RANDOM_SAMPLING else 'greedy'}\"\n + (f\", temperature={TEMPERATURE}, top_p={TOP_P}\" if RANDOM_SAMPLING else \"\"))\nprint(\"(This may take several minutes...)\")\n\nsynthetic = model.synthesize_dataset(\n num_samples=N_SYNTHETIC_SAMPLES,\n random_sampling=RANDOM_SAMPLING,\n)\n\nprint(f\"\\n✓ Generated {len(synthetic):,} synthetic patients\")\n\n# Preview\n_preview = []\nfor p in synthetic[:10]:\n _v0 = p[\"visits\"][0] if p[\"visits\"] else []\n _sample = \", \".join(_v0[:4]) + (\"...\" if len(_v0) > 4 else \"\")\n _preview.append({\n \"patient_id\": p[\"patient_id\"],\n \"n_visits\": len(p[\"visits\"]),\n \"total_codes\": sum(len(v) for v in p[\"visits\"]),\n \"first_visit_codes\": _sample or \"(empty)\",\n })\ndisplay(pd.DataFrame(_preview))" }, { "cell_type": "code", @@ -170,7 +166,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Save as 
JSON (full nested records \u2014 directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"\u2713 {len(synthetic):,} patients \u2192 {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE \u2014 matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"\u2713 {len(df_synthetic):,} records \u2192 {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" + "source": "# Save as JSON (full nested records — directly loadable back into PyHealth)\njson_path = f'{OUTPUT_DIR}/synthetic_patients.json'\nwith open(json_path, 'w') as f:\n json.dump(synthetic, f, indent=2)\nprint(f\"✓ {len(synthetic):,} patients → {json_path}\")\n\n# Save as CSV (flat SUBJECT_ID, VISIT_NUM, ICD9_CODE — matches MIMIC-III output schema)\n_rows = []\nfor p in synthetic:\n for _vnum, _visit in enumerate(p[\"visits\"], 1):\n for _code in _visit:\n _rows.append({\"SUBJECT_ID\": p[\"patient_id\"],\n \"VISIT_NUM\": _vnum,\n \"ICD9_CODE\": _code})\ndf_synthetic = pd.DataFrame(_rows)\ncsv_path = f'{OUTPUT_DIR}/synthetic_patients.csv'\ndf_synthetic.to_csv(csv_path, index=False)\nprint(f\"✓ {len(df_synthetic):,} records → {csv_path}\")\nprint(f\" Columns: SUBJECT_ID, VISIT_NUM, ICD9_CODE\")\nprint(\"\\nSample rows:\")\ndisplay(df_synthetic.head(8))" }, { "cell_type": "markdown", @@ -184,7 +180,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in 
synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_visits):.2f} \u00b1 {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean \u00b1 SD : {np.mean(n_codes):.2f} \u00b1 {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" + "source": "print(\"=\" * 60)\nprint(\"SYNTHETIC DATASET STATISTICS\")\nprint(\"=\" * 60)\n\nn_visits = [len(p[\"visits\"]) for p in synthetic]\nn_codes = [sum(len(v) for v in p[\"visits\"]) for p in synthetic]\n\nprint(f\"\\nPatients: {len(synthetic):,}\")\nprint(f\"\\nVisits per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_visits):.2f} ± {np.std(n_visits):.2f}\")\nprint(f\" Median : {np.median(n_visits):.0f}\")\nprint(f\" Range : [{min(n_visits)}, {max(n_visits)}]\")\nprint(f\"\\nDiagnosis codes per patient:\")\nprint(f\" Mean ± SD : {np.mean(n_codes):.2f} ± {np.std(n_codes):.2f}\")\nprint(f\" Median : {np.median(n_codes):.0f}\")\nprint(f\" Range : [{min(n_codes)}, {max(n_codes)}]\")\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\nax1.hist(n_visits, bins=20, color='steelblue', edgecolor='white', alpha=0.85)\nax1.set_xlabel('Visits per patient'); 
ax1.set_ylabel('Count')\nax1.set_title('Visit Count Distribution')\nax2.hist(n_codes, bins=30, color='coral', edgecolor='white', alpha=0.85)\nax2.set_xlabel('Codes per patient'); ax2.set_ylabel('Count')\nax2.set_title('Code Count Distribution')\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/count_distributions.png', dpi=150)\nplt.show()" }, { "cell_type": "code", @@ -192,7 +188,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n\u26a0 Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage \u2014 expected for demo preset.\")\n print(\"Production training typically achieves 60\u201380%.\")\nelse:\n print(f\"\\n\u2713 Good vocabulary coverage.\")" + "source": "all_synth_codes = set(c for p in synthetic for v in p[\"visits\"] for c in v)\nn_real_codes = len(model._vocab._bart_to_code) # ICD-9 codes in vocabulary\ncoverage = len(all_synth_codes) / n_real_codes * 100 if n_real_codes > 0 else 0.0\n\nprint(f\"Vocabulary size (ICD-9 codes): {n_real_codes:,}\")\nprint(f\"Unique codes in synthetic: {len(all_synth_codes):,}\")\nprint(f\"Vocabulary coverage: {coverage:.1f}%\")\n\nif coverage < 30:\n print(\"\\n⚠ Low coverage may indicate mode collapse.\")\n print(\" Consider: more EPOCHS, lower LR, or check _demo_pool is populated.\")\nelif coverage < 60:\n print(\"\\nModerate coverage — expected for demo preset.\")\n print(\"Production training typically achieves 60–80%.\")\nelse:\n 
print(f\"\\n✓ Good vocabulary coverage.\")" }, { "cell_type": "code", @@ -200,7 +196,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} 
({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n\u2713 Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty \u2014 model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" + "source": "# model._demo_pool stores (age, gender) pairs from training data.\n# synthesize_dataset samples from this pool for each synthetic patient,\n# so the synthetic cohort's demographics automatically mirror the training population.\nif model._demo_pool:\n _ages = [a for a, g in model._demo_pool]\n _genders = [g for a, g in model._demo_pool]\n _n_male = sum(1 for g in _genders if g == 0)\n _n_female = sum(1 for g in _genders if g == 1)\n\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n ax1.hist(_ages, bins=25, density=True, color='steelblue', edgecolor='white',\n alpha=0.8, label='Training population')\n ax1.axvline(np.mean(_ages), color='navy', linestyle='--', linewidth=1.5,\n label=f'Mean age: {np.mean(_ages):.1f}')\n ax1.set_xlabel('Age at first admission', fontsize=12)\n ax1.set_ylabel('Density', fontsize=12)\n ax1.set_title('Age Distribution\\n(Conditioning Source)', fontsize=13)\n ax1.legend(fontsize=10)\n\n _bars = ax2.bar(['Male', 'Female'], [_n_male, _n_female],\n color=['steelblue', 'coral'], edgecolor='white', alpha=0.85)\n for _bar, _val in zip(_bars, [_n_male, _n_female]):\n ax2.text(_bar.get_x() + _bar.get_width()/2, _bar.get_height() + 5,\n f'{_val:,}\\n({_val/len(_genders)*100:.1f}%)',\n ha='center', va='bottom', fontsize=11)\n ax2.set_ylabel('Patient count', fontsize=12)\n ax2.set_title('Gender Distribution\\n(Conditioning Source)', fontsize=13)\n\n plt.tight_layout()\n plt.savefig(f'{OUTPUT_DIR}/demographics_distribution.png', dpi=150)\n plt.show()\n\n print(f\"Demographics pool: {len(model._demo_pool):,} training patients\")\n print(f\" Age: mean={np.mean(_ages):.1f}, std={np.std(_ages):.1f}, \"\n 
f\"range=[{min(_ages):.0f}, {max(_ages):.0f}]\")\n print(f\" Male: {_n_male:,} ({_n_male/len(_genders)*100:.1f}%)\")\n print(f\" Female: {_n_female:,} ({_n_female/len(_genders)*100:.1f}%)\")\n print(\"\\n✓ Synthetic patients are generated with demographics sampled from this distribution.\")\nelse:\n print(\"_demo_pool is empty — model was not trained before calling synthesize_dataset.\")\n print(\"Run Section 4 first, or load a checkpoint that was saved after training.\")" }, { "cell_type": "code", @@ -208,7 +204,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', 
alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation \u2014 no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"\u2713 Strong correlation \u2014 good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation \u2014 consider more epochs.\")\nelse: print(\"Weak correlation \u2014 model may need more training.\")" + "source": "# Build real training code frequencies by decoding processor-encoded visit tensors.\n# NestedSequenceProcessor: index 0=pad, 1=unk, 2+=codes.\n# _PromptEHRVocab mapping: bart_id = processor_idx + 5 for codes (idx>=2).\n_vocab_map = model._vocab._bart_to_code # bart_token_id -> ICD-9 code string\n_real_counts = Counter()\n\nfor _sample in train_dataset:\n for _visit in _sample.get(\"visits\", []):\n for _tok in _visit:\n _idx = int(_tok.item()) if hasattr(_tok, 'item') else int(_tok)\n if _idx >= 2: # skip pad(0) and unk(1)\n _bart_id = _idx + 5\n _code = _vocab_map.get(_bart_id)\n if _code:\n _real_counts[_code] += 1\n\n_synth_counts = Counter(c for p in synthetic for v in p[\"visits\"] for c in v)\n\n_top_codes = [c for c, _ in _real_counts.most_common(20)]\n_real_freq = [_real_counts[c] for c in _top_codes]\n_synth_freq = [_synth_counts.get(c, 0) for c in _top_codes]\n\nfig, ax = plt.subplots(figsize=(15, 5))\n_x = range(len(_top_codes))\nax.bar([i - 0.2 for i in _x], _real_freq, 0.38, label='Real (training)', color='steelblue', alpha=0.85)\nax.bar([i + 0.2 for i in _x], _synth_freq, 0.38, label='Synthetic', color='coral', 
alpha=0.85)\nax.set_xticks(_x)\nax.set_xticklabels(_top_codes, rotation=45, ha='right', fontsize=9)\nax.set_ylabel('Frequency', fontsize=12)\nax.set_title('Top-20 ICD-9 Code Frequency: Real vs Synthetic', fontsize=14)\nax.legend(fontsize=11); ax.grid(axis='y', alpha=0.3)\nplt.tight_layout()\nplt.savefig(f'{OUTPUT_DIR}/code_frequency_comparison.png', dpi=150)\nplt.show()\n\n# Pearson r (manual computation — no scipy dependency)\n_r_mean = np.mean(_real_freq); _s_mean = np.mean(_synth_freq)\n_num = sum((r - _r_mean)*(s - _s_mean) for r, s in zip(_real_freq, _synth_freq))\n_denom = (sum((r-_r_mean)**2 for r in _real_freq) * sum((s-_s_mean)**2 for s in _synth_freq)) ** 0.5\npearson_r = _num / _denom if _denom > 0 else 0.0\nprint(f\"Pearson r (top-20 code frequencies, real vs synthetic): {pearson_r:.3f}\")\nif pearson_r > 0.8: print(\"✓ Strong correlation — good distributional fidelity.\")\nelif pearson_r > 0.5: print(\"Moderate correlation — consider more epochs.\")\nelse: print(\"Weak correlation — model may need more training.\")" }, { "cell_type": "code", @@ -216,7 +212,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"\u26a0 {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"\u2713 All {len(synthetic):,} patients have at least one visit with at least one code.\")" + "source": "_empty = [p for p in synthetic if not p[\"visits\"] or all(len(v) == 0 for v in p[\"visits\"])]\nif _empty:\n print(f\"⚠ {len(_empty)} / {len(synthetic)} patients have empty visit sequences.\")\n print(\" Possible causes:\")\n print(\" - Model is undertrained (increase EPOCHS)\")\n 
print(\" - Temperature too low (try TEMPERATURE = 1.0)\")\n print(\" - _demo_pool not populated (train before calling synthesize_dataset)\")\nelse:\n print(f\"✓ All {len(synthetic):,} patients have at least one visit with at least one code.\")" }, { "cell_type": "code", @@ -224,7 +220,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n \"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n\u2713 Saved to {report_path}\")" + "source": "quality = {\n \"total_synthetic_patients\": len(synthetic),\n \"mean_visits_per_patient\": round(float(np.mean(n_visits)), 3),\n \"std_visits_per_patient\": round(float(np.std(n_visits)), 3),\n \"mean_codes_per_patient\": round(float(np.mean(n_codes)), 3),\n \"std_codes_per_patient\": round(float(np.std(n_codes)), 3),\n \"unique_codes_generated\": len(all_synth_codes),\n \"vocabulary_size\": n_real_codes,\n \"vocabulary_coverage_pct\": round(coverage, 2),\n \"empty_patients_count\": len(_empty),\n \"code_freq_pearson_r\": round(pearson_r, 4),\n \"training_patients\": len(train_dataset),\n 
\"vocab_total_size\": model._vocab.total_size,\n \"demo_mode\": DEMO_MODE,\n \"preset\": PRESET,\n \"epochs\": EPOCHS,\n \"seed\": SEED,\n \"timestamp\": datetime.now().isoformat(),\n}\nreport_path = f'{OUTPUT_DIR}/quality_report.json'\nwith open(report_path, 'w') as f:\n json.dump(quality, f, indent=2)\nprint(\"Quality Report:\")\nprint(json.dumps(quality, indent=2))\nprint(f\"\\n✓ Saved to {report_path}\")" }, { "cell_type": "markdown", @@ -238,7 +234,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Download output files (Colab only \u2014 silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" \u2713 {os.path.basename(_p)}\")\n else:\n print(f\" \u2014 {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" + "source": "# Download output files (Colab only — silently skipped in local/SLURM environments)\n_outputs = [\n csv_path,\n json_path,\n report_path,\n f'{OUTPUT_DIR}/training_loss.png',\n f'{OUTPUT_DIR}/demographics_distribution.png',\n f'{OUTPUT_DIR}/code_frequency_comparison.png',\n f'{CHECKPOINT_DIR}/checkpoint.pt',\n f'{CHECKPOINT_DIR}/config.json',\n]\n\nif IN_COLAB:\n from google.colab import files as _colab_files\n print(\"Downloading output files...\")\n for _p in _outputs:\n if os.path.exists(_p):\n _colab_files.download(_p)\n print(f\" ✓ 
{os.path.basename(_p)}\")\n else:\n print(f\" — {os.path.basename(_p)} (not found)\")\nelse:\n print(f\"Output files saved to: {OUTPUT_DIR}\")\n print(f\"Checkpoint: {CHECKPOINT_DIR}/checkpoint.pt\")\n for _p in _outputs:\n if os.path.exists(_p):\n _kb = os.path.getsize(_p) / 1024\n print(f\" {os.path.basename(_p):45s} {_kb:8.1f} KB\")" }, { "cell_type": "code", @@ -246,13 +242,13 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# CHECKPOINT RESUME \u2014 Run this cell instead of Section 4 if you already trained\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, 
[0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"\u2713 Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template \u2014 uncomment the lines above to use)\")" + "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. 
Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")" }, { "cell_type": "markdown", "id": "s7-congrats", "metadata": {}, - "source": "---\n## \ud83c\udf89 Congratulations!\n\nYou've successfully:\n1. \u2705 Trained a PromptEHR model conditioned on patient demographics\n2. \u2705 Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. \u2705 Validated ICD-9 code frequency fidelity against real training data\n4. \u2705 Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime \u2192 Change runtime type \u2192 T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. 
https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 \u00b7 Branch: `promptehr-pr-integration` \u00b7 jalengg/PyHealth_" + "source": "---\n## 🎉 Congratulations!\n\nYou've successfully:\n1. ✅ Trained a PromptEHR model conditioned on patient demographics\n2. ✅ Generated synthetic patients whose age/gender distribution mirrors MIMIC-III\n3. ✅ Validated ICD-9 code frequency fidelity against real training data\n4. ✅ Saved output files for downstream use\n\n## Next Steps\n\n**Use your synthetic data:**\n- Train readmission/mortality/LoS prediction models on synthetic data\n- Evaluate fairness across demographic subgroups\n- Share synthetic patients without privacy concerns\n\n**Reload and generate more:**\n```python\nfrom pyhealth.models import PromptEHR\nmodel = PromptEHR(dataset=train_dataset, ...)\nmodel.load_model('./promptehr_training/checkpoints/checkpoint.pt')\nextra = model.synthesize_dataset(num_samples=50_000)\n```\n\n## Troubleshooting\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| `AssertionError: transformers>=4.48.3 required` | Old transformers installed | `pip install transformers --upgrade` |\n| Empty patients in output | Undertrained model | Increase `EPOCHS` or raise `TEMPERATURE` to `1.0` |\n| Training loss not decreasing after 2+ epochs | LR too high | Try `LR = 5e-6` and `WARMUP_STEPS = 500` |\n| Out of memory (OOM) | Batch too large | Reduce `BATCH_SIZE = 8` |\n| Very slow training | No GPU | Runtime → Change runtime type → T4 GPU |\n| `KeyError: 'visits'` in demo mode | Wrong schema | Ensure `input_schema={\"visits\": \"nested_sequence\"}` |\n| Synthetic codes all the same | Temperature too low | Try `TEMPERATURE = 1.0`, `RANDOM_SAMPLING = True` |\n\n---\n\n## Reference\n\nWang, Y., et al. \"PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning.\" *EMNLP 2023*. 
https://arxiv.org/abs/2211.01761\n\n---\n_Notebook for PyHealth 2.0 · Branch: `promptehr-pr-integration` · jalengg/PyHealth_" } ] } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 308e6b114..c9fd4626d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "networkx", "mne~=1.10.0", "urllib3~=2.5.0", - "numpy~=2.2.0", + "numpy>=2.0.0", "tqdm", "polars~=1.35.2", "pandas~=2.3.1", From 364d6f64804d8b774a11e23bbccb92ed2e90a6e5 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 12:22:13 -0600 Subject: [PATCH 29/37] Fix numpy mixed-version error: replace --force-reinstall with uninstall+install --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 14d6bc1fb..03868f859 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 17:15:17 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. 
**Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 17:45:00 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. 
**Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub — force-reinstall ensures Colab never uses a stale cached build.\n# (This may pull in old/mismatched transitive deps — we normalize them in step 2.)\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\", \"--force-reinstall\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Upgrade optional deps that Colab may have at outdated versions.\n# We do NOT use --force-reinstall here — that would also force-reinstall\n# numpy/scipy and their transitive deps, creating a mixed-version state\n# (Python files from one version, compiled extensions from another).\n# PyHealth's pyproject.toml already pins numpy>=2.0.0, so step 1 ensures\n# a compatible numpy is installed. 
We just upgrade scipy and Pillow:\n# - scipy>=1.14: first release with full numpy 2.x support\n# - Pillow>=10.4.0: consistent PIL internals (_Ink moved between versions)\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--upgrade\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, Pillow>=10.4.0 upgraded\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. 
We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps (numpy, scipy, etc.), which in\n# Colab's system environment creates a mixed-version state where numpy's\n# .py files and compiled .so extensions come from different versions.\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Upgrade scipy and Pillow if Colab has outdated versions.\n# Normal --upgrade only touches packages that need updating and does NOT\n# force-reinstall transitive deps (i.e., numpy stays untouched).\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--upgrade\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, Pillow>=10.4.0 upgraded\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__} (should match Colab's pre-installed version)\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = 
torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From 121999f4fb9a6c935116fe5e3f66400af20c4781 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 12:51:32 -0600 Subject: [PATCH 30/37] Remove step 2 dep upgrades causing Pillow mixed-version state --- examples/promptehr_mimic3_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 03868f859..c2d9008c9 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# 1. Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. 
We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps (numpy, scipy, etc.), which in\n# Colab's system environment creates a mixed-version state where numpy's\n# .py files and compiled .so extensions come from different versions.\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# 2. Upgrade scipy and Pillow if Colab has outdated versions.\n# Normal --upgrade only touches packages that need updating and does NOT\n# force-reinstall transitive deps (i.e., numpy stays untouched).\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\",\n \"scipy>=1.14\", \"Pillow>=10.4.0\",\n \"--quiet\", \"--no-cache-dir\", \"--upgrade\"],\n check=True,\n)\nprint(\"✓ scipy>=1.14, Pillow>=10.4.0 upgraded\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__} (should match Colab's pre-installed version)\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = 
torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps (numpy, scipy, Pillow, etc.),\n# which in Colab's system environment creates mixed-version states.\n# We also do NOT explicitly upgrade transitive deps — Colab's\n# pre-installed numpy/scipy/Pillow are left untouched.\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as 
pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From 732d207b5f70ddc7e15821a33d689942e91e244f Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 13:18:28 -0600 Subject: [PATCH 31/37] Fix Pillow mixed-version state: force-reinstall Pillow after PyHealth install MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same pattern as HALO's scipy fix (b80f837b): PyHealth install may partially upgrade Pillow (via torch→torchvision→Pillow cascade), leaving mixed .py/.so files. Force-reinstall only Pillow (--no-deps) before it gets imported so all files come from one version. 
--- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index c2d9008c9..87fae512d 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "preamble", "metadata": {}, - "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 17:45:00 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" + "source": "# PromptEHR Synthetic Data Generation for MIMIC-III\n\n_Last updated: 2026-03-04 19:00:00 (UTC)_\n\nThis notebook trains PromptEHR on your MIMIC-III data and generates synthetic patients conditioned on patient demographics (age and gender).\n\n## What You'll Need\n\n1. **MIMIC-III Access**: Download these files from PhysioNet:\n - `PATIENTS.csv` — patient demographics (date of birth, gender)\n - `ADMISSIONS.csv` — hospital admission records (timestamps)\n - `DIAGNOSES_ICD.csv` — ICD-9 diagnosis codes\n\n2. **Google Colab**: Free tier works, but GPU recommended (Runtime → Change runtime type → GPU)\n\n3. **Time**:\n - Demo (5 epochs, 1K samples): ~30–45 min on GPU\n - Production (20 epochs, 10K samples): ~3–5 hrs on GPU\n\n## How It Works\n\n1. **Setup**: Install PyHealth and mount Google Drive\n2. **Upload Data**: Upload your MIMIC-III CSV files (persisted to Drive across sessions)\n3. **Configure**: Set hyperparameters (epochs, batch size, etc.)\n4. **Train**: Fine-tune PromptEHR on MIMIC-III (checkpoints saved to Drive)\n5. **Generate**: Create synthetic patients conditioned on sampled demographics\n6. **Download**: Get CSV and JSON files with synthetic data\n\n## Important Notes\n\n**Colab Timeout**: Free Colab sessions timeout after ~12 hours. 
For production training, consider:\n- Colab Pro for longer sessions\n- Running on a GPU cluster using `examples/slurm/`\n\n**Demo vs Production**:\n- Demo defaults (5 epochs, 1K samples) let you test the full pipeline quickly\n- Production settings (20 epochs, 10K samples) produce publication-quality results\n\n## References\n\n- [PromptEHR Paper](https://arxiv.org/abs/2211.01761) — Wang et al., EMNLP 2023\n- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n- [MIMIC-III Access](https://physionet.org/content/mimiciii/)" }, { "cell_type": "markdown", @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps (numpy, scipy, Pillow, etc.),\n# which in Colab's system environment creates mixed-version states.\n# We also do NOT explicitly upgrade transitive deps — Colab's\n# pre-installed numpy/scipy/Pillow are left untouched.\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as 
pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. 
We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps, which in Colab's system environment\n# creates mixed-version states (old .so + new .py from different versions).\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Fix Pillow consistency — PyHealth install may partially upgrade Pillow\n# (via torch → torchvision → Pillow), leaving mixed-version files.\n# Force-reinstall ONLY Pillow (--no-deps) so all files come from one version.\n# Works because Pillow hasn't been imported yet in this Python process.\n# Same pattern as HALO's scipy fix (commit b80f837b).\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", \"Pillow\",\n \"--force-reinstall\", \"--no-deps\", \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nprint(\"✓ Pillow reinstalled (consistent state)\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: 
{torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", From d05d445811182527a51c841b6db52600a35f8b93 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 13:37:10 -0600 Subject: [PATCH 32/37] Fix Colab PIL error: hide torchvision during BART import in PromptEHR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit transformers 4.53+ eagerly imports loss_utils → image_utils → torchvision → PIL, even for non-vision models like BART. In Colab, Pillow is in a mixed-version state that can't be fixed by pip (system-managed files). Fix: temporarily remove torchvision from sys.modules during the BART import so transformers skips the vision chain entirely. PromptEHR only needs BART, not vision functionality. 
--- examples/promptehr_mimic3_colab.ipynb | 2 +- pyhealth/models/promptehr/model.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 87fae512d..3e4366251 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps, which in Colab's system environment\n# creates mixed-version states (old .so + new .py from different versions).\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Fix Pillow consistency — PyHealth install may partially upgrade Pillow\n# (via torch → torchvision → Pillow), leaving mixed-version files.\n# Force-reinstall ONLY Pillow (--no-deps) so all files come from one version.\n# Works because Pillow hasn't been imported yet in this Python process.\n# Same pattern as HALO's scipy fix (commit b80f837b).\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", \"Pillow\",\n \"--force-reinstall\", \"--no-deps\", \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nprint(\"✓ Pillow reinstalled 
(consistent state)\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). \"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" + "source": "import subprocess\nimport sys\n\n# Install PyHealth from GitHub.\n# We uninstall first (to clear any stale build from a previous session),\n# then do a normal install. 
We do NOT use --force-reinstall because it\n# force-reinstalls ALL transitive deps, which in Colab's system environment\n# creates mixed-version states (old .so + new .py from different versions).\nFORK = 'jalengg'\nBRANCH = 'promptehr-pr-integration'\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"uninstall\", \"pyhealth\", \"-y\"],\n capture_output=True, text=True,\n)\nresult = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url,\n \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True,\n)\nif result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed — see error above.\")\nprint(f\"✓ PyHealth installed from {FORK}/{BRANCH}\")\n\n# Environment detection — MUST come before any google.colab import\ntry:\n import google.colab # noqa: F401\n IN_COLAB = True\nexcept ImportError:\n IN_COLAB = False\n\nimport os\nimport json\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom datetime import datetime\nfrom collections import Counter\nfrom IPython.display import display\n\nprint(f\"Running in Colab: {IN_COLAB}\")\nprint(f\"numpy: {np.__version__}\")\nprint(f\"PyTorch: {torch.__version__}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n print(f\"GPU memory: {gb:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Training will be slow.\")\n print(\" → Runtime → Change runtime type → T4 GPU\")\n\n# Verify HuggingFace transformers version (>=4.48.3 required for use_cpu parameter)\nimport transformers\n_ver = tuple(int(x) for x in transformers.__version__.split(\".\")[:2])\nassert _ver >= (4, 48), (\n f\"transformers>=4.48.3 required (got {transformers.__version__}). 
\"\n \"Fix: pip install transformers --upgrade\"\n)\nprint(f\"transformers: {transformers.__version__} ✓\")\nprint(\"✓ All setup complete\")" }, { "cell_type": "markdown", diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 2657c0c7d..11847c0bf 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -8,12 +8,25 @@ import os import random +import sys from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from transformers import BartConfig, BartForConditionalGeneration -from transformers.modeling_outputs import Seq2SeqLMOutput + +# Temporarily hide torchvision so transformers skips the +# image_utils → torchvision → PIL import chain (which fails in Colab +# due to mixed-version Pillow files). PromptEHR only needs BART, +# not any vision functionality from transformers. +_tv = sys.modules.pop("torchvision", None) +try: + from transformers import BartConfig, BartForConditionalGeneration + from transformers.modeling_outputs import Seq2SeqLMOutput +finally: + if _tv is not None: + sys.modules["torchvision"] = _tv + +del _tv from pyhealth.models import BaseModel from .conditional_prompt import ConditionalPromptEncoder From cd42f5bb9b027369477e14cb251d913dd15c26f6 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 14:54:11 -0600 Subject: [PATCH 33/37] Fix PyHealth 2.0 API: remove code_mapping kwarg, use unique_patient_ids --- examples/promptehr_mimic3_colab.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/promptehr_mimic3_colab.ipynb b/examples/promptehr_mimic3_colab.ipynb index 3e4366251..f5450f227 100644 --- a/examples/promptehr_mimic3_colab.ipynb +++ b/examples/promptehr_mimic3_colab.ipynb @@ -114,7 +114,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# Set all random seeds before any stochastic 
operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n code_mapping={},\n )\n print(f\"Loaded {len(dataset.patients):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" + "source": "# Set all random seeds before any stochastic operation\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\nrandom.seed(SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.deterministic = True\nprint(f\"✓ Random seed set to {SEED}\")\n\nfrom pyhealth.datasets import split_by_patient\nfrom pyhealth.models import PromptEHR\n\nif not DEMO_MODE:\n from pyhealth.datasets import MIMIC3Dataset\n from pyhealth.tasks import promptehr_generation_mimic3_fn\n\n print(\"\\nLoading MIMIC-III dataset (this may take a few minutes)...\")\n dataset = MIMIC3Dataset(\n root=DATA_DIR,\n 
tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n )\n print(f\"Loaded {len(dataset.unique_patient_ids):,} patients\")\n\n print(\"Applying PromptEHR generation task...\")\n sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n print(f\"Eligible patients (≥2 visits with ICD-9 codes): {len(sample_dataset):,}\")\nelse:\n from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n sample_dataset = InMemorySampleDataset(\n samples=demo_samples,\n input_schema={\"visits\": \"nested_sequence\"},\n output_schema={},\n )\n print(f\"Demo dataset ready: {len(sample_dataset)} patients\")\n\ntrain_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\nprint(f\"\\nSplit: {len(train_dataset):,} train / {len(val_dataset):,} val patients\")" }, { "cell_type": "code", @@ -242,7 +242,7 @@ "metadata": {}, "outputs": [], "execution_count": null, - "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# code_mapping={},\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# 
save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")" + "source": "# ─────────────────────────────────────────────────────────────────────────────\n# CHECKPOINT RESUME — Run this cell instead of Section 4 if you already trained\n# ─────────────────────────────────────────────────────────────────────────────\n# Uncomment everything below to load an existing checkpoint, then skip to Section 5.\n\n# from pyhealth.datasets import MIMIC3Dataset, split_by_patient\n# from pyhealth.tasks import promptehr_generation_mimic3_fn\n# from pyhealth.models import PromptEHR\n#\n# dataset = MIMIC3Dataset(\n# root=DATA_DIR,\n# tables=[\"patients\", \"admissions\", \"diagnoses_icd\"],\n# )\n# sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)\n# train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])\n#\n# model = PromptEHR(\n# dataset=train_dataset,\n# n_num_features=1, cat_cardinalities=[2],\n# d_hidden=D_HIDDEN, prompt_length=PROMPT_LENGTH,\n# bart_config_name=BART_CONFIG_NAME,\n# epochs=EPOCHS, batch_size=BATCH_SIZE,\n# lr=LR, warmup_steps=WARMUP_STEPS,\n# max_seq_length=MAX_SEQ_LENGTH,\n# save_dir=CHECKPOINT_DIR,\n# )\n# ckpt = f'{CHECKPOINT_DIR}/checkpoint.pt'\n# model.load_model(ckpt)\n# print(f\"✓ Loaded checkpoint from {ckpt}. Proceed to Section 5.\")\n\nprint(\"(Resume template — uncomment the lines above to use)\")" }, { "cell_type": "markdown", From 8c1120ead2816c977119a1aa3d866f2587d851a2 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 15:39:35 -0600 Subject: [PATCH 34/37] Remove icustays from MIMIC3Dataset defaults (same fix as HALO c52aa0b0) No PromptEHR task uses icustays, and most users don't have the file. 
--- pyhealth/datasets/mimic3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 7e569d2f3..bdc00f904 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -53,7 +53,7 @@ def __init__( if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "mimic3.yaml" - default_tables = ["patients", "admissions", "icustays"] + default_tables = ["patients", "admissions"] tables = default_tables + tables if "prescriptions" in tables: warnings.warn( From cb0f6f979364ef4fb10b2df38551cca8b49f27bd Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 16:10:10 -0600 Subject: [PATCH 35/37] Fix device mismatch in synthesize_dataset: use bart_model's device HuggingFace Trainer moves bart_model to GPU but doesn't move the parent PromptEHR module. self.device (from _dummy_param) stays CPU while bart_model is on GPU, causing RuntimeError during generation. --- pyhealth/models/promptehr/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 11847c0bf..35cb1d67c 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -722,7 +722,9 @@ def synthesize_dataset( ``"visits"`` (list of list of str): decoded code strings per visit. """ self.bart_model.eval() - device = self.device + # Use bart_model's device, not self.device — HuggingFace Trainer + # moves bart_model to GPU but doesn't move the parent PromptEHR module. 
+ device = next(self.bart_model.parameters()).device results = [] with torch.no_grad(): From f5a7d4c881f326a8ae274ea5f9ae7179749767f0 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 16:12:12 -0600 Subject: [PATCH 36/37] Fix beam search crash in generate: set num_beams=1 explicitly transformers defaults to beam search which fails with batch_size inference on our single-token encoder input. PromptEHR uses nucleus/greedy sampling, not beam search. --- pyhealth/models/promptehr/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 35cb1d67c..1780b8717 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -751,6 +751,7 @@ def synthesize_dataset( x_num=x_num, x_cat=x_cat, max_length=self.max_seq_length, + num_beams=1, do_sample=random_sampling, temperature=0.7 if random_sampling else 1.0, top_p=0.95 if random_sampling else 1.0, From d591d0b91b1d4528878db74c2cf063d85be73086 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 4 Mar 2026 17:01:39 -0600 Subject: [PATCH 37/37] Fix decode_tokens: skip BOS/EOS/PAD instead of breaking on them MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BART generate() always starts output with decoder_start_token_id (BOS=1). The ported code treated BOS as a stop token (break), causing decode_tokens to return empty visits for every patient. Original pehr_scratch/generate.py::parse_sequence_to_visits uses continue to skip BOS — this was a porting bug. Fix from promptehr-port branch commit 97f6a7b. 
--- pyhealth/models/promptehr/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyhealth/models/promptehr/model.py b/pyhealth/models/promptehr/model.py index 1780b8717..9e7ee6037 100644 --- a/pyhealth/models/promptehr/model.py +++ b/pyhealth/models/promptehr/model.py @@ -458,6 +458,10 @@ def decode_tokens(self, token_ids: List[int]) -> List[List[str]]: current_visit: List[str] = [] in_visit = False for tid in token_ids: + if tid in (self.PAD, self.BOS, self.EOS): + continue # skip framing tokens (BOS is first in generate output) + if tid == self.SEQ_END: + break if tid == self.VISIT_START: in_visit = True current_visit = [] @@ -465,8 +469,6 @@ def decode_tokens(self, token_ids: List[int]) -> List[List[str]]: if in_visit: visits.append(current_visit) in_visit = False - elif tid in (self.SEQ_END, self.EOS, self.PAD, self.BOS): - break elif in_visit and tid >= self.CODE_OFFSET: code = self._bart_to_code.get(tid) if code: