From eb1299e746e8d71cecc014ea2dead49048cb690f Mon Sep 17 00:00:00 2001 From: Isabel Kantak Date: Thu, 22 Jan 2026 18:14:05 +0100 Subject: [PATCH 1/2] Update Configurable for ML based photon cuts and add Configurable for selecting centrality variable for 2D ML model selection. --- PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h | 18 +++-- PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h | 18 +++-- PWGEM/PhotonMeson/Core/V0PhotonCandidate.h | 14 +++- PWGEM/PhotonMeson/Core/V0PhotonCut.cxx | 30 +++++-- PWGEM/PhotonMeson/Core/V0PhotonCut.h | 80 ++++++++++++++++--- .../TableProducer/photonconversionbuilder.cxx | 55 ++++++++++++- 6 files changed, 177 insertions(+), 38 deletions(-) diff --git a/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h b/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h index 1d155b7cb2a..8cc1efd73dd 100644 --- a/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h +++ b/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h @@ -143,17 +143,20 @@ struct Pi0EtaToGammaGamma { o2::framework::Configurable cfg_disable_tpconly_track{"cfg_disable_tpconly_track", false, "flag to disable TPConly tracks"}; o2::framework::Configurable cfg_apply_ml_cuts{"cfg_apply_ml", false, "flag to apply ML cut"}; - o2::framework::Configurable cfg_use_2d_binning{"cfg_use_2d_binning", true, "flag to use 2D binning (pT, cent)"}; + o2::framework::Configurable cfg_use_2d_binning{"cfg_use_2d_binning", false, "flag to use 2D binning (pT, cent)"}; o2::framework::Configurable cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"}; o2::framework::Configurable cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"}; o2::framework::Configurable cfg_nclasses_ml{"cfg_nclasses_ml", static_cast(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"}; + o2::framework::Configurable cfg_cent_type_ml{"cfg_cent_type_ml", "CentFT0C", "centrality type for 2D ML application: CentFT0C, CentFT0M, or CentFT0A"}; o2::framework::Configurable> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"}; o2::framework::Configurable> cfg_input_feature_names{"cfg_input_feature_names", std::vector{"feature1", "feature2"}, "input feature names for ML models"}; o2::framework::Configurable> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"}; o2::framework::Configurable> cfg_onnx_file_names{"cfg_onnx_file_names", std::vector{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for ML models"}; - o2::framework::Configurable> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector{o2::analysis::em_cuts_ml::vecBinsPt}, "pT bins for ML"}; + o2::framework::Configurable> cfg_labels_bins_ml{"cfg_labels_bins_ml", std::vector{"bin 0", "bin 1"}, "Labels for bins"}; + o2::framework::Configurable> cfg_labels_cut_scores_ml{"cfg_labels_cut_scores_ml", std::vector{o2::analysis::em_cuts_ml::labelsCutScore}, "Labels for cut scores"}; + o2::framework::Configurable> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector{0.0, +1e+10}, "pT bin limits for ML application"}; o2::framework::Configurable> cfg_bins_cent_ml{"cfg_bins_cent_ml", std::vector{o2::analysis::em_cuts_ml::vecBinsCent}, "centrality bins for ML"}; - o2::framework::Configurable> cfg_cuts_pcm_ml{"cfg_cuts_pcm_ml", {o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}, "ML selections per pT bin"}; + o2::framework::Configurable> cfg_cuts_ml_flat{"cfg_cuts_ml_flat", {0.5}, "Flattened ML cuts: [bin0_score0, bin0_score1, ..., binN_scoreM]"}; } pcmcuts; DalitzEECut fDileptonCut; @@ -435,6 +438,7 @@ struct Pi0EtaToGammaGamma { d_bz = std::lround(5.f * grpmag->getL3Current() / 30000.f); LOG(info) << "Retrieved GRP for timestamp " << run3grp_timestamp << " with magnetic field of " << d_bz << " kZG"; } + fV0PhotonCut.SetD_Bz(d_bz); mRunNumber = collision.runNumber(); } @@ -504,13 +508,16 @@ struct Pi0EtaToGammaGamma { fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml); fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb); fV0PhotonCut.SetCcdbUrl(ccdburl); + fV0PhotonCut.SetCentralityTypeMl(pcmcuts.cfg_cent_type_ml); fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml); fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb); fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names); fV0PhotonCut.SetBinsPtMl(pcmcuts.cfg_bins_pt_ml); fV0PhotonCut.SetBinsCentMl(pcmcuts.cfg_bins_cent_ml); - fV0PhotonCut.SetCutsPCMMl(pcmcuts.cfg_cuts_pcm_ml); + fV0PhotonCut.SetCutsMl(pcmcuts.cfg_cuts_ml_flat); fV0PhotonCut.SetNamesInputFeatures(pcmcuts.cfg_input_feature_names); + fV0PhotonCut.SetLabelsBinsMl(pcmcuts.cfg_labels_bins_ml); + fV0PhotonCut.SetLabelsCutScoresMl(pcmcuts.cfg_labels_cut_scores_ml); if (pcmcuts.cfg_apply_ml_cuts) { fV0PhotonCut.initV0MlModels(ccdbApi); @@ -695,8 +702,7 @@ struct Pi0EtaToGammaGamma { { for (const auto& collision : collisions) { initCCDB(collision); - fV0PhotonCut.SetCentrality(collision.centFT0M()); - fV0PhotonCut.SetD_Bz(d_bz); + fV0PhotonCut.SetCentrality(collision.centFT0A(), collision.centFT0C(), collision.centFT0M()); int ndiphoton = 0; if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) { continue; diff --git a/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h b/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h index f592386b74e..8e05897fd6d 100644 --- a/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h +++ b/PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h @@ -132,17 +132,20 @@ struct Pi0EtaToGammaGammaMC { o2::framework::Configurable cfg_disable_tpconly_track{"cfg_disable_tpconly_track", false, "flag to disable TPConly tracks"}; o2::framework::Configurable cfg_apply_ml_cuts{"cfg_apply_ml", false, "flag to apply ML cut"}; - o2::framework::Configurable cfg_use_2d_binning{"cfg_use_2d_binning", true, "flag to use 2D binning (pT, cent)"}; + o2::framework::Configurable cfg_use_2d_binning{"cfg_use_2d_binning", false, "flag to use 2D binning (pT, cent)"}; o2::framework::Configurable cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"}; o2::framework::Configurable cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"}; o2::framework::Configurable cfg_nclasses_ml{"cfg_nclasses_ml", static_cast(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"}; + o2::framework::Configurable cfg_cent_type_ml{"cfg_cent_type_ml", "CentFT0C", "centrality type for 2D ML application: CentFT0C, CentFT0M, or CentFT0A"}; o2::framework::Configurable> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"}; o2::framework::Configurable> cfg_input_feature_names{"cfg_input_feature_names", std::vector{"feature1", "feature2"}, "input feature names for ML models"}; o2::framework::Configurable> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"}; o2::framework::Configurable> cfg_onnx_file_names{"cfg_onnx_file_names", std::vector{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for ML models"}; - o2::framework::Configurable> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector{o2::analysis::em_cuts_ml::vecBinsPt}, "pT bins for ML"}; + o2::framework::Configurable> cfg_labels_bins_ml{"cfg_labels_bins_ml", std::vector{"bin 0", "bin 1"}, "Labels for bins"}; + o2::framework::Configurable> cfg_labels_cut_scores_ml{"cfg_labels_cut_scores_ml", std::vector{o2::analysis::em_cuts_ml::labelsCutScore}, "Labels for cut scores"}; + o2::framework::Configurable> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector{0.0, +1e+10}, "pT bin limits for ML application"}; o2::framework::Configurable> cfg_bins_cent_ml{"cfg_bins_cent_ml", std::vector{o2::analysis::em_cuts_ml::vecBinsCent}, "centrality bins for ML"}; - o2::framework::Configurable> cfg_cuts_pcm_ml{"cfg_cuts_pcm_ml", {o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}, "ML selections per pT bin"}; + o2::framework::Configurable> cfg_cuts_ml_flat{"cfg_cuts_ml_flat", {0.5}, "Flattened ML cuts: [bin0_score0, bin0_score1, ..., binN_scoreM]"}; } pcmcuts; DalitzEECut fDileptonCut; @@ -283,6 +286,7 @@ struct Pi0EtaToGammaGammaMC { d_bz = std::lround(5.f * grpmag->getL3Current() / 30000.f); LOG(info) << "Retrieved GRP for timestamp " << run3grp_timestamp << " with magnetic field of " << d_bz << " kZG"; } + fV0PhotonCut.SetD_Bz(d_bz); mRunNumber = collision.runNumber(); } @@ -344,13 +348,16 @@ struct Pi0EtaToGammaGammaMC { fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml); fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb); fV0PhotonCut.SetCcdbUrl(ccdburl); + fV0PhotonCut.SetCentralityTypeMl(pcmcuts.cfg_cent_type_ml); fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml); fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb); fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names); fV0PhotonCut.SetBinsPtMl(pcmcuts.cfg_bins_pt_ml); fV0PhotonCut.SetBinsCentMl(pcmcuts.cfg_bins_cent_ml); - fV0PhotonCut.SetCutsPCMMl(pcmcuts.cfg_cuts_pcm_ml); + fV0PhotonCut.SetCutsMl(pcmcuts.cfg_cuts_ml_flat); fV0PhotonCut.SetNamesInputFeatures(pcmcuts.cfg_input_feature_names); + fV0PhotonCut.SetLabelsBinsMl(pcmcuts.cfg_labels_bins_ml); + fV0PhotonCut.SetLabelsCutScoresMl(pcmcuts.cfg_labels_cut_scores_ml); if (pcmcuts.cfg_apply_ml_cuts) { fV0PhotonCut.initV0MlModels(ccdbApi); @@ -553,8 +560,7 @@ struct Pi0EtaToGammaGammaMC { { for (auto& collision : collisions) { initCCDB(collision); - fV0PhotonCut.SetCentrality(collision.centFT0M()); - fV0PhotonCut.SetD_Bz(d_bz); + fV0PhotonCut.SetCentrality(collision.centFT0A(), collision.centFT0C(), collision.centFT0M()); if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) { continue; } diff --git a/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h b/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h index 68ad1a4e207..88078a4b497 100644 --- a/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h +++ b/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h @@ -48,7 +48,9 @@ struct V0PhotonCandidate { float psipair; float cospa; float chi2ndf; - float centrality; + float centFT0M; + float centFT0C; + float centFT0A; float pca; public: @@ -89,11 +91,13 @@ struct V0PhotonCandidate { phiv = o2::aod::pwgem::dilepton::utils::pairutil::getPhivPair(posPx, posPy, posPz, elePx, elePy, elePz, posSign, eleSign, d_bz); psipair = o2::aod::pwgem::dilepton::utils::pairutil::getPsiPair(posPx, posPy, posPz, elePx, elePy, elePz); - centrality = collision.centFT0M(); + centFT0M = collision.centFT0M(); + centFT0C = collision.centFT0C(); + centFT0A = collision.centFT0A(); } // Constructor for V0PhotonCut - V0PhotonCandidate(const auto& v0, const auto& pos, const auto& ele, float cent, float d_bz) : centrality(cent) + V0PhotonCandidate(const auto& v0, const auto& pos, const auto& ele, float centFT0A, float centFT0C, float centFT0M, float d_bz) : centFT0A(centFT0A), centFT0C(centFT0C), centFT0M(centFT0M) { px = v0.px(); py = v0.py(); @@ -144,7 +148,9 @@ struct V0PhotonCandidate { float GetElePx() const { return elePx; } float GetElePy() const { return elePy; } float GetElePz() const { return elePz; } - float GetCent() const { return centrality; } + float GetCentFT0M() const { return centFT0M; } + float GetCentFT0C() const { return centFT0C; } + float GetCentFT0A() const { return centFT0A; } float GetPCA() const { return pca; } }; diff --git a/PWGEM/PhotonMeson/Core/V0PhotonCut.cxx b/PWGEM/PhotonMeson/Core/V0PhotonCut.cxx index e6087f76fe7..f37ad9a4e29 100644 --- a/PWGEM/PhotonMeson/Core/V0PhotonCut.cxx +++ b/PWGEM/PhotonMeson/Core/V0PhotonCut.cxx @@ -297,10 +297,10 @@ void V0PhotonCut::SetBinsCentMl(const std::vector& binsCent) LOG(info) << "V0 Photon Cut, set bins centrality ML with size:" << mBinsCentMl.size(); } -void V0PhotonCut::SetCutsPCMMl(const o2::framework::LabeledArray& cuts) +void V0PhotonCut::SetCutsMl(const std::vector& cuts) { - mCutsPCMMl = cuts; - LOG(info) << "V0 Photon Cut, set cuts PCM ML"; + mCutsMlFlat = cuts; + LOG(info) << "V0 Photon Cut, set cuts ML with size:" << mCutsMlFlat.size(); } void V0PhotonCut::SetNClassesMl(int nClasses) @@ -315,9 +315,11 @@ void V0PhotonCut::SetNamesInputFeatures(const std::vector& featureN LOG(info) << "V0 Photon Cut, set ML input feature names with size:" << mNamesInputFeatures.size(); } -void V0PhotonCut::SetCentrality(float cent) +void V0PhotonCut::SetCentrality(float centFT0A, float centFT0C, float centFT0M) { - mCent = cent; + mCentFT0A = centFT0A; + mCentFT0C = centFT0C; + mCentFT0M = centFT0M; } void V0PhotonCut::SetD_Bz(float d_bz) { @@ -329,3 +331,21 @@ void V0PhotonCut::SetCutDirMl(const std::vector& cutDirMl) mCutDirMl = cutDirMl; LOG(info) << "V0 Photon Cut, set ML cut directions with size:" << mCutDirMl.size(); } + +void V0PhotonCut::SetCentralityTypeMl(const std::string& centType) +{ + mCentralityTypeMl = centType; + LOG(info) << "V0 Photon Cut, set centrality type ML: " << mCentralityTypeMl; +} + +void V0PhotonCut::SetLabelsBinsMl(const std::vector& labelsBins) +{ + mLabelsBinsMl = labelsBins; + LOG(info) << "V0 Photon Cut, set ML labels bins with size:" << mLabelsBinsMl.size(); +} + +void V0PhotonCut::SetLabelsCutScoresMl(const std::vector& labelsCutScores) +{ + mLabelsCutScoresMl = labelsCutScores; + LOG(info) << "V0 Photon Cut, set ML labels cut scores with size:" << mLabelsCutScoresMl.size(); +} diff --git a/PWGEM/PhotonMeson/Core/V0PhotonCut.h b/PWGEM/PhotonMeson/Core/V0PhotonCut.h index 6c643b40d69..15770493477 100644 --- a/PWGEM/PhotonMeson/Core/V0PhotonCut.h +++ b/PWGEM/PhotonMeson/Core/V0PhotonCut.h @@ -55,6 +55,8 @@ enum CutDirection { CutNot // do not cut on score }; +static constexpr int NBins = 12; + static constexpr int NBinsPt = 12; static constexpr int NCutScores = 2; // default values for the pT bin edges, offset by 1 from the bin numbers in cuts array @@ -98,7 +100,7 @@ constexpr int CutDir[NCutScores] = {CutGreater, CutSmaller}; const auto vecCutDir = std::vector{CutDir, CutDir + NCutScores}; // default values for the cuts -constexpr double Cuts[NBinsPt][NCutScores] = { +constexpr double Cuts[NBins][NCutScores] = { {0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}, @@ -318,16 +320,24 @@ class V0PhotonCut : public TNamed LOG(error) << "EM ML Response is not initialized!"; return false; } - bool isSelectedMl = false; + bool mIsSelectedMl = false; std::vector mOutputML; - V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCent, mD_Bz); + V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCentFT0A, mCentFT0C, mCentFT0M, mD_Bz); std::vector mlInputFeatures = mEmMlResponse->getInputFeatures(v0photoncandidate, pos, ele); if (mUse2DBinning) { - isSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCent(), mOutputML); + if (mCentralityTypeMl == "CentFT0C") { + mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0C(), mOutputML); + } else if (mCentralityTypeMl == "CentFT0A") { + mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0A(), mOutputML); + } else if (mCentralityTypeMl == "CentFT0M") { + mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0M(), mOutputML); + } else { + LOG(fatal) << "Unsupported centTypePCMMl: " << mCentralityTypeMl << " , please choose from CentFT0C, CentFT0A, CentFT0M."; + } } else { - isSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), mOutputML); + mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), mOutputML); } - if (!isSelectedMl) { + if (!mIsSelectedMl) { return false; } } @@ -542,9 +552,45 @@ class V0PhotonCut : public TNamed mEmMlResponse = new o2::analysis::EmMlResponsePCM(); } if (mUse2DBinning) { - mEmMlResponse->configure2D(mBinsPtMl, mBinsCentMl, mCutsPCMMl, mCutDirMl, mNClassesMl); + int binsNPt = static_cast(mBinsPtMl.size()) - 1; + int binsNCent = static_cast(mBinsCentMl.size()) - 1; + int binsN = binsNPt * binsNCent; + if (binsN * static_cast(mCutDirMl.size()) != static_cast(mCutsMlFlat.size())) { + LOG(fatal) << "Mismatch in number of bins and cuts provided for 2D ML application: binsN * mCutDirMl: " << int(binsN) * int(mCutDirMl.size()) << " bins vs. mCutsMlFlat: " << mCutsMlFlat.size() << " cuts"; + } + if (binsN != static_cast(mOnnxFileNames.size())) { + LOG(fatal) << "Mismatch in number of bins and ONNX files provided for 2D ML application: binsN " << binsN << " bins vs. mOnnxFileNames: " << mOnnxFileNames.size() << " ONNX files"; + } + if (binsN != static_cast(mLabelsBinsMl.size())) { + LOG(fatal) << "Mismatch in number of bins and labels provided for 2D ML application: binsN:" << binsN << " bins vs. mLabelsBinsMl: " << mLabelsBinsMl.size() << " labels"; + } + if (static_cast(mCutDirMl.size()) != mNClassesMl) { + LOG(fatal) << "Mismatch in number of classes and cut directions provided for 2D ML application: mNClassesMl: " << mNClassesMl << " classes vs. mCutDirMl: " << mCutDirMl.size() << " cut directions"; + } + if (static_cast(mLabelsCutScoresMl.size()) != mNClassesMl) { + LOG(fatal) << "Mismatch in number of labels for cut scores and number of classes provided for 2D ML application: mNClassesMl: " << mNClassesMl << " classes vs. mLabelsCutScoresMl: " << mLabelsCutScoresMl.size() << " labels"; + } + o2::framework::LabeledArray mCutsMl(mCutsMlFlat.data(), binsN, mNClassesMl, mLabelsBinsMl, mLabelsCutScoresMl); + mEmMlResponse->configure2D(mBinsPtMl, mBinsCentMl, mCutsMl, mCutDirMl, mNClassesMl); } else { - mEmMlResponse->configure(mBinsPtMl, mCutsPCMMl, mCutDirMl, mNClassesMl); + int binsNPt = static_cast(mBinsPtMl.size()) - 1; + if (binsNPt * static_cast(mCutDirMl.size()) != static_cast(mCutsMlFlat.size())) { + LOG(fatal) << "Mismatch in number of pT bins and cuts provided for ML application: binsNPt * mCutDirMl:" << binsNPt * mCutDirMl.size() << " bins vs. mCutsMlFlat: " << mCutsMlFlat.size() << " cuts"; + } + if (binsNPt != static_cast(mOnnxFileNames.size())) { + LOG(fatal) << "Mismatch in number of pT bins and ONNX files provided for ML application: binsNPt " << binsNPt << " bins vs. mOnnxFileNames: " << mOnnxFileNames.size() << " ONNX files"; + } + if (binsNPt != static_cast(mLabelsBinsMl.size())) { + LOG(fatal) << "Mismatch in number of pT bins and labels provided for ML application: binsNPt:" << binsNPt << " bins vs. mLabelsBinsMl: " << mLabelsBinsMl.size() << " labels"; + } + if (mNClassesMl != static_cast(mCutDirMl.size())) { + LOG(fatal) << "Mismatch in number of classes and cut directions provided for ML application: mNClassesMl: " << mNClassesMl << " classes vs. mCutDirMl: " << mCutDirMl.size() << " cut directions"; + } + if (static_cast(mLabelsCutScoresMl.size()) != mNClassesMl) { + LOG(fatal) << "Mismatch in number of labels for cut scores and number of classes provided for ML application: mNClassesMl:" << mNClassesMl << " classes vs. mLabelsCutScoresMl: " << mLabelsCutScoresMl.size() << " labels"; + } + o2::framework::LabeledArray mCutsMl(mCutsMlFlat.data(), binsNPt, mNClassesMl, mLabelsBinsMl, mLabelsCutScoresMl); + mEmMlResponse->configure(mBinsPtMl, mCutsMl, mCutDirMl, mNClassesMl); } if (mLoadMlModelsFromCCDB) { ccdbApi.init(mCcdbUrl); @@ -604,15 +650,18 @@ class V0PhotonCut : public TNamed void SetLoadMlModelsFromCCDB(bool flag = true); void SetNClassesMl(int nClasses); void SetMlTimestampCCDB(int timestamp); - void SetCentrality(float cent); + void SetCentrality(float centFT0A, float centFT0C, float centFT0M); void SetD_Bz(float d_bz); void SetCcdbUrl(const std::string& url = "http://alice-ccdb.cern.ch"); + void SetCentralityTypeMl(const std::string& centType); void SetCutDirMl(const std::vector& cutDirMl); void SetMlModelPathsCCDB(const std::vector& modelPaths); void SetMlOnnxFileNames(const std::vector& onnxFileNamesVec); + void SetLabelsBinsMl(const std::vector& labelsBins); + void SetLabelsCutScoresMl(const std::vector& labelsCutScores); void SetBinsPtMl(const std::vector& binsPt); void SetBinsCentMl(const std::vector& binsCent); - void SetCutsPCMMl(const o2::framework::LabeledArray& cutsPCM); + void SetCutsMl(const std::vector& cutsMlFlat); void SetNamesInputFeatures(const std::vector& namesInputFeaturesVec); private: @@ -642,16 +691,21 @@ class V0PhotonCut : public TNamed bool mLoadMlModelsFromCCDB{true}; int mTimestampCCDB{-1}; int mNClassesMl{static_cast(o2::analysis::em_cuts_ml::NCutScores)}; - float mCent{0.f}; + float mCentFT0A{0.f}; + float mCentFT0C{0.f}; + float mCentFT0M{0.f}; float mD_Bz{0.f}; std::string mCcdbUrl{"http://alice-ccdb.cern.ch"}; + std::string mCentralityTypeMl{"FT0C"}; std::vector mCutDirMl{std::vector{o2::analysis::em_cuts_ml::vecCutDir}}; std::vector mModelPathsCCDB{std::vector{"path_ccdb/BDT_PCM/"}}; std::vector mOnnxFileNames{std::vector{"ModelHandler_onnx_PCM.onnx"}}; std::vector mNamesInputFeatures{std::vector{"feature1", "feature2"}}; + std::vector mLabelsBinsMl{std::vector{"bin 0", "bin 1"}}; + std::vector mLabelsCutScoresMl{std::vector{"score primary photons", "score background"}}; std::vector mBinsPtMl{std::vector{o2::analysis::em_cuts_ml::vecBinsPt}}; std::vector mBinsCentMl{std::vector{o2::analysis::em_cuts_ml::vecBinsCent}}; - o2::framework::LabeledArray mCutsPCMMl{o2::framework::LabeledArray{o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}}; + std::vector mCutsMlFlat{std::vector{0.5}}; o2::analysis::EmMlResponsePCM* mEmMlResponse{nullptr}; // pid cuts @@ -683,7 +737,7 @@ class V0PhotonCut : public TNamed bool mDisableITSonly{false}; bool mDisableTPConly{false}; - ClassDef(V0PhotonCut, 3); + ClassDef(V0PhotonCut, 4); }; #endif // PWGEM_PHOTONMESON_CORE_V0PHOTONCUT_H_ diff --git a/PWGEM/PhotonMeson/TableProducer/photonconversionbuilder.cxx b/PWGEM/PhotonMeson/TableProducer/photonconversionbuilder.cxx index 93107aa689b..b85b0046ce5 100644 --- a/PWGEM/PhotonMeson/TableProducer/photonconversionbuilder.cxx +++ b/PWGEM/PhotonMeson/TableProducer/photonconversionbuilder.cxx @@ -161,17 +161,20 @@ struct PhotonConversionBuilder { // PCM ML inference Configurable applyPCMMl{"applyPCMMl", false, "Flag to apply ML selections"}; - Configurable use2DBinning{"use2DBinning", true, "Flag to enable/disable 2D binning for ML application"}; + Configurable use2DBinning{"use2DBinning", false, "Flag to enable/disable 2D binning for ML application"}; Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; Configurable nClassesPCMMl{"nClassesPCMMl", static_cast(o2::analysis::em_cuts_ml::NCutScores), "Number of classes in ML model"}; Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable centTypePCMMl{"centTypePCMMl", "CentFT0C", "Centrality type for 2D ML application: CentFT0C, CentFT0M, or CentFT0A"}; Configurable> cutDirPCMMl{"cutDirPCMMl", std::vector{o2::analysis::em_cuts_ml::vecCutDir}, "Whether to reject score values greater or smaller than the threshold"}; Configurable> namesInputFeatures{"namesInputFeatures", std::vector{"feature1", "feature2"}, "Names of ML model input features"}; Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"path_ccdb/BDT_PCM/"}, "Paths of models on CCDB"}; Configurable> onnxFileNames{"onnxFileNames", std::vector{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"}; - Configurable> binsPtPCMMl{"binsPtPCMMl", std::vector{o2::analysis::em_cuts_ml::vecBinsPt}, "pT bin limits for ML application"}; + Configurable> labelsBinsPCMMl{"labelsBinsPCMMl", std::vector{"bin 0", "bin 1"}, "Labels for bins"}; + Configurable> labelsCutScoresPCMMl{"labelsCutScoresPCMMl", std::vector{o2::analysis::em_cuts_ml::labelsCutScore}, "Labels for cut scores"}; + Configurable> binsPtPCMMl{"binsPtPCMMl", std::vector{0.0, +1e+10}, "pT bin limits for ML application"}; Configurable> binsCentPCMMl{"binsCentPCMMl", std::vector{0.0, 100.0}, "Centrality bin limits for ML application"}; - Configurable> cutsPCMMl{"cutsPCMMl", {o2::analysis::em_cuts_ml::Cuts[0], o2::analysis::em_cuts_ml::NBinsPt, o2::analysis::em_cuts_ml::NCutScores, o2::analysis::em_cuts_ml::labelsPt, o2::analysis::em_cuts_ml::labelsCutScore}, "ML selections per pT bin"}; + Configurable> cutsPCMMlFlat{"cutsPCMMlFlat", {0.5}, "Flattened ML cuts: [bin0_score0, bin0_score1, ..., binN_scoreM]"}; o2::analysis::EmMlResponsePCM emMlResponse; std::vector outputML; @@ -254,8 +257,44 @@ struct PhotonConversionBuilder { if (applyPCMMl) { if (use2DBinning) { + int binsNPt = static_cast(binsPtPCMMl->size()) - 1; + int binsNCent = static_cast(binsCentPCMMl->size()) - 1; + int binsN = binsNPt * binsNCent; + if (binsN * static_cast(cutDirPCMMl->size()) != static_cast(cutsPCMMlFlat->size())) { + LOG(fatal) << "Mismatch in number of bins and cuts provided for 2D ML application: binsN * cutDirPCMMl: " << int(binsN) * int(cutDirPCMMl->size()) << " bins vs. cutsPCMMlFlat: " << cutsPCMMlFlat->size() << " cuts"; + } + if (binsN != static_cast(onnxFileNames->size())) { + LOG(fatal) << "Mismatch in number of bins and ONNX files provided for 2D ML application: binsN " << binsN << " bins vs. onnxFileNames: " << onnxFileNames->size() << " ONNX files"; + } + if (binsN != static_cast(labelsBinsPCMMl->size())) { + LOG(fatal) << "Mismatch in number of bins and labels provided for 2D ML application: binsN:" << binsN << " bins vs. labelsBinsPCMMl: " << labelsBinsPCMMl->size() << " labels"; + } + if (static_cast(cutDirPCMMl->size()) != nClassesPCMMl) { + LOG(fatal) << "Mismatch in number of classes and cut directions provided for 2D ML application: nClassesPCMMl: " << nClassesPCMMl << " classes vs. cutDirPCMMl: " << cutDirPCMMl->size() << " cut directions"; + } + if (static_cast(labelsCutScoresPCMMl->size()) != nClassesPCMMl) { + LOG(fatal) << "Mismatch in number of labels for cut scores and number of classes provided for 2D ML application: nClassesPCMMl: " << nClassesPCMMl << " classes vs. labelsCutScoresPCMMl: " << labelsCutScoresPCMMl->size() << " labels"; + } + LabeledArray cutsPCMMl(cutsPCMMlFlat->data(), binsN, nClassesPCMMl, labelsBinsPCMMl, labelsCutScoresPCMMl); emMlResponse.configure2D(binsPtPCMMl, binsCentPCMMl, cutsPCMMl, cutDirPCMMl, nClassesPCMMl); } else { + int binsNPt = static_cast(binsPtPCMMl->size()) - 1; + if (binsNPt * static_cast(cutDirPCMMl->size()) != static_cast(cutsPCMMlFlat->size())) { + LOG(fatal) << "Mismatch in number of pT bins and cuts provided for ML application: binsNPt * cutDirPCMMl:" << binsNPt * cutDirPCMMl->size() << " bins vs. cutsPCMMlFlat: " << cutsPCMMlFlat->size() << " cuts"; + } + if (binsNPt != static_cast(onnxFileNames->size())) { + LOG(fatal) << "Mismatch in number of pT bins and ONNX files provided for ML application: binsNPt " << binsNPt << " bins vs. onnxFileNames: " << onnxFileNames->size() << " ONNX files"; + } + if (binsNPt != static_cast(labelsBinsPCMMl->size())) { + LOG(fatal) << "Mismatch in number of pT bins and labels provided for ML application: binsNPt:" << binsNPt << " bins vs. labelsBinsPCMMl: " << labelsBinsPCMMl->size() << " labels"; + } + if (nClassesPCMMl != static_cast(cutDirPCMMl->size())) { + LOG(fatal) << "Mismatch in number of classes and cut directions provided for ML application: nClassesPCMMl: " << nClassesPCMMl << " classes vs. cutDirPCMMl: " << cutDirPCMMl->size() << " cut directions"; + } + if (static_cast(labelsCutScoresPCMMl->size()) != nClassesPCMMl) { + LOG(fatal) << "Mismatch in number of labels for cut scores and number of classes provided for ML application: nClassesPCMMl:" << nClassesPCMMl << " classes vs. labelsCutScoresPCMMl: " << labelsCutScoresPCMMl->size() << " labels"; + } + LabeledArray cutsPCMMl(cutsPCMMlFlat->data(), binsNPt, nClassesPCMMl, labelsBinsPCMMl, labelsCutScoresPCMMl); emMlResponse.configure(binsPtPCMMl, cutsPCMMl, cutDirPCMMl, nClassesPCMMl); } if (loadModelsFromCCDB) { @@ -685,7 +724,15 @@ struct PhotonConversionBuilder { bool isSelectedML = false; std::vector mlInputFeatures = emMlResponse.getInputFeatures(v0photoncandidate, pos, ele); if (use2DBinning) { - isSelectedML = emMlResponse.isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCent(), outputML); + if (std::string(centTypePCMMl) == "CentFT0C") { + isSelectedML = emMlResponse.isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0C(), outputML); + } else if (std::string(centTypePCMMl) == "CentFT0A") { + isSelectedML = emMlResponse.isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0A(), outputML); + } else if (std::string(centTypePCMMl) == "CentFT0M") { + isSelectedML = emMlResponse.isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), v0photoncandidate.GetCentFT0M(), outputML); + } else { + LOG(fatal) << "Unsupported centTypePCMMl: " << centTypePCMMl << " , please choose from CentFT0C, CentFT0A, CentFT0M."; + } } else { isSelectedML = emMlResponse.isSelectedMl(mlInputFeatures, v0photoncandidate.GetPt(), outputML); } From 3c3a0c295f2a12ff735c94818b0f42efc41dd583 Mon Sep 17 00:00:00 2001 From: Isabel Kantak Date: Thu, 22 Jan 2026 20:23:35 +0100 Subject: [PATCH 2/2] Change declaration order of struct variables to fit with the constructor --- PWGEM/PhotonMeson/Core/V0PhotonCandidate.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h b/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h index 88078a4b497..87d72496d20 100644 --- a/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h +++ b/PWGEM/PhotonMeson/Core/V0PhotonCandidate.h @@ -48,9 +48,9 @@ struct V0PhotonCandidate { float psipair; float cospa; float chi2ndf; - float centFT0M; - float centFT0C; float centFT0A; + float centFT0C; + float centFT0M; float pca; public: