
Commit b1af011

Merge pull request #133 from ChEB-AI/fix/mlp_enable_pretrained_ckpt
Enable loading pretrained weights for the MLP model
2 parents 678ae19 + 45fee03

File tree: 2 files changed (+35, -0 lines)


chebai/models/base.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def __init__(
         assert input_dim is not None, "input_dim must be specified"
         self.out_dim = out_dim
         self.input_dim = input_dim
+        print(
+            f"Input dimension for the model: {self.input_dim}",
+            f"Output dimension for the model: {self.out_dim}",
+        )
 
         self.save_hyperparameters(
             ignore=[

chebai/models/ffn.py

Lines changed: 31 additions & 0 deletions
@@ -1,9 +1,11 @@
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
 
 from chebai.models import ChebaiBaseNet
+from chebai.models.electra import filter_dict
 
 
 class FFN(ChebaiBaseNet):
@@ -15,6 +17,8 @@ def __init__(
             1024,
         ],
         use_adam_optimizer: bool = False,
+        pretrained_checkpoint: Optional[str] = None,
+        load_prefix: Optional[str] = "model.",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -32,6 +36,33 @@ def __init__(
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
         self.model = nn.Sequential(*layers)
 
+        if pretrained_checkpoint is not None:
+            ckpt_file = torch.load(
+                pretrained_checkpoint, map_location=self.device, weights_only=False
+            )
+            if load_prefix is not None:
+                state_dict = filter_dict(ckpt_file["state_dict"], load_prefix)
+            else:
+                state_dict = ckpt_file["state_dict"]
+
+            model_sd = self.model.state_dict()
+            filtered = OrderedDict()
+            skipped = set()
+            for k, v in state_dict.items():
+                if model_sd[k].shape == v.shape:
+                    filtered[k] = v  # only load params with matching shapes
+                else:
+                    skipped.add(k)
+                    filtered[k] = model_sd[k]
+                    # keep the fresh initialization for mismatched keys such as
+                    # "2.weight"/"2.bias", the final linear layer mapping to out_dim
+
+            self.model.load_state_dict(filtered)
+            print(
+                f"Loaded (shape-matched) weights from {pretrained_checkpoint}",
+                f"Skipped the following weights: {skipped}",
+            )
+
     def _get_prediction_and_labels(self, data, labels, model_output):
         d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())
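
The loading path above leans on filter_dict, imported from chebai.models.electra, to pull out the checkpoint entries that belong to the wrapped model. The diff does not show its implementation; as a hedged sketch, it presumably keeps the keys that start with the given prefix and strips that prefix, roughly:

    # Hypothetical sketch of filter_dict's presumed behavior, not the actual
    # implementation in chebai/models/electra.py: keep entries whose key
    # starts with `filter_key` and drop the prefix
    # ("model.0.weight" -> "0.weight").
    def filter_dict(d: dict, filter_key: str) -> dict:
        return {
            k[len(filter_key):]: v
            for k, v in d.items()
            if k.startswith(filter_key)
        }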

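Taken together, the two new constructor arguments let an FFN be warm-started from an earlier run. A minimal usage sketch, assuming input_dim and out_dim reach ChebaiBaseNet via **kwargs as base.py suggests; the path and dimensions are placeholders:

    from chebai.models.ffn import FFN

    # All values here are illustrative; pretrained_checkpoint and load_prefix
    # are the parameters introduced by this PR.
    model = FFN(
        input_dim=1024,
        out_dim=500,
        hidden_layers=[1024],
        pretrained_checkpoint="logs/pretraining/last.ckpt",  # placeholder path
        load_prefix="model.",  # strips the Lightning "model." key prefix
    )
    # Layers with matching shapes are loaded from the checkpoint; any layer
    # whose shape differs (typically the final linear layer mapping to
    # out_dim) keeps its fresh initialization and is reported in the
    # "Skipped the following weights" printout.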