diff --git a/pyhealth/models/mlp.py b/pyhealth/models/mlp.py
index ae2a4ff70..ed440f116 100644
--- a/pyhealth/models/mlp.py
+++ b/pyhealth/models/mlp.py
@@ -6,6 +6,8 @@
 from pyhealth.datasets import SampleDataset
 from pyhealth.models import BaseModel
 
+from .embedding import EmbeddingModel
+
 
 class MLP(BaseModel):
     """Multi-layer perceptron model.
@@ -123,9 +125,8 @@ def __init__(
         assert len(self.label_keys) == 1, "Only one label key is supported"
         self.label_key = self.label_keys[0]
 
-        # Create embedding and linear layers (will be populated dynamically)
-        self.embeddings = nn.ModuleDict()
-        self.linear_layers = nn.ModuleDict()
+        # Use the EmbeddingModel to handle embedding logic
+        self.embedding_model = EmbeddingModel(dataset, embedding_dim)
 
         # Set up activation function
         if activation == "relu":
@@ -209,6 +210,10 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
         """
 
         patient_emb = []
 
+        # Preprocess inputs for EmbeddingModel
+        processed_inputs = {}
+        reshape_info = {}  # Track which inputs were reshaped
+
         for feature_key in self.feature_keys:
             x = kwargs[feature_key]
@@ -218,77 +223,48 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
             else:
                 x = x.to(self.device)
 
-            # Handle different tensor dimensions
-            if x.dim() == 1:
-                # Case: Single sample with 1D data [val1, val2, ...]
-                # Need to add batch dimension
-                x = x.unsqueeze(0)  # (1, features)
-
-            if x.dim() == 2:
-                # Case 1: Sequence of tokens (batch, seq_len) - for codes
-                # Case 3: Numerical features (batch, features) - for values
-
-                # Check if sequence data (int) or numerical data (float)
-                if x.dtype in [torch.int32, torch.int64]:
-                    # This is sequence data (codes) - need embedding
-                    if feature_key not in self.embeddings:
-                        # Create embedding layer with max token + 1 vocab size
-                        max_token_id = x.max().item()
-                        vocab_size = max_token_id + 1
-                        self.embeddings[feature_key] = nn.Embedding(
-                            vocab_size, self.embedding_dim
-                        ).to(self.device)
-
-                    x = self.embeddings[feature_key](x)  # (batch, seq, embed)
-                    # Apply mean pooling
-                    mask = torch.any(x != 0, dim=2)
-                    x = self.mean_pooling(x, mask)
-                else:
-                    # This is numerical data - apply linear layer
-                    if feature_key not in self.linear_layers:
-                        input_size = x.shape[-1]
-                        self.linear_layers[feature_key] = nn.Linear(
-                            input_size, self.embedding_dim
-                        ).to(self.device)
-                    # Ensure x is float for linear layer
-                    x = x.float()
-                    x = self.linear_layers[feature_key](x)
-
-            elif x.dim() == 3:
-                # Case 2: Nested sequences [[code1, code2], [code3, ...], ...]
-                # Case 4: Time series [[val1, val2], [val3, val4], ...]
-
-                if x.dtype in [torch.int32, torch.int64]:
-                    # Sequence data with embeddings
-                    if feature_key not in self.embeddings:
-                        max_token_id = x.max().item()
-                        vocab_size = max_token_id + 1
-                        self.embeddings[feature_key] = nn.Embedding(
-                            vocab_size, self.embedding_dim
-                        ).to(self.device)
-
-                    batch_size, seq_len, inner_len = x.shape
-                    x = x.view(-1, inner_len)  # Flatten for embedding
-                    x = self.embeddings[feature_key](x)  # Embed
-                    x = x.view(batch_size, seq_len, -1)  # Reshape back
-                    # Apply mean pooling over sequence dimension
-                    mask = torch.any(x != 0, dim=2)
+            # Handle 3D input: (patient, event, # of codes) -> flatten to 2D
+            if x.dim() == 3:
+                batch_size, seq_len, inner_len = x.shape
+                # Flatten to (patient, event * # of codes)
+                x = x.view(batch_size, seq_len * inner_len)
+                # Store reshape info for later reconstruction
+                reshape_info[feature_key] = {
+                    "original_shape": (batch_size, seq_len, inner_len),
+                    "was_3d": True,
+                }
+            else:
+                reshape_info[feature_key] = {"was_3d": False}
+
+            processed_inputs[feature_key] = x
+
+        # Pass through EmbeddingModel
+        embedded = self.embedding_model(processed_inputs)
+
+        for feature_key in self.feature_keys:
+            x = embedded[feature_key]
+
+            # Handle reshaped 3D inputs
+            if reshape_info[feature_key]["was_3d"]:
+                # Reconstruct 3D shape: (batch, seq_len, embedding_dim)
+                original_shape = reshape_info[feature_key]["original_shape"]
+                batch_size, seq_len, inner_len = original_shape
+                # x is currently (batch, embedding_dim) from EmbeddingModel
+                # We need to handle the sequence dimension through pooling
+                # For now, treat as already pooled since EmbeddingModel did it
+                pass
+
+            # Handle different tensor dimensions for pooling
+            if x.dim() == 3:
+                # Case: (batch, seq_len, embedding_dim) - apply mean pooling
+                mask = (x.sum(dim=-1) != 0).float()
+                if mask.sum(dim=-1, keepdim=True).any():
                     x = self.mean_pooling(x, mask)
                 else:
-                    # Numerical data - apply linear layer then pool
-                    if feature_key not in self.linear_layers:
-                        input_size = x.shape[-1]
-                        self.linear_layers[feature_key] = nn.Linear(
-                            input_size, self.embedding_dim
-                        ).to(self.device)
-                    # Ensure x is float for linear layer
-                    x = x.float()
-                    # Apply linear layer to get embeddings
-                    x = self.linear_layers[feature_key](x)
-                    # Apply mean pooling
-                    mask = torch.ones(x.shape[:2], device=self.device)
-                    x = self.mean_pooling(x, mask)
-
+                    x = x.mean(dim=1)
+            elif x.dim() == 2:
+                # Case: (batch, embedding_dim) - already pooled, use as is
+                pass
             else:
                 raise ValueError(f"Unsupported tensor dimension: {x.dim()}")
 
@@ -297,7 +273,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
             patient_emb.append(x)
 
         patient_emb = torch.cat(patient_emb, dim=1)
-        print("debug:", patient_emb.shape)
+        # (patient, label_size)
         logits = self.fc(patient_emb)
 
         # obtain y_true, loss, y_prob
diff --git a/tests/core/test_mlp.py b/tests/core/test_mlp.py
index 20332c8a8..2320b577c 100644
--- a/tests/core/test_mlp.py
+++ b/tests/core/test_mlp.py
@@ -14,7 +14,7 @@ def setUp(self):
             {
                 "patient_id": "patient-0",
                 "visit_id": "visit-0",
-                "conditions": ["cond-33", "cond-86", "cond-80"],
+                "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
                 "procedures": [1.0, 2.0, 3.5, 4],
                 "label": 0,
            },
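
Reviewer note: the sketch below is illustrative only. It mimics the dict-in/dict-out contract this patch assumes EmbeddingModel satisfies (a mapping of 2D token tensors in, embedded tensors out) and walks the new flatten -> embed -> mask -> mean-pool path on dummy shapes. ToyEmbedding and every shape/name here are hypothetical stand-ins, not the real pyhealth.models.embedding.EmbeddingModel, and the manual pooling only approximates MLP.mean_pooling.

import torch
import torch.nn as nn


class ToyEmbedding(nn.Module):
    """Hypothetical stand-in for the interface this patch assumes of EmbeddingModel."""

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

    def forward(self, inputs: dict) -> dict:
        # Assumed contract: each value is a 2D (batch, tokens) code tensor.
        return {key: self.embed(x) for key, x in inputs.items()}


batch_size, seq_len, inner_len, embedding_dim = 2, 3, 4, 8
codes_3d = torch.randint(1, 100, (batch_size, seq_len, inner_len))

# Step 1 (as in the patch): flatten (patient, event, code) -> (patient, event * code).
codes_2d = codes_3d.view(batch_size, seq_len * inner_len)

# Step 2: embed to (patient, event * code, embedding_dim).
model = ToyEmbedding(vocab_size=100, embedding_dim=embedding_dim)
x = model({"conditions": codes_2d})["conditions"]

# Step 3 (as in the patch): mask all-zero positions, then mean-pool the token
# dimension down to a fixed (patient, embedding_dim) vector for the MLP head.
mask = (x.sum(dim=-1) != 0).float()
pooled = (x * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
print(pooled.shape)  # torch.Size([2, 8])

Design note: flattening (patient, event, code) to 2D before embedding discards the event boundary, so per-event pooling is no longer possible downstream; the reshape_info bookkeeping keeps original_shape around so a later revision could restore it, which is why that branch currently ends in a no-op pass.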