Merged
122 changes: 49 additions & 73 deletions pyhealth/models/mlp.py
@@ -6,6 +6,8 @@
from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel

from .embedding import EmbeddingModel


class MLP(BaseModel):
"""Multi-layer perceptron model.
@@ -123,9 +125,8 @@ def __init__(
assert len(self.label_keys) == 1, "Only one label key is supported"
self.label_key = self.label_keys[0]

# Create embedding and linear layers (will be populated dynamically)
self.embeddings = nn.ModuleDict()
self.linear_layers = nn.ModuleDict()
# Use the EmbeddingModel to handle embedding logic
self.embedding_model = EmbeddingModel(dataset, embedding_dim)

# Set up activation function
if activation == "relu":
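
Note on this hunk: the per-feature ModuleDicts (self.embeddings, self.linear_layers) that the old code built lazily in forward() are replaced by a single shared EmbeddingModel. A minimal sketch of the interface this change appears to assume — illustrative only, not the actual pyhealth/models/embedding.py implementation:

    import torch
    import torch.nn as nn

    class EmbeddingModelSketch(nn.Module):
        """Illustrative stand-in for pyhealth's EmbeddingModel: integer code
        tensors go through nn.Embedding, float tensors through nn.Linear,
        one layer per feature key (an assumption, not the real class)."""

        def __init__(self, dataset, embedding_dim: int):
            super().__init__()
            self.embedding_dim = embedding_dim
            self.layers = nn.ModuleDict()  # built lazily, one entry per feature key

        def forward(self, inputs: dict) -> dict:
            out = {}
            for key, x in inputs.items():
                if key not in self.layers:
                    if x.dtype in (torch.int32, torch.int64):
                        vocab_size = int(x.max().item()) + 1  # max token id + 1
                        self.layers[key] = nn.Embedding(vocab_size, self.embedding_dim)
                    else:
                        self.layers[key] = nn.Linear(x.shape[-1], self.embedding_dim)
                layer = self.layers[key]
                out[key] = layer(x) if isinstance(layer, nn.Embedding) else layer(x.float())
            return out
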
@@ -209,6 +210,10 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""
patient_emb = []

# Preprocess inputs for EmbeddingModel
processed_inputs = {}
reshape_info = {} # Track which inputs were reshaped

for feature_key in self.feature_keys:
x = kwargs[feature_key]

@@ -218,77 +223,48 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
else:
x = x.to(self.device)

# Handle different tensor dimensions
if x.dim() == 1:
# Case: Single sample with 1D data [val1, val2, ...]
# Need to add batch dimension
x = x.unsqueeze(0) # (1, features)

if x.dim() == 2:
# Case 1: Sequence of tokens (batch, seq_len) - for codes
# Case 3: Numerical features (batch, features) - for values

# Check if sequence data (int) or numerical data (float)
if x.dtype in [torch.int32, torch.int64]:
# This is sequence data (codes) - need embedding
if feature_key not in self.embeddings:
# Create embedding layer with max token + 1 vocab size
max_token_id = x.max().item()
vocab_size = max_token_id + 1
self.embeddings[feature_key] = nn.Embedding(
vocab_size, self.embedding_dim
).to(self.device)

x = self.embeddings[feature_key](x) # (batch, seq, embed)
# Apply mean pooling
mask = torch.any(x != 0, dim=2)
x = self.mean_pooling(x, mask)
else:
# This is numerical data - apply linear layer
if feature_key not in self.linear_layers:
input_size = x.shape[-1]
self.linear_layers[feature_key] = nn.Linear(
input_size, self.embedding_dim
).to(self.device)
# Ensure x is float for linear layer
x = x.float()
x = self.linear_layers[feature_key](x)

elif x.dim() == 3:
# Case 2: Nested sequences [[code1, code2], [code3, ...], ...]
# Case 4: Time series [[val1, val2], [val3, val4], ...]

if x.dtype in [torch.int32, torch.int64]:
# Sequence data with embeddings
if feature_key not in self.embeddings:
max_token_id = x.max().item()
vocab_size = max_token_id + 1
self.embeddings[feature_key] = nn.Embedding(
vocab_size, self.embedding_dim
).to(self.device)

batch_size, seq_len, inner_len = x.shape
x = x.view(-1, inner_len) # Flatten for embedding
x = self.embeddings[feature_key](x) # Embed
x = x.view(batch_size, seq_len, -1) # Reshape back
# Apply mean pooling over sequence dimension
mask = torch.any(x != 0, dim=2)
# Handle 3D input: (patient, event, # of codes) -> flatten to 2D
if x.dim() == 3:
batch_size, seq_len, inner_len = x.shape
# Flatten to (patient, event * # of codes)
x = x.view(batch_size, seq_len * inner_len)
# Store reshape info for later reconstruction
reshape_info[feature_key] = {
"original_shape": (batch_size, seq_len, inner_len),
"was_3d": True,
}
else:
reshape_info[feature_key] = {"was_3d": False}

processed_inputs[feature_key] = x

# Pass through EmbeddingModel
embedded = self.embedding_model(processed_inputs)

for feature_key in self.feature_keys:
x = embedded[feature_key]

# Handle reshaped 3D inputs
if reshape_info[feature_key]["was_3d"]:
# Reconstruct 3D shape: (batch, seq_len, embedding_dim)
original_shape = reshape_info[feature_key]["original_shape"]
batch_size, seq_len, inner_len = original_shape
# x is currently (batch, embedding_dim) from EmbeddingModel
# We need to handle the sequence dimension through pooling
# For now, treat as already pooled since EmbeddingModel did it
pass

# Handle different tensor dimensions for pooling
if x.dim() == 3:
# Case: (batch, seq_len, embedding_dim) - apply mean pooling
mask = (x.sum(dim=-1) != 0).float()
if mask.sum(dim=-1, keepdim=True).any():
x = self.mean_pooling(x, mask)
else:
# Numerical data - apply linear layer then pool
if feature_key not in self.linear_layers:
input_size = x.shape[-1]
self.linear_layers[feature_key] = nn.Linear(
input_size, self.embedding_dim
).to(self.device)
# Ensure x is float for linear layer
x = x.float()
# Apply linear layer to get embeddings
x = self.linear_layers[feature_key](x)
# Apply mean pooling
mask = torch.ones(x.shape[:2], device=self.device)
x = self.mean_pooling(x, mask)

x = x.mean(dim=1)
elif x.dim() == 2:
# Case: (batch, embedding_dim) - already pooled, use as is
pass
else:
raise ValueError(f"Unsupported tensor dimension: {x.dim()}")

Expand All @@ -297,7 +273,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
patient_emb.append(x)

patient_emb = torch.cat(patient_emb, dim=1)
print("debug:", patient_emb.shape)

# (patient, label_size)
logits = self.fc(patient_emb)
# obtain y_true, loss, y_prob
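
Taken together, the rewritten forward() does three steps per feature: flatten any 3D code tensor to 2D, embed every feature through the shared EmbeddingModel, then mean-pool any remaining sequence dimension. A condensed sketch of that flow for a single feature key (names illustrative; the real code loops over self.feature_keys):

    def forward_one_feature(self, feature_key, x):
        # Flatten (patient, event, codes) into one long code sequence per patient.
        if x.dim() == 3:
            b, s, i = x.shape
            x = x.view(b, s * i)
        # Shared embedding lookup / projection.
        x = self.embedding_model({feature_key: x})[feature_key]
        # Pool a remaining sequence dimension down to (batch, embedding_dim).
        if x.dim() == 3:
            mask = (x.sum(dim=-1) != 0).float()  # all-zero embeddings treated as padding
            if mask.sum(dim=-1, keepdim=True).any():
                x = self.mean_pooling(x, mask)
            else:
                x = x.mean(dim=1)
        return x  # (batch, embedding_dim)
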
2 changes: 1 addition & 1 deletion tests/core/test_mlp.py
@@ -14,7 +14,7 @@ def setUp(self):
{
"patient_id": "patient-0",
"visit_id": "visit-0",
"conditions": ["cond-33", "cond-86", "cond-80"],
"conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
"procedures": [1.0, 2.0, 3.5, 4],
"label": 0,
},
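
For context, a hypothetical usage mirroring the updated test fixture; the SampleDataset constructor arguments and schema keys below are assumptions and may differ across pyhealth versions:

    from pyhealth.datasets import SampleDataset
    from pyhealth.models import MLP

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
            "procedures": [1.0, 2.0, 3.5, 4],
            "label": 0,
        },
    ]
    # Assumed constructor; check the SampleDataset signature in your pyhealth version.
    dataset = SampleDataset(
        samples=samples,
        input_schema={"conditions": "sequence", "procedures": "tensor"},
        output_schema={"label": "binary"},
    )
    model = MLP(dataset=dataset, embedding_dim=128, hidden_dim=128)
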