Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions pyprophet/io/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,11 @@ def _save_tsv_weights(self, weights):

def _save_bin_weights(self, weights):
"""
Save the model weights to a binary file.
Save the model weights to a binary file with metadata.

For XGBoost/HistGradientBoosting classifiers, saves the model along with
metadata including ss_main_score and feature names to ensure proper
feature alignment when applying weights.

Args:
weights: Model weights or trained object.
Expand All @@ -537,8 +541,16 @@ def _save_bin_weights(self, weights):
f"trained_model_path_{self.level}"
)
if trained_weights_path is not None:
# For XGBoost/HistGradientBoosting, wrap model with metadata
# to ensure feature alignment when applying weights
model_data = {
"model": weights,
"ss_main_score": self.config.runner.ss_main_score,
"classifier": self.classifier,
"level": self.level,
}
with open(trained_weights_path, "wb") as file:
self.persisted_weights = pickle.dump(weights, file)
pickle.dump(model_data, file)
logger.success("%s written." % trained_weights_path)
else:
logger.error(f"Trained model path {trained_weights_path} not found. ")
Expand Down
14 changes: 11 additions & 3 deletions pyprophet/io/scoring/osw.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,17 @@ def save_weights(self, weights):

weights.to_sql("PYPROPHET_WEIGHTS", con, index=False, if_exists="append")

elif self.classifier == "XGBoost":
elif self.classifier == "XGBoost" or self.classifier == "HistGradientBoosting":
con = sqlite3.connect(self.outfile)

# Wrap model with metadata for feature alignment
model_data = {
"model": weights,
"ss_main_score": self.config.runner.ss_main_score,
"classifier": self.classifier,
"level": self.level,
}

c = con.cursor()
if self.glyco and self.level in ["ms2", "ms1ms2"]:
c.execute(
Expand All @@ -743,7 +751,7 @@ def save_weights(self, weights):

c.execute(
"INSERT INTO GLYCOPEPTIDEPROPHET_XGB VALUES(?, ?)",
[self.level, pickle.dumps(weights)],
[self.level, pickle.dumps(model_data)],
)
else:
c.execute(
Expand All @@ -758,7 +766,7 @@ def save_weights(self, weights):

c.execute(
"INSERT INTO PYPROPHET_XGB VALUES(?, ?)",
[self.level, pickle.dumps(weights)],
[self.level, pickle.dumps(model_data)],
)
con.commit()
c.close()
Loading
Loading