+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple

 import torch
 from torch import Tensor, nn

 from chebai.models import ChebaiBaseNet
+from chebai.models.electra import filter_dict


 class FFN(ChebaiBaseNet):
@@ -15,6 +17,8 @@ def __init__(
             1024,
         ],
         use_adam_optimizer: bool = False,
+        pretrained_checkpoint: Optional[str] = None,
+        load_prefix: Optional[str] = "model.",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -32,6 +36,33 @@ def __init__(
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
         self.model = nn.Sequential(*layers)

+        if pretrained_checkpoint is not None:
+            ckpt_file = torch.load(
+                pretrained_checkpoint, map_location=self.device, weights_only=False
+            )
+            if load_prefix is not None:
+                state_dict = filter_dict(ckpt_file["state_dict"], load_prefix)
+            else:
+                state_dict = ckpt_file["state_dict"]
+
+            model_sd = self.model.state_dict()
+            filtered = OrderedDict()
+            skipped = set()
+            for k, v in state_dict.items():
+                if model_sd[k].shape == v.shape:
+                    # Load pretrained parameters whose shapes match the current model.
+                    filtered[k] = v
+                else:
+                    # Shape mismatch (e.g. "2.weight"/"2.bias" of the final linear
+                    # layer that maps to the output dimension): keep the freshly
+                    # initialized weights and record the key as skipped.
+                    skipped.add(k)
+                    filtered[k] = model_sd[k]
+
+            self.model.load_state_dict(filtered)
+            print(
+                f"Loaded (shape-matched) weights from {pretrained_checkpoint}",
+                f"Skipped the following weights: {skipped}",
+            )
+
     def _get_prediction_and_labels(self, data, labels, model_output):
         d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())
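
As a side note, the shape-matched loading that the new block performs can be sketched in isolation. The snippet below is a minimal, self-contained illustration of the pattern; it does not use chebai, and the model names and layer sizes are hypothetical:

# Minimal sketch of shape-matched state_dict loading (illustrative only;
# the models and dimensions below are hypothetical, not part of chebai).
from collections import OrderedDict

from torch import nn

# "Pretrained" network with a 10-unit head vs. a new network with a 20-unit head.
pretrained = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
new_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 20))

state_dict = pretrained.state_dict()
model_sd = new_model.state_dict()

filtered, skipped = OrderedDict(), set()
for k, v in state_dict.items():
    if model_sd[k].shape == v.shape:
        filtered[k] = v            # reuse the pretrained weights
    else:
        skipped.add(k)
        filtered[k] = model_sd[k]  # keep the fresh init for the mismatched head

new_model.load_state_dict(filtered)
print(f"Skipped (shape mismatch): {skipped}")  # e.g. {'2.weight', '2.bias'}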