From a3f36a8f4c72e88fbf1a898bd086a1230684e353 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 18:53:49 -0600 Subject: [PATCH 01/15] Non-working code, copy from add/bms_run to add reference code for stageattn. --- pyhealth/interpret/methods/chefer.py | 113 ++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 9 deletions(-) diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 5bef2fc45..19a4bb6a2 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -16,6 +16,14 @@ HAS_TORCHVISION_MODEL = False TorchvisionModel = None +# Import StageAttentionNet conditionally to avoid circular imports +try: + from pyhealth.models import StageAttentionNet + HAS_STAGEATTN = True +except ImportError: + HAS_STAGEATTN = False + StageAttentionNet = None + def apply_self_attention_rules(R_ss, cam_ss): """Apply Chefer's self-attention rules for relevance propagation. @@ -73,11 +81,13 @@ class CheferRelevance(BaseInterpreter): Supported Models: - PyHealth Transformer: For sequential/EHR data with multiple feature keys + - StageAttentionNet: For temporal/EHR data with MHA-based StageNet layers - TorchvisionModel (ViT variants): vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14 Args: - model (BaseModel): A trained PyHealth model to interpret. Must be either: + model (BaseModel): A trained PyHealth model to interpret. 
Must be one of: - A ``Transformer`` model for sequential/EHR data + - A ``StageAttentionNet`` model for temporal/EHR data - A ``TorchvisionModel`` with a ViT architecture for image data Example: @@ -163,15 +173,20 @@ def __init__(self, model: BaseModel): # Determine model type self._is_transformer = isinstance(model, Transformer) self._is_vit = False + self._is_stageattn = False + + if HAS_STAGEATTN and StageAttentionNet is not None: + self._is_stageattn = isinstance(model, StageAttentionNet) if HAS_TORCHVISION_MODEL and TorchvisionModel is not None: if isinstance(model, TorchvisionModel): self._is_vit = model.is_vit_model() - if not self._is_transformer and not self._is_vit: + if not self._is_transformer and not self._is_vit and not self._is_stageattn: raise ValueError( - f"CheferRelevance requires a Transformer or TorchvisionModel (ViT), " - f"got {type(model).__name__}. For TorchvisionModel, only ViT variants " + f"CheferRelevance requires a Transformer, StageAttentionNet, " + f"or TorchvisionModel (ViT), got {type(model).__name__}. " + f"For TorchvisionModel, only ViT variants " f"(vit_b_16, vit_b_32, etc.) are supported." ) @@ -193,13 +208,14 @@ def attribute( you want to explain why a specific class was predicted or to compare attributions across different classes. **data: Input data from dataloader batch containing: - - For Transformer: feature keys (conditions, procedures, etc.) + label + - For Transformer/StageAttentionNet: feature keys + label - For ViT: image feature key (e.g., "image") + label Returns: Dict[str, torch.Tensor]: Dictionary keyed by feature keys from the task schema. - - For Transformer: ``{"conditions": tensor, "procedures": tensor, ...}`` + - For Transformer/StageAttentionNet: + ``{"conditions": tensor, "procedures": tensor, ...}`` where each tensor has shape ``[batch, num_tokens]``. - For ViT: ``{"image": tensor}`` (or whatever the task's image key is) where tensor has shape ``[batch, 1, H, W]``. 
@@ -210,6 +226,8 @@ def attribute( class_index=class_index, **data ) + if self._is_stageattn: + return self._attribute_stageattn(class_index=class_index, **data) return self._attribute_transformer(class_index=class_index, **data) def _attribute_transformer( @@ -229,7 +247,10 @@ def _attribute_transformer( if class_index is None: class_index = torch.argmax(logits, dim=-1) - one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() one_hot = one_hot.requires_grad_(True) one_hot = torch.sum(one_hot.to(logits.device) * logits) self.model.zero_grad() @@ -242,12 +263,13 @@ def _attribute_transformer( for block in feature_transformer: num_tokens[key] = block.attention.get_attn_map().shape[-1] + batch_size = logits.shape[0] attn = {} for key in feature_keys: R = ( torch.eye(num_tokens[key]) .unsqueeze(0) - .repeat(len(data[key]), 1, 1) + .repeat(batch_size, 1, 1) .to(logits.device) ) for blk in self.model.transformer[key].transformer: @@ -259,6 +281,79 @@ def _attribute_transformer( return attn + def _attribute_stageattn( + self, + class_index: int = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute relevance for StageAttentionNet models. + + StageAttentionNet has a single MHA layer per feature key (inside + ``model.stagenet[key]``) rather than a stack of TransformerBlocks. + It also uses the *last valid timestep* (via ``get_last_visit``) + instead of a CLS token for classification, so we extract the + relevance row corresponding to that timestep. + + Args: + class_index: Target class for attribution. If None, uses predicted class. + **data: Input data from dataloader batch. 
+ """ + # StageAttentionNet uses 'register_attn_hook' (not 'register_hook') + data["register_attn_hook"] = True + + logits = self.model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) + + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = one_hot.requires_grad_(True) + one_hot = torch.sum(one_hot.to(logits.device) * logits) + self.model.zero_grad() + one_hot.backward(retain_graph=True) + + batch_size = logits.shape[0] + feature_keys = self.model.feature_keys + attn = {} + + for key in feature_keys: + layer = self.model.stagenet[key] + cam = layer.get_attn_map() + grad = layer.get_attn_grad() + num_tokens = cam.shape[-1] + + R = ( + torch.eye(num_tokens) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(logits.device) + ) + cam = avg_heads(cam, grad) + R += apply_self_attention_rules(R, cam).detach() + + # StageAttentionNet uses get_last_visit (last valid timestep) + # instead of a CLS token. Reconstruct the mask to find the + # index that was actually used for classification. 
+ feature = data[key] + if isinstance(feature, tuple) and len(feature) == 2: + _, x_val = feature + else: + x_val = feature + + embedded = self.model.embedding_model({key: x_val}) + emb = embedded[key] + if emb.dim() == 4: + emb = emb.sum(dim=2) + mask = (emb.sum(dim=-1) != 0).long().to(logits.device) + + # last valid index per sample + last_idx = mask.sum(dim=1) - 1 # [batch] + attn[key] = R[torch.arange(batch_size, device=logits.device), last_idx] + + return attn + def _attribute_vit( self, interpolate: bool = True, @@ -364,4 +459,4 @@ def get_vit_attribution_map( ) # Return the attribution tensor directly (get the first/only value) feature_key = self.model.feature_keys[0] - return result[feature_key] + return result[feature_key] \ No newline at end of file From 8b28480313e38b04974f17896af230220bd17a8e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 19:29:48 -0600 Subject: [PATCH 02/15] Move interpretability methods into a separate API --- pyhealth/interpret/api.py | 218 ++++++++++++++++++++++++++++++++ pyhealth/models/base_model.py | 31 ----- pyhealth/models/mlp.py | 3 +- pyhealth/models/stagenet.py | 3 +- pyhealth/models/stagenet_mha.py | 3 +- pyhealth/models/transformer.py | 3 +- 6 files changed, 226 insertions(+), 35 deletions(-) create mode 100644 pyhealth/interpret/api.py diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py new file mode 100644 index 000000000..09d521c68 --- /dev/null +++ b/pyhealth/interpret/api.py @@ -0,0 +1,218 @@ +from abc import ABC, abstractmethod +import torch +from torch import nn + +class InterpretableModelInterface(ABC): + """Abstract interface for models supporting interpretability methods. + + This class defines the contract that models must fulfill to be compatible + with PyHealth's interpretability module. It enables gradient-based + attribution methods and embedding-level perturbation methods to work with + your model. 
+ + The interface separates the embedding stage (which generates learned + representations from raw features) from the prediction stage + (which generates outputs from embeddings). This separation allows + interpretability methods to either: + + 1. Use gradients flowing through embeddings + 2. Perturb embeddings and pass them through prediction head + 3. Directly access and analyze the learned representations + + Methods + ------- + forward_from_embedding + Perform forward pass starting from embeddings. + get_embedding_model + Get the embedding/feature extraction stage if applicable. + + Assumptions + ----------- + Models implementing this interface must adhere to the following assumptions: + + 1. **Optional label handling**: The ``forward_from_embedding()`` method must + accept label keys (as specified in ``self.label_keys``) as optional keyword + arguments. The method should handle cases where labels are missing without + raising exceptions. When labels are absent, the method should skip loss + computation and omit 'loss' and 'y_true' from the return dictionary. + + 2. **Non-linearity as nn.Module**: All non-linear activation functions + (ReLU, Sigmoid, Softmax, Tanh, etc.) must be defined as nn.Module instances + in the model's ``__init__`` method and called as instance methods + (e.g., ``self.relu(x)``). Do NOT use functional variants like ``F.relu(x)``, + ``F.sigmoid(x)``, or ``F.softmax(x)``. This is critical for + gradient-based interpretability methods (e.g., DeepLIFT) that require + hooks to be registered on non-linearities. 
+ + Examples of correct activation usage:: + + class GoodModel(nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() # Correct + self.sigmoid = nn.Sigmoid() # Correct + + def forward(self, x): + x = self.relu(x) # Correct + x = self.sigmoid(x) # Correct + return x + + Examples of incorrect activation usage:: + + class BadModel(nn.Module): + def forward(self, x): + x = F.relu(x) # WRONG - functional variant + x = F.sigmoid(x) # WRONG - functional variant + return x + """ + + def forward_from_embedding( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...] + ) -> dict[str, torch.Tensor]: + """Forward pass of the model starting from embeddings. + + This method enables interpretability methods to pass embeddings directly + into the model's prediction head, bypassing the embedding stage. This is + useful for: + + - **Gradient-based attribution** (DeepLIFT, Integrated Gradients): + Allows gradients to be computed with respect to embeddings + - **Embedding perturbation** (LIME, SHAP): Allows perturbing embeddings + instead of raw features + - **Intermediate representation analysis**: Enables inspection of learned + representations at the embedding layer + + Kwargs keys typically mirror the model's feature keys (from the dataset's + input_schema), but represent embeddings instead of raw features. + + Parameters + ---------- + **kwargs : torch.Tensor or tuple[torch.Tensor, ...] + Variable keyword arguments representing input embeddings and optional labels. + + **Embedding arguments** (required): Should include all feature keys that + your model expects. Examples: + + - 'conditions': (batch_size, seq_length, embedding_dim) + - 'procedures': (batch_size, seq_length, embedding_dim) + - 'image': (batch_size, embedding_dim, height, width) + + **Label arguments** (optional): May include any label keys defined in + ``self.label_keys``. 
If label keys are present, the method should compute + loss and include 'loss' and 'y_true' in the return dictionary. If label + keys are absent, the method must not crash; simply omit 'loss' and 'y_true' + from the return dictionary. + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing model outputs with the following keys: + + - **logit** (torch.Tensor): Raw model predictions/logits of shape + (batch_size, num_classes) for classification tasks. + + - **y_prob** (torch.Tensor): Predicted probabilities of shape + (batch_size, num_classes). For binary classification, often + shape (batch_size, 1). + + - **loss** (torch.Tensor, optional): Scalar loss value if + any of ``self.label_keys`` are present in kwargs. Returned only when + ground truth labels are provided. Should not be included if labels are + unavailable. + + - **y_true** (torch.Tensor, optional): True labels if present in + kwargs. Useful for consistency checking during attribution. Should not + be included if labels are unavailable. + + Additional keys may be returned depending on the model's task type + (e.g., 'risks' for survival analysis, 'seq_output' for sequence models). + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + + Notes + ----- + The implementation must gracefully handle missing label keys. Interpretability + methods may invoke this method with only embedding inputs (no labels), expecting + forward passes for attribution computation. The method should compute predictions + successfully in both scenarios. + + Examples + -------- + For an EHR model with embedding dimension 64: + + >>> model = MyEHRModel(...) + >>> batch_embeddings = { + ... 'conditions': torch.randn(32, 100, 64), # 32 samples, 100 time steps + ... 'procedures': torch.randn(32, 100, 64), + ... 
} + >>> output = model.forward_from_embedding(**batch_embeddings) + >>> logits = output['logit'] # Shape: (32, num_classes) + >>> y_prob = output['y_prob'] # Shape: (32, num_classes) + + With optional labels: + + >>> batch_embeddings['mortality'] = torch.tensor([0, 1, 0, ...]) # Add labels + >>> output = model.forward_from_embedding(**batch_embeddings) + >>> loss = output['loss'] # Now included + >>> y_true = output['y_true'] # Now included + + For an image model with spatial embeddings: + + >>> model = MyImageModel(...) + >>> batch_embeddings = { + ... 'image': torch.randn(16, 768, 14, 14), # Vision Transformer embeddings + ... } + >>> output = model.forward_from_embedding(**batch_embeddings) + """ + raise NotImplementedError + + def get_embedding_model(self) -> nn.Module | None: + """Get the embedding/feature extraction stage of the model. + + This method provides access to the model's embedding stage, which + transforms raw input features into learned vector representations. + This is used by interpretability methods to: + + - Generate embeddings from raw features before attribution + - Identify the boundary between feature processing and prediction + - Apply embedding-level analysis separately from prediction + + Returns + ------- + nn.Module or None + The embedding model/stage as an nn.Module if applicable, or None + if the model does not have a separable embedding stage. + + When returning a model, it should: + + - Accept the same input signature as the parent model (raw features) + - Produce embeddings that are compatible with forward_from_embedding() + - Be in the same device as the parent model + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + + Examples + -------- + For a model with explicit embedding and prediction stages: + + >>> class MyModel(InterpretableModelInterface): + ... def __init__(self): + ... self.embedding_layer = EmbeddingBlock(...) + ... self.prediction_head = PredictionBlock(...) + ... + ... 
def get_embedding_model(self): + ... return self.embedding_layer + + For models without a clear separable embedding stage, return None: + + >>> def get_embedding_model(self): + ... return None # Embeddings are not separately accessible + """ + raise NotImplementedError \ No newline at end of file diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index fc59e6026..f40615794 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -73,37 +73,6 @@ def forward(self, y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs. """ raise NotImplementedError - - def forward_from_embedding( - self, - **kwargs: torch.Tensor | tuple[torch.Tensor, ...] - ) -> dict[str, torch.Tensor]: - """Forward pass of the model from embeddings. - - This method should be implemented for interpretability methods that require - access to the model's forward pass from embeddings. - - Args: - **kwargs: A variable number of keyword arguments representing input features - as embeddings. Each keyword argument is a tensor or a tuple of tensors of - shape (batch_size, ...). - - Returns: - A dictionary with the following keys: - logit: a tensor of predicted logits. - y_prob: a tensor of predicted probabilities. - loss [optional]: a scalar tensor representing the final loss, if self.label_keys in kwargs. - y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs. - """ - raise NotImplementedError - - def get_embedding_model(self) -> nn.Module | None: - """Get the embedding model if applicable. This is used in pair with `forward_from_embedding`. - - Returns: - nn.Module | None: The embedding model or None if not applicable. 
- """ - raise NotImplementedError # ------------------------------------------------------------------ # Internal helpers diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py index 306073313..2d8826956 100644 --- a/pyhealth/models/mlp.py +++ b/pyhealth/models/mlp.py @@ -5,11 +5,12 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel +from pyhealth.interpret.api import InterpretableModelInterface from .embedding import EmbeddingModel -class MLP(BaseModel): +class MLP(BaseModel, InterpretableModelInterface): """Multi-layer perceptron model. This model applies a separate MLP layer for each feature, and then diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index bf78bb216..2bdf12296 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -6,6 +6,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit +from pyhealth.interpret.api import InterpretableModelInterface from .embedding import EmbeddingModel @@ -239,7 +240,7 @@ def forward( return last_output, output, torch.stack(distance) -class StageNet(BaseModel): +class StageNet(BaseModel, InterpretableModelInterface): """StageNet model. Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index f95d1e787..7cac034d3 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -8,6 +8,7 @@ from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit from .transformer import MultiHeadedAttention +from pyhealth.interpret.api import InterpretableModelInterface from .embedding import EmbeddingModel @@ -297,7 +298,7 @@ def forward( return last_output, output, distance -class StageAttentionNet(BaseModel): +class StageAttentionNet(BaseModel, InterpretableModelInterface): """StageAttentionNet model. Paper: Junyi Gao et al. 
Stagenet: Stage-aware neural networks for health diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index f1e28304d..b571cda99 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -12,6 +12,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.embedding import EmbeddingModel +from pyhealth.interpret.api import InterpretableModelInterface # VALID_OPERATION_LEVEL = ["visit", "event"] @@ -312,7 +313,7 @@ def forward( return emb, cls_emb -class Transformer(BaseModel): +class Transformer(BaseModel, InterpretableModelInterface): """Transformer model for PyHealth 2.0 datasets. Each feature stream is embedded with :class:`EmbeddingModel` and encoded by From 6389f0a55882116c374fb22f71217865952c18b6 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 19:56:55 -0600 Subject: [PATCH 03/15] First draft for chefer API --- pyhealth/interpret/api.py | 290 +++++++++++++++++++++++++++++++++++++- 1 file changed, 289 insertions(+), 1 deletion(-) diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py index 09d521c68..81e32d6b8 100644 --- a/pyhealth/interpret/api.py +++ b/pyhealth/interpret/api.py @@ -21,6 +21,8 @@ class InterpretableModelInterface(ABC): Methods ------- + forward + Standard forward pass of the model. forward_from_embedding Perform forward pass starting from embeddings. get_embedding_model @@ -66,6 +68,43 @@ def forward(self, x): return x """ + def forward( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + """Forward pass of the model. + + This is the standard entry point for running the model on a batch + of data. It accepts the raw feature tensors (as produced by the + dataloader) and returns predictions. + + Parameters + ---------- + **kwargs : torch.Tensor or tuple[torch.Tensor, ...] + Keyword arguments keyed by the model's ``feature_keys`` and + ``label_keys``. 
Each value is either a single tensor or a + tuple of tensors (e.g. ``(value, mask)``). + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing at least: + + - **logit** (torch.Tensor): Raw model predictions / logits of + shape ``(batch_size, num_classes)``. + - **y_prob** (torch.Tensor): Predicted probabilities. + - **loss** (torch.Tensor, optional): Scalar loss, present only + when label keys are included in ``kwargs``. + - **y_true** (torch.Tensor, optional): Ground-truth labels, + present only when label keys are included in ``kwargs``. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError + def forward_from_embedding( self, **kwargs: torch.Tensor | tuple[torch.Tensor, ...] @@ -215,4 +254,253 @@ def get_embedding_model(self) -> nn.Module | None: >>> def get_embedding_model(self): ... return None # Embeddings are not separately accessible """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + +class CheferInterpretableModelInterface(InterpretableModelInterface): + """Abstract interface for models supporting Chefer relevance attribution. + + This is a subclass of :class:`InterpretableModelInterface` and therefore + inherits the embedding-level interface (``forward_from_embedding``, + ``get_embedding_model``). Models that implement this interface + automatically satisfy the general interpretability contract **and** the + Chefer-specific contract, so they work with both embedding-perturbation + methods (DeepLIFT, LIME, …) and gradient-weighted attention methods + (Chefer). + + The Chefer algorithm works as follows: + + 1. **Forward + hook registration** — run the model while capturing + attention weight tensors and registering backward hooks so their + gradients are stored. + 2. **Backward** — back-propagate from a one-hot target class through + the logits. + 3. 
**Relevance propagation** — for every feature key, iterate over + attention layers, compute gradient-weighted attention + (``clamp(attn * grad, min=0)``), and accumulate into a relevance + matrix ``R`` via ``R += cam @ R``. + 4. **Attribution extraction** — extract the final per-token + attribution from ``R`` (e.g. read the CLS row, or the + last-valid-timestep row, possibly with reshaping). + + Steps 1, 3-b and 4 are model-specific; the rest is generic. This + interface captures exactly those model-specific pieces. + + Inherited from ``InterpretableModelInterface`` + ----------------------------------------------- + forward_from_embedding(**kwargs) -> dict[str, Tensor] + Forward pass starting from pre-computed embeddings. + get_embedding_model() -> nn.Module | None + Access the embedding / feature-extraction stage. + + Additional (Chefer-specific) methods + ------------------------------------- + set_attention_hooks(enabled) -> None + Toggle attention map capture and gradient hook registration. + get_attention_layers() -> dict[str, list[tuple[Tensor, Tensor]]] + Paired (attn_map, attn_grad) for each attention layer, keyed by + feature key. + extract_attribution(feature_key, R, **data) -> Tensor + Extract per-token attribution from the relevance matrix. + + Attributes + ---------- + feature_keys : list[str] + The feature keys from the task's ``input_schema`` (e.g. + ``["conditions", "procedures"]``). Already provided by + :class:`~pyhealth.models.base_model.BaseModel`. + + Notes + ----- + * ``set_attention_hooks(True)`` must be called **before** the forward + pass, and ``get_attention_layers`` must be called **after** the + forward + backward passes, because attention maps are populated + during forward and gradients during backward. 
+ * The interface intentionally does **not** prescribe how hooks are + registered internally — ``nn.MultiheadAttention`` with + ``register_hook``, manual ``save_attn_grad`` callbacks, or explicit + QKV computation all work as long as the getter methods return the + right tensors. + + Examples + -------- + Minimal skeleton for a new model: + + >>> class MyAttentionModel(BaseModel, CheferInterpretableModelInterface): + ... # feature_keys is inherited from BaseModel + ... + ... def forward_from_embedding(self, **kwargs): + ... # ... prediction head from pre-computed embeddings ... + ... + ... def get_embedding_model(self): + ... return self.embedding_layer + ... + ... def set_attention_hooks(self, enabled): + ... self._register_hooks = enabled + ... + ... def get_attention_layers(self): + ... result = {} + ... for key in self.feature_keys: + ... result[key] = [ + ... (blk.attention.get_attn_map(), + ... blk.attention.get_attn_grad()) + ... for blk in self.encoder[key].blocks + ... ] + ... return result + ... + ... def extract_attribution(self, feature_key, R, **data): + ... return R[:, 0] # CLS token row + """ + + @abstractmethod + def set_attention_hooks(self, enabled: bool) -> None: + """Toggle attention hook registration for subsequent forward passes. + + When ``enabled=True``, the next call to ``forward()`` (or + ``forward_from_embedding()``) must: + + 1. Store attention weight tensors so they are retrievable via + :meth:`get_attention_layers`. + 2. Register backward hooks on those tensors so that after + ``.backward()`` the corresponding gradients are also stored. + + When ``enabled=False``, subsequent forward passes should **not** + capture attention maps or register gradient hooks, restoring the + model to its normal (faster) execution mode. + + Parameters + ---------- + enabled : bool + ``True`` to start capturing attention maps and registering + gradient hooks; ``False`` to stop. 
+ + Typical implementations set an internal flag that the model's + forward method checks:: + + def set_attention_hooks(self, enabled): + self._attention_hooks_enabled = enabled + + And inside the forward / encoder logic:: + + if self._attention_hooks_enabled: + attn.register_hook(self.save_attn_grad) + """ + ... + + @abstractmethod + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + """Return (attention_map, attention_gradient) pairs for all feature keys. + + Must be called **after** ``set_attention_hooks(True)``, + a ``forward()`` call, and a subsequent ``backward()`` call so + that both attention maps and their gradients are populated. + + Returns + ------- + dict[str, list[tuple[torch.Tensor, torch.Tensor]]] + A dictionary keyed by ``feature_keys``. Each value is a list + with one ``(attn_map, attn_grad)`` tuple per attention layer, + ordered from the first (closest to input) to the last + (closest to output). + + Each tensor may have shape: + + * ``[batch, heads, seq, seq]`` — multi-head (will be + gradient-weighted-averaged across heads by Chefer). + * ``[batch, seq, seq]`` — already head-averaged. + + ``attn_map`` and ``attn_grad`` in the same tuple must have + the same shape. + + Examples + -------- + A model with stacked ``TransformerBlock`` layers per feature key: + + >>> def get_attention_layers(self): + ... return { + ... key: [ + ... (blk.attention.get_attn_map(), + ... blk.attention.get_attn_grad()) + ... for blk in self.transformer[key].transformer + ... ] + ... for key in self.feature_keys + ... } + + A model with a single MHA layer per feature key: + + >>> def get_attention_layers(self): + ... return { + ... key: [(self.stagenet[key].get_attn_map(), + ... self.stagenet[key].get_attn_grad())] + ... for key in self.feature_keys + ... } + """ + ... 
+ + @abstractmethod + def extract_attribution( + self, + feature_key: str, + R: torch.Tensor, + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> torch.Tensor: + """Extract per-token attribution from the relevance matrix. + + The Chefer algorithm builds a relevance matrix ``R`` of shape + ``[batch, seq_len, seq_len]`` for each feature key. This method + extracts the final attribution vector from ``R``, giving the + model full control over how the extraction is done. + + For most models this means selecting a single row (the + classification token's row) from ``R``. But the method can + perform any transformation — including post-processing such as + dropping columns, reshaping, or interpolation — giving maximum + flexibility. + + Parameters + ---------- + feature_key : str + One of the model's ``feature_keys``. + R : torch.Tensor + Relevance matrix of shape ``[batch, seq_len, seq_len]``. + **data : torch.Tensor or tuple[torch.Tensor, ...] + The original input data (same kwargs passed to + ``forward()``). Available for context when the extraction + logic is data-dependent (e.g. last valid timestep depends on + mask). + + Returns + ------- + torch.Tensor + Attribution tensor. Shape is model-dependent: + + * EHR models with CLS token: ``R[:, 0]`` → + ``[batch, seq_len]``. + * Last-valid-timestep models: ``R[i, last_idx[i]]`` → + ``[batch, seq_len]``. + + Examples + -------- + CLS-token model (e.g. Transformer): + + >>> def extract_attribution(self, feature_key, R, **data): + ... return R[:, 0] + + Last-valid-timestep model (e.g. StageAttentionNet): + + >>> def extract_attribution(self, feature_key, R, **data): + ... mask = self._get_mask(feature_key, **data) + ... last_idx = mask.sum(dim=1) - 1 # [batch] + ... batch_idx = torch.arange(R.shape[0], device=R.device) + ... return R[batch_idx, last_idx] + """ + ... + + # TODO: Add postprocess_attribution() when ViT support is ready. 
+ # ViT models need to strip the CLS column, reshape the patch vector + # into a spatial [batch, 1, H, W] map, and optionally interpolate to + # the original image size. For EHR models this is a no-op. We can + # either fold this into extract_attribution() or add it as a separate + # optional method. \ No newline at end of file From 4f40094a5ad864f526cabfe01382d5c74449e9ae Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 20:04:29 -0600 Subject: [PATCH 04/15] rename API --- pyhealth/interpret/api.py | 6 +++--- pyhealth/models/mlp.py | 4 ++-- pyhealth/models/stagenet.py | 4 ++-- pyhealth/models/stagenet_mha.py | 4 ++-- pyhealth/models/transformer.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py index 81e32d6b8..6d1e8a7a4 100644 --- a/pyhealth/interpret/api.py +++ b/pyhealth/interpret/api.py @@ -2,7 +2,7 @@ import torch from torch import nn -class InterpretableModelInterface(ABC): +class Interpretable(ABC): """Abstract interface for models supporting interpretability methods. This class defines the contract that models must fulfill to be compatible @@ -257,10 +257,10 @@ def get_embedding_model(self) -> nn.Module | None: raise NotImplementedError -class CheferInterpretableModelInterface(InterpretableModelInterface): +class CheferInterpretable(Interpretable): """Abstract interface for models supporting Chefer relevance attribution. - This is a subclass of :class:`InterpretableModelInterface` and therefore + This is a subclass of :class:`Interpretable` and therefore inherits the embedding-level interface (``forward_from_embedding``, ``get_embedding_model``). 
Models that implement this interface automatically satisfy the general interpretability contract **and** the diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py index 2d8826956..4c6432225 100644 --- a/pyhealth/models/mlp.py +++ b/pyhealth/models/mlp.py @@ -5,12 +5,12 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel -from pyhealth.interpret.api import InterpretableModelInterface +from pyhealth.interpret.api import Interpretable from .embedding import EmbeddingModel -class MLP(BaseModel, InterpretableModelInterface): +class MLP(BaseModel, Interpretable): """Multi-layer perceptron model. This model applies a separate MLP layer for each feature, and then diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index 2bdf12296..aad794b02 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -6,7 +6,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit -from pyhealth.interpret.api import InterpretableModelInterface +from pyhealth.interpret.api import Interpretable from .embedding import EmbeddingModel @@ -240,7 +240,7 @@ def forward( return last_output, output, torch.stack(distance) -class StageNet(BaseModel, InterpretableModelInterface): +class StageNet(BaseModel, Interpretable): """StageNet model. Paper: Junyi Gao et al. 
Stagenet: Stage-aware neural networks for health diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index 7cac034d3..f6f1f67ed 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -8,7 +8,7 @@ from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit from .transformer import MultiHeadedAttention -from pyhealth.interpret.api import InterpretableModelInterface +from pyhealth.interpret.api import Interpretable from .embedding import EmbeddingModel @@ -298,7 +298,7 @@ def forward( return last_output, output, distance -class StageAttentionNet(BaseModel, InterpretableModelInterface): +class StageAttentionNet(BaseModel, Interpretable): """StageAttentionNet model. Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index b571cda99..fa5cbcf1c 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -12,7 +12,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.embedding import EmbeddingModel -from pyhealth.interpret.api import InterpretableModelInterface +from pyhealth.interpret.api import Interpretable # VALID_OPERATION_LEVEL = ["visit", "event"] @@ -313,7 +313,7 @@ def forward( return emb, cls_emb -class Transformer(BaseModel, InterpretableModelInterface): +class Transformer(BaseModel, Interpretable): """Transformer model for PyHealth 2.0 datasets. 
Each feature stream is embedded with :class:`EmbeddingModel` and encoded by From cafaae8d3a5268ead104fbd8e3d28c3cc2883a99 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 20:13:57 -0600 Subject: [PATCH 05/15] Update typehint --- pyhealth/interpret/methods/base_interpreter.py | 8 +++++++- pyhealth/interpret/methods/deeplift.py | 6 +++--- pyhealth/interpret/methods/gim.py | 5 ++--- pyhealth/interpret/methods/ig_gim.py | 4 ++-- pyhealth/interpret/methods/integrated_gradients.py | 4 ++-- pyhealth/interpret/methods/lime.py | 4 ++-- pyhealth/interpret/methods/shap.py | 4 ++-- 7 files changed, 20 insertions(+), 15 deletions(-) diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index de75c897a..230a860ec 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -16,7 +16,13 @@ import torch.nn as nn from pyhealth.models import BaseModel +from pyhealth.interpret.api import Interpretable, CheferInterpretable +class _InterpretableModel(BaseModel, Interpretable): + pass + +class _CheferInterpretableModel(BaseModel, CheferInterpretable): + pass class BaseInterpreter(ABC): """Abstract base class for interpretability methods. @@ -97,7 +103,7 @@ class BaseInterpreter(ABC): >>> print(attributions["image"].shape) # [batch, 1, H, W] """ - def __init__(self, model: BaseModel): + def __init__(self, model: _InterpretableModel): """Initialize the base interpreter. 
Args: diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py index 211aeff3b..f6ab49896 100644 --- a/pyhealth/interpret/methods/deeplift.py +++ b/pyhealth/interpret/methods/deeplift.py @@ -1,13 +1,13 @@ from __future__ import annotations import contextlib -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type, cast import torch import torch.nn.functional as F from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _InterpretableModel def _iter_child_modules(module: torch.nn.Module): @@ -328,7 +328,7 @@ class DeepLift(BaseInterpreter): Learning (ICML), 2017. https://proceedings.mlr.press/v70/shrikumar17a.html """ - def __init__(self, model: BaseModel, use_embeddings: bool = True): + def __init__(self, model: _InterpretableModel, use_embeddings: bool = True): super().__init__(model) self.use_embeddings = use_embeddings diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py index f53dd3199..35ebcdcaf 100644 --- a/pyhealth/interpret/methods/gim.py +++ b/pyhealth/interpret/methods/gim.py @@ -8,8 +8,7 @@ import torch.nn.functional as F from pyhealth.models import BaseModel - -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _InterpretableModel def _iter_child_modules(module: torch.nn.Module): @@ -347,7 +346,7 @@ class GIM(BaseInterpreter): def __init__( self, - model: BaseModel, + model: _InterpretableModel, temperature: float = 2.0, ): super().__init__(model) diff --git a/pyhealth/interpret/methods/ig_gim.py b/pyhealth/interpret/methods/ig_gim.py index b0a81d7b4..ca6798a8c 100644 --- a/pyhealth/interpret/methods/ig_gim.py +++ b/pyhealth/interpret/methods/ig_gim.py @@ -35,7 +35,7 @@ from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _InterpretableModel 
from .gim import _GIMHookContext @@ -74,7 +74,7 @@ class IntegratedGradientGIM(BaseInterpreter): def __init__( self, - model: BaseModel, + model: _InterpretableModel, temperature: float = 2.0, steps: int = 50, ): diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index 7cc24b39d..06701fb1c 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -7,7 +7,7 @@ from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _InterpretableModel class IntegratedGradients(BaseInterpreter): @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter): ... ) """ - def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50): + def __init__(self, model: _InterpretableModel, use_embeddings: bool = True, steps: int = 50): """Initialize IntegratedGradients interpreter. Args: diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 7518aaa50..7d0f39580 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -8,7 +8,7 @@ from torch.nn import CosineSimilarity from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _InterpretableModel class LimeExplainer(BaseInterpreter): @@ -102,7 +102,7 @@ class LimeExplainer(BaseInterpreter): def __init__( self, - model: BaseModel, + model: _InterpretableModel, use_embeddings: bool = True, n_samples: int = 1000, kernel_width: float = 0.25, diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index c68a2e08f..83565099e 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -6,7 +6,7 @@ import torch from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import 
BaseInterpreter, _InterpretableModel
 
 
 class ShapExplainer(BaseInterpreter):
@@ -94,7 +94,7 @@ class ShapExplainer(BaseInterpreter):
 
     def __init__(
         self,
-        model: BaseModel,
+        model: _InterpretableModel,
         use_embeddings: bool = True,
         n_background_samples: int = 100,
         max_coalitions: int = 1000,

From da80a5e0c52a33c5b856329e8b66b6a6843c5204 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 10 Feb 2026 20:27:15 -0600
Subject: [PATCH 06/15] Initial attempt for Chefer API

---
 pyhealth/interpret/api.py                     | 69 +++++++++----------
 .../interpret/methods/base_interpreter.py     |  4 +-
 pyhealth/models/stagenet_mha.py               | 60 ++++++++++++++--
 pyhealth/models/transformer.py                | 41 +++++++++--
 4 files changed, 127 insertions(+), 47 deletions(-)

diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py
index 6d1e8a7a4..675ab92fe 100644
--- a/pyhealth/interpret/api.py
+++ b/pyhealth/interpret/api.py
@@ -300,8 +300,8 @@ class CheferInterpretable(Interpretable):
     get_attention_layers() -> dict[str, list[tuple[Tensor, Tensor]]]
         Paired (attn_map, attn_grad) for each attention layer, keyed by
         feature key.
-    extract_attribution(feature_key, R, **data) -> Tensor
-        Extract per-token attribution from the relevance matrix.
+    get_relevance_vector(R, **data) -> dict[str, Tensor]
+        Reduce relevance matrices to per-token attribution vectors.
 
     Attributes
     ----------
@@ -348,8 +348,8 @@ class CheferInterpretable(Interpretable):
     ...         ]
     ...         return result
     ...
-    ...     def extract_attribution(self, feature_key, R, **data):
-    ...         return R[:, 0]  # CLS token row
+    ...     def get_relevance_vector(self, R, **data):
+    ...         return {key: r[:, 0] for key, r in R.items()}
     """
 
     @abstractmethod
@@ -440,61 +440,54 @@ def get_attention_layers(
         ...
 
     @abstractmethod
-    def extract_attribution(
+    def get_relevance_vector(
         self,
-        feature_key: str,
-        R: torch.Tensor,
+        R: dict[str, torch.Tensor],
         **data: torch.Tensor | tuple[torch.Tensor, ...],
-    ) -> torch.Tensor:
-        """Extract per-token attribution from the relevance matrix.
+ ) -> dict[str, torch.Tensor]: + """Reduce relevance matrices to per-token attribution vectors. - The Chefer algorithm builds a relevance matrix ``R`` of shape + The Chefer algorithm builds a relevance matrix of shape ``[batch, seq_len, seq_len]`` for each feature key. This method - extracts the final attribution vector from ``R``, giving the - model full control over how the extraction is done. - - For most models this means selecting a single row (the - classification token's row) from ``R``. But the method can - perform any transformation — including post-processing such as - dropping columns, reshaping, or interpolation — giving maximum - flexibility. + reduces each matrix to a ``[batch, seq_len]`` vector by selecting + the row corresponding to the classification position — giving the + model full control over how the selection is done. Parameters ---------- - feature_key : str - One of the model's ``feature_keys``. - R : torch.Tensor - Relevance matrix of shape ``[batch, seq_len, seq_len]``. + R : dict[str, torch.Tensor] + Relevance matrices keyed by ``feature_keys``. Each tensor + has shape ``[batch, seq_len, seq_len]`` (seq_len may differ + across keys). **data : torch.Tensor or tuple[torch.Tensor, ...] The original input data (same kwargs passed to - ``forward()``). Available for context when the extraction + ``forward()``). Available for context when the selection logic is data-dependent (e.g. last valid timestep depends on mask). Returns ------- - torch.Tensor - Attribution tensor. Shape is model-dependent: - - * EHR models with CLS token: ``R[:, 0]`` → - ``[batch, seq_len]``. - * Last-valid-timestep models: ``R[i, last_idx[i]]`` → - ``[batch, seq_len]``. + dict[str, torch.Tensor] + Attribution vectors keyed by ``feature_keys``. Each tensor + has shape ``[batch, seq_len]``. Examples -------- - CLS-token model (e.g. Transformer): + CLS-token model (e.g. Transformer) — row 0 for all keys: - >>> def extract_attribution(self, feature_key, R, **data): - ... 
return R[:, 0] + >>> def get_relevance_vector(self, R, **data): + ... return {key: r[:, 0] for key, r in R.items()} Last-valid-timestep model (e.g. StageAttentionNet): - >>> def extract_attribution(self, feature_key, R, **data): - ... mask = self._get_mask(feature_key, **data) - ... last_idx = mask.sum(dim=1) - 1 # [batch] - ... batch_idx = torch.arange(R.shape[0], device=R.device) - ... return R[batch_idx, last_idx] + >>> def get_relevance_vector(self, R, **data): + ... result = {} + ... for key, r in R.items(): + ... mask = self._get_mask(key, **data) + ... last_idx = mask.sum(dim=1) - 1 + ... batch_idx = torch.arange(r.shape[0], device=r.device) + ... result[key] = r[batch_idx, last_idx] + ... return result """ ... diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index 230a860ec..b8b129043 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -24,6 +24,8 @@ class _InterpretableModel(BaseModel, Interpretable): class _CheferInterpretableModel(BaseModel, CheferInterpretable): pass +type _AnyInterpretableModel = _InterpretableModel | _CheferInterpretableModel + class BaseInterpreter(ABC): """Abstract base class for interpretability methods. @@ -103,7 +105,7 @@ class BaseInterpreter(ABC): >>> print(attributions["image"].shape) # [batch, 1, H, W] """ - def __init__(self, model: _InterpretableModel): + def __init__(self, model: _AnyInterpretableModel): """Initialize the base interpreter. 
Args: diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index f6f1f67ed..987e4394f 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -8,7 +8,7 @@ from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit from .transformer import MultiHeadedAttention -from pyhealth.interpret.api import Interpretable +from pyhealth.interpret.api import CheferInterpretable from .embedding import EmbeddingModel @@ -298,7 +298,7 @@ def forward( return last_output, output, distance -class StageAttentionNet(BaseModel, Interpretable): +class StageAttentionNet(BaseModel, CheferInterpretable): """StageAttentionNet model. Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health @@ -406,6 +406,8 @@ def __init__( self.embedding_dim = embedding_dim self.chunk_size = chunk_size self.levels = levels + self._attention_hooks_enabled = False + self._masks: dict[str, torch.Tensor] = {} # validate kwargs for StageNet layer if "input_dim" in kwargs: @@ -475,7 +477,8 @@ def forward_from_embedding( logit: the raw logits before activation. embed: (if embed=True in kwargs) the patient embedding. 
""" - register_attn_hook = kwargs.pop("register_attn_hook", False) + # Support both the flag-based API and legacy kwarg-based API + register_attn_hook = self._attention_hooks_enabled patient_emb = [] distance = [] @@ -545,6 +548,9 @@ def forward_from_embedding( value, time=time, mask=mask, register_hook=register_attn_hook ) + # Store the final mask for get_relevance_vector + self._masks[feature_key] = mask + patient_emb.append(last_output) distance.append(cur_dis) @@ -598,7 +604,7 @@ def forward( """ register_attn_hook = kwargs.pop("register_attn_hook", False) if register_attn_hook: - kwargs["register_attn_hook"] = register_attn_hook # type: ignore + kwargs["register_attn_hook"] = register_attn_hook # type: ignore for feature_key in self.feature_keys: feature = kwargs[feature_key] @@ -629,3 +635,49 @@ def forward( def get_embedding_model(self) -> nn.Module | None: return self.embedding_model + + # ------------------------------------------------------------------ + # CheferInterpretable interface + # ------------------------------------------------------------------ + + def set_attention_hooks(self, enabled: bool) -> None: + self._attention_hooks_enabled = enabled + + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + return { # type: ignore[return-value] + key: [ + ( + cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_map(), + cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_grad(), + ) + ] + for key in self.feature_keys + } + + def get_relevance_vector( + self, + R: dict[str, torch.Tensor], + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + # StageAttentionNet uses get_last_visit (last valid timestep) + # instead of a CLS token. Use the masks stored during forward. 
+ result = {} + for key in self.feature_keys: + r = R[key] + batch_size = r.shape[0] + device = r.device + mask = self._masks.get(key) + if mask is not None: + last_idx = mask.sum(dim=1).long() - 1 + last_idx = last_idx.clamp(min=0) + else: + # No mask stored → fall back to last position + last_idx = torch.full( + (batch_size,), r.shape[1] - 1, device=device, dtype=torch.long + ) + result[key] = r[ + torch.arange(batch_size, device=device), last_idx + ] + return result diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index fa5cbcf1c..05a9880af 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -12,7 +12,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.embedding import EmbeddingModel -from pyhealth.interpret.api import Interpretable +from pyhealth.interpret.api import CheferInterpretable # VALID_OPERATION_LEVEL = ["visit", "event"] @@ -313,7 +313,7 @@ def forward( return emb, cls_emb -class Transformer(BaseModel, Interpretable): +class Transformer(BaseModel, CheferInterpretable): """Transformer model for PyHealth 2.0 datasets. Each feature stream is embedded with :class:`EmbeddingModel` and encoded by @@ -374,6 +374,7 @@ def __init__( self.heads = heads self.dropout = dropout self.num_layers = num_layers + self._attention_hooks_enabled = False assert ( len(self.label_keys) == 1 @@ -383,7 +384,7 @@ def __init__( self.embedding_model = EmbeddingModel(dataset, embedding_dim) - self.transformer = nn.ModuleDict() + self.transformer: nn.ModuleDict = nn.ModuleDict() for feature_key in self.feature_keys: self.transformer[feature_key] = TransformerLayer( feature_size=embedding_dim, @@ -466,7 +467,8 @@ def forward_from_embedding( logit: the raw logits before activation. embed: (if embed=True in kwargs) the patient embedding. 
""" - register_hook = bool(kwargs.pop("register_hook", False)) + # Support both the flag-based API and legacy kwarg-based API + register_hook = self._attention_hooks_enabled patient_emb = [] for feature_key in self.feature_keys: @@ -580,6 +582,37 @@ def get_embedding_model(self) -> nn.Module | None: """ return self.embedding_model + # ------------------------------------------------------------------ + # CheferInterpretable interface + # ------------------------------------------------------------------ + + def set_attention_hooks(self, enabled: bool) -> None: + self._attention_hooks_enabled = enabled + + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + return { # type: ignore[return-value] + key: [ + ( + cast(TransformerBlock, blk).attention.get_attn_map(), + cast(TransformerBlock, blk).attention.get_attn_grad(), + ) + for blk in cast( + TransformerLayer, self.transformer[key] + ).transformer + ] + for key in self.feature_keys + } + + def get_relevance_vector( + self, + R: dict[str, torch.Tensor], + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + # CLS token is at index 0 for all feature keys + return {key: r[:, 0] for key, r in R.items()} + if __name__ == "__main__": from pyhealth.datasets import create_sample_dataset, get_dataloader From 41aa26e5f2ae4f39cab889151e757f44ee86406d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 20:28:55 -0600 Subject: [PATCH 07/15] Fix get_relevance_vector --- pyhealth/models/stagenet_mha.py | 53 +++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index 987e4394f..b5ccb5ac6 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -407,7 +407,6 @@ def __init__( self.chunk_size = chunk_size self.levels = levels self._attention_hooks_enabled = False - self._masks: dict[str, torch.Tensor] = {} # 
validate kwargs for StageNet layer if "input_dim" in kwargs: @@ -548,9 +547,6 @@ def forward_from_embedding( value, time=time, mask=mask, register_hook=register_attn_hook ) - # Store the final mask for get_relevance_vector - self._masks[feature_key] = mask - patient_emb.append(last_output) distance.append(cur_dis) @@ -662,21 +658,52 @@ def get_relevance_vector( **data: torch.Tensor | tuple[torch.Tensor, ...], ) -> dict[str, torch.Tensor]: # StageAttentionNet uses get_last_visit (last valid timestep) - # instead of a CLS token. Use the masks stored during forward. + # instead of a CLS token. Derive the mask from **data using + # the same logic as forward_from_embedding. result = {} for key in self.feature_keys: r = R[key] batch_size = r.shape[0] device = r.device - mask = self._masks.get(key) - if mask is not None: - last_idx = mask.sum(dim=1).long() - 1 - last_idx = last_idx.clamp(min=0) + + processor = self.dataset.input_processors[key] + schema = processor.schema() + feature = data[key] + + if isinstance(feature, torch.Tensor): + feature = (feature,) + + value = feature[schema.index("value")] if "value" in schema else None + mask = feature[schema.index("mask")] if "mask" in schema else None + + if len(feature) == len(schema) + 1 and mask is None: + mask = feature[-1] + + if mask is None: + if value is not None: + v = value.to(device) + mask = (v.abs().sum(dim=-1) != 0).int() + else: + # Cannot determine mask; fall back to last position + last_idx = torch.full( + (batch_size,), r.shape[1] - 1, + device=device, dtype=torch.long, + ) + result[key] = r[ + torch.arange(batch_size, device=device), last_idx + ] + continue else: - # No mask stored → fall back to last position - last_idx = torch.full( - (batch_size,), r.shape[1] - 1, device=device, dtype=torch.long - ) + mask = mask.to(device) + if not processor.is_token() and value is not None and value.dim() == mask.dim(): + mask = mask.any(dim=-1).int() + + if mask.dim() == 3: + # Nested sequences: collapse inner 
dimension
+                mask = mask.any(dim=2).int()
+
+            last_idx = mask.sum(dim=1).long() - 1
+            last_idx = last_idx.clamp(min=0)
             result[key] = r[
                 torch.arange(batch_size, device=device), last_idx
             ]

From 1242ef6e17677030e72ded766de326d352e94de5 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 10 Feb 2026 20:29:27 -0600
Subject: [PATCH 08/15] rename method

---
 pyhealth/interpret/api.py       | 2 +-
 pyhealth/models/stagenet_mha.py | 2 +-
 pyhealth/models/transformer.py  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py
index 675ab92fe..f84b8d249 100644
--- a/pyhealth/interpret/api.py
+++ b/pyhealth/interpret/api.py
@@ -440,7 +440,7 @@ def get_attention_layers(
         ...
 
     @abstractmethod
-    def get_relevance_vector(
+    def get_relevance_tensor(
         self,
         R: dict[str, torch.Tensor],
         **data: torch.Tensor | tuple[torch.Tensor, ...],
diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py
index b5ccb5ac6..7dba08cea 100644
--- a/pyhealth/models/stagenet_mha.py
+++ b/pyhealth/models/stagenet_mha.py
@@ -652,7 +652,7 @@ def get_attention_layers(
             for key in self.feature_keys
         }
 
-    def get_relevance_vector(
+    def get_relevance_tensor(
         self,
         R: dict[str, torch.Tensor],
         **data: torch.Tensor | tuple[torch.Tensor, ...],
diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py
index 05a9880af..69c69d0be 100644
--- a/pyhealth/models/transformer.py
+++ b/pyhealth/models/transformer.py
@@ -605,7 +605,7 @@ def get_attention_layers(
             for key in self.feature_keys
         }
 
-    def get_relevance_vector(
+    def get_relevance_tensor(
         self,
         R: dict[str, torch.Tensor],
         **data: torch.Tensor | tuple[torch.Tensor, ...],

From 6eb04b0d96822b8f4465bf20b28db307771dc3a3 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 10 Feb 2026 20:41:31 -0600
Subject: [PATCH 09/15] Initial attempt for chefer

---
 pyhealth/interpret/methods/chefer.py | 606 ++++++++++++---------
 1 file changed, 278 insertions(+), 328 deletions(-)

diff
--git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 19a4bb6a2..da48af7a0 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -1,29 +1,32 @@ -from typing import Dict +"""Chefer's gradient-weighted attention relevance propagation. + +This module implements the Chefer et al. relevance propagation method for +explaining transformer-family model predictions. It relies on the +:class:`~pyhealth.interpret.api.CheferInterpretable` interface — any model +that implements that interface is automatically supported. + +Paper: + Chefer, Hila, Shir Gur, and Lior Wolf. + "Generic Attention-model Explainability for Interpreting Bi-Modal and + Encoder-Decoder Transformers." + Proceedings of the IEEE/CVF International Conference on Computer Vision + (ICCV), 2021. +""" + +from typing import Dict, Optional, cast import torch import torch.nn.functional as F -from pyhealth.models import Transformer +from pyhealth.interpret.api import CheferInterpretable from pyhealth.models.base_model import BaseModel -from .base_interpreter import BaseInterpreter +from .base_interpreter import BaseInterpreter, _CheferInterpretableModel -# Import TorchvisionModel conditionally to avoid circular imports -try: - from pyhealth.models import TorchvisionModel - HAS_TORCHVISION_MODEL = True -except ImportError: - HAS_TORCHVISION_MODEL = False - TorchvisionModel = None - -# Import StageAttentionNet conditionally to avoid circular imports -try: - from pyhealth.models import StageAttentionNet - HAS_STAGEATTN = True -except ImportError: - HAS_STAGEATTN = False - StageAttentionNet = None +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- def apply_self_attention_rules(R_ss, cam_ss): """Apply Chefer's self-attention rules for relevance propagation. 
@@ -55,43 +58,38 @@ def avg_heads(cam, grad): return cam.clone() +# --------------------------------------------------------------------------- +# Main interpreter +# --------------------------------------------------------------------------- + class CheferRelevance(BaseInterpreter): """Chefer's gradient-weighted attention method for transformer interpretability. - This class implements the relevance propagation method from Chefer et al. for - explaining transformer model predictions. It computes relevance scores for each - input token (for text/EHR transformers) or patch (for Vision Transformers) by - combining attention weights with their gradients. + This interpreter works with **any** model that implements the + :class:`~pyhealth.interpret.api.CheferInterpretable` interface, which + currently includes: - The method works by: - 1. Performing a forward pass to capture attention maps from each layer - 2. Computing gradients of the target class w.r.t. attention weights - 3. Combining attention and gradients using element-wise multiplication - 4. Propagating relevance through layers using attention rollout rules + * :class:`~pyhealth.models.Transformer` + * :class:`~pyhealth.models.StageAttentionNet` - This approach provides more faithful explanations than raw attention weights - alone, as it accounts for how attention contributes to the final prediction. + The algorithm: - Paper: - Chefer, Hila, Shir Gur, and Lior Wolf. - "Generic Attention-model Explainability for Interpreting Bi-Modal and - Encoder-Decoder Transformers." - Proceedings of the IEEE/CVF International Conference on Computer Vision - (ICCV), 2021. + 1. Enable attention hooks via ``model.set_attention_hooks(True)``. + 2. Forward pass → capture attention maps and register gradient hooks. + 3. Backward pass from a one-hot target class. + 4. Retrieve ``(attn_map, attn_grad)`` pairs via ``model.get_attention_layers()``. + 5. Propagate relevance: ``R += clamp(attn * grad, min=0) @ R``. + 6. 
Reduce ``R`` to per-token vectors via ``model.get_relevance_tensor()``. - Supported Models: - - PyHealth Transformer: For sequential/EHR data with multiple feature keys - - StageAttentionNet: For temporal/EHR data with MHA-based StageNet layers - - TorchvisionModel (ViT variants): vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14 + Steps 1, 4 and 6 are delegated to the model through the + ``CheferInterpretable`` interface, making this class fully + model-agnostic. Args: - model (BaseModel): A trained PyHealth model to interpret. Must be one of: - - A ``Transformer`` model for sequential/EHR data - - A ``StageAttentionNet`` model for temporal/EHR data - - A ``TorchvisionModel`` with a ViT architecture for image data + model (BaseModel): A trained PyHealth model that implements + :class:`~pyhealth.interpret.api.CheferInterpretable`. Example: - >>> # Example 1: PyHealth Transformer for EHR data >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> from pyhealth.models import Transformer >>> from pyhealth.interpret.methods import CheferRelevance @@ -121,7 +119,6 @@ class CheferRelevance(BaseInterpreter): >>> model = Transformer(dataset=dataset) >>> # ... train the model ... >>> - >>> # Create interpreter and compute attribution >>> interpreter = CheferRelevance(model) >>> batch = next(iter(get_dataloader(dataset, batch_size=2))) >>> @@ -132,331 +129,284 @@ class CheferRelevance(BaseInterpreter): >>> >>> # Optional: attribute to a specific class (e.g., class 1) >>> attributions = interpreter.attribute(class_index=1, **batch) - >>> - >>> # Example 2: TorchvisionModel ViT for image data - >>> from pyhealth.datasets import COVID19CXRDataset - >>> from pyhealth.models import TorchvisionModel - >>> from pyhealth.interpret.utils import visualize_image_attr - >>> - >>> base_dataset = COVID19CXRDataset(root="/path/to/data") - >>> sample_dataset = base_dataset.set_task() - >>> model = TorchvisionModel( - ... dataset=sample_dataset, - ... 
model_name="vit_b_16", - ... model_config={"weights": "DEFAULT"}, - ... ) - >>> # ... train the model ... - >>> - >>> # Create interpreter and compute attribution - >>> # Task schema: input_schema={"image": "image"}, so feature_key="image" - >>> interpreter = CheferRelevance(model) - >>> - >>> # Default: attribute to predicted class - >>> result = interpreter.attribute(**batch) - >>> # Returns dict keyed by feature_key: {"image": tensor} - >>> attr_map = result["image"] # Shape: [batch, 1, H, W] - >>> - >>> # Optional: attribute to a specific class (e.g., predicted class) - >>> pred_class = model(**batch)["y_prob"].argmax().item() - >>> result = interpreter.attribute(class_index=pred_class, **batch) - >>> - >>> # Visualize - >>> img, attr, overlay = visualize_image_attr( - ... image=batch["image"][0], - ... attribution=result["image"][0, 0], - ... ) """ - def __init__(self, model: BaseModel): + def __init__(self, model: _CheferInterpretableModel): super().__init__(model) - - # Determine model type - self._is_transformer = isinstance(model, Transformer) - self._is_vit = False - self._is_stageattn = False - - if HAS_STAGEATTN and StageAttentionNet is not None: - self._is_stageattn = isinstance(model, StageAttentionNet) - - if HAS_TORCHVISION_MODEL and TorchvisionModel is not None: - if isinstance(model, TorchvisionModel): - self._is_vit = model.is_vit_model() - - if not self._is_transformer and not self._is_vit and not self._is_stageattn: + self.model = cast(_CheferInterpretableModel, model) + + if not isinstance(model, CheferInterpretable): raise ValueError( - f"CheferRelevance requires a Transformer, StageAttentionNet, " - f"or TorchvisionModel (ViT), got {type(model).__name__}. " - f"For TorchvisionModel, only ViT variants " - f"(vit_b_16, vit_b_32, etc.) are supported." + f"CheferRelevance requires a model implementing " + f"CheferInterpretable, got {type(model).__name__}." 
) def attribute( self, - interpolate: bool = True, - class_index: int = None, + class_index: Optional[int] = None, **data, ) -> Dict[str, torch.Tensor]: - """Compute relevance scores for each token/patch. - - This is the primary method for computing attributions. Returns a - dictionary keyed by the model's feature keys (from the task schema). + """Compute relevance scores for each input token. Args: - interpolate: For ViT models, if True interpolate attribution to image size. - class_index: Target class index to compute attribution for. If None - (default), uses the model's predicted class. This is useful when - you want to explain why a specific class was predicted or to - compare attributions across different classes. - **data: Input data from dataloader batch containing: - - For Transformer/StageAttentionNet: feature keys + label - - For ViT: image feature key (e.g., "image") + label + class_index: Target class index to compute attribution for. + If None (default), uses the model's predicted class. + **data: Input data from dataloader batch containing feature + keys and label key. Returns: - Dict[str, torch.Tensor]: Dictionary keyed by feature keys from the task schema. - - - For Transformer/StageAttentionNet: - ``{"conditions": tensor, "procedures": tensor, ...}`` - where each tensor has shape ``[batch, num_tokens]``. - - For ViT: ``{"image": tensor}`` (or whatever the task's image key is) - where tensor has shape ``[batch, 1, H, W]``. + Dict[str, torch.Tensor]: Dictionary keyed by feature keys, + where each tensor has shape ``[batch, seq_len]`` with + per-token attribution scores. 
""" - if self._is_vit: - return self._attribute_vit( - interpolate=interpolate, - class_index=class_index, - **data - ) - if self._is_stageattn: - return self._attribute_stageattn(class_index=class_index, **data) - return self._attribute_transformer(class_index=class_index, **data) - - def _attribute_transformer( - self, - class_index: int = None, - **data - ) -> Dict[str, torch.Tensor]: - """Compute relevance for PyHealth Transformer models. - - Args: - class_index: Target class for attribution. If None, uses predicted class. - **data: Input data from dataloader batch. - """ - data["register_hook"] = True - - logits = self.model(**data)["logit"] + # --- 1. Forward with attention hooks enabled --- + self.model.set_attention_hooks(True) + try: + logits = self.model(**data)["logit"] + finally: + self.model.set_attention_hooks(False) + + # --- 2. Backward from target class --- if class_index is None: - class_index = torch.argmax(logits, dim=-1) - - if isinstance(class_index, torch.Tensor): - one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + class_index_t = torch.argmax(logits, dim=-1) + elif isinstance(class_index, int): + class_index_t = torch.tensor(class_index) else: - one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + class_index_t = class_index + + one_hot = F.one_hot( + class_index_t.detach().clone(), logits.size(1) + ).float() one_hot = one_hot.requires_grad_(True) - one_hot = torch.sum(one_hot.to(logits.device) * logits) + scalar = torch.sum(one_hot.to(logits.device) * logits) self.model.zero_grad() - one_hot.backward(retain_graph=True) + scalar.backward(retain_graph=True) - feature_keys = self.model.feature_keys - num_tokens = {} - for key in feature_keys: - feature_transformer = self.model.transformer[key].transformer - for block in feature_transformer: - num_tokens[key] = block.attention.get_attn_map().shape[-1] + # --- 3. 
Retrieve (attn_map, attn_grad) pairs per feature key --- + attention_layers = self.model.get_attention_layers() batch_size = logits.shape[0] - attn = {} - for key in feature_keys: + device = logits.device + + # --- 4. Relevance propagation per feature key --- + R_dict: dict[str, torch.Tensor] = {} + for key, layers in attention_layers.items(): + num_tokens = layers[0][0].shape[-1] R = ( - torch.eye(num_tokens[key]) + torch.eye(num_tokens, device=device) .unsqueeze(0) .repeat(batch_size, 1, 1) - .to(logits.device) ) - for blk in self.model.transformer[key].transformer: - grad = blk.attention.get_attn_grad() - cam = blk.attention.get_attn_map() + for cam, grad in layers: cam = avg_heads(cam, grad) - R += apply_self_attention_rules(R, cam).detach() - attn[key] = R[:, 0] + R = R + apply_self_attention_rules(R, cam).detach() + R_dict[key] = R - return attn + # --- 5. Reduce R matrices to per-token vectors --- + return self.model.get_relevance_tensor(R_dict, **data) - def _attribute_stageattn( - self, - class_index: int = None, - **data, - ) -> Dict[str, torch.Tensor]: - """Compute relevance for StageAttentionNet models. + # ------------------------------------------------------------------ + # Backward compatibility aliases + # ------------------------------------------------------------------ - StageAttentionNet has a single MHA layer per feature key (inside - ``model.stagenet[key]``) rather than a stack of TransformerBlocks. - It also uses the *last valid timestep* (via ``get_last_visit``) - instead of a CLS token for classification, so we extract the - relevance row corresponding to that timestep. + def get_relevance_matrix(self, **data): + """Alias for attribute(). 
Deprecated.""" + return self.attribute(**data) + + +# ====================================================================== +# LEGACY REFERENCE IMPLEMENTATIONS +# ====================================================================== +# The functions below are the original model-specific implementations +# that existed before the CheferInterpretable API was introduced. They +# are kept here ONLY as a reference for future developers and are NOT +# called by any production code. They may be removed in a future +# release. +# +# For ViT models, _reference_attribute_vit is the only implementation +# until ViT models implement CheferInterpretable. +# ====================================================================== + +def _reference_attribute_transformer( + model, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original Transformer-specific Chefer attribution. + + This was the body of ``CheferRelevance._attribute_transformer()`` + before the CheferInterpretable API was introduced. It accesses + model internals (``model.transformer[key].transformer``) directly. 
+ """ + data["register_hook"] = True + + logits = model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) + + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = one_hot.requires_grad_(True) + one_hot = torch.sum(one_hot.to(logits.device) * logits) + model.zero_grad() + one_hot.backward(retain_graph=True) + + feature_keys = model.feature_keys + num_tokens = {} + for key in feature_keys: + feature_transformer = model.transformer[key].transformer + for block in feature_transformer: + num_tokens[key] = block.attention.get_attn_map().shape[-1] + + batch_size = logits.shape[0] + attn = {} + for key in feature_keys: + R = ( + torch.eye(num_tokens[key]) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(logits.device) + ) + for blk in model.transformer[key].transformer: + grad = blk.attention.get_attn_grad() + cam = blk.attention.get_attn_map() + cam = avg_heads(cam, grad) + R += apply_self_attention_rules(R, cam).detach() + attn[key] = R[:, 0] - Args: - class_index: Target class for attribution. If None, uses predicted class. - **data: Input data from dataloader batch. - """ - # StageAttentionNet uses 'register_attn_hook' (not 'register_hook') - data["register_attn_hook"] = True + return attn - logits = self.model(**data)["logit"] - if class_index is None: - class_index = torch.argmax(logits, dim=-1) - if isinstance(class_index, torch.Tensor): - one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() +def _reference_attribute_stageattn( + model, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original StageAttentionNet-specific Chefer attribution. + + This was the body of ``CheferRelevance._attribute_stageattn()`` + before the CheferInterpretable API was introduced. 
It accesses + model internals (``model.stagenet[key]``, ``model.embedding_model``) + directly. + """ + data["register_attn_hook"] = True + + logits = model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) + + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = one_hot.requires_grad_(True) + one_hot = torch.sum(one_hot.to(logits.device) * logits) + model.zero_grad() + one_hot.backward(retain_graph=True) + + batch_size = logits.shape[0] + feature_keys = model.feature_keys + attn = {} + + for key in feature_keys: + layer = model.stagenet[key] + cam = layer.get_attn_map() + grad = layer.get_attn_grad() + num_tokens = cam.shape[-1] + + R = ( + torch.eye(num_tokens) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(logits.device) + ) + cam = avg_heads(cam, grad) + R += apply_self_attention_rules(R, cam).detach() + + feature = data[key] + if isinstance(feature, tuple) and len(feature) == 2: + _, x_val = feature else: - one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() - one_hot = one_hot.requires_grad_(True) - one_hot = torch.sum(one_hot.to(logits.device) * logits) - self.model.zero_grad() - one_hot.backward(retain_graph=True) + x_val = feature - batch_size = logits.shape[0] - feature_keys = self.model.feature_keys - attn = {} + embedded = model.embedding_model({key: x_val}) + emb = embedded[key] + if emb.dim() == 4: + emb = emb.sum(dim=2) + mask = (emb.sum(dim=-1) != 0).long().to(logits.device) - for key in feature_keys: - layer = self.model.stagenet[key] - cam = layer.get_attn_map() - grad = layer.get_attn_grad() - num_tokens = cam.shape[-1] + last_idx = mask.sum(dim=1) - 1 + attn[key] = R[torch.arange(batch_size, device=logits.device), last_idx] - R = ( - torch.eye(num_tokens) - .unsqueeze(0) - .repeat(batch_size, 1, 1) - .to(logits.device) - ) - 
cam = avg_heads(cam, grad) - R += apply_self_attention_rules(R, cam).detach() + return attn - # StageAttentionNet uses get_last_visit (last valid timestep) - # instead of a CLS token. Reconstruct the mask to find the - # index that was actually used for classification. - feature = data[key] - if isinstance(feature, tuple) and len(feature) == 2: - _, x_val = feature - else: - x_val = feature - embedded = self.model.embedding_model({key: x_val}) - emb = embedded[key] - if emb.dim() == 4: - emb = emb.sum(dim=2) - mask = (emb.sum(dim=-1) != 0).long().to(logits.device) +def _reference_attribute_vit( + model, + interpolate: bool = True, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original ViT-specific Chefer attribution. - # last valid index per sample - last_idx = mask.sum(dim=1) - 1 # [batch] - attn[key] = R[torch.arange(batch_size, device=logits.device), last_idx] + ViT models do not yet implement CheferInterpretable. This code + shows the ViT-specific flow that will be needed when ViT support is + added to the unified API. + """ + feature_key = model.feature_keys[0] + x = data.get(feature_key) + if x is None: + raise ValueError( + f"Expected feature key '{feature_key}' in data. " + f"Available keys: {list(data.keys())}" + ) - return attn + x = x.to(model.device) + input_size = x.shape[-1] - def _attribute_vit( - self, - interpolate: bool = True, - class_index: int = None, - **data, - ) -> Dict[str, torch.Tensor]: - """Compute ViT attribution and return spatial attribution map. - - Args: - interpolate: If True, interpolate to full image size. - class_index: Target class for attribution. If None, uses predicted class. - **data: Must contain the image feature key. - - Returns: - Dict keyed by the model's feature_key (e.g., "image") with spatial - attribution map of shape [batch, 1, H, W]. 
- """ - # Get the feature key (first element of feature_keys list) - feature_key = self.model.feature_keys[0] - x = data.get(feature_key) - if x is None: - raise ValueError( - f"Expected feature key '{feature_key}' in data. " - f"Available keys: {list(data.keys())}" - ) - - x = x.to(self.model.device) - - # Infer input size from image dimensions (assumes square images) - input_size = x.shape[-1] - - # Forward pass with attention capture - self.model.zero_grad() - logits, attention_maps = self.model.forward_with_attention(x, register_hook=True) - - # Use predicted class if not specified - target_class = class_index - if target_class is None: - target_class = logits.argmax(dim=-1) - - # Backward pass - one_hot = torch.zeros_like(logits) - if isinstance(target_class, int): - one_hot[:, target_class] = 1 - else: - if target_class.dim() == 0: - target_class = target_class.unsqueeze(0) - one_hot.scatter_(1, target_class.unsqueeze(1), 1) - - one_hot = one_hot.requires_grad_(True) - (logits * one_hot).sum().backward(retain_graph=True) - - # Compute gradient-weighted attention - attention_gradients = self.model.get_attention_gradients() - batch_size = attention_maps[0].shape[0] - num_tokens = attention_maps[0].shape[-1] - device = attention_maps[0].device - - R = torch.eye(num_tokens, device=device) - R = R.unsqueeze(0).expand(batch_size, -1, -1).clone() - - for attn, grad in zip(attention_maps, attention_gradients): - cam = avg_heads(attn, grad) - R = R + apply_self_attention_rules(R.detach(), cam.detach()) - - # CLS token's relevance to patches (excluding CLS itself) - patches_attr = R[:, 0, 1:] - - # Reshape to spatial layout - h_patches, w_patches = self.model.get_num_patches(input_size) - attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches) - - if interpolate: - attr_map = F.interpolate( - attr_map, - size=(input_size, input_size), - mode="bilinear", - align_corners=False, - ) - - # Return keyed by the model's feature key (e.g., "image") - return 
{feature_key: attr_map} + model.zero_grad() + logits, attention_maps = model.forward_with_attention(x, register_hook=True) - # Backwards compatibility aliases - def get_relevance_matrix(self, **data): - """Alias for _attribute_transformer. Use attribute() instead.""" - return self._attribute_transformer(**data) + target_class = class_index + if target_class is None: + target_class = logits.argmax(dim=-1) - def get_vit_attribution_map( - self, - interpolate: bool = True, - class_index: int = None, - **data - ): - """Alias for attribute() for ViT. Use attribute() instead. - - Returns the attribution tensor directly (not wrapped in a dict). - """ - result = self._attribute_vit( - interpolate=interpolate, - class_index=class_index, - **data + one_hot = torch.zeros_like(logits) + if isinstance(target_class, int): + one_hot[:, target_class] = 1 + else: + if target_class.dim() == 0: + target_class = target_class.unsqueeze(0) + one_hot.scatter_(1, target_class.unsqueeze(1), 1) + + one_hot = one_hot.requires_grad_(True) + (logits * one_hot).sum().backward(retain_graph=True) + + attention_gradients = model.get_attention_gradients() + batch_size = attention_maps[0].shape[0] + num_tokens = attention_maps[0].shape[-1] + device = attention_maps[0].device + + R = torch.eye(num_tokens, device=device) + R = R.unsqueeze(0).expand(batch_size, -1, -1).clone() + + for attn, grad in zip(attention_maps, attention_gradients): + cam = avg_heads(attn, grad) + R = R + apply_self_attention_rules(R.detach(), cam.detach()) + + patches_attr = R[:, 0, 1:] + + h_patches, w_patches = model.get_num_patches(input_size) + attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches) + + if interpolate: + attr_map = F.interpolate( + attr_map, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, ) - # Return the attribution tensor directly (get the first/only value) - feature_key = self.model.feature_keys[0] - return result[feature_key] \ No newline at end of file + + return 
{feature_key: attr_map} \ No newline at end of file From b365b47869f6119f84d057e744237d1eba8b7b37 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 20:50:56 -0600 Subject: [PATCH 10/15] Update type hint --- .../interpret/methods/base_interpreter.py | 10 +--------- pyhealth/interpret/methods/deeplift.py | 17 +++++++--------- pyhealth/interpret/methods/gim.py | 19 +++++++----------- pyhealth/interpret/methods/ig_gim.py | 20 ++++++------------- .../interpret/methods/integrated_gradients.py | 18 +++++++---------- pyhealth/interpret/methods/lime.py | 9 +++++++-- pyhealth/interpret/methods/shap.py | 8 ++++++-- 7 files changed, 41 insertions(+), 60 deletions(-) diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index b8b129043..de75c897a 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -16,15 +16,7 @@ import torch.nn as nn from pyhealth.models import BaseModel -from pyhealth.interpret.api import Interpretable, CheferInterpretable -class _InterpretableModel(BaseModel, Interpretable): - pass - -class _CheferInterpretableModel(BaseModel, CheferInterpretable): - pass - -type _AnyInterpretableModel = _InterpretableModel | _CheferInterpretableModel class BaseInterpreter(ABC): """Abstract base class for interpretability methods. @@ -105,7 +97,7 @@ class BaseInterpreter(ABC): >>> print(attributions["image"].shape) # [batch, 1, H, W] """ - def __init__(self, model: _AnyInterpretableModel): + def __init__(self, model: BaseModel): """Initialize the base interpreter. 
Args: diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py index f6ab49896..29f99d795 100644 --- a/pyhealth/interpret/methods/deeplift.py +++ b/pyhealth/interpret/methods/deeplift.py @@ -7,7 +7,8 @@ import torch.nn.functional as F from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter, _InterpretableModel +from pyhealth.interpret.api import Interpretable +from .base_interpreter import BaseInterpreter def _iter_child_modules(module: torch.nn.Module): @@ -328,18 +329,14 @@ class DeepLift(BaseInterpreter): Learning (ICML), 2017. https://proceedings.mlr.press/v70/shrikumar17a.html """ - def __init__(self, model: _InterpretableModel, use_embeddings: bool = True): + def __init__(self, model: BaseModel, use_embeddings: bool = True): super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "DeepLIFT. Set use_embeddings=False to use input-level " - "gradients (only for continuous features)." 
- ) - # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py index 35ebcdcaf..d1b74b573 100644 --- a/pyhealth/interpret/methods/gim.py +++ b/pyhealth/interpret/methods/gim.py @@ -8,7 +8,8 @@ import torch.nn.functional as F from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter, _InterpretableModel +from pyhealth.interpret.api import Interpretable +from .base_interpreter import BaseInterpreter def _iter_child_modules(module: torch.nn.Module): @@ -346,20 +347,14 @@ class GIM(BaseInterpreter): def __init__( self, - model: _InterpretableModel, + model: BaseModel, temperature: float = 2.0, ): super().__init__(model) - if not hasattr(model, "forward_from_embedding"): - raise AssertionError( - "GIM requires models that implement `forward_from_embedding`." - ) - embedding_model = model.get_embedding_model() - if embedding_model is None: - raise AssertionError( - "GIM requires a model with an embedding model " - "accessible via `get_embedding_model()`." 
- ) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.temperature = max(float(temperature), 1.0) def attribute( diff --git a/pyhealth/interpret/methods/ig_gim.py b/pyhealth/interpret/methods/ig_gim.py index ca6798a8c..a33f5529b 100644 --- a/pyhealth/interpret/methods/ig_gim.py +++ b/pyhealth/interpret/methods/ig_gim.py @@ -35,7 +35,8 @@ from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter, _InterpretableModel +from .base_interpreter import BaseInterpreter +from pyhealth.interpret.api import Interpretable from .gim import _GIMHookContext @@ -74,23 +75,14 @@ class IntegratedGradientGIM(BaseInterpreter): def __init__( self, - model: _InterpretableModel, + model: BaseModel, temperature: float = 2.0, steps: int = 50, ): super().__init__(model) - - if not hasattr(model, "forward_from_embedding"): - raise AssertionError( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() to use IG-GIM. " - "Set use_embeddings=False for input-level IG instead." - ) - if model.get_embedding_model() is None: - raise AssertionError( - "Model must provide an embedding model via " - "get_embedding_model() for IG-GIM." 
- ) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model self.temperature = max(float(temperature), 1.0) self.steps = steps diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index 06701fb1c..a529a6f3f 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -6,8 +6,8 @@ import torch.nn.functional as F from pyhealth.models import BaseModel - -from .base_interpreter import BaseInterpreter, _InterpretableModel +from pyhealth.interpret.api import Interpretable +from .base_interpreter import BaseInterpreter class IntegratedGradients(BaseInterpreter): @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter): ... ) """ - def __init__(self, model: _InterpretableModel, use_embeddings: bool = True, steps: int = 50): + def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50): """Initialize IntegratedGradients interpreter. Args: @@ -187,17 +187,13 @@ def __init__(self, model: _InterpretableModel, use_embeddings: bool = True, step implement forward_from_embedding() method. """ super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings self.steps = steps - # Check model supports forward_from_embedding if needed - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "Integrated Gradients. Set use_embeddings=False to use " - "input-level gradients (only for continuous features)." 
- ) def attribute( self, diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 7d0f39580..8481c50bf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -8,7 +8,8 @@ from torch.nn import CosineSimilarity from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter, _InterpretableModel +from pyhealth.interpret.api import Interpretable +from .base_interpreter import BaseInterpreter class LimeExplainer(BaseInterpreter): @@ -102,7 +103,7 @@ class LimeExplainer(BaseInterpreter): def __init__( self, - model: _InterpretableModel, + model: BaseModel, use_embeddings: bool = True, n_samples: int = 1000, kernel_width: float = 0.25, @@ -131,6 +132,10 @@ def __init__( ValueError: If feature_selection is not "lasso", "ridge", or "none". """ super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings self.n_samples = n_samples self.kernel_width = kernel_width diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index 83565099e..edcd02fc2 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -6,7 +6,8 @@ import torch from pyhealth.models import BaseModel -from .base_interpreter import BaseInterpreter, _InterpretableModel +from pyhealth.interpret.api import Interpretable +from .base_interpreter import BaseInterpreter class ShapExplainer(BaseInterpreter): @@ -94,7 +95,7 @@ class ShapExplainer(BaseInterpreter): def __init__( self, - model: _InterpretableModel, + model: BaseModel, use_embeddings: bool = True, n_background_samples: int = 100, max_coalitions: int = 1000, @@ -120,6 +121,9 @@ def __init__( implement forward_from_embedding() method. 
""" super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model self.use_embeddings = use_embeddings self.n_background_samples = n_background_samples self.max_coalitions = max_coalitions From a250dafdac6281a96c9f72bb11a927f9dec2c6fa Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 20:52:01 -0600 Subject: [PATCH 11/15] Fix chefer --- pyhealth/interpret/methods/chefer.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index da48af7a0..6bbbbad17 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -20,8 +20,8 @@ from pyhealth.interpret.api import CheferInterpretable from pyhealth.models.base_model import BaseModel - -from .base_interpreter import BaseInterpreter, _CheferInterpretableModel +from pyhealth.interpret.api import CheferInterpretable +from .base_interpreter import BaseInterpreter # --------------------------------------------------------------------------- @@ -131,15 +131,11 @@ class CheferRelevance(BaseInterpreter): >>> attributions = interpreter.attribute(class_index=1, **batch) """ - def __init__(self, model: _CheferInterpretableModel): + def __init__(self, model: BaseModel): super().__init__(model) - self.model = cast(_CheferInterpretableModel, model) - if not isinstance(model, CheferInterpretable): - raise ValueError( - f"CheferRelevance requires a model implementing " - f"CheferInterpretable, got {type(model).__name__}." 
- ) + raise ValueError("Model must implement CheferInterpretable interface") + self.model = model def attribute( self, From 4c081b2baeb4584e681cad6915940e0dcd675c0f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 21:04:05 -0600 Subject: [PATCH 12/15] Fix chefer --- pyhealth/interpret/methods/chefer.py | 50 +++++++++++++++++++++++++++- pyhealth/models/stagenet_mha.py | 6 ++-- pyhealth/models/transformer.py | 6 +++- 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 6bbbbad17..5efca68eb 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -199,7 +199,55 @@ def attribute( R_dict[key] = R # --- 5. Reduce R matrices to per-token vectors --- - return self.model.get_relevance_tensor(R_dict, **data) + attributions = self.model.get_relevance_tensor(R_dict, **data) + + # --- 6. Expand to match raw input shapes (nested sequences) --- + return self._map_to_input_shapes(attributions, data) + + # ------------------------------------------------------------------ + # Shape mapping + # ------------------------------------------------------------------ + + def _map_to_input_shapes( + self, + attributions: Dict[str, torch.Tensor], + data: dict, + ) -> Dict[str, torch.Tensor]: + """Expand attributions to match raw input value shapes. + + For nested sequences the attention operates on a pooled + (visit-level) sequence, but downstream consumers (e.g. ablation + metrics) expect attributions to match the raw input value shape. + Per-visit relevance scores are replicated across all codes + within each visit. + + Args: + attributions: Per-feature attribution tensors returned by + ``model.get_relevance_tensor()``. + data: Original ``**data`` kwargs from the dataloader batch. + + Returns: + Attributions expanded to raw input value shapes where needed. 
+ """ + result: Dict[str, torch.Tensor] = {} + for key, attr in attributions.items(): + feature = data.get(key) + if feature is not None: + if isinstance(feature, torch.Tensor): + val = feature + else: + schema = self.model.dataset.input_processors[key].schema() + val = ( + feature[schema.index("value")] + if "value" in schema + else None + ) + if val is not None and val.dim() > attr.dim(): + for _ in range(val.dim() - attr.dim()): + attr = attr.unsqueeze(-1) + attr = attr.expand_as(val) + result[key] = attr + return result # ------------------------------------------------------------------ # Backward compatibility aliases diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index 7dba08cea..637f2fdc5 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -704,7 +704,9 @@ def get_relevance_tensor( last_idx = mask.sum(dim=1).long() - 1 last_idx = last_idx.clamp(min=0) - result[key] = r[ + attn = r[ torch.arange(batch_size, device=device), last_idx - ] + ] # [batch, attention_seq_len] + + result[key] = attn return result diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index 69c69d0be..de8555579 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -611,7 +611,11 @@ def get_relevance_tensor( **data: torch.Tensor | tuple[torch.Tensor, ...], ) -> dict[str, torch.Tensor]: # CLS token is at index 0 for all feature keys - return {key: r[:, 0] for key, r in R.items()} + result = {} + for key, r in R.items(): + # CLS token is at index 0; extract its attention row + result[key] = r[:, 0] # [batch, attention_seq_len] + return result if __name__ == "__main__": From a3bfb7dea5ae15520be199434e66670c4238d5cf Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 21:26:29 -0600 Subject: [PATCH 13/15] Fix test --- tests/core/test_gim.py | 3 ++- tests/core/test_ig_gim.py | 3 ++- tests/core/test_lime.py | 7 ++++--- tests/core/test_shap.py | 7 ++++--- 4 files 
changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/core/test_gim.py b/tests/core/test_gim.py index 0893b874a..284931883 100644 --- a/tests/core/test_gim.py +++ b/tests/core/test_gim.py @@ -6,6 +6,7 @@ import torch.nn as nn from pyhealth.interpret.methods import GIM +from pyhealth.interpret.api import Interpretable from pyhealth.models import BaseModel from pyhealth.models.transformer import Attention @@ -64,7 +65,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(val.long()) for key, val in inputs.items()} -class _ToyGIMModel(BaseModel): +class _ToyGIMModel(BaseModel, Interpretable): """Small attention-style model with module-based nonlinearities. Follows the new API conventions: ``forward_from_embedding(**kwargs)`` diff --git a/tests/core/test_ig_gim.py b/tests/core/test_ig_gim.py index 3b8ff1556..89ad6e73f 100644 --- a/tests/core/test_ig_gim.py +++ b/tests/core/test_ig_gim.py @@ -9,6 +9,7 @@ from pyhealth.interpret.methods import GIM, IntegratedGradients from pyhealth.interpret.methods.ig_gim import IntegratedGradientGIM +from pyhealth.interpret.api import Interpretable from pyhealth.models import BaseModel from pyhealth.models.transformer import Attention @@ -75,7 +76,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(val.long()) for key, val in inputs.items()} -class _ToyModel(BaseModel): +class _ToyModel(BaseModel, Interpretable): """Small model with softmax attention for testing IG-GIM.""" def __init__(self, vocab_size=32, embedding_dim=4, schema=("value",)): diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py index b7970592d..0ed0f33fc 100644 --- a/tests/core/test_lime.py +++ b/tests/core/test_lime.py @@ -14,6 +14,7 @@ from pyhealth.datasets import SampleDataset, get_dataloader from pyhealth.datasets.sample_dataset import SampleBuilder from pyhealth.models import MLP, StageNet, BaseModel +from pyhealth.interpret.api 
import Interpretable from pyhealth.interpret.methods import LimeExplainer from pyhealth.interpret.methods.base_interpreter import BaseInterpreter @@ -50,7 +51,7 @@ def __init__(self, input_schema, output_schema, processors=None): # Test model helpers # --------------------------------------------------------------------------- -class _SimpleLimeModel(BaseModel): +class _SimpleLimeModel(BaseModel, Interpretable): """Minimal model for testing LIME with continuous inputs.""" def __init__(self): @@ -100,7 +101,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(value.long()) for key, value in inputs.items()} -class _EmbeddingForwardModel(BaseModel): +class _EmbeddingForwardModel(BaseModel, Interpretable): """Toy model exposing forward_from_embedding for discrete features.""" def __init__(self, schema=("value",)): @@ -156,7 +157,7 @@ def get_embedding_model(self): return self.embedding_model -class _MultiFeatureModel(BaseModel): +class _MultiFeatureModel(BaseModel, Interpretable): """Model with multiple feature inputs for testing multi-feature LIME.""" def __init__(self): diff --git a/tests/core/test_shap.py b/tests/core/test_shap.py index 105775cfd..be01bf957 100644 --- a/tests/core/test_shap.py +++ b/tests/core/test_shap.py @@ -11,6 +11,7 @@ from pyhealth.datasets import SampleDataset, get_dataloader from pyhealth.datasets.sample_dataset import SampleBuilder from pyhealth.models import MLP, StageNet, BaseModel +from pyhealth.interpret.api import Interpretable from pyhealth.interpret.methods import ShapExplainer from pyhealth.interpret.methods.base_interpreter import BaseInterpreter @@ -47,7 +48,7 @@ def __init__(self, input_schema, output_schema, processors=None): # Test model helpers # --------------------------------------------------------------------------- -class _SimpleShapModel(BaseModel): +class _SimpleShapModel(BaseModel, Interpretable): """Minimal model for testing SHAP with continuous inputs.""" 
def __init__(self): @@ -97,7 +98,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(value.long()) for key, value in inputs.items()} -class _EmbeddingForwardModel(BaseModel): +class _EmbeddingForwardModel(BaseModel, Interpretable): """Toy model exposing forward_from_embedding for discrete features.""" def __init__(self, schema=("value",)): @@ -152,7 +153,7 @@ def get_embedding_model(self): return self.embedding_model -class _MultiFeatureModel(BaseModel): +class _MultiFeatureModel(BaseModel, Interpretable): """Model with multiple feature inputs for testing multi-feature SHAP.""" def __init__(self): From 574d9bf7bdda19df37f76ed80630071b17ac21f8 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 21:43:47 -0600 Subject: [PATCH 14/15] Fix test --- tests/core/test_ig_gim.py | 11 +++++++---- tests/core/test_lime.py | 2 +- tests/core/test_shap.py | 2 +- tests/core/test_stagenet_mha.py | 4 +++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/core/test_ig_gim.py b/tests/core/test_ig_gim.py index 89ad6e73f..38565227d 100644 --- a/tests/core/test_ig_gim.py +++ b/tests/core/test_ig_gim.py @@ -376,15 +376,18 @@ class _NoFwdEmb: obj.label_keys = ["label"] obj.get_embedding_model = lambda: _ToyEmbeddingModel() - with self.assertRaises((AssertionError, AttributeError)): + with self.assertRaises((AssertionError, AttributeError, ValueError)): IntegratedGradientGIM(obj) def test_init_rejects_model_without_embedding_model(self): - """Should raise if get_embedding_model returns None.""" + """Should raise if get_embedding_model returns None during attribution.""" model = _ToyModel() model.get_embedding_model = lambda: None + ig_gim = IntegratedGradientGIM(model) with self.assertRaises(AssertionError): - IntegratedGradientGIM(model) + ig_gim.attribute( + codes=self.tokens, label=self.labels, target_class_idx=0, + ) def test_temperature_clamped_to_one(self): """Temperatures below 1.0 should be 
clamped to 1.0.""" @@ -654,7 +657,7 @@ def test_repr(self): # Additional toy model without Attention (no GIM-swappable modules) # --------------------------------------------------------------------------- -class _ToyModelNoAttention(BaseModel): +class _ToyModelNoAttention(BaseModel, Interpretable): """Simple model without Attention or LayerNorm — GIM hooks are no-ops.""" def __init__(self, vocab_size=32, embedding_dim=4): diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py index 0ed0f33fc..ab061cb4a 100644 --- a/tests/core/test_lime.py +++ b/tests/core/test_lime.py @@ -585,7 +585,7 @@ def __init__(self): self.linear = nn.Linear(3, 1) model = _BareModel() - with self.assertRaises(AssertionError): + with self.assertRaises((AssertionError, ValueError)): LimeExplainer(model, use_embeddings=True) diff --git a/tests/core/test_shap.py b/tests/core/test_shap.py index be01bf957..8c03a1c1f 100644 --- a/tests/core/test_shap.py +++ b/tests/core/test_shap.py @@ -504,7 +504,7 @@ def __init__(self): self.linear = nn.Linear(3, 1) model = _BareModel() - with self.assertRaises(AssertionError): + with self.assertRaises((AssertionError, ValueError)): ShapExplainer(model, use_embeddings=True) diff --git a/tests/core/test_stagenet_mha.py b/tests/core/test_stagenet_mha.py index 78206f309..4d4f1eebd 100644 --- a/tests/core/test_stagenet_mha.py +++ b/tests/core/test_stagenet_mha.py @@ -117,8 +117,10 @@ def test_attention_hook_records_map_and_grad(self): loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) batch = next(iter(loader)) - ret = self.model(register_attn_hook=True, **batch) + self.model.set_attention_hooks(True) + ret = self.model(**batch) ret["loss"].backward() + self.model.set_attention_hooks(False) for feature_key, layer in self.model.stagenet.items(): attn_map = layer.get_attn_map() From 224a5459c7f0f861058d3b21c719ba187dfbdea2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 10 Feb 2026 21:51:27 -0600 Subject: [PATCH 15/15] Enable chefer in run 
script --- examples/interpretability/dka_stageattn_mimic4_interpret.py | 2 +- examples/interpretability/dka_transformer_mimic4_interpret.py | 2 +- examples/interpretability/los_stageattn_mimic4_interpret.py | 2 +- examples/interpretability/los_transformer_mimic4_interpret.py | 2 +- examples/interpretability/mp_stageattn_mimic4_interpret.py | 2 +- examples/interpretability/mp_transformer_mimic4_interpret.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/interpretability/dka_stageattn_mimic4_interpret.py b/examples/interpretability/dka_stageattn_mimic4_interpret.py index ffbdf778d..3b405bd52 100644 --- a/examples/interpretability/dka_stageattn_mimic4_interpret.py +++ b/examples/interpretability/dka_stageattn_mimic4_interpret.py @@ -134,7 +134,7 @@ def count_labels(ds): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/dka_transformer_mimic4_interpret.py b/examples/interpretability/dka_transformer_mimic4_interpret.py index cec0fc2dc..d2617d652 100644 --- a/examples/interpretability/dka_transformer_mimic4_interpret.py +++ b/examples/interpretability/dka_transformer_mimic4_interpret.py @@ -134,7 +134,7 @@ def count_labels(ds): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_stageattn_mimic4_interpret.py b/examples/interpretability/los_stageattn_mimic4_interpret.py index e9b64aa95..51d253f25 100644 --- a/examples/interpretability/los_stageattn_mimic4_interpret.py 
+++ b/examples/interpretability/los_stageattn_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_transformer_mimic4_interpret.py b/examples/interpretability/los_transformer_mimic4_interpret.py index 1abb96636..ccb06c707 100644 --- a/examples/interpretability/los_transformer_mimic4_interpret.py +++ b/examples/interpretability/los_transformer_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_stageattn_mimic4_interpret.py b/examples/interpretability/mp_stageattn_mimic4_interpret.py index 7037ddbf4..e42b9aca6 100644 --- a/examples/interpretability/mp_stageattn_mimic4_interpret.py +++ b/examples/interpretability/mp_stageattn_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_transformer_mimic4_interpret.py b/examples/interpretability/mp_transformer_mimic4_interpret.py index 57b90e91c..dcdb55215 100644 --- a/examples/interpretability/mp_transformer_mimic4_interpret.py +++ b/examples/interpretability/mp_transformer_mimic4_interpret.py @@ -120,7 +120,7 @@ 
def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), }