diff --git a/examples/interpretability/dka_stageattn_mimic4_interpret.py b/examples/interpretability/dka_stageattn_mimic4_interpret.py index ffbdf778d..3b405bd52 100644 --- a/examples/interpretability/dka_stageattn_mimic4_interpret.py +++ b/examples/interpretability/dka_stageattn_mimic4_interpret.py @@ -134,7 +134,7 @@ def count_labels(ds): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/dka_transformer_mimic4_interpret.py b/examples/interpretability/dka_transformer_mimic4_interpret.py index cec0fc2dc..d2617d652 100644 --- a/examples/interpretability/dka_transformer_mimic4_interpret.py +++ b/examples/interpretability/dka_transformer_mimic4_interpret.py @@ -134,7 +134,7 @@ def count_labels(ds): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_stageattn_mimic4_interpret.py b/examples/interpretability/los_stageattn_mimic4_interpret.py index e9b64aa95..51d253f25 100644 --- a/examples/interpretability/los_stageattn_mimic4_interpret.py +++ b/examples/interpretability/los_stageattn_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git 
a/examples/interpretability/los_transformer_mimic4_interpret.py b/examples/interpretability/los_transformer_mimic4_interpret.py index 1abb96636..ccb06c707 100644 --- a/examples/interpretability/los_transformer_mimic4_interpret.py +++ b/examples/interpretability/los_transformer_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_stageattn_mimic4_interpret.py b/examples/interpretability/mp_stageattn_mimic4_interpret.py index 7037ddbf4..e42b9aca6 100644 --- a/examples/interpretability/mp_stageattn_mimic4_interpret.py +++ b/examples/interpretability/mp_stageattn_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_transformer_mimic4_interpret.py b/examples/interpretability/mp_transformer_mimic4_interpret.py index 57b90e91c..dcdb55215 100644 --- a/examples/interpretability/mp_transformer_mimic4_interpret.py +++ b/examples/interpretability/mp_transformer_mimic4_interpret.py @@ -120,7 +120,7 @@ def main(): "ig_gim": IntegratedGradientGIM(model), "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), - # "chefer": CheferRelevance(model), + "chefer": CheferRelevance(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/pyhealth/interpret/api.py b/pyhealth/interpret/api.py new file mode 100644 index 
000000000..f84b8d249 --- /dev/null +++ b/pyhealth/interpret/api.py @@ -0,0 +1,499 @@ +from abc import ABC, abstractmethod +import torch +from torch import nn + +class Interpretable(ABC): + """Abstract interface for models supporting interpretability methods. + + This class defines the contract that models must fulfill to be compatible + with PyHealth's interpretability module. It enables gradient-based + attribution methods and embedding-level perturbation methods to work with + your model. + + The interface separates the embedding stage (which generates learned + representations from raw features) from the prediction stage + (which generates outputs from embeddings). This separation allows + interpretability methods to either: + + 1. Use gradients flowing through embeddings + 2. Perturb embeddings and pass them through prediction head + 3. Directly access and analyze the learned representations + + Methods + ------- + forward + Standard forward pass of the model. + forward_from_embedding + Perform forward pass starting from embeddings. + get_embedding_model + Get the embedding/feature extraction stage if applicable. + + Assumptions + ----------- + Models implementing this interface must adhere to the following assumptions: + + 1. **Optional label handling**: The ``forward_from_embedding()`` method must + accept label keys (as specified in ``self.label_keys``) as optional keyword + arguments. The method should handle cases where labels are missing without + raising exceptions. When labels are absent, the method should skip loss + computation and omit 'loss' and 'y_true' from the return dictionary. + + 2. **Non-linearity as nn.Module**: All non-linear activation functions + (ReLU, Sigmoid, Softmax, Tanh, etc.) must be defined as nn.Module instances + in the model's ``__init__`` method and called as instance methods + (e.g., ``self.relu(x)``). Do NOT use functional variants like ``F.relu(x)``, + ``F.sigmoid(x)``, or ``F.softmax(x)``. 
This is critical for + gradient-based interpretability methods (e.g., DeepLIFT) that require + hooks to be registered on non-linearities. + + Examples of correct activation usage:: + + class GoodModel(nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() # Correct + self.sigmoid = nn.Sigmoid() # Correct + + def forward(self, x): + x = self.relu(x) # Correct + x = self.sigmoid(x) # Correct + return x + + Examples of incorrect activation usage:: + + class BadModel(nn.Module): + def forward(self, x): + x = F.relu(x) # WRONG - functional variant + x = F.sigmoid(x) # WRONG - functional variant + return x + """ + + def forward( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + """Forward pass of the model. + + This is the standard entry point for running the model on a batch + of data. It accepts the raw feature tensors (as produced by the + dataloader) and returns predictions. + + Parameters + ---------- + **kwargs : torch.Tensor or tuple[torch.Tensor, ...] + Keyword arguments keyed by the model's ``feature_keys`` and + ``label_keys``. Each value is either a single tensor or a + tuple of tensors (e.g. ``(value, mask)``). + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing at least: + + - **logit** (torch.Tensor): Raw model predictions / logits of + shape ``(batch_size, num_classes)``. + - **y_prob** (torch.Tensor): Predicted probabilities. + - **loss** (torch.Tensor, optional): Scalar loss, present only + when label keys are included in ``kwargs``. + - **y_true** (torch.Tensor, optional): Ground-truth labels, + present only when label keys are included in ``kwargs``. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError + + def forward_from_embedding( + self, + **kwargs: torch.Tensor | tuple[torch.Tensor, ...] + ) -> dict[str, torch.Tensor]: + """Forward pass of the model starting from embeddings. 
+ + This method enables interpretability methods to pass embeddings directly + into the model's prediction head, bypassing the embedding stage. This is + useful for: + + - **Gradient-based attribution** (DeepLIFT, Integrated Gradients): + Allows gradients to be computed with respect to embeddings + - **Embedding perturbation** (LIME, SHAP): Allows perturbing embeddings + instead of raw features + - **Intermediate representation analysis**: Enables inspection of learned + representations at the embedding layer + + Kwargs keys typically mirror the model's feature keys (from the dataset's + input_schema), but represent embeddings instead of raw features. + + Parameters + ---------- + **kwargs : torch.Tensor or tuple[torch.Tensor, ...] + Variable keyword arguments representing input embeddings and optional labels. + + **Embedding arguments** (required): Should include all feature keys that + your model expects. Examples: + + - 'conditions': (batch_size, seq_length, embedding_dim) + - 'procedures': (batch_size, seq_length, embedding_dim) + - 'image': (batch_size, embedding_dim, height, width) + + **Label arguments** (optional): May include any label keys defined in + ``self.label_keys``. If label keys are present, the method should compute + loss and include 'loss' and 'y_true' in the return dictionary. If label + keys are absent, the method must not crash; simply omit 'loss' and 'y_true' + from the return dictionary. + + Returns + ------- + dict[str, torch.Tensor] + A dictionary containing model outputs with the following keys: + + - **logit** (torch.Tensor): Raw model predictions/logits of shape + (batch_size, num_classes) for classification tasks. + + - **y_prob** (torch.Tensor): Predicted probabilities of shape + (batch_size, num_classes). For binary classification, often + shape (batch_size, 1). + + - **loss** (torch.Tensor, optional): Scalar loss value if + any of ``self.label_keys`` are present in kwargs. Returned only when + ground truth labels are provided. 
Should not be included if labels are + unavailable. + + - **y_true** (torch.Tensor, optional): True labels if present in + kwargs. Useful for consistency checking during attribution. Should not + be included if labels are unavailable. + + Additional keys may be returned depending on the model's task type + (e.g., 'risks' for survival analysis, 'seq_output' for sequence models). + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + + Notes + ----- + The implementation must gracefully handle missing label keys. Interpretability + methods may invoke this method with only embedding inputs (no labels), expecting + forward passes for attribution computation. The method should compute predictions + successfully in both scenarios. + + Examples + -------- + For an EHR model with embedding dimension 64: + + >>> model = MyEHRModel(...) + >>> batch_embeddings = { + ... 'conditions': torch.randn(32, 100, 64), # 32 samples, 100 time steps + ... 'procedures': torch.randn(32, 100, 64), + ... } + >>> output = model.forward_from_embedding(**batch_embeddings) + >>> logits = output['logit'] # Shape: (32, num_classes) + >>> y_prob = output['y_prob'] # Shape: (32, num_classes) + + With optional labels: + + >>> batch_embeddings['mortality'] = torch.tensor([0, 1, 0, ...]) # Add labels + >>> output = model.forward_from_embedding(**batch_embeddings) + >>> loss = output['loss'] # Now included + >>> y_true = output['y_true'] # Now included + + For an image model with spatial embeddings: + + >>> model = MyImageModel(...) + >>> batch_embeddings = { + ... 'image': torch.randn(16, 768, 14, 14), # Vision Transformer embeddings + ... } + >>> output = model.forward_from_embedding(**batch_embeddings) + """ + raise NotImplementedError + + def get_embedding_model(self) -> nn.Module | None: + """Get the embedding/feature extraction stage of the model. 
+ + This method provides access to the model's embedding stage, which + transforms raw input features into learned vector representations. + This is used by interpretability methods to: + + - Generate embeddings from raw features before attribution + - Identify the boundary between feature processing and prediction + - Apply embedding-level analysis separately from prediction + + Returns + ------- + nn.Module or None + The embedding model/stage as an nn.Module if applicable, or None + if the model does not have a separable embedding stage. + + When returning a model, it should: + + - Accept the same input signature as the parent model (raw features) + - Produce embeddings that are compatible with forward_from_embedding() + - Be in the same device as the parent model + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + + Examples + -------- + For a model with explicit embedding and prediction stages: + + >>> class MyModel(Interpretable): + ... def __init__(self): + ... self.embedding_layer = EmbeddingBlock(...) + ... self.prediction_head = PredictionBlock(...) + ... + ... def get_embedding_model(self): + ... return self.embedding_layer + + For models without a clear separable embedding stage, return None: + + >>> def get_embedding_model(self): + ... return None # Embeddings are not separately accessible + """ + raise NotImplementedError + + +class CheferInterpretable(Interpretable): + """Abstract interface for models supporting Chefer relevance attribution. + + This is a subclass of :class:`Interpretable` and therefore + inherits the embedding-level interface (``forward_from_embedding``, + ``get_embedding_model``). Models that implement this interface + automatically satisfy the general interpretability contract **and** the + Chefer-specific contract, so they work with both embedding-perturbation + methods (DeepLIFT, LIME, …) and gradient-weighted attention methods + (Chefer). 
+ + The Chefer algorithm works as follows: + + 1. **Forward + hook registration** — run the model while capturing + attention weight tensors and registering backward hooks so their + gradients are stored. + 2. **Backward** — back-propagate from a one-hot target class through + the logits. + 3. **Relevance propagation** — for every feature key, iterate over + attention layers, compute gradient-weighted attention + (``clamp(attn * grad, min=0)``), and accumulate into a relevance + matrix ``R`` via ``R += cam @ R``. + 4. **Attribution extraction** — extract the final per-token + attribution from ``R`` (e.g. read the CLS row, or the + last-valid-timestep row, possibly with reshaping). + + Steps 1, 3-b and 4 are model-specific; the rest is generic. This + interface captures exactly those model-specific pieces. + + Inherited from ``Interpretable`` + ----------------------------------------------- + forward_from_embedding(**kwargs) -> dict[str, Tensor] + Forward pass starting from pre-computed embeddings. + get_embedding_model() -> nn.Module | None + Access the embedding / feature-extraction stage. + + Additional (Chefer-specific) methods + ------------------------------------- + set_attention_hooks(enabled) -> None + Toggle attention map capture and gradient hook registration. + get_attention_layers() -> dict[str, list[tuple[Tensor, Tensor]]] + Paired (attn_map, attn_grad) for each attention layer, keyed by + feature key. + get_relevance_tensor(R, **data) -> dict[str, Tensor] + Reduce relevance matrices to per-token attribution vectors. + + Attributes + ---------- + feature_keys : list[str] + The feature keys from the task's ``input_schema`` (e.g. + ``["conditions", "procedures"]``). Already provided by + :class:`~pyhealth.models.base_model.BaseModel`. 
+ + Notes + ----- + * ``set_attention_hooks(True)`` must be called **before** the forward + pass, and ``get_attention_layers`` must be called **after** the + forward + backward passes, because attention maps are populated + during forward and gradients during backward. + * The interface intentionally does **not** prescribe how hooks are + registered internally — ``nn.MultiheadAttention`` with + ``register_hook``, manual ``save_attn_grad`` callbacks, or explicit + QKV computation all work as long as the getter methods return the + right tensors. + + Examples + -------- + Minimal skeleton for a new model: + + >>> class MyAttentionModel(BaseModel, CheferInterpretable): + ... # feature_keys is inherited from BaseModel + ... + ... def forward_from_embedding(self, **kwargs): + ... # ... prediction head from pre-computed embeddings ... + ... + ... def get_embedding_model(self): + ... return self.embedding_layer + ... + ... def set_attention_hooks(self, enabled): + ... self._register_hooks = enabled + ... + ... def get_attention_layers(self): + ... result = {} + ... for key in self.feature_keys: + ... result[key] = [ + ... (blk.attention.get_attn_map(), + ... blk.attention.get_attn_grad()) + ... for blk in self.encoder[key].blocks + ... ] + ... return result + ... + ... def get_relevance_tensor(self, R, **data): + ... return {key: r[:, 0] for key, r in R.items()} + """ + + @abstractmethod + def set_attention_hooks(self, enabled: bool) -> None: + """Toggle attention hook registration for subsequent forward passes. + + When ``enabled=True``, the next call to ``forward()`` (or + ``forward_from_embedding()``) must: + + 1. Store attention weight tensors so they are retrievable via + :meth:`get_attention_layers`. + 2. Register backward hooks on those tensors so that after + ``.backward()`` the corresponding gradients are also stored. 
+ + When ``enabled=False``, subsequent forward passes should **not** + capture attention maps or register gradient hooks, restoring the + model to its normal (faster) execution mode. + + Parameters + ---------- + enabled : bool + ``True`` to start capturing attention maps and registering + gradient hooks; ``False`` to stop. + + Typical implementations set an internal flag that the model's + forward method checks:: + + def set_attention_hooks(self, enabled): + self._attention_hooks_enabled = enabled + + And inside the forward / encoder logic:: + + if self._attention_hooks_enabled: + attn.register_hook(self.save_attn_grad) + """ + ... + + @abstractmethod + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + """Return (attention_map, attention_gradient) pairs for all feature keys. + + Must be called **after** ``set_attention_hooks(True)``, + a ``forward()`` call, and a subsequent ``backward()`` call so + that both attention maps and their gradients are populated. + + Returns + ------- + dict[str, list[tuple[torch.Tensor, torch.Tensor]]] + A dictionary keyed by ``feature_keys``. Each value is a list + with one ``(attn_map, attn_grad)`` tuple per attention layer, + ordered from the first (closest to input) to the last + (closest to output). + + Each tensor may have shape: + + * ``[batch, heads, seq, seq]`` — multi-head (will be + gradient-weighted-averaged across heads by Chefer). + * ``[batch, seq, seq]`` — already head-averaged. + + ``attn_map`` and ``attn_grad`` in the same tuple must have + the same shape. + + Examples + -------- + A model with stacked ``TransformerBlock`` layers per feature key: + + >>> def get_attention_layers(self): + ... return { + ... key: [ + ... (blk.attention.get_attn_map(), + ... blk.attention.get_attn_grad()) + ... for blk in self.transformer[key].transformer + ... ] + ... for key in self.feature_keys + ... 
} + + A model with a single MHA layer per feature key: + + >>> def get_attention_layers(self): + ... return { + ... key: [(self.stagenet[key].get_attn_map(), + ... self.stagenet[key].get_attn_grad())] + ... for key in self.feature_keys + ... } + """ + ... + + @abstractmethod + def get_relevance_tensor( + self, + R: dict[str, torch.Tensor], + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + """Reduce relevance matrices to per-token attribution vectors. + + The Chefer algorithm builds a relevance matrix of shape + ``[batch, seq_len, seq_len]`` for each feature key. This method + reduces each matrix to a ``[batch, seq_len]`` vector by selecting + the row corresponding to the classification position — giving the + model full control over how the selection is done. + + Parameters + ---------- + R : dict[str, torch.Tensor] + Relevance matrices keyed by ``feature_keys``. Each tensor + has shape ``[batch, seq_len, seq_len]`` (seq_len may differ + across keys). + **data : torch.Tensor or tuple[torch.Tensor, ...] + The original input data (same kwargs passed to + ``forward()``). Available for context when the selection + logic is data-dependent (e.g. last valid timestep depends on + mask). + + Returns + ------- + dict[str, torch.Tensor] + Attribution vectors keyed by ``feature_keys``. Each tensor + has shape ``[batch, seq_len]``. + + Examples + -------- + CLS-token model (e.g. Transformer) — row 0 for all keys: + + >>> def get_relevance_tensor(self, R, **data): + ... return {key: r[:, 0] for key, r in R.items()} + + Last-valid-timestep model (e.g. StageAttentionNet): + + >>> def get_relevance_tensor(self, R, **data): + ... result = {} + ... for key, r in R.items(): + ... mask = self._get_mask(key, **data) + ... last_idx = mask.sum(dim=1) - 1 + ... batch_idx = torch.arange(r.shape[0], device=r.device) + ... result[key] = r[batch_idx, last_idx] + ... return result + """ + ... + + # TODO: Add postprocess_attribution() when ViT support is ready. 
+ # ViT models need to strip the CLS column, reshape the patch vector + # into a spatial [batch, 1, H, W] map, and optionally interpolate to + # the original image size. For EHR models this is a no-op. We can + # either fold this into extract_attribution() or add it as a separate + # optional method. \ No newline at end of file diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 5bef2fc45..5efca68eb 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -1,21 +1,32 @@ -from typing import Dict +"""Chefer's gradient-weighted attention relevance propagation. + +This module implements the Chefer et al. relevance propagation method for +explaining transformer-family model predictions. It relies on the +:class:`~pyhealth.interpret.api.CheferInterpretable` interface — any model +that implements that interface is automatically supported. + +Paper: + Chefer, Hila, Shir Gur, and Lior Wolf. + "Generic Attention-model Explainability for Interpreting Bi-Modal and + Encoder-Decoder Transformers." + Proceedings of the IEEE/CVF International Conference on Computer Vision + (ICCV), 2021. 
+""" + +from typing import Dict, Optional, cast import torch import torch.nn.functional as F -from pyhealth.models import Transformer +from pyhealth.interpret.api import CheferInterpretable from pyhealth.models.base_model import BaseModel - +from pyhealth.interpret.api import CheferInterpretable from .base_interpreter import BaseInterpreter -# Import TorchvisionModel conditionally to avoid circular imports -try: - from pyhealth.models import TorchvisionModel - HAS_TORCHVISION_MODEL = True -except ImportError: - HAS_TORCHVISION_MODEL = False - TorchvisionModel = None +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- def apply_self_attention_rules(R_ss, cam_ss): """Apply Chefer's self-attention rules for relevance propagation. @@ -47,41 +58,38 @@ def avg_heads(cam, grad): return cam.clone() +# --------------------------------------------------------------------------- +# Main interpreter +# --------------------------------------------------------------------------- + class CheferRelevance(BaseInterpreter): """Chefer's gradient-weighted attention method for transformer interpretability. - This class implements the relevance propagation method from Chefer et al. for - explaining transformer model predictions. It computes relevance scores for each - input token (for text/EHR transformers) or patch (for Vision Transformers) by - combining attention weights with their gradients. + This interpreter works with **any** model that implements the + :class:`~pyhealth.interpret.api.CheferInterpretable` interface, which + currently includes: - The method works by: - 1. Performing a forward pass to capture attention maps from each layer - 2. Computing gradients of the target class w.r.t. attention weights - 3. Combining attention and gradients using element-wise multiplication - 4. 
Propagating relevance through layers using attention rollout rules + * :class:`~pyhealth.models.Transformer` + * :class:`~pyhealth.models.StageAttentionNet` - This approach provides more faithful explanations than raw attention weights - alone, as it accounts for how attention contributes to the final prediction. + The algorithm: - Paper: - Chefer, Hila, Shir Gur, and Lior Wolf. - "Generic Attention-model Explainability for Interpreting Bi-Modal and - Encoder-Decoder Transformers." - Proceedings of the IEEE/CVF International Conference on Computer Vision - (ICCV), 2021. + 1. Enable attention hooks via ``model.set_attention_hooks(True)``. + 2. Forward pass → capture attention maps and register gradient hooks. + 3. Backward pass from a one-hot target class. + 4. Retrieve ``(attn_map, attn_grad)`` pairs via ``model.get_attention_layers()``. + 5. Propagate relevance: ``R += clamp(attn * grad, min=0) @ R``. + 6. Reduce ``R`` to per-token vectors via ``model.get_relevance_tensor()``. - Supported Models: - - PyHealth Transformer: For sequential/EHR data with multiple feature keys - - TorchvisionModel (ViT variants): vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14 + Steps 1, 4 and 6 are delegated to the model through the + ``CheferInterpretable`` interface, making this class fully + model-agnostic. Args: - model (BaseModel): A trained PyHealth model to interpret. Must be either: - - A ``Transformer`` model for sequential/EHR data - - A ``TorchvisionModel`` with a ViT architecture for image data + model (BaseModel): A trained PyHealth model that implements + :class:`~pyhealth.interpret.api.CheferInterpretable`. Example: - >>> # Example 1: PyHealth Transformer for EHR data >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> from pyhealth.models import Transformer >>> from pyhealth.interpret.methods import CheferRelevance @@ -111,7 +119,6 @@ class CheferRelevance(BaseInterpreter): >>> model = Transformer(dataset=dataset) >>> # ... train the model ... 
>>> - >>> # Create interpreter and compute attribution >>> interpreter = CheferRelevance(model) >>> batch = next(iter(get_dataloader(dataset, batch_size=2))) >>> @@ -122,246 +129,328 @@ class CheferRelevance(BaseInterpreter): >>> >>> # Optional: attribute to a specific class (e.g., class 1) >>> attributions = interpreter.attribute(class_index=1, **batch) - >>> - >>> # Example 2: TorchvisionModel ViT for image data - >>> from pyhealth.datasets import COVID19CXRDataset - >>> from pyhealth.models import TorchvisionModel - >>> from pyhealth.interpret.utils import visualize_image_attr - >>> - >>> base_dataset = COVID19CXRDataset(root="/path/to/data") - >>> sample_dataset = base_dataset.set_task() - >>> model = TorchvisionModel( - ... dataset=sample_dataset, - ... model_name="vit_b_16", - ... model_config={"weights": "DEFAULT"}, - ... ) - >>> # ... train the model ... - >>> - >>> # Create interpreter and compute attribution - >>> # Task schema: input_schema={"image": "image"}, so feature_key="image" - >>> interpreter = CheferRelevance(model) - >>> - >>> # Default: attribute to predicted class - >>> result = interpreter.attribute(**batch) - >>> # Returns dict keyed by feature_key: {"image": tensor} - >>> attr_map = result["image"] # Shape: [batch, 1, H, W] - >>> - >>> # Optional: attribute to a specific class (e.g., predicted class) - >>> pred_class = model(**batch)["y_prob"].argmax().item() - >>> result = interpreter.attribute(class_index=pred_class, **batch) - >>> - >>> # Visualize - >>> img, attr, overlay = visualize_image_attr( - ... image=batch["image"][0], - ... attribution=result["image"][0, 0], - ... 
) """ def __init__(self, model: BaseModel): super().__init__(model) - - # Determine model type - self._is_transformer = isinstance(model, Transformer) - self._is_vit = False - - if HAS_TORCHVISION_MODEL and TorchvisionModel is not None: - if isinstance(model, TorchvisionModel): - self._is_vit = model.is_vit_model() - - if not self._is_transformer and not self._is_vit: - raise ValueError( - f"CheferRelevance requires a Transformer or TorchvisionModel (ViT), " - f"got {type(model).__name__}. For TorchvisionModel, only ViT variants " - f"(vit_b_16, vit_b_32, etc.) are supported." - ) + if not isinstance(model, CheferInterpretable): + raise ValueError("Model must implement CheferInterpretable interface") + self.model = model def attribute( self, - interpolate: bool = True, - class_index: int = None, + class_index: Optional[int] = None, **data, ) -> Dict[str, torch.Tensor]: - """Compute relevance scores for each token/patch. - - This is the primary method for computing attributions. Returns a - dictionary keyed by the model's feature keys (from the task schema). + """Compute relevance scores for each input token. Args: - interpolate: For ViT models, if True interpolate attribution to image size. - class_index: Target class index to compute attribution for. If None - (default), uses the model's predicted class. This is useful when - you want to explain why a specific class was predicted or to - compare attributions across different classes. - **data: Input data from dataloader batch containing: - - For Transformer: feature keys (conditions, procedures, etc.) + label - - For ViT: image feature key (e.g., "image") + label + class_index: Target class index to compute attribution for. + If None (default), uses the model's predicted class. + **data: Input data from dataloader batch containing feature + keys and label key. Returns: - Dict[str, torch.Tensor]: Dictionary keyed by feature keys from the task schema. 
- - - For Transformer: ``{"conditions": tensor, "procedures": tensor, ...}`` - where each tensor has shape ``[batch, num_tokens]``. - - For ViT: ``{"image": tensor}`` (or whatever the task's image key is) - where tensor has shape ``[batch, 1, H, W]``. + Dict[str, torch.Tensor]: Dictionary keyed by feature keys, + where each tensor has shape ``[batch, seq_len]`` with + per-token attribution scores. """ - if self._is_vit: - return self._attribute_vit( - interpolate=interpolate, - class_index=class_index, - **data - ) - return self._attribute_transformer(class_index=class_index, **data) - - def _attribute_transformer( - self, - class_index: int = None, - **data - ) -> Dict[str, torch.Tensor]: - """Compute relevance for PyHealth Transformer models. - - Args: - class_index: Target class for attribution. If None, uses predicted class. - **data: Input data from dataloader batch. - """ - data["register_hook"] = True - - logits = self.model(**data)["logit"] + # --- 1. Forward with attention hooks enabled --- + self.model.set_attention_hooks(True) + try: + logits = self.model(**data)["logit"] + finally: + self.model.set_attention_hooks(False) + + # --- 2. 
Backward from target class --- if class_index is None: - class_index = torch.argmax(logits, dim=-1) + class_index_t = torch.argmax(logits, dim=-1) + elif isinstance(class_index, int): + class_index_t = torch.tensor(class_index) + else: + class_index_t = class_index - one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = F.one_hot( + class_index_t.detach().clone(), logits.size(1) + ).float() one_hot = one_hot.requires_grad_(True) - one_hot = torch.sum(one_hot.to(logits.device) * logits) + scalar = torch.sum(one_hot.to(logits.device) * logits) self.model.zero_grad() - one_hot.backward(retain_graph=True) + scalar.backward(retain_graph=True) - feature_keys = self.model.feature_keys - num_tokens = {} - for key in feature_keys: - feature_transformer = self.model.transformer[key].transformer - for block in feature_transformer: - num_tokens[key] = block.attention.get_attn_map().shape[-1] + # --- 3. Retrieve (attn_map, attn_grad) pairs per feature key --- + attention_layers = self.model.get_attention_layers() - attn = {} - for key in feature_keys: + batch_size = logits.shape[0] + device = logits.device + + # --- 4. Relevance propagation per feature key --- + R_dict: dict[str, torch.Tensor] = {} + for key, layers in attention_layers.items(): + num_tokens = layers[0][0].shape[-1] R = ( - torch.eye(num_tokens[key]) + torch.eye(num_tokens, device=device) .unsqueeze(0) - .repeat(len(data[key]), 1, 1) - .to(logits.device) + .repeat(batch_size, 1, 1) ) - for blk in self.model.transformer[key].transformer: - grad = blk.attention.get_attn_grad() - cam = blk.attention.get_attn_map() + for cam, grad in layers: cam = avg_heads(cam, grad) - R += apply_self_attention_rules(R, cam).detach() - attn[key] = R[:, 0] + R = R + apply_self_attention_rules(R, cam).detach() + R_dict[key] = R + + # --- 5. Reduce R matrices to per-token vectors --- + attributions = self.model.get_relevance_tensor(R_dict, **data) + + # --- 6. 
Expand to match raw input shapes (nested sequences) --- + return self._map_to_input_shapes(attributions, data) - return attn + # ------------------------------------------------------------------ + # Shape mapping + # ------------------------------------------------------------------ - def _attribute_vit( + def _map_to_input_shapes( self, - interpolate: bool = True, - class_index: int = None, - **data, + attributions: Dict[str, torch.Tensor], + data: dict, ) -> Dict[str, torch.Tensor]: - """Compute ViT attribution and return spatial attribution map. - + """Expand attributions to match raw input value shapes. + + For nested sequences the attention operates on a pooled + (visit-level) sequence, but downstream consumers (e.g. ablation + metrics) expect attributions to match the raw input value shape. + Per-visit relevance scores are replicated across all codes + within each visit. + Args: - interpolate: If True, interpolate to full image size. - class_index: Target class for attribution. If None, uses predicted class. - **data: Must contain the image feature key. - + attributions: Per-feature attribution tensors returned by + ``model.get_relevance_tensor()``. + data: Original ``**data`` kwargs from the dataloader batch. + Returns: - Dict keyed by the model's feature_key (e.g., "image") with spatial - attribution map of shape [batch, 1, H, W]. + Attributions expanded to raw input value shapes where needed. """ - # Get the feature key (first element of feature_keys list) - feature_key = self.model.feature_keys[0] - x = data.get(feature_key) - if x is None: - raise ValueError( - f"Expected feature key '{feature_key}' in data. 
" - f"Available keys: {list(data.keys())}" - ) - - x = x.to(self.model.device) - - # Infer input size from image dimensions (assumes square images) - input_size = x.shape[-1] - - # Forward pass with attention capture - self.model.zero_grad() - logits, attention_maps = self.model.forward_with_attention(x, register_hook=True) - - # Use predicted class if not specified - target_class = class_index - if target_class is None: - target_class = logits.argmax(dim=-1) - - # Backward pass - one_hot = torch.zeros_like(logits) - if isinstance(target_class, int): - one_hot[:, target_class] = 1 - else: - if target_class.dim() == 0: - target_class = target_class.unsqueeze(0) - one_hot.scatter_(1, target_class.unsqueeze(1), 1) - - one_hot = one_hot.requires_grad_(True) - (logits * one_hot).sum().backward(retain_graph=True) - - # Compute gradient-weighted attention - attention_gradients = self.model.get_attention_gradients() - batch_size = attention_maps[0].shape[0] - num_tokens = attention_maps[0].shape[-1] - device = attention_maps[0].device - - R = torch.eye(num_tokens, device=device) - R = R.unsqueeze(0).expand(batch_size, -1, -1).clone() - - for attn, grad in zip(attention_maps, attention_gradients): - cam = avg_heads(attn, grad) - R = R + apply_self_attention_rules(R.detach(), cam.detach()) - - # CLS token's relevance to patches (excluding CLS itself) - patches_attr = R[:, 0, 1:] - - # Reshape to spatial layout - h_patches, w_patches = self.model.get_num_patches(input_size) - attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches) - - if interpolate: - attr_map = F.interpolate( - attr_map, - size=(input_size, input_size), - mode="bilinear", - align_corners=False, - ) - - # Return keyed by the model's feature key (e.g., "image") - return {feature_key: attr_map} + result: Dict[str, torch.Tensor] = {} + for key, attr in attributions.items(): + feature = data.get(key) + if feature is not None: + if isinstance(feature, torch.Tensor): + val = feature + else: + schema = 
self.model.dataset.input_processors[key].schema() + val = ( + feature[schema.index("value")] + if "value" in schema + else None + ) + if val is not None and val.dim() > attr.dim(): + for _ in range(val.dim() - attr.dim()): + attr = attr.unsqueeze(-1) + attr = attr.expand_as(val) + result[key] = attr + return result + + # ------------------------------------------------------------------ + # Backward compatibility aliases + # ------------------------------------------------------------------ - # Backwards compatibility aliases def get_relevance_matrix(self, **data): - """Alias for _attribute_transformer. Use attribute() instead.""" - return self._attribute_transformer(**data) + """Alias for attribute(). Deprecated.""" + return self.attribute(**data) + + +# ====================================================================== +# LEGACY REFERENCE IMPLEMENTATIONS +# ====================================================================== +# The functions below are the original model-specific implementations +# that existed before the CheferInterpretable API was introduced. They +# are kept here ONLY as a reference for future developers and are NOT +# called by any production code. They may be removed in a future +# release. +# +# For ViT models, _reference_attribute_vit is the only implementation +# until ViT models implement CheferInterpretable. +# ====================================================================== + +def _reference_attribute_transformer( + model, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original Transformer-specific Chefer attribution. + + This was the body of ``CheferRelevance._attribute_transformer()`` + before the CheferInterpretable API was introduced. It accesses + model internals (``model.transformer[key].transformer``) directly. 
+ """ + data["register_hook"] = True - def get_vit_attribution_map( - self, - interpolate: bool = True, - class_index: int = None, - **data - ): - """Alias for attribute() for ViT. Use attribute() instead. - - Returns the attribution tensor directly (not wrapped in a dict). - """ - result = self._attribute_vit( - interpolate=interpolate, - class_index=class_index, - **data + logits = model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) + + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = one_hot.requires_grad_(True) + one_hot = torch.sum(one_hot.to(logits.device) * logits) + model.zero_grad() + one_hot.backward(retain_graph=True) + + feature_keys = model.feature_keys + num_tokens = {} + for key in feature_keys: + feature_transformer = model.transformer[key].transformer + for block in feature_transformer: + num_tokens[key] = block.attention.get_attn_map().shape[-1] + + batch_size = logits.shape[0] + attn = {} + for key in feature_keys: + R = ( + torch.eye(num_tokens[key]) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(logits.device) + ) + for blk in model.transformer[key].transformer: + grad = blk.attention.get_attn_grad() + cam = blk.attention.get_attn_map() + cam = avg_heads(cam, grad) + R += apply_self_attention_rules(R, cam).detach() + attn[key] = R[:, 0] + + return attn + + +def _reference_attribute_stageattn( + model, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original StageAttentionNet-specific Chefer attribution. + + This was the body of ``CheferRelevance._attribute_stageattn()`` + before the CheferInterpretable API was introduced. It accesses + model internals (``model.stagenet[key]``, ``model.embedding_model``) + directly. 
+ """ + data["register_attn_hook"] = True + + logits = model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) + + if isinstance(class_index, torch.Tensor): + one_hot = F.one_hot(class_index.detach().clone(), logits.size()[1]).float() + else: + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() + one_hot = one_hot.requires_grad_(True) + one_hot = torch.sum(one_hot.to(logits.device) * logits) + model.zero_grad() + one_hot.backward(retain_graph=True) + + batch_size = logits.shape[0] + feature_keys = model.feature_keys + attn = {} + + for key in feature_keys: + layer = model.stagenet[key] + cam = layer.get_attn_map() + grad = layer.get_attn_grad() + num_tokens = cam.shape[-1] + + R = ( + torch.eye(num_tokens) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + .to(logits.device) + ) + cam = avg_heads(cam, grad) + R += apply_self_attention_rules(R, cam).detach() + + feature = data[key] + if isinstance(feature, tuple) and len(feature) == 2: + _, x_val = feature + else: + x_val = feature + + embedded = model.embedding_model({key: x_val}) + emb = embedded[key] + if emb.dim() == 4: + emb = emb.sum(dim=2) + mask = (emb.sum(dim=-1) != 0).long().to(logits.device) + + last_idx = mask.sum(dim=1) - 1 + attn[key] = R[torch.arange(batch_size, device=logits.device), last_idx] + + return attn + + +def _reference_attribute_vit( + model, + interpolate: bool = True, + class_index=None, + **data, +) -> Dict[str, torch.Tensor]: + """[REFERENCE ONLY] Original ViT-specific Chefer attribution. + + ViT models do not yet implement CheferInterpretable. This code + shows the ViT-specific flow that will be needed when ViT support is + added to the unified API. + """ + feature_key = model.feature_keys[0] + x = data.get(feature_key) + if x is None: + raise ValueError( + f"Expected feature key '{feature_key}' in data. 
" + f"Available keys: {list(data.keys())}" + ) + + x = x.to(model.device) + input_size = x.shape[-1] + + model.zero_grad() + logits, attention_maps = model.forward_with_attention(x, register_hook=True) + + target_class = class_index + if target_class is None: + target_class = logits.argmax(dim=-1) + + one_hot = torch.zeros_like(logits) + if isinstance(target_class, int): + one_hot[:, target_class] = 1 + else: + if target_class.dim() == 0: + target_class = target_class.unsqueeze(0) + one_hot.scatter_(1, target_class.unsqueeze(1), 1) + + one_hot = one_hot.requires_grad_(True) + (logits * one_hot).sum().backward(retain_graph=True) + + attention_gradients = model.get_attention_gradients() + batch_size = attention_maps[0].shape[0] + num_tokens = attention_maps[0].shape[-1] + device = attention_maps[0].device + + R = torch.eye(num_tokens, device=device) + R = R.unsqueeze(0).expand(batch_size, -1, -1).clone() + + for attn, grad in zip(attention_maps, attention_gradients): + cam = avg_heads(attn, grad) + R = R + apply_self_attention_rules(R.detach(), cam.detach()) + + patches_attr = R[:, 0, 1:] + + h_patches, w_patches = model.get_num_patches(input_size) + attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches) + + if interpolate: + attr_map = F.interpolate( + attr_map, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, ) - # Return the attribution tensor directly (get the first/only value) - feature_key = self.model.feature_keys[0] - return result[feature_key] + + return {feature_key: attr_map} \ No newline at end of file diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py index 211aeff3b..29f99d795 100644 --- a/pyhealth/interpret/methods/deeplift.py +++ b/pyhealth/interpret/methods/deeplift.py @@ -1,12 +1,13 @@ from __future__ import annotations import contextlib -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type, cast import torch import 
torch.nn.functional as F from pyhealth.models import BaseModel +from pyhealth.interpret.api import Interpretable from .base_interpreter import BaseInterpreter @@ -330,16 +331,12 @@ class DeepLift(BaseInterpreter): def __init__(self, model: BaseModel, use_embeddings: bool = True): super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "DeepLIFT. Set use_embeddings=False to use input-level " - "gradients (only for continuous features)." - ) - # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py index f53dd3199..d1b74b573 100644 --- a/pyhealth/interpret/methods/gim.py +++ b/pyhealth/interpret/methods/gim.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from pyhealth.models import BaseModel - +from pyhealth.interpret.api import Interpretable from .base_interpreter import BaseInterpreter @@ -351,16 +351,10 @@ def __init__( temperature: float = 2.0, ): super().__init__(model) - if not hasattr(model, "forward_from_embedding"): - raise AssertionError( - "GIM requires models that implement `forward_from_embedding`." - ) - embedding_model = model.get_embedding_model() - if embedding_model is None: - raise AssertionError( - "GIM requires a model with an embedding model " - "accessible via `get_embedding_model()`." 
- ) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.temperature = max(float(temperature), 1.0) def attribute( diff --git a/pyhealth/interpret/methods/ig_gim.py b/pyhealth/interpret/methods/ig_gim.py index b0a81d7b4..a33f5529b 100644 --- a/pyhealth/interpret/methods/ig_gim.py +++ b/pyhealth/interpret/methods/ig_gim.py @@ -36,6 +36,7 @@ from pyhealth.models import BaseModel from .base_interpreter import BaseInterpreter +from pyhealth.interpret.api import Interpretable from .gim import _GIMHookContext @@ -79,18 +80,9 @@ def __init__( steps: int = 50, ): super().__init__(model) - - if not hasattr(model, "forward_from_embedding"): - raise AssertionError( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() to use IG-GIM. " - "Set use_embeddings=False for input-level IG instead." - ) - if model.get_embedding_model() is None: - raise AssertionError( - "Model must provide an embedding model via " - "get_embedding_model() for IG-GIM." - ) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model self.temperature = max(float(temperature), 1.0) self.steps = steps diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index 7cc24b39d..a529a6f3f 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from pyhealth.models import BaseModel - +from pyhealth.interpret.api import Interpretable from .base_interpreter import BaseInterpreter @@ -187,17 +187,13 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5 implement forward_from_embedding() method. 
""" super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings self.steps = steps - # Check model supports forward_from_embedding if needed - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "Integrated Gradients. Set use_embeddings=False to use " - "input-level gradients (only for continuous features)." - ) def attribute( self, diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 7518aaa50..8481c50bf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -8,6 +8,7 @@ from torch.nn import CosineSimilarity from pyhealth.models import BaseModel +from pyhealth.interpret.api import Interpretable from .base_interpreter import BaseInterpreter @@ -131,6 +132,10 @@ def __init__( ValueError: If feature_selection is not "lasso", "ridge", or "none". """ super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model + self.use_embeddings = use_embeddings self.n_samples = n_samples self.kernel_width = kernel_width diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index c68a2e08f..edcd02fc2 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -6,6 +6,7 @@ import torch from pyhealth.models import BaseModel +from pyhealth.interpret.api import Interpretable from .base_interpreter import BaseInterpreter @@ -120,6 +121,9 @@ def __init__( implement forward_from_embedding() method. 
""" super().__init__(model) + if not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface") + self.model = model self.use_embeddings = use_embeddings self.n_background_samples = n_background_samples self.max_coalitions = max_coalitions diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index fc59e6026..f40615794 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -73,37 +73,6 @@ def forward(self, y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs. """ raise NotImplementedError - - def forward_from_embedding( - self, - **kwargs: torch.Tensor | tuple[torch.Tensor, ...] - ) -> dict[str, torch.Tensor]: - """Forward pass of the model from embeddings. - - This method should be implemented for interpretability methods that require - access to the model's forward pass from embeddings. - - Args: - **kwargs: A variable number of keyword arguments representing input features - as embeddings. Each keyword argument is a tensor or a tuple of tensors of - shape (batch_size, ...). - - Returns: - A dictionary with the following keys: - logit: a tensor of predicted logits. - y_prob: a tensor of predicted probabilities. - loss [optional]: a scalar tensor representing the final loss, if self.label_keys in kwargs. - y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs. - """ - raise NotImplementedError - - def get_embedding_model(self) -> nn.Module | None: - """Get the embedding model if applicable. This is used in pair with `forward_from_embedding`. - - Returns: - nn.Module | None: The embedding model or None if not applicable. 
- """ - raise NotImplementedError # ------------------------------------------------------------------ # Internal helpers diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py index 306073313..4c6432225 100644 --- a/pyhealth/models/mlp.py +++ b/pyhealth/models/mlp.py @@ -5,11 +5,12 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel +from pyhealth.interpret.api import Interpretable from .embedding import EmbeddingModel -class MLP(BaseModel): +class MLP(BaseModel, Interpretable): """Multi-layer perceptron model. This model applies a separate MLP layer for each feature, and then diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py index bf78bb216..aad794b02 100644 --- a/pyhealth/models/stagenet.py +++ b/pyhealth/models/stagenet.py @@ -6,6 +6,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit +from pyhealth.interpret.api import Interpretable from .embedding import EmbeddingModel @@ -239,7 +240,7 @@ def forward( return last_output, output, torch.stack(distance) -class StageNet(BaseModel): +class StageNet(BaseModel, Interpretable): """StageNet model. Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index f95d1e787..637f2fdc5 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -8,6 +8,7 @@ from pyhealth.models import BaseModel from pyhealth.models.utils import get_last_visit from .transformer import MultiHeadedAttention +from pyhealth.interpret.api import CheferInterpretable from .embedding import EmbeddingModel @@ -297,7 +298,7 @@ def forward( return last_output, output, distance -class StageAttentionNet(BaseModel): +class StageAttentionNet(BaseModel, CheferInterpretable): """StageAttentionNet model. Paper: Junyi Gao et al. 
Stagenet: Stage-aware neural networks for health @@ -405,6 +406,7 @@ def __init__( self.embedding_dim = embedding_dim self.chunk_size = chunk_size self.levels = levels + self._attention_hooks_enabled = False # validate kwargs for StageNet layer if "input_dim" in kwargs: @@ -474,7 +476,8 @@ def forward_from_embedding( logit: the raw logits before activation. embed: (if embed=True in kwargs) the patient embedding. """ - register_attn_hook = kwargs.pop("register_attn_hook", False) + # Support both the flag-based API and legacy kwarg-based API + register_attn_hook = self._attention_hooks_enabled patient_emb = [] distance = [] @@ -597,7 +600,7 @@ def forward( """ register_attn_hook = kwargs.pop("register_attn_hook", False) if register_attn_hook: - kwargs["register_attn_hook"] = register_attn_hook # type: ignore + kwargs["register_attn_hook"] = register_attn_hook # type: ignore for feature_key in self.feature_keys: feature = kwargs[feature_key] @@ -628,3 +631,82 @@ def forward( def get_embedding_model(self) -> nn.Module | None: return self.embedding_model + + # ------------------------------------------------------------------ + # CheferInterpretable interface + # ------------------------------------------------------------------ + + def set_attention_hooks(self, enabled: bool) -> None: + self._attention_hooks_enabled = enabled + + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + return { # type: ignore[return-value] + key: [ + ( + cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_map(), + cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_grad(), + ) + ] + for key in self.feature_keys + } + + def get_relevance_tensor( + self, + R: dict[str, torch.Tensor], + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + # StageAttentionNet uses get_last_visit (last valid timestep) + # instead of a CLS token. Derive the mask from **data using + # the same logic as forward_from_embedding. 
+ result = {} + for key in self.feature_keys: + r = R[key] + batch_size = r.shape[0] + device = r.device + + processor = self.dataset.input_processors[key] + schema = processor.schema() + feature = data[key] + + if isinstance(feature, torch.Tensor): + feature = (feature,) + + value = feature[schema.index("value")] if "value" in schema else None + mask = feature[schema.index("mask")] if "mask" in schema else None + + if len(feature) == len(schema) + 1 and mask is None: + mask = feature[-1] + + if mask is None: + if value is not None: + v = value.to(device) + mask = (v.abs().sum(dim=-1) != 0).int() + else: + # Cannot determine mask; fall back to last position + last_idx = torch.full( + (batch_size,), r.shape[1] - 1, + device=device, dtype=torch.long, + ) + result[key] = r[ + torch.arange(batch_size, device=device), last_idx + ] + continue + else: + mask = mask.to(device) + if not processor.is_token() and value is not None and value.dim() == mask.dim(): + mask = mask.any(dim=-1).int() + + if mask.dim() == 3: + # Nested sequences: collapse inner dimension + mask = mask.any(dim=2).int() + + last_idx = mask.sum(dim=1).long() - 1 + last_idx = last_idx.clamp(min=0) + attn = r[ + torch.arange(batch_size, device=device), last_idx + ] # [batch, attention_seq_len] + + result[key] = attn + return result diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index f1e28304d..de8555579 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -12,6 +12,7 @@ from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.embedding import EmbeddingModel +from pyhealth.interpret.api import CheferInterpretable # VALID_OPERATION_LEVEL = ["visit", "event"] @@ -312,7 +313,7 @@ def forward( return emb, cls_emb -class Transformer(BaseModel): +class Transformer(BaseModel, CheferInterpretable): """Transformer model for PyHealth 2.0 datasets. 
Each feature stream is embedded with :class:`EmbeddingModel` and encoded by @@ -373,6 +374,7 @@ def __init__( self.heads = heads self.dropout = dropout self.num_layers = num_layers + self._attention_hooks_enabled = False assert ( len(self.label_keys) == 1 @@ -382,7 +384,7 @@ def __init__( self.embedding_model = EmbeddingModel(dataset, embedding_dim) - self.transformer = nn.ModuleDict() + self.transformer: nn.ModuleDict = nn.ModuleDict() for feature_key in self.feature_keys: self.transformer[feature_key] = TransformerLayer( feature_size=embedding_dim, @@ -465,7 +467,8 @@ def forward_from_embedding( logit: the raw logits before activation. embed: (if embed=True in kwargs) the patient embedding. """ - register_hook = bool(kwargs.pop("register_hook", False)) + # Support both the flag-based API and legacy kwarg-based API + register_hook = self._attention_hooks_enabled patient_emb = [] for feature_key in self.feature_keys: @@ -579,6 +582,41 @@ def get_embedding_model(self) -> nn.Module | None: """ return self.embedding_model + # ------------------------------------------------------------------ + # CheferInterpretable interface + # ------------------------------------------------------------------ + + def set_attention_hooks(self, enabled: bool) -> None: + self._attention_hooks_enabled = enabled + + def get_attention_layers( + self, + ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]: + return { # type: ignore[return-value] + key: [ + ( + cast(TransformerBlock, blk).attention.get_attn_map(), + cast(TransformerBlock, blk).attention.get_attn_grad(), + ) + for blk in cast( + TransformerLayer, self.transformer[key] + ).transformer + ] + for key in self.feature_keys + } + + def get_relevance_tensor( + self, + R: dict[str, torch.Tensor], + **data: torch.Tensor | tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor]: + # CLS token is at index 0 for all feature keys + result = {} + for key, r in R.items(): + # CLS token is at index 0; extract its attention row + 
result[key] = r[:, 0] # [batch, attention_seq_len] + return result + if __name__ == "__main__": from pyhealth.datasets import create_sample_dataset, get_dataloader diff --git a/tests/core/test_gim.py b/tests/core/test_gim.py index 0893b874a..284931883 100644 --- a/tests/core/test_gim.py +++ b/tests/core/test_gim.py @@ -6,6 +6,7 @@ import torch.nn as nn from pyhealth.interpret.methods import GIM +from pyhealth.interpret.api import Interpretable from pyhealth.models import BaseModel from pyhealth.models.transformer import Attention @@ -64,7 +65,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(val.long()) for key, val in inputs.items()} -class _ToyGIMModel(BaseModel): +class _ToyGIMModel(BaseModel, Interpretable): """Small attention-style model with module-based nonlinearities. Follows the new API conventions: ``forward_from_embedding(**kwargs)`` diff --git a/tests/core/test_ig_gim.py b/tests/core/test_ig_gim.py index 3b8ff1556..38565227d 100644 --- a/tests/core/test_ig_gim.py +++ b/tests/core/test_ig_gim.py @@ -9,6 +9,7 @@ from pyhealth.interpret.methods import GIM, IntegratedGradients from pyhealth.interpret.methods.ig_gim import IntegratedGradientGIM +from pyhealth.interpret.api import Interpretable from pyhealth.models import BaseModel from pyhealth.models.transformer import Attention @@ -75,7 +76,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(val.long()) for key, val in inputs.items()} -class _ToyModel(BaseModel): +class _ToyModel(BaseModel, Interpretable): """Small model with softmax attention for testing IG-GIM.""" def __init__(self, vocab_size=32, embedding_dim=4, schema=("value",)): @@ -375,15 +376,18 @@ class _NoFwdEmb: obj.label_keys = ["label"] obj.get_embedding_model = lambda: _ToyEmbeddingModel() - with self.assertRaises((AssertionError, AttributeError)): + with self.assertRaises((AssertionError, AttributeError, ValueError)): 
IntegratedGradientGIM(obj) def test_init_rejects_model_without_embedding_model(self): - """Should raise if get_embedding_model returns None.""" + """Should raise if get_embedding_model returns None during attribution.""" model = _ToyModel() model.get_embedding_model = lambda: None + ig_gim = IntegratedGradientGIM(model) with self.assertRaises(AssertionError): - IntegratedGradientGIM(model) + ig_gim.attribute( + codes=self.tokens, label=self.labels, target_class_idx=0, + ) def test_temperature_clamped_to_one(self): """Temperatures below 1.0 should be clamped to 1.0.""" @@ -653,7 +657,7 @@ def test_repr(self): # Additional toy model without Attention (no GIM-swappable modules) # --------------------------------------------------------------------------- -class _ToyModelNoAttention(BaseModel): +class _ToyModelNoAttention(BaseModel, Interpretable): """Simple model without Attention or LayerNorm — GIM hooks are no-ops.""" def __init__(self, vocab_size=32, embedding_dim=4): diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py index b7970592d..ab061cb4a 100644 --- a/tests/core/test_lime.py +++ b/tests/core/test_lime.py @@ -14,6 +14,7 @@ from pyhealth.datasets import SampleDataset, get_dataloader from pyhealth.datasets.sample_dataset import SampleBuilder from pyhealth.models import MLP, StageNet, BaseModel +from pyhealth.interpret.api import Interpretable from pyhealth.interpret.methods import LimeExplainer from pyhealth.interpret.methods.base_interpreter import BaseInterpreter @@ -50,7 +51,7 @@ def __init__(self, input_schema, output_schema, processors=None): # Test model helpers # --------------------------------------------------------------------------- -class _SimpleLimeModel(BaseModel): +class _SimpleLimeModel(BaseModel, Interpretable): """Minimal model for testing LIME with continuous inputs.""" def __init__(self): @@ -100,7 +101,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(value.long()) 
for key, value in inputs.items()} -class _EmbeddingForwardModel(BaseModel): +class _EmbeddingForwardModel(BaseModel, Interpretable): """Toy model exposing forward_from_embedding for discrete features.""" def __init__(self, schema=("value",)): @@ -156,7 +157,7 @@ def get_embedding_model(self): return self.embedding_model -class _MultiFeatureModel(BaseModel): +class _MultiFeatureModel(BaseModel, Interpretable): """Model with multiple feature inputs for testing multi-feature LIME.""" def __init__(self): @@ -584,7 +585,7 @@ def __init__(self): self.linear = nn.Linear(3, 1) model = _BareModel() - with self.assertRaises(AssertionError): + with self.assertRaises((AssertionError, ValueError)): LimeExplainer(model, use_embeddings=True) diff --git a/tests/core/test_shap.py b/tests/core/test_shap.py index 105775cfd..8c03a1c1f 100644 --- a/tests/core/test_shap.py +++ b/tests/core/test_shap.py @@ -11,6 +11,7 @@ from pyhealth.datasets import SampleDataset, get_dataloader from pyhealth.datasets.sample_dataset import SampleBuilder from pyhealth.models import MLP, StageNet, BaseModel +from pyhealth.interpret.api import Interpretable from pyhealth.interpret.methods import ShapExplainer from pyhealth.interpret.methods.base_interpreter import BaseInterpreter @@ -47,7 +48,7 @@ def __init__(self, input_schema, output_schema, processors=None): # Test model helpers # --------------------------------------------------------------------------- -class _SimpleShapModel(BaseModel): +class _SimpleShapModel(BaseModel, Interpretable): """Minimal model for testing SHAP with continuous inputs.""" def __init__(self): @@ -97,7 +98,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {key: self.embedding(value.long()) for key, value in inputs.items()} -class _EmbeddingForwardModel(BaseModel): +class _EmbeddingForwardModel(BaseModel, Interpretable): """Toy model exposing forward_from_embedding for discrete features.""" def __init__(self, schema=("value",)): @@ 
-152,7 +153,7 @@ def get_embedding_model(self): return self.embedding_model -class _MultiFeatureModel(BaseModel): +class _MultiFeatureModel(BaseModel, Interpretable): """Model with multiple feature inputs for testing multi-feature SHAP.""" def __init__(self): @@ -503,7 +504,7 @@ def __init__(self): self.linear = nn.Linear(3, 1) model = _BareModel() - with self.assertRaises(AssertionError): + with self.assertRaises((AssertionError, ValueError)): ShapExplainer(model, use_embeddings=True) diff --git a/tests/core/test_stagenet_mha.py b/tests/core/test_stagenet_mha.py index 78206f309..4d4f1eebd 100644 --- a/tests/core/test_stagenet_mha.py +++ b/tests/core/test_stagenet_mha.py @@ -117,8 +117,10 @@ def test_attention_hook_records_map_and_grad(self): loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) batch = next(iter(loader)) - ret = self.model(register_attn_hook=True, **batch) + self.model.set_attention_hooks(True) + ret = self.model(**batch) ret["loss"].backward() + self.model.set_attention_hooks(False) for feature_key, layer in self.model.stagenet.items(): attn_map = layer.get_attn_map()