From 2e556764d77be21d77e75187cac752800fffbcf1 Mon Sep 17 00:00:00 2001 From: Martin Date: Thu, 27 Mar 2025 15:23:56 -0500 Subject: [PATCH 01/26] RetryAction compiles --- .../SonicCore/interface/RetryActionBase.h | 26 +++++++++++++++++++ .../interface/RetrySameServerAction.h | 14 ++++++++++ .../SonicCore/src/RetryActionBase.cc | 15 +++++++++++ .../SonicCore/src/RetrySameServerAction.cc | 11 ++++++++ 4 files changed, 66 insertions(+) create mode 100644 HeterogeneousCore/SonicCore/interface/RetryActionBase.h create mode 100644 HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h create mode 100644 HeterogeneousCore/SonicCore/src/RetryActionBase.cc create mode 100644 HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h new file mode 100644 index 0000000000000..3a95578783b3d --- /dev/null +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -0,0 +1,26 @@ +#ifndef RETRY_ACTION_BASE_H +#define RETRY_ACTION_BASE_H + +#include "FWCore/PluginManager/interface/PluginFactory.h" +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include +#include + +// Base class for retry actions +class RetryActionBase { +public: + RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client); + virtual ~RetryActionBase() = default; +protected: + virtual void retry() = 0; // Pure virtual function for execution logic + void eval(); // interface for calling evaluate in client + +protected: + SonicClientBase* client_; +}; + +// Define the factory for creating retry actions +using RetryActionFactory = edmplugin::PluginFactory; + +#endif // RETRY_ACTION_BASE_H diff --git a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h new file mode 100644 index 0000000000000..cb752262dce28 --- /dev/null +++ b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h @@ -0,0 +1,14 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" + +class RetrySameServerAction : public RetryActionBase { +public: + RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) + : RetryActionBase(pset, client), + allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} +protected: + void retry(); + +private: + unsigned allowedTries_,tries_; +}; diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc new file mode 100644 index 0000000000000..ecdae15543654 --- /dev/null +++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc @@ -0,0 +1,15 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +// Constructor implementation +RetryActionBase::RetryActionBase(const edm::ParameterSet& conf,SonicClientBase* client) : client_(client) {} + + +void RetryActionBase::eval() { + if (client_) { + client_->evaluate(); + } else { + edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate."; + } +} + +EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory"); diff --git a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc new file mode 100644 index 0000000000000..637c4fe2bbff9 --- /dev/null +++ b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc @@ -0,0 +1,11 @@ +#include "HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" + +void RetrySameServerAction::retry() { + ++tries_; + //if max retries has not been exceeded, call evaluate again + if (tries_ < allowedTries_) { + eval(); + return; + } +} From 78315915b338378bc14fe4fcbb8a2313c8e5bdc2 Mon Sep 17 00:00:00 2001 From: Martin Date: Wed, 2 Apr 2025 09:52:35 -0500 Subject: [PATCH 02/26] Include RetryAction in SonicClientBase --- .../SonicCore/interface/RetryActionBase.h | 23 +++++--- .../interface/RetrySameServerAction.h | 12 ++-- .../SonicCore/interface/SonicClientBase.h | 12 ++++ .../SonicCore/src/RetryActionBase.cc | 13 ++--- .../SonicCore/src/RetrySameServerAction.cc | 19 ++++--- .../SonicCore/src/SonicClientBase.cc | 55 ++++++++++++++++--- 6 files changed, 99 insertions(+), 35 deletions(-) diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h index 3a95578783b3d..4732abc27a38f 100644 --- a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -10,17 +10,24 @@ // Base class for retry actions class RetryActionBase { public: - RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client); - virtual ~RetryActionBase() = default; + RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client); + virtual ~RetryActionBase() = default; + + bool shouldRetry() const { return shouldRetry_; } // Getter for shouldRetry_ + + virtual void retry() = 0; // Pure virtual function for execution logic + virtual void start() = 0; // Pure virtual function for execution logic for initialization + protected: - virtual void retry() = 0; // Pure virtual function for execution logic - void eval(); // interface for calling evaluate in client - + void eval(); // interface for calling evaluate in client + protected: - SonicClientBase* client_; + SonicClientBase* client_; + bool shouldRetry_; // Flag to track if further retries should happen }; // Define the factory for creating retry actions -using RetryActionFactory = edmplugin::PluginFactory; +using RetryActionFactory = + edmplugin::PluginFactory; -#endif // RETRY_ACTION_BASE_H +#endif // RETRY_ACTION_BASE_H diff --git a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h index cb752262dce28..cd8cda3a2d435 100644 --- a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h +++ b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h @@ -3,12 +3,14 @@ class RetrySameServerAction : public RetryActionBase { public: - RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) - : RetryActionBase(pset, client), - allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} + RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) + : RetryActionBase(pset, client), allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} + + void start() override { tries_=0;}; + protected: - void retry(); + void retry() override; private: - unsigned allowedTries_,tries_; + unsigned allowedTries_, tries_; }; diff --git a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h index 47caaae8b2052..5038f566dbc27 100644 --- a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h +++ b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h @@ -9,12 +9,15 @@ #include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h" #include +#include #include #include #include enum class SonicMode { Sync = 1, Async = 2, PseudoAsync = 3 }; +class RetryActionBase; + class SonicClientBase { public: //constructor @@ -57,11 +60,20 @@ class SonicClientBase { unsigned allowedTries_, tries_; std::optional holder_; + // Use a unique_ptr with a custom deleter to avoid incomplete type issues + struct RetryDeleter { + void operator()(RetryActionBase* ptr) const; + }; + + using RetryActionPtr = std::unique_ptr; + std::vector retryActions_; + //for logging/debugging std::string debugName_, clientName_, fullDebugName_; friend class SonicDispatcher; friend class SonicDispatcherPseudoAsync; + friend class RetryActionBase; }; #endif diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc index ecdae15543654..c595458570b0d 100644 --- a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc +++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc @@ -1,15 +1,14 @@ #include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" // Constructor implementation -RetryActionBase::RetryActionBase(const edm::ParameterSet& conf,SonicClientBase* client) : client_(client) {} - +RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client) : client_(client), shouldRetry_(true) {} void RetryActionBase::eval() { - if (client_) { - client_->evaluate(); - } else { - edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate."; - } + if (client_) { + client_->evaluate(); + } else { + edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate."; + } } EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory"); diff --git a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc index 637c4fe2bbff9..16959bec547a1 100644 --- a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc +++ b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc @@ -1,11 +1,16 @@ #include "HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h" #include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" -void RetrySameServerAction::retry() { - ++tries_; - //if max retries has not been exceeded, call evaluate again - if (tries_ < allowedTries_) { - eval(); - return; - } +void RetrySameServerAction::retry() { + ++tries_; + //if max retries has not been exceeded, call evaluate again + if (tries_ < allowedTries_) { + eval(); + return; + }else{ + shouldRetry_ = false; // Flip flag when max retries are reached + edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; + } } + +DEFINE_EDM_PLUGIN(RetryActionFactory, RetrySameServerAction, "RetrySameServerAction"); diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 745c51f17aaf3..2a4bb73a128b8 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -1,7 +1,14 @@ #include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" #include "FWCore/Utilities/interface/Exception.h" #include "FWCore/ParameterSet/interface/allowedValues.h" + +// Custom deleter implementation +void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { + delete ptr; +} + SonicClientBase::SonicClientBase(const edm::ParameterSet& params, const std::string& debugName, const std::string& clientName) @@ -12,6 +19,18 @@ SonicClientBase::SonicClientBase(const edm::ParameterSet& params, if (!clientName_.empty()) fullDebugName_ += ":" + clientName_; + std::vector retryPSetList = params.getParameter>("Retry"); + + for (const auto& retryPSet : retryPSetList) { + std::string actionType = retryPSet.getParameter("retryType"); + + auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this); + if (retryAction) { + //Convert to RetryActionPtr Type from raw pointer of retryAction + retryActions_.emplace_back(RetryActionPtr(retryAction.release())); + } + } + std::string modeName(params.getParameter("mode")); if (modeName == "Sync") setMode(SonicMode::Sync); @@ -40,19 +59,39 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) { holder_ = std::move(holder); } -void SonicClientBase::start() { tries_ = 0; } +void SonicClientBase::start() { + tries_ = 0; + // initialize all actions + for (const auto& action : retryActions_) { + action->start(); + } +} void SonicClientBase::finish(bool success, std::exception_ptr eptr) { //retries are only allowed if no exception was raised if (!success and !eptr) { - ++tries_; - //if max retries has not been exceeded, call evaluate again - if (tries_ < allowedTries_) { - evaluate(); - //avoid calling doneWaiting() twice - return; + //++tries_; + ////if max retries has not been exceeded, call evaluate again + //if (tries_ < allowedTries_) { + // evaluate(); + // //avoid calling doneWaiting() twice + // return; + //} + + // Check if any retry actions are still valid + bool anyRetryAllowed = false; + for (const auto& action : retryActions_) { + if (action->shouldRetry()) { + action->retry(); // Call retry only if shouldRetry_ is true + return; + } + } + // If no actions allow retries, stop retrying + if (!anyRetryAllowed) { + edm::LogInfo("SonicClientBase") << "No retry actions available. Stopping retries."; + return; } - //prepare an exception if exceeded + //prepare an exception if no more retries left else { edm::Exception ex(edm::errors::ExternalFailure); ex << "SonicCallFailed: call failed after max " << tries_ << " tries"; From d7d40b0e2ca7c3888baee46cd63c2d84cc6a2c10 Mon Sep 17 00:00:00 2001 From: Martin Date: Mon, 7 Apr 2025 08:55:41 -0500 Subject: [PATCH 03/26] Update PR comments --- .../SonicCore/interface/RetryActionBase.h | 8 +- .../interface/RetrySameServerAction.h | 2 +- .../SonicCore/interface/SonicClientBase.h | 4 +- .../SonicCore/src/RetryActionBase.cc | 3 +- .../SonicCore/src/RetrySameServerAction.cc | 6 +- .../SonicCore/src/SonicClientBase.cc | 82 +++++++++---------- .../SonicCore/test/DummyClient.h | 2 +- .../SonicCore/test/sonicTest_cfg.py | 44 ++++++++-- 8 files changed, 91 insertions(+), 60 deletions(-) diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h index 4732abc27a38f..d81183df39a47 100644 --- a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -1,5 +1,5 @@ -#ifndef RETRY_ACTION_BASE_H -#define RETRY_ACTION_BASE_H +#ifndef HeterogeneousCore_SonicCore_RetryActionBase +#define HeterogeneousCore_SonicCore_RetryActionBase #include "FWCore/PluginManager/interface/PluginFactory.h" #include "FWCore/ParameterSet/interface/ParameterSet.h" @@ -19,7 +19,7 @@ class RetryActionBase { virtual void start() = 0; // Pure virtual function for execution logic for initialization protected: - void eval(); // interface for calling evaluate in client + void eval(); // interface for calling evaluate in client protected: SonicClientBase* client_; @@ -30,4 +30,4 @@ class RetryActionBase { using RetryActionFactory = edmplugin::PluginFactory; -#endif // RETRY_ACTION_BASE_H +#endif diff --git a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h index cd8cda3a2d435..8ecce2a170847 100644 --- a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h +++ b/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h @@ -6,7 +6,7 @@ class RetrySameServerAction : public RetryActionBase { RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) : RetryActionBase(pset, client), allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} - void start() override { tries_=0;}; + void start() override { tries_ = 0; }; protected: void retry() override; diff --git a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h index 5038f566dbc27..45a089701ed12 100644 --- a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h +++ b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h @@ -57,12 +57,12 @@ class SonicClientBase { SonicMode mode_; bool verbose_; std::unique_ptr dispatcher_; - unsigned allowedTries_, tries_; + unsigned totalTries_; std::optional holder_; // Use a unique_ptr with a custom deleter to avoid incomplete type issues struct RetryDeleter { - void operator()(RetryActionBase* ptr) const; + void operator()(RetryActionBase* ptr) const; }; using RetryActionPtr = std::unique_ptr; diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc index c595458570b0d..41b9a6186da2b 100644 --- a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc +++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc @@ -1,7 +1,8 @@ #include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" // Constructor implementation -RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client) : client_(client), shouldRetry_(true) {} +RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client) + : client_(client), shouldRetry_(true) {} void RetryActionBase::eval() { if (client_) { diff --git a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc index 16959bec547a1..b5a24af935596 100644 --- a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc +++ b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc @@ -7,9 +7,9 @@ void RetrySameServerAction::retry() { if (tries_ < allowedTries_) { eval(); return; - }else{ - shouldRetry_ = false; // Flip flag when max retries are reached - edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; + } else { + shouldRetry_ = false; // Flip flag when max retries are reached + edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; } } diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 2a4bb73a128b8..514a680b2518b 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -3,35 +3,31 @@ #include "FWCore/Utilities/interface/Exception.h" #include "FWCore/ParameterSet/interface/allowedValues.h" - // Custom deleter implementation -void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { - delete ptr; -} +void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { delete ptr; } SonicClientBase::SonicClientBase(const edm::ParameterSet& params, const std::string& debugName, const std::string& clientName) - : allowedTries_(params.getUntrackedParameter("allowedTries", 0)), - debugName_(debugName), - clientName_(clientName), - fullDebugName_(debugName_) { + : debugName_(debugName), clientName_(clientName), fullDebugName_(debugName_) { if (!clientName_.empty()) fullDebugName_ += ":" + clientName_; - std::vector retryPSetList = params.getParameter>("Retry"); + const auto& retryPSetList = params.getParameter>("Retry"); + std::string modeName(params.getParameter("mode")); for (const auto& retryPSet : retryPSetList) { - std::string actionType = retryPSet.getParameter("retryType"); + const std::string& actionType = retryPSet.getParameter("retryType"); - auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this); - if (retryAction) { - //Convert to RetryActionPtr Type from raw pointer of retryAction - retryActions_.emplace_back(RetryActionPtr(retryAction.release())); - } + auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this); + if (retryAction) { + //Convert to RetryActionPtr Type from raw pointer of retryAction + retryActions_.emplace_back(RetryActionPtr(retryAction.release())); + } else { + throw cms::Exception("Configuration") << "Unknown Retry type" << actionType << " for SonicClient: " << modeName; + } } - std::string modeName(params.getParameter("mode")); if (modeName == "Sync") setMode(SonicMode::Sync); else if (modeName == "Async") @@ -59,42 +55,32 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) { holder_ = std::move(holder); } -void SonicClientBase::start() { - tries_ = 0; - // initialize all actions - for (const auto& action : retryActions_) { - action->start(); - } +void SonicClientBase::start() { + totalTries_ = 0; + // initialize all actions + for (const auto& action : retryActions_) { + action->start(); + } } void SonicClientBase::finish(bool success, std::exception_ptr eptr) { //retries are only allowed if no exception was raised if (!success and !eptr) { - //++tries_; - ////if max retries has not been exceeded, call evaluate again - //if (tries_ < allowedTries_) { - // evaluate(); - // //avoid calling doneWaiting() twice - // return; - //} - + ++totalTries_; // Check if any retry actions are still valid bool anyRetryAllowed = false; for (const auto& action : retryActions_) { - if (action->shouldRetry()) { - action->retry(); // Call retry only if shouldRetry_ is true - return; - } - } - // If no actions allow retries, stop retrying - if (!anyRetryAllowed) { - edm::LogInfo("SonicClientBase") << "No retry actions available. Stopping retries."; + if (action->shouldRetry()) { + action->retry(); // Call retry only if shouldRetry_ is true return; + } } //prepare an exception if no more retries left - else { + if (!anyRetryAllowed) { + edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " + << totalTries_ << " tries."; edm::Exception ex(edm::errors::ExternalFailure); - ex << "SonicCallFailed: call failed after max " << tries_ << " tries"; + ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; eptr = make_exception_ptr(ex); } } @@ -113,7 +99,19 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc //restrict allowed values desc.ifValue(edm::ParameterDescription("mode", "PseudoAsync", true), edm::allowedValues("Sync", "Async", "PseudoAsync")); - if (allowRetry) - desc.addUntracked("allowedTries", 0); + if (allowRetry) { + // Defines the structure of each entry in the VPSet + edm::ParameterSetDescription retryDesc; + retryDesc.add("retryType", "RetrySameServerAction"); + + // Define a default retry action + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0); + + // Add the VPSet with the default retry action + desc.addVPSet("Retry", retryDesc, {defaultRetry}); + } + desc.add("sonicClientBase", desc); desc.addUntracked("verbose", false); } diff --git a/HeterogeneousCore/SonicCore/test/DummyClient.h b/HeterogeneousCore/SonicCore/test/DummyClient.h index ccef888ad9f7d..6504843926c0a 100644 --- a/HeterogeneousCore/SonicCore/test/DummyClient.h +++ b/HeterogeneousCore/SonicCore/test/DummyClient.h @@ -36,7 +36,7 @@ class DummyClient : public SonicClient { this->output_ = this->input_ * factor_; //simulate a failure - if (this->tries_ < fails_) + if (this->totalTries_ < fails_) this->finish(false); else this->finish(true); diff --git a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py index 614297d86e3bb..92dad76112429 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py @@ -24,8 +24,13 @@ mode = cms.string("Sync"), factor = cms.int32(-1), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) @@ -35,8 +40,14 @@ mode = cms.string("PseudoAsync"), factor = cms.int32(2), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) + ), ) @@ -46,32 +57,53 @@ mode = cms.string("Async"), factor = cms.int32(5), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) process.dummySyncRetry = process.dummySync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) + ) ) process.dummyPseudoAsyncRetry = process.dummyPseudoAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) ) ) process.dummyAsyncRetry = process.dummyAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + allowedTries = cms.untracked.uint32(2), + retryType = cms.string('RetrySameServerAction') + ) + ) ) ) From 5cb6e8e7679e54e3012b9222130cd4bb51476391 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 11 Apr 2025 11:43:38 -0500 Subject: [PATCH 04/26] PR comments, fix fillDescriptions --- .../SonicCore/interface/RetryActionBase.h | 2 ++ .../SonicCore/src/RetrySameServerAction.cc | 2 +- .../SonicCore/src/SonicClientBase.cc | 17 +++++++---------- .../SonicCore/test/sonicTest_cfg.py | 1 - 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h index d81183df39a47..e3fc0bbb8af9a 100644 --- a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -31,3 +31,5 @@ using RetryActionFactory = edmplugin::PluginFactory; #endif + +#define DEFINE_RETRY_ACTION(type) DEFINE_EDM_PLUGIN(RetryActionFactory, type, #type); diff --git a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc index b5a24af935596..31c4fec227500 100644 --- a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc +++ b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc @@ -13,4 +13,4 @@ void RetrySameServerAction::retry() { } } -DEFINE_EDM_PLUGIN(RetryActionFactory, RetrySameServerAction, "RetrySameServerAction"); +DEFINE_RETRY_ACTION(RetrySameServerAction) diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 514a680b2518b..9949d9d1f2ea2 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -24,7 +24,7 @@ SonicClientBase::SonicClientBase(const edm::ParameterSet& params, //Convert to RetryActionPtr Type from raw pointer of retryAction retryActions_.emplace_back(RetryActionPtr(retryAction.release())); } else { - throw cms::Exception("Configuration") << "Unknown Retry type" << actionType << " for SonicClient: " << modeName; + throw cms::Exception("Configuration") << "Unknown Retry type " << actionType << " for SonicClient: " << modeName; } } @@ -67,8 +67,6 @@ void SonicClientBase::finish(bool success, std::exception_ptr eptr) { //retries are only allowed if no exception was raised if (!success and !eptr) { ++totalTries_; - // Check if any retry actions are still valid - bool anyRetryAllowed = false; for (const auto& action : retryActions_) { if (action->shouldRetry()) { action->retry(); // Call retry only if shouldRetry_ is true @@ -76,13 +74,11 @@ void SonicClientBase::finish(bool success, std::exception_ptr eptr) { } } //prepare an exception if no more retries left - if (!anyRetryAllowed) { - edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " - << totalTries_ << " tries."; - edm::Exception ex(edm::errors::ExternalFailure); - ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; - eptr = make_exception_ptr(ex); - } + edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ + << " tries."; + edm::Exception ex(edm::errors::ExternalFailure); + ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; + eptr = make_exception_ptr(ex); } if (holder_) { holder_->doneWaiting(eptr); @@ -103,6 +99,7 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc // Defines the structure of each entry in the VPSet edm::ParameterSetDescription retryDesc; retryDesc.add("retryType", "RetrySameServerAction"); + retryDesc.addUntracked("allowedTries", 0); // Define a default retry action edm::ParameterSet defaultRetry; diff --git a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py index 92dad76112429..bcbe820030440 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py @@ -17,7 +17,6 @@ process.options.numberOfThreads = 2 process.options.numberOfStreams = 0 - process.dummySync = _moduleClass(_moduleName, input = cms.int32(1), Client = cms.PSet( From a10ee94f9c0e84fb2fae31cd916bb7def0eeb7b6 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 12 Sep 2025 13:36:52 -0500 Subject: [PATCH 05/26] rebase 15_1_0_pre6 --- HeterogeneousCore/SonicCore/BuildFile.xml | 3 ++- .../SonicCore/plugins/BuildFile.xml | 6 ++++++ .../RetrySameServerAction.cc} | 14 ++++++++++++++ .../SonicCore/src/RetrySameServerAction.cc | 16 ---------------- .../SonicTriton/src/TritonClient.cc | 2 +- .../SonicTriton/test/tritonTest_cfg.py | 9 +++++++++ 6 files changed, 32 insertions(+), 18 deletions(-) create mode 100644 HeterogeneousCore/SonicCore/plugins/BuildFile.xml rename HeterogeneousCore/SonicCore/{interface/RetrySameServerAction.h => plugins/RetrySameServerAction.cc} (56%) delete mode 100644 HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc diff --git a/HeterogeneousCore/SonicCore/BuildFile.xml b/HeterogeneousCore/SonicCore/BuildFile.xml index b0d5e2a08b98f..9796c4363c612 100644 --- a/HeterogeneousCore/SonicCore/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/BuildFile.xml @@ -2,7 +2,8 @@ + - +i diff --git a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml new file mode 100644 index 0000000000000..0ecf2187a0f82 --- /dev/null +++ b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc similarity index 56% rename from HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h rename to HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc index 8ecce2a170847..9877013b93d5b 100644 --- a/HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h +++ b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc @@ -14,3 +14,17 @@ class RetrySameServerAction : public RetryActionBase { private: unsigned allowedTries_, tries_; }; + +void RetrySameServerAction::retry() { + ++tries_; + //if max retries has not been exceeded, call evaluate again + if (tries_ < allowedTries_) { + eval(); + return; + } else { + shouldRetry_ = false; // Flip flag when max retries are reached + edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; + } +} + +DEFINE_RETRY_ACTION(RetrySameServerAction) diff --git a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc deleted file mode 100644 index 31c4fec227500..0000000000000 --- a/HeterogeneousCore/SonicCore/src/RetrySameServerAction.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "HeterogeneousCore/SonicCore/interface/RetrySameServerAction.h" -#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" - -void RetrySameServerAction::retry() { - ++tries_; - //if max retries has not been exceeded, call evaluate again - if (tries_ < allowedTries_) { - eval(); - return; - } else { - shouldRetry_ = false; // Flip flag when max retries are reached - edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; - } -} - -DEFINE_RETRY_ACTION(RetrySameServerAction) diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index ddcdff83448d0..729b6b74ca8dc 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -369,7 +369,7 @@ void TritonClient::getResults(const std::vector //default case for sync and pseudo async void TritonClient::evaluate() { //undo previous signal from TritonException - if (tries_ > 0) { + if (totalTries_ > 0) { // If we are retrying then the evaluate method is called outside the frameworks TBB thread pool. // So we need to setup the service token for the current thread to access the service registry. edm::ServiceRegistry::Operate op(token_); diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 33d6a9c60aad4..0f8404de2d4c8 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -62,6 +62,15 @@ modelName = cms.string(model), modelVersion = cms.string(""), modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)), + verbose = cms.untracked.bool(options.verbose or options.verboseClient), + useSharedMemory = cms.untracked.bool(not options.noShm), + compression = cms.untracked.string(options.compression), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(options.tries) + ) + ) ) ) ) From 8843bbc88f22aae75fd0e28f28059d0fb5c6e480 Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 3 Jun 2025 18:04:44 -0500 Subject: [PATCH 06/26] Add update server function for client --- .../SonicTriton/interface/TritonClient.h | 1 + .../SonicTriton/src/TritonClient.cc | 41 +++++++++++-------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index df8f9b559427c..670e1a750bf0a 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -65,6 +65,7 @@ class TritonClient : public SonicClient { bool handle_exception(F&& call); void reportServerSideStats(const ServerSideStats& stats) const; + void updateServer(std::string serverName); ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, const inference::ModelStatistics& end_status) const; diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 729b6b74ca8dc..f232414c1e9e5 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -61,7 +61,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d useSharedMemory_(params.getUntrackedParameter("useSharedMemory")), compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter("compression"))) { options_.emplace_back(params.getParameter("modelName")); - //get appropriate server for this model + edm::Service ts; // We save the token to be able to notify the service in case of an exception in the evaluate method. @@ -70,22 +70,9 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d // create the context. token_ = edm::ServiceRegistry::instance().presentToken(); - const auto& server = - ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter("preferredServer")); - serverType_ = server.type; - edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; - //enforce sync mode for fallback CPU server to avoid contention - //todo: could enforce async mode otherwise (unless mode was specified by user?) - if (serverType_ == TritonServerType::LocalCPU) - setMode(SonicMode::Sync); - isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; - - //connect to the server - TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), - "TritonClient(): unable to create inference context", - isLocal_); - + //Connect to server + updateServer(params.getUntrackedParameter("preferredServer")); + //set options options_[0].model_version_ = params.getParameter("modelVersion"); options_[0].client_timeout_ = params.getUntrackedParameter("timeout"); @@ -574,6 +561,26 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const { return inference::ModelStatistics{}; } +void TritonClient::updateServer(std::string serverName){ + //get appropriate server for this model + edm::Service ts; + + const auto& server = ts->serverInfo(options_[0].model_name_, serverName); + serverType_ = server.type; + edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; + //enforce sync mode for fallback CPU server to avoid contention + //todo: could enforce async mode otherwise (unless mode was specified by user?) + if (serverType_ == TritonServerType::LocalCPU) + setMode(SonicMode::Sync); + isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; + + //connect to the server + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), + "TritonClient(): unable to create inference context", + isLocal_); +} + //for fillDescriptions void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { edm::ParameterSetDescription descClient; From 2f06ebe1d0f5699c8964dababe485bb6ae327a23 Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Mon, 28 Jul 2025 06:26:58 +0200 Subject: [PATCH 07/26] Add test for Triton retry action in BuildFile.xml --- .../interface/RetryActionDiffServer.h | 20 +++++++++ .../SonicTriton/src/RetryActionDiffServer.cc | 12 +++++ .../SonicTriton/test/BuildFile.xml | 1 + .../test/tritonRetryActionTest_cfg.py | 44 +++++++++++++++++++ 4 files changed, 77 insertions(+) create mode 100644 HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h create mode 100644 HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc create mode 100644 HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h new file mode 100644 index 0000000000000..7fd5050762252 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -0,0 +1,20 @@ +#ifndef HeterogeneousCore_SonicTriton_RetryActionDiffServer_h +#define HeterogeneousCore_SonicTriton_RetryActionDiffServer_h + +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +class RetryActionDiffServer : public RetryActionBase { +public: + RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client); + ~RetryActionDiffServer() override = default; + + void retry() override; + void start() override; + +private: + std::string diff_server_url_; + std::string diff_server_token_; +}; + +#endif + diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc new file mode 100644 index 0000000000000..1056490b262d6 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -0,0 +1,12 @@ +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" + +RetryActionDiffServer::RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client) + : RetryActionBase(conf, client) {} + +void RetryActionDiffServer::start() { + // to-do +} + +void RetryActionDiffServer::retry() { + // to-do +} \ No newline at end of file diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index e4ff7a0bb56f3..25cf78abe2664 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,5 +1,6 @@ + diff --git a/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py new file mode 100644 index 0000000000000..f07461f4cad9b --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py @@ -0,0 +1,44 @@ +import FWCore.ParameterSet.Config as cms +import os, sys, json +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton + +process = cms.Process('tritonTest', enableSonicTriton) + +process.load("HeterogeneousCore.SonicTriton.TritonService_cff") + +process.maxEvents = cms.untracked.PSet(input=cms.untracked.int32(10)) + +process.source = cms.Source("EmptySource") + +process.myProducer = cms.EDProducer("TritonGraphProducer", + # minimal inputs for testing + nodeMin = cms.uint32(1), + nodeMax = cms.uint32(10), + edgeMin = cms.uint32(20), + edgeMax = cms.uint32(40), + # client setup + Client = cms.PSet( + # This address is fake to force an error + address = cms.string("localhost:9999"), + mode = cms.string("Sync"), + # This is your retry logic + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string("RetryActionDiffServer"), + # The address of the real server will be filled in by the TritonService + diff_server_url = cms.string(""), + diff_server_token = cms.string("") + ) + ) + ) +) + +process.p = cms.Path(process.myProducer) + +process.load('FWCore/MessageService/MessageLogger_cfi') +process.MessageLogger.cerr.FwkReport.reportEvery = 1 +# enable verbose output for everything +process.MessageLogger.cerr.default = cms.untracked.PSet( + limit = cms.untracked.int32(10000000) +) \ No newline at end of file From 6d61227960c3f6b68fc2d9d79d7a6c2c5f9903a9 Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Mon, 4 Aug 2025 07:28:31 +0200 Subject: [PATCH 08/26] Implement retry logic in RetryActionDiffServer and add connectToServer method in TritonClient. Update BuildFile.xml and fix formatting in header files. --- HeterogeneousCore/SonicCore/BuildFile.xml | 2 +- .../interface/RetryActionDiffServer.h | 2 +- .../SonicTriton/interface/TritonClient.h | 1 + .../SonicTriton/src/RetryActionDiffServer.cc | 44 ++++++++++++++++--- .../SonicTriton/src/TritonClient.cc | 19 ++++++++ 5 files changed, 61 insertions(+), 7 deletions(-) diff --git a/HeterogeneousCore/SonicCore/BuildFile.xml b/HeterogeneousCore/SonicCore/BuildFile.xml index 9796c4363c612..5208c91638f37 100644 --- a/HeterogeneousCore/SonicCore/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/BuildFile.xml @@ -6,4 +6,4 @@ -i + diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h index 7fd5050762252..7a7abfe84db5d 100644 --- a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -14,7 +14,7 @@ class RetryActionDiffServer : public RetryActionBase { private: std::string diff_server_url_; std::string diff_server_token_; -}; +}; #endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index 670e1a750bf0a..93d41568a853c 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -50,6 +50,7 @@ class TritonClient : public SonicClient { void reset() override; TritonServerType serverType() const { return serverType_; } bool isLocal() const { return isLocal_; } + void connectToServer(const std::string& url); //for fillDescriptions static void fillPSetDescription(edm::ParameterSetDescription& iDesc); diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index 1056490b262d6..c92310e189abb 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -1,12 +1,46 @@ #include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "FWCore/MessageLogger/interface/MessageLogger.h" -RetryActionDiffServer::RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client) - : RetryActionBase(conf, client) {} +RetryActionDiffServer::RetryActionDiffServer( const edm::ParameterSet& conf, SonicClientBase* client) +: RetryActionBase(conf, client), + diff_server_url_(conf.getUntrackedParameter("diffServerUrl", "")), + diff_server_token_(conf.getUntrackedParameter("diffServerToken", "")) +{ + if (this->diff_server_url_.empty()) { + edm::LogWarning("RetryActionDiffServer") << "No alternative server URL provided. This retry action will be disabled."; + this->shouldRetry_ = false; + } +} void RetryActionDiffServer::start() { - // to-do + this->shouldRetry_ = true; } void RetryActionDiffServer::retry() { - // to-do -} \ No newline at end of file + if (!this->shouldRetry_ || this->diff_server_url_.empty()) { + this->shouldRetry_ = false; + edm::LogInfo("RetryActionDiffServer") << "No alternative server available for retry."; + return; + } + + try { + TritonClient* tritonClient = static_cast(client_); + + edm::LogInfo("RetryActionDiffServer") + << "Attempting retry by switching to server: " + << this->diff_server_url_; + + tritonClient->connectToServer(this->diff_server_url_); + eval(); + + } catch (const std::exception& e) { + edm::LogError("RetryActionDiffServer") + << "Failed to retry with alternative server: " + << e.what(); + } + + this->shouldRetry_ = false; +} + +DEFINE_RETRY_ACTION(RetryActionDiffServer); \ No newline at end of file diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index f232414c1e9e5..5f6d4ad48ad76 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -598,3 +598,22 @@ void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { descClient.addUntracked>("outputs", {}); iDesc.add("Client", descClient); } + +void TritonClient::connectToServer(const std::string& url) { + // Update client state for a generic remote server + serverType_ = TritonServerType::Remote; + isLocal_ = false; + + edm::LogInfo("TritonDiscovery") << debugName_ << " connecting to server: " << url; + + // Use default SSL options + triton::client::SslOptions sslOptions; + bool useSsl = false; // Assuming no SSL for direct URL connection + + // Connect to the server + TRITON_THROW_IF_ERROR( + triton::client::InferenceServerGrpcClient::Create(&client_, url, false, useSsl, sslOptions), + "TritonClient::connectToServer(): unable to create inference context", + false // isLocal is false + ); +} From c5d4a935f3f94ee905a58c073ed109f58e6489da Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Mon, 4 Aug 2025 08:12:40 +0200 Subject: [PATCH 09/26] Add RetryActionDiffServer class documentation, implement testing constructor for TritonClient, and update BuildFile.xml to include Catch2 for testing. --- HeterogeneousCore/SonicTriton/BuildFile.xml | 6 +- .../interface/RetryActionDiffServer.h | 13 ++ .../SonicTriton/interface/TritonClient.h | 221 +++++++++--------- .../SonicTriton/src/TritonClient.cc | 4 + .../SonicTriton/test/RetryActionDiffServer.cc | 111 +++++++++ 5 files changed, 250 insertions(+), 105 deletions(-) create mode 100644 HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml index b93d51e711e87..0f6e5f6bd24a6 100644 --- a/HeterogeneousCore/SonicTriton/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/BuildFile.xml @@ -7,9 +7,13 @@ + + - + + + diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h index 7a7abfe84db5d..2b24d4a643786 100644 --- a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -3,6 +3,19 @@ #include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" +/** + * @class RetryActionDiffServer + * @brief A concrete implementation of RetryActionBase that attempts to retry an inference + * request on a different, user-specified Triton server. + * + * This class is designed to provide a fallback mechanism. If an initial inference + * request fails (e.g., due to server unavailability or a model-specific error), + * this action will be triggered. It reads an alternative server URL from the + * ParameterSet and instructs the TritonClient to reconnect to this new server + * for the retry attempt. This action is designed for one-time use per inference + * call; after the retry attempt, it disables itself until the next `start()` call. + */ + class RetryActionDiffServer : public RetryActionBase { public: RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client); diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index 93d41568a853c..bb5cdab4164a8 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -1,104 +1,117 @@ -#ifndef HeterogeneousCore_SonicTriton_TritonClient -#define HeterogeneousCore_SonicTriton_TritonClient - -#include "FWCore/ParameterSet/interface/ParameterSet.h" -#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" -#include "FWCore/ServiceRegistry/interface/ServiceToken.h" -#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" - -#include -#include -#include -#include -#include - -#include "grpc_client.h" -#include "grpc_service.pb.h" - -enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; - -class TritonClient : public SonicClient { -public: - struct ServerSideStats { - uint64_t inference_count_; - uint64_t execution_count_; - uint64_t success_count_; - uint64_t cumm_time_ns_; - uint64_t queue_time_ns_; - uint64_t compute_input_time_ns_; - uint64_t compute_infer_time_ns_; - uint64_t compute_output_time_ns_; - }; - - //constructor - TritonClient(const edm::ParameterSet& params, const std::string& debugName); - - //destructor - ~TritonClient() override; - - //accessors - unsigned batchSize() const; - TritonBatchMode batchMode() const { return batchMode_; } - bool verbose() const { return verbose_; } - bool useSharedMemory() const { return useSharedMemory_; } - void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } - bool setBatchSize(unsigned bsize); - void setBatchMode(TritonBatchMode batchMode); - void resetBatchMode(); - void reset() override; - TritonServerType serverType() const { return serverType_; } - bool isLocal() const { return isLocal_; } - void connectToServer(const std::string& url); - - //for fillDescriptions - static void fillPSetDescription(edm::ParameterSetDescription& iDesc); - -protected: - //helpers - bool noOuterDim() const { return noOuterDim_; } - unsigned outerDim() const { return outerDim_; } - unsigned nEntries() const; - void getResults(const std::vector>& results); - void evaluate() override; - template - bool handle_exception(F&& call); - - void reportServerSideStats(const ServerSideStats& stats) const; - void updateServer(std::string serverName); - ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, - const inference::ModelStatistics& end_status) const; - - inference::ModelStatistics getServerSideStatus() const; - - //members - unsigned maxOuterDim_; - unsigned outerDim_; - bool noOuterDim_; - unsigned nEntries_; - TritonBatchMode batchMode_; - bool manualBatchMode_; - bool verbose_; - bool useSharedMemory_; - TritonServerType serverType_; - bool isLocal_; - grpc_compression_algorithm compressionAlgo_; - triton::client::Headers headers_; - - std::unique_ptr client_; - //stores timeout, model name and version - std::vector options_; - edm::ServiceToken token_; - -private: - friend TritonInputData; - friend TritonOutputData; - - //private accessors only used by data - auto client() { return client_.get(); } - void addEntry(unsigned entry); - void resizeEntries(unsigned entry); -}; - -#endif +#ifndef HeterogeneousCore_SonicTriton_TritonClient +#define HeterogeneousCore_SonicTriton_TritonClient + +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" +#include "FWCore/ServiceRegistry/interface/ServiceToken.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" + +#include +#include +#include +#include +#include + +#include "grpc_client.h" +#include "grpc_service.pb.h" + +enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; + +class TritonClient : public SonicClient { +public: + struct ServerSideStats { + uint64_t inference_count_; + uint64_t execution_count_; + uint64_t success_count_; + uint64_t cumm_time_ns_; + uint64_t queue_time_ns_; + uint64_t compute_input_time_ns_; + uint64_t compute_infer_time_ns_; + uint64_t compute_output_time_ns_; + }; + + //constructor + TritonClient(const edm::ParameterSet& params, const std::string& debugName); + + //destructor + ~TritonClient() override; + + //accessors + unsigned batchSize() const; + TritonBatchMode batchMode() const { return batchMode_; } + bool verbose() const { return verbose_; } + bool useSharedMemory() const { return useSharedMemory_; } + void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } + bool setBatchSize(unsigned bsize); + void setBatchMode(TritonBatchMode batchMode); + void resetBatchMode(); + void reset() override; + TritonServerType serverType() const { return serverType_; } + bool isLocal() const { return isLocal_; } + virtual void connectToServer(const std::string& url); + + //for fillDescriptions + static void fillPSetDescription(edm::ParameterSetDescription& iDesc); + +protected: + /** + * @brief Constructor for unit testing purposes only. + * + * This constructor is provided to allow the creation of a TritonClient + * instance (or a mock derived from it) without needing the full CMSSW + * Service framework, which is required by the standard constructor. + * This is essential for writing isolated unit tests that do not depend + * on external services. It initializes the base SonicClient with dummy + * parameters. + * @param is_testing A boolean flag to select this constructor. + */ + TritonClient(bool is_testing); + + //helpers + bool noOuterDim() const { return noOuterDim_; } + unsigned outerDim() const { return outerDim_; } + unsigned nEntries() const; + void getResults(const std::vector>& results); + virtual void evaluate() override; + template + bool handle_exception(F&& call); + + void reportServerSideStats(const ServerSideStats& stats) const; + void updateServer(std::string serverName); + ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, + const inference::ModelStatistics& end_status) const; + + inference::ModelStatistics getServerSideStatus() const; + + //members + unsigned maxOuterDim_; + unsigned outerDim_; + bool noOuterDim_; + unsigned nEntries_; + TritonBatchMode batchMode_; + bool manualBatchMode_; + bool verbose_; + bool useSharedMemory_; + TritonServerType serverType_; + bool isLocal_; + grpc_compression_algorithm compressionAlgo_; + triton::client::Headers headers_; + + std::unique_ptr client_; + //stores timeout, model name and version + std::vector options_; + edm::ServiceToken token_; + +private: + friend TritonInputData; + friend TritonOutputData; + + //private accessors only used by data + auto client() { return client_.get(); } + void addEntry(unsigned entry); + void resizeEntries(unsigned entry); +}; + +#endif diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 5f6d4ad48ad76..a112896e33026 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -617,3 +617,7 @@ void TritonClient::connectToServer(const std::string& url) { false // isLocal is false ); } + +//constructor for testing +TritonClient::TritonClient(bool /*is_testing*/) : SonicClient(edm::ParameterSet(), "TritonClient_test", "TritonClient") {} + diff --git a/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc new file mode 100644 index 0000000000000..6c429a95a9816 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc @@ -0,0 +1,111 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" + +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" + +#include + +// Anonymous namespace to hold our mock object, keeping it local to this test file. +namespace { + // Mock TritonClient to intercept and verify method calls without needing a real server or CMSSW services. + class MockTritonClient : public TritonClient { + public: + // Use the protected, testing-only constructor from the base class. + MockTritonClient() : TritonClient(true) {} + + // --- Methods to override for testing --- + void evaluate() override { + // This method is called by RetryActionBase::eval() + // We can leave it empty as the test directly calls retry(). + } + + void connectToServer(const std::string& url) override { + connectToServer_called_ = true; + last_url_ = url; + } + + // --- Test utility methods --- + bool connectToServerCalled() const { return connectToServer_called_; } + const std::string& getLastUrl() const { return last_url_; } + void reset() { + connectToServer_called_ = false; + last_url_ = ""; + } + + private: + bool connectToServer_called_ = false; + std::string last_url_; + }; +} + +TEST_CASE("Test RetryActionDiffServer Logic", "[RetryActionDiffServer]") { + + // 1. Create the mock client object. + MockTritonClient mockClient; + + // 2. Create the ParameterSet that configures the retry action. + edm::ParameterSet retryPSet; + const std::string alternate_server = "grpc://new-server-for-retry.com:8001"; + retryPSet.addUntrackedParameter("diffServerUrl", alternate_server); + + // 3. Create an instance of the class we are testing. + RetryActionDiffServer retryAction(retryPSet, &mockClient); + + SECTION("Retry calls connectToServer with the correct URL") { + // ARRANGE: Reset state before the test. + mockClient.reset(); + retryAction.start(); // Arms the action, setting shouldRetry_ = true + + // ACT: Manually call the retry method to simulate a failure event. + retryAction.retry(); + + // ASSERT: Verify that our mock's overridden method was called with the expected arguments. + REQUIRE(mockClient.connectToServerCalled()); + REQUIRE(mockClient.getLastUrl() == alternate_server); + } + + SECTION("Retry action is a one-shot") { + // ARRANGE + mockClient.reset(); + retryAction.start(); + + // ACT + retryAction.retry(); // First retry, should work. + + // After the first retry, the internal `shouldRetry_` flag should be false. + // A second call to retry() should do nothing. + // We can verify this by checking that connectToServer was not called a second time. + mockClient.reset(); // Reset our trackers. + retryAction.retry(); // Second retry, should fail silently. + + // ASSERT + REQUIRE_FALSE(mockClient.connectToServerCalled()); + } + + SECTION("Start method re-arms the action") { + // ARRANGE + mockClient.reset(); + retryAction.start(); + retryAction.retry(); // Use up the action. + REQUIRE_FALSE(retryAction.shouldRetry()); // Verify it's spent. + + // ACT: A new inference call begins, so `start()` is called again. + retryAction.start(); + + // ASSERT: The action should now be ready for another retry. + REQUIRE(retryAction.shouldRetry()); + } + + SECTION("Constructor disables action if URL is missing") { + // ARRANGE: Create a PSet with no URL. + edm::ParameterSet emptyPSet; + + // ACT + RetryActionDiffServer disabledAction(emptyPSet, &mockClient); + + // ASSERT + REQUIRE_FALSE(disabledAction.shouldRetry()); + } +} From 91f5cb3ae29db280740cf09dbf26ad8f3b9a6465 Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Tue, 12 Aug 2025 01:30:49 +0200 Subject: [PATCH 10/26] SonicTriton: implement retry action against different server; update tests; remove old cfg --- .../interface/RetryActionDiffServer.h | 4 +- .../SonicTriton/src/RetryActionDiffServer.cc | 32 +++++++------- .../SonicTriton/test/RetryActionDiffServer.cc | 2 - .../test/tritonRetryActionTest_cfg.py | 44 ------------------- 4 files changed, 18 insertions(+), 64 deletions(-) delete mode 100644 HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h index 2b24d4a643786..d823d78d3a960 100644 --- a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -25,8 +25,8 @@ class RetryActionDiffServer : public RetryActionBase { void start() override; private: - std::string diff_server_url_; - std::string diff_server_token_; + std::string alt_server_url_; + std::string alt_server_token_; }; #endif diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index c92310e189abb..f516289e655da 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -2,15 +2,19 @@ #include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" #include "FWCore/MessageLogger/interface/MessageLogger.h" -RetryActionDiffServer::RetryActionDiffServer( const edm::ParameterSet& conf, SonicClientBase* client) -: RetryActionBase(conf, client), - diff_server_url_(conf.getUntrackedParameter("diffServerUrl", "")), - diff_server_token_(conf.getUntrackedParameter("diffServerToken", "")) -{ - if (this->diff_server_url_.empty()) { - edm::LogWarning("RetryActionDiffServer") << "No alternative server URL provided. This retry action will be disabled."; - this->shouldRetry_ = false; - } +RetryActionDiffServer::RetryActionDiffServer( + const edm::ParameterSet& conf, + SonicClientBase* client +): RetryActionBase(conf, client) { + alt_server_url_ = conf.getUntrackedParameter("altServerUrl", ""); + alt_server_token_ = conf.getUntrackedParameter("altServerToken", ""); + + if (this->alt_server_url_.empty()) { + edm::LogWarning("RetryActionDiffServer") + << "No alternative server URL provided. " + << "This retry action will be disabled."; + this->shouldRetry_ = false; + } } void RetryActionDiffServer::start() { @@ -18,7 +22,7 @@ void RetryActionDiffServer::start() { } void RetryActionDiffServer::retry() { - if (!this->shouldRetry_ || this->diff_server_url_.empty()) { + if (!this->shouldRetry_ || this->alt_server_url_.empty()) { this->shouldRetry_ = false; edm::LogInfo("RetryActionDiffServer") << "No alternative server available for retry."; return; @@ -26,20 +30,16 @@ void RetryActionDiffServer::retry() { try { TritonClient* tritonClient = static_cast(client_); - edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to server: " - << this->diff_server_url_; - - tritonClient->connectToServer(this->diff_server_url_); + << this->alt_server_url_; + tritonClient->connectToServer(this->alt_server_url_); eval(); - } catch (const std::exception& e) { edm::LogError("RetryActionDiffServer") << "Failed to retry with alternative server: " << e.what(); } - this->shouldRetry_ = false; } diff --git a/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc index 6c429a95a9816..2648abfc202d2 100644 --- a/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc @@ -7,9 +7,7 @@ #include -// Anonymous namespace to hold our mock object, keeping it local to this test file. namespace { - // Mock TritonClient to intercept and verify method calls without needing a real server or CMSSW services. class MockTritonClient : public TritonClient { public: // Use the protected, testing-only constructor from the base class. diff --git a/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py deleted file mode 100644 index f07461f4cad9b..0000000000000 --- a/HeterogeneousCore/SonicTriton/test/tritonRetryActionTest_cfg.py +++ /dev/null @@ -1,44 +0,0 @@ -import FWCore.ParameterSet.Config as cms -import os, sys, json -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter -from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton - -process = cms.Process('tritonTest', enableSonicTriton) - -process.load("HeterogeneousCore.SonicTriton.TritonService_cff") - -process.maxEvents = cms.untracked.PSet(input=cms.untracked.int32(10)) - -process.source = cms.Source("EmptySource") - -process.myProducer = cms.EDProducer("TritonGraphProducer", - # minimal inputs for testing - nodeMin = cms.uint32(1), - nodeMax = cms.uint32(10), - edgeMin = cms.uint32(20), - edgeMax = cms.uint32(40), - # client setup - Client = cms.PSet( - # This address is fake to force an error - address = cms.string("localhost:9999"), - mode = cms.string("Sync"), - # This is your retry logic - Retry = cms.VPSet( - cms.PSet( - retryType = cms.string("RetryActionDiffServer"), - # The address of the real server will be filled in by the TritonService - diff_server_url = cms.string(""), - diff_server_token = cms.string("") - ) - ) - ) -) - -process.p = cms.Path(process.myProducer) - -process.load('FWCore/MessageService/MessageLogger_cfi') -process.MessageLogger.cerr.FwkReport.reportEvery = 1 -# enable verbose output for everything -process.MessageLogger.cerr.default = cms.untracked.PSet( - limit = cms.untracked.int32(10000000) -) \ No newline at end of file From 00cfd4ded7c4f378b499a9a3ee4886d98e12cce4 Mon Sep 17 00:00:00 2001 From: Trevin Lee Date: Mon, 18 Aug 2025 12:18:57 +0200 Subject: [PATCH 11/26] Refactor RetryActionDiffServer to utilize TritonService for server selection; remove unused parameters and improve documentation. --- .../interface/RetryActionDiffServer.h | 17 ++++++------ .../SonicTriton/interface/TritonClient.h | 2 +- .../SonicTriton/src/RetryActionDiffServer.cc | 26 ++++++------------- 3 files changed, 17 insertions(+), 28 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h index d823d78d3a960..af7720b90cb0b 100644 --- a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -6,14 +6,15 @@ /** * @class RetryActionDiffServer * @brief A concrete implementation of RetryActionBase that attempts to retry an inference - * request on a different, user-specified Triton server. + * request on a different Triton server. * - * This class is designed to provide a fallback mechanism. If an initial inference - * request fails (e.g., due to server unavailability or a model-specific error), - * this action will be triggered. It reads an alternative server URL from the - * ParameterSet and instructs the TritonClient to reconnect to this new server - * for the retry attempt. This action is designed for one-time use per inference - * call; after the retry attempt, it disables itself until the next `start()` call. + * This class provides a fallback mechanism. If an initial inference request fails + * (e.g., due to server unavailability or a model-specific error), this action will be + * triggered. It queries the central TritonService to select an alternative server (e.g., + * the fallback server when available) and instructs the TritonClient to reconnect to + * that server for the retry attempt. This action is designed for one-time use per + * inference call; after the retry attempt, it disables itself until the next `start()` + * call. */ class RetryActionDiffServer : public RetryActionBase { @@ -25,8 +26,6 @@ class RetryActionDiffServer : public RetryActionBase { void start() override; private: - std::string alt_server_url_; - std::string alt_server_token_; }; #endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index bb5cdab4164a8..e41ccb076ecb1 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -51,6 +51,7 @@ class TritonClient : public SonicClient { TritonServerType serverType() const { return serverType_; } bool isLocal() const { return isLocal_; } virtual void connectToServer(const std::string& url); + void updateServer(std::string serverName); //for fillDescriptions static void fillPSetDescription(edm::ParameterSetDescription& iDesc); @@ -79,7 +80,6 @@ class TritonClient : public SonicClient { bool handle_exception(F&& call); void reportServerSideStats(const ServerSideStats& stats) const; - void updateServer(std::string serverName); ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, const inference::ModelStatistics& end_status) const; diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index f516289e655da..c3f931ae567ba 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -1,39 +1,29 @@ #include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" #include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" #include "FWCore/MessageLogger/interface/MessageLogger.h" +#include "FWCore/ServiceRegistry/interface/Service.h" RetryActionDiffServer::RetryActionDiffServer( const edm::ParameterSet& conf, SonicClientBase* client -): RetryActionBase(conf, client) { - alt_server_url_ = conf.getUntrackedParameter("altServerUrl", ""); - alt_server_token_ = conf.getUntrackedParameter("altServerToken", ""); - - if (this->alt_server_url_.empty()) { - edm::LogWarning("RetryActionDiffServer") - << "No alternative server URL provided. " - << "This retry action will be disabled."; - this->shouldRetry_ = false; - } -} +): RetryActionBase(conf, client) {} void RetryActionDiffServer::start() { this->shouldRetry_ = true; } void RetryActionDiffServer::retry() { - if (!this->shouldRetry_ || this->alt_server_url_.empty()) { + if (!this->shouldRetry_) { this->shouldRetry_ = false; - edm::LogInfo("RetryActionDiffServer") << "No alternative server available for retry."; + edm::LogInfo("RetryActionDiffServer") << "Retry not armed; skipping."; return; } try { - TritonClient* tritonClient = static_cast(client_); - edm::LogInfo("RetryActionDiffServer") - << "Attempting retry by switching to server: " - << this->alt_server_url_; - tritonClient->connectToServer(this->alt_server_url_); + auto* tritonClient = static_cast(client_); + edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; + tritonClient->updateServer(TritonService::Server::fallbackName); eval(); } catch (const std::exception& e) { edm::LogError("RetryActionDiffServer") From aec18e64ad4105aa76467735ccff1d596f69b7c4 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 12 Sep 2025 13:43:00 -0500 Subject: [PATCH 12/26] rebase 15_1_0_pre6 --- HeterogeneousCore/SonicTriton/BuildFile.xml | 3 - .../SonicTriton/interface/TritonClient.h | 16 +-- .../SonicTriton/src/TritonClient.cc | 15 ++- .../SonicTriton/test/BuildFile.xml | 14 ++- .../SonicTriton/test/RetryActionDiffServer.cc | 109 ------------------ .../test/retry_action_diff_log_test.sh | 22 ++++ .../test/test_RetryActionDiffServer.cc | 73 ++++++++++++ .../SonicTriton/test/tritonTest_cfg.py | 36 +++++- 8 files changed, 156 insertions(+), 132 deletions(-) delete mode 100644 HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc create mode 100755 HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh create mode 100644 HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml index 0f6e5f6bd24a6..4af38d69d89e9 100644 --- a/HeterogeneousCore/SonicTriton/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/BuildFile.xml @@ -7,7 +7,6 @@ - @@ -15,5 +14,3 @@ - - diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index e41ccb076ecb1..9e21b646508e9 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -51,24 +51,14 @@ class TritonClient : public SonicClient { TritonServerType serverType() const { return serverType_; } bool isLocal() const { return isLocal_; } virtual void connectToServer(const std::string& url); - void updateServer(std::string serverName); + virtual void updateServer(std::string serverName); //for fillDescriptions static void fillPSetDescription(edm::ParameterSetDescription& iDesc); protected: - /** - * @brief Constructor for unit testing purposes only. - * - * This constructor is provided to allow the creation of a TritonClient - * instance (or a mock derived from it) without needing the full CMSSW - * Service framework, which is required by the standard constructor. - * This is essential for writing isolated unit tests that do not depend - * on external services. It initializes the base SonicClient with dummy - * parameters. - * @param is_testing A boolean flag to select this constructor. - */ - TritonClient(bool is_testing); + // Protected default constructor for unit testing (no framework services) + TritonClient(); //helpers bool noOuterDim() const { return noOuterDim_; } diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index a112896e33026..404c84455e87d 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -28,6 +28,19 @@ namespace tc = triton::client; namespace { + // Minimal ParameterSet to satisfy SonicClientBase requirements during unit tests + edm::ParameterSet makeMinimalSonicParamsForTest() { + edm::ParameterSet params; + params.addParameter("mode", "PseudoAsync"); + + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0u); + std::vector retryVec{defaultRetry}; + params.addParameter>("Retry", retryVec); + + return params; + } grpc_compression_algorithm getCompressionAlgo(const std::string& name) { if (name.empty() or name.compare("none") == 0) return grpc_compression_algorithm::GRPC_COMPRESS_NONE; @@ -619,5 +632,5 @@ void TritonClient::connectToServer(const std::string& url) { } //constructor for testing -TritonClient::TritonClient(bool /*is_testing*/) : SonicClient(edm::ParameterSet(), "TritonClient_test", "TritonClient") {} +TritonClient::TritonClient() : SonicClient(makeMinimalSonicParamsForTest(), "TritonClient_test", "TritonClient") {} diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 25cf78abe2664..f6ab9108035f2 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,12 +1,22 @@ - + + - + + + + + + + + + + diff --git a/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc deleted file mode 100644 index 2648abfc202d2..0000000000000 --- a/HeterogeneousCore/SonicTriton/test/RetryActionDiffServer.cc +++ /dev/null @@ -1,109 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch.hpp" - -#include "FWCore/ParameterSet/interface/ParameterSet.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" -#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" - -#include - -namespace { - class MockTritonClient : public TritonClient { - public: - // Use the protected, testing-only constructor from the base class. - MockTritonClient() : TritonClient(true) {} - - // --- Methods to override for testing --- - void evaluate() override { - // This method is called by RetryActionBase::eval() - // We can leave it empty as the test directly calls retry(). - } - - void connectToServer(const std::string& url) override { - connectToServer_called_ = true; - last_url_ = url; - } - - // --- Test utility methods --- - bool connectToServerCalled() const { return connectToServer_called_; } - const std::string& getLastUrl() const { return last_url_; } - void reset() { - connectToServer_called_ = false; - last_url_ = ""; - } - - private: - bool connectToServer_called_ = false; - std::string last_url_; - }; -} - -TEST_CASE("Test RetryActionDiffServer Logic", "[RetryActionDiffServer]") { - - // 1. Create the mock client object. - MockTritonClient mockClient; - - // 2. Create the ParameterSet that configures the retry action. - edm::ParameterSet retryPSet; - const std::string alternate_server = "grpc://new-server-for-retry.com:8001"; - retryPSet.addUntrackedParameter("diffServerUrl", alternate_server); - - // 3. Create an instance of the class we are testing. - RetryActionDiffServer retryAction(retryPSet, &mockClient); - - SECTION("Retry calls connectToServer with the correct URL") { - // ARRANGE: Reset state before the test. - mockClient.reset(); - retryAction.start(); // Arms the action, setting shouldRetry_ = true - - // ACT: Manually call the retry method to simulate a failure event. - retryAction.retry(); - - // ASSERT: Verify that our mock's overridden method was called with the expected arguments. - REQUIRE(mockClient.connectToServerCalled()); - REQUIRE(mockClient.getLastUrl() == alternate_server); - } - - SECTION("Retry action is a one-shot") { - // ARRANGE - mockClient.reset(); - retryAction.start(); - - // ACT - retryAction.retry(); // First retry, should work. - - // After the first retry, the internal `shouldRetry_` flag should be false. - // A second call to retry() should do nothing. - // We can verify this by checking that connectToServer was not called a second time. - mockClient.reset(); // Reset our trackers. - retryAction.retry(); // Second retry, should fail silently. - - // ASSERT - REQUIRE_FALSE(mockClient.connectToServerCalled()); - } - - SECTION("Start method re-arms the action") { - // ARRANGE - mockClient.reset(); - retryAction.start(); - retryAction.retry(); // Use up the action. - REQUIRE_FALSE(retryAction.shouldRetry()); // Verify it's spent. - - // ACT: A new inference call begins, so `start()` is called again. - retryAction.start(); - - // ASSERT: The action should now be ready for another retry. - REQUIRE(retryAction.shouldRetry()); - } - - SECTION("Constructor disables action if URL is missing") { - // ARRANGE: Create a PSet with no URL. - edm::ParameterSet emptyPSet; - - // ACT - RetryActionDiffServer disabledAction(emptyPSet, &mockClient); - - // ASSERT - REQUIRE_FALSE(disabledAction.shouldRetry()); - } -} diff --git a/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh b/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh new file mode 100755 index 0000000000000..00e1cc1dce90a --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +LOCALTOP=$1 + +tmpFile=$(mktemp -p ${LOCALTOP} RetryActionDiffLogXXXXXXXX.log) +cmsRun ${LOCALTOP}/src/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py \ + --modules TritonGraphProducer --models gat_test \ + --maxEvents 2 --unittest --device cpu --retryAction diff --verbose \ + > "$tmpFile" 2>&1 +status=$? + +if ! grep -q "Retry type: RetryActionDiffServer" "$tmpFile"; then + echo "Expected retry type log line not found" >&2 + cat "$tmpFile" + rm -f "$tmpFile" + exit 1 +fi + +rm -f "$tmpFile" +exit $status + + diff --git a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc new file mode 100644 index 0000000000000..48305dd083216 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc @@ -0,0 +1,73 @@ +#include "catch.hpp" + +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +#include "FWCore/ParameterSet/interface/ParameterSet.h" + +#include + +// Test double for TritonClient to observe updateServer calls without framework/services +class TestTritonClient : public TritonClient { +public: + TestTritonClient() : TritonClient() {} + + void connectToServer(const std::string& url) override { lastConnectedUrl = url; } + + void updateServer(std::string serverName) override { + lastUpdatedServerName = std::move(serverName); + } + + const std::string& lastUrl() const { return lastConnectedUrl; } + const std::string& lastServerName() const { return lastUpdatedServerName; } + +protected: + void evaluate() override {} + +private: + std::string lastConnectedUrl; + std::string lastUpdatedServerName; +}; + +TEST_CASE("RetryActionDiffServer switches to fallback via updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + TestTritonClient client; + + RetryActionDiffServer action(empty, static_cast(&client)); + + // start should arm the action + action.start(); + REQUIRE(action.shouldRetry()); + + // retry should call updateServer with fallback name then disarm + action.retry(); + REQUIRE(client.lastServerName() == TritonService::Server::fallbackName); + + // second retry without re-arming should be a no-op: lastServerName unchanged + std::string afterFirst = client.lastServerName(); + action.retry(); + REQUIRE(client.lastServerName() == afterFirst); +} + +// A client that throws during updateServer to exercise error handling path +class ThrowingTritonClient : public TritonClient { +public: + ThrowingTritonClient() : TritonClient() {} + void updateServer(std::string) override { throw std::runtime_error("updateServer failure"); } +protected: + void evaluate() override {} +}; + +TEST_CASE("RetryActionDiffServer catches exceptions from updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + ThrowingTritonClient client; + RetryActionDiffServer action(empty, static_cast(&client)); + action.start(); + + // Should not throw despite client throwing internally; action disarms afterward + REQUIRE_NOTHROW(action.retry()); +} + + diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 0f8404de2d4c8..8588dcf5cad2c 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -21,6 +21,14 @@ parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") +parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory") +parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression") +parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication") +parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server") +parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server") +parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request") +parser.add_argument("--retryAction", default="same", type=str, choices=["same","diff"], help="retry policy: same server or different server") +options = parser.parse_args() options = getOptions(parser, verbose=True) @@ -50,6 +58,16 @@ } defaultClient = applyClientOptions(getDefaultClientPSet().clone(), options) +keepMsgs = [] +if options.verbose or options.verboseDiscovery: + keepMsgs.append('TritonDiscovery') +if options.verbose or options.verboseClient: + keepMsgs.append('TritonClient') +if options.verbose or options.verboseService: + keepMsgs.append('TritonService') +if options.verbose: + # ensure RetryActionDiffServer messages are not suppressed if emitted + keepMsgs.append('RetryActionDiffServer') for im,module in enumerate(options.modules): model = options.models[im] @@ -65,10 +83,16 @@ verbose = cms.untracked.bool(options.verbose or options.verboseClient), useSharedMemory = cms.untracked.bool(not options.noShm), compression = cms.untracked.string(options.compression), - Retry = cms.VPSet( - cms.PSet( - retryType = cms.string('RetrySameServerAction'), - allowedTries = cms.untracked.uint32(options.tries) + Retry = ( + cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(options.tries) + ) + ) if options.retryAction == 'same' else cms.VPSet( + cms.PSet( + retryType = cms.string('RetryActionDiffServer') + ) ) ) ) @@ -93,6 +117,10 @@ processModule.edgeMax = cms.uint32(15000) processModule.brief = cms.bool(options.brief) process.p += processModule + if options.verbose: + print("Retry type:", ('RetrySameServerAction' if options.retryAction == 'same' else 'RetryActionDiffServer')) + if options.verbose or options.verboseClient: + keepMsgs.extend([module,module+':TritonClient']) if options.testother: # clone modules to test both gRPC and shared memory _module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM" From d07ba13e9efe6a611e01512c15e019095be53b71 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 12 Sep 2025 14:56:29 -0500 Subject: [PATCH 13/26] Fixes to compile --- HeterogeneousCore/SonicCore/plugins/BuildFile.xml | 2 +- .../SonicTriton/test/test_RetryActionDiffServer.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml index 0ecf2187a0f82..eaff0919e46bc 100644 --- a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml @@ -2,5 +2,5 @@ - + diff --git a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc index 48305dd083216..24f9342af3cd5 100644 --- a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc @@ -1,3 +1,4 @@ +#define CATCH_CONFIG_MAIN #include "catch.hpp" #include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" From ccc6607b96f6f5fda8fc98a78c254601c7721d9a Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 16 Sep 2025 10:22:08 -0500 Subject: [PATCH 14/26] PR comments --- .../SonicCore/src/SonicClientBase.cc | 7 +-- .../SonicCore/test/sonicTestAna_cfg.py | 15 +++++- .../interface/RetryActionDiffServer.h | 7 +-- .../SonicTriton/interface/TritonClient.h | 2 +- .../SonicTriton/src/RetryActionDiffServer.cc | 23 +++++---- .../SonicTriton/src/TritonClient.cc | 48 +++++++++---------- .../SonicTriton/test/BuildFile.xml | 4 +- .../test/test_RetryActionDiffServer.cc | 9 ++-- .../SonicTriton/test/tritonTest_cfg.py | 17 ------- 9 files changed, 59 insertions(+), 73 deletions(-) diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 9949d9d1f2ea2..739e7e6fe7913 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -24,7 +24,8 @@ SonicClientBase::SonicClientBase(const edm::ParameterSet& params, //Convert to RetryActionPtr Type from raw pointer of retryAction retryActions_.emplace_back(RetryActionPtr(retryAction.release())); } else { - throw cms::Exception("Configuration") << "Unknown Retry type " << actionType << " for SonicClient: " << modeName; + throw cms::Exception("Configuration") + << "Unknown Retry type " << actionType << " for SonicClient: " << fullDebugName_; } } @@ -58,7 +59,7 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) { void SonicClientBase::start() { totalTries_ = 0; // initialize all actions - for (const auto& action : retryActions_) { + for (auto& action : retryActions_) { action->start(); } } @@ -77,7 +78,7 @@ void SonicClientBase::finish(bool success, std::exception_ptr eptr) { edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; edm::Exception ex(edm::errors::ExternalFailure); - ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; + ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; eptr = make_exception_ptr(ex); } if (holder_) { diff --git a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py index 11c23c6cdfcc9..35cf42fa2b5ae 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py @@ -16,16 +16,27 @@ mode = cms.string("Sync"), factor = cms.int32(-1), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0), + ) + ) ), ) process.dummySyncAnaRetry = process.dummySyncAna.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2), + ) + ) + ) ) diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h index af7720b90cb0b..e992e9631f92c 100644 --- a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -16,7 +16,7 @@ * inference call; after the retry attempt, it disables itself until the next `start()` * call. */ - + class RetryActionDiffServer : public RetryActionBase { public: RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client); @@ -24,9 +24,6 @@ class RetryActionDiffServer : public RetryActionBase { void retry() override; void start() override; - -private: -}; +}; #endif - diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index 9e21b646508e9..84a9cf0328147 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -51,7 +51,7 @@ class TritonClient : public SonicClient { TritonServerType serverType() const { return serverType_; } bool isLocal() const { return isLocal_; } virtual void connectToServer(const std::string& url); - virtual void updateServer(std::string serverName); + virtual void updateServer(const std::string& serverName); //for fillDescriptions static void fillPSetDescription(edm::ParameterSetDescription& iDesc); diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index c3f931ae567ba..c61700030f61d 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -4,14 +4,10 @@ #include "FWCore/MessageLogger/interface/MessageLogger.h" #include "FWCore/ServiceRegistry/interface/Service.h" -RetryActionDiffServer::RetryActionDiffServer( - const edm::ParameterSet& conf, - SonicClientBase* client -): RetryActionBase(conf, client) {} +RetryActionDiffServer::RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client) + : RetryActionBase(conf, client) {} -void RetryActionDiffServer::start() { - this->shouldRetry_ = true; -} +void RetryActionDiffServer::start() { this->shouldRetry_ = true; } void RetryActionDiffServer::retry() { if (!this->shouldRetry_) { @@ -23,14 +19,17 @@ void RetryActionDiffServer::retry() { try { auto* tritonClient = static_cast(client_); edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; + // TODO: Get the server name from TritonService, use fallback for testing tritonClient->updateServer(TritonService::Server::fallbackName); eval(); - } catch (const std::exception& e) { - edm::LogError("RetryActionDiffServer") - << "Failed to retry with alternative server: " - << e.what(); + } catch (TritonException& e) { + e.convertToWarning(); + } catch (std::exception& e) { + edm::LogError("RetryActionDiffServer") << "Failed to retry with alternative server: " << e.what(); + } catch (...) { + edm::LogError("RetryActionDiffServe: rUnknownFailure") << "An unknown exception was thrown"; } this->shouldRetry_ = false; } -DEFINE_RETRY_ACTION(RetryActionDiffServer); \ No newline at end of file +DEFINE_RETRY_ACTION(RetryActionDiffServer); diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 404c84455e87d..a17d379db2ed4 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -85,7 +85,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d //Connect to server updateServer(params.getUntrackedParameter("preferredServer")); - + //set options options_[0].model_version_ = params.getParameter("modelVersion"); options_[0].client_timeout_ = params.getUntrackedParameter("timeout"); @@ -574,24 +574,24 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const { return inference::ModelStatistics{}; } -void TritonClient::updateServer(std::string serverName){ - //get appropriate server for this model - edm::Service ts; - - const auto& server = ts->serverInfo(options_[0].model_name_, serverName); - serverType_ = server.type; - edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; - //enforce sync mode for fallback CPU server to avoid contention - //todo: could enforce async mode otherwise (unless mode was specified by user?) - if (serverType_ == TritonServerType::LocalCPU) - setMode(SonicMode::Sync); - isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; - - //connect to the server - TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), - "TritonClient(): unable to create inference context", - isLocal_); +void TritonClient::updateServer(const std::string& serverName) { + //get appropriate server for this model + edm::Service ts; + + const auto& server = ts->serverInfo(options_[0].model_name_, serverName); + serverType_ = server.type; + edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; + //enforce sync mode for fallback CPU server to avoid contention + //todo: could enforce async mode otherwise (unless mode was specified by user?) + if (serverType_ == TritonServerType::LocalCPU) + setMode(SonicMode::Sync); + isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; + + //connect to the server + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), + "TritonClient(): unable to create inference context", + isLocal_); } //for fillDescriptions @@ -621,16 +621,14 @@ void TritonClient::connectToServer(const std::string& url) { // Use default SSL options triton::client::SslOptions sslOptions; - bool useSsl = false; // Assuming no SSL for direct URL connection + bool useSsl = false; // Assuming no SSL for direct URL connection // Connect to the server - TRITON_THROW_IF_ERROR( - triton::client::InferenceServerGrpcClient::Create(&client_, url, false, useSsl, sslOptions), - "TritonClient::connectToServer(): unable to create inference context", - false // isLocal is false + TRITON_THROW_IF_ERROR(triton::client::InferenceServerGrpcClient::Create(&client_, url, false, useSsl, sslOptions), + "TritonClient::connectToServer(): unable to create inference context", + false // isLocal is false ); } //constructor for testing TritonClient::TritonClient() : SonicClient(makeMinimalSonicParamsForTest(), "TritonClient_test", "TritonClient") {} - diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index f6ab9108035f2..2f6def462c7b0 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,7 +1,7 @@ - - + + diff --git a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc index 24f9342af3cd5..c501e8f9c9e3c 100644 --- a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc @@ -17,9 +17,7 @@ class TestTritonClient : public TritonClient { void connectToServer(const std::string& url) override { lastConnectedUrl = url; } - void updateServer(std::string serverName) override { - lastUpdatedServerName = std::move(serverName); - } + void updateServer(const std::string& serverName) override { lastUpdatedServerName = serverName; } const std::string& lastUrl() const { return lastConnectedUrl; } const std::string& lastServerName() const { return lastUpdatedServerName; } @@ -56,7 +54,8 @@ TEST_CASE("RetryActionDiffServer switches to fallback via updateServer", "[Retry class ThrowingTritonClient : public TritonClient { public: ThrowingTritonClient() : TritonClient() {} - void updateServer(std::string) override { throw std::runtime_error("updateServer failure"); } + void updateServer(const std::string&) override { throw TritonException("updateServer failure"); } + protected: void evaluate() override {} }; @@ -70,5 +69,3 @@ TEST_CASE("RetryActionDiffServer catches exceptions from updateServer", "[RetryA // Should not throw despite client throwing internally; action disarms afterward REQUIRE_NOTHROW(action.retry()); } - - diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 8588dcf5cad2c..539d6444059df 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -58,16 +58,6 @@ } defaultClient = applyClientOptions(getDefaultClientPSet().clone(), options) -keepMsgs = [] -if options.verbose or options.verboseDiscovery: - keepMsgs.append('TritonDiscovery') -if options.verbose or options.verboseClient: - keepMsgs.append('TritonClient') -if options.verbose or options.verboseService: - keepMsgs.append('TritonService') -if options.verbose: - # ensure RetryActionDiffServer messages are not suppressed if emitted - keepMsgs.append('RetryActionDiffServer') for im,module in enumerate(options.modules): model = options.models[im] @@ -80,9 +70,6 @@ modelName = cms.string(model), modelVersion = cms.string(""), modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)), - verbose = cms.untracked.bool(options.verbose or options.verboseClient), - useSharedMemory = cms.untracked.bool(not options.noShm), - compression = cms.untracked.string(options.compression), Retry = ( cms.VPSet( cms.PSet( @@ -117,10 +104,6 @@ processModule.edgeMax = cms.uint32(15000) processModule.brief = cms.bool(options.brief) process.p += processModule - if options.verbose: - print("Retry type:", ('RetrySameServerAction' if options.retryAction == 'same' else 'RetryActionDiffServer')) - if options.verbose or options.verboseClient: - keepMsgs.extend([module,module+':TritonClient']) if options.testother: # clone modules to test both gRPC and shared memory _module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM" From 5fee17922ecba1fecc636d642e238e6e0aa70c80 Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 16 Sep 2025 10:29:29 -0500 Subject: [PATCH 15/26] Move retry options to customize.py --- HeterogeneousCore/SonicTriton/python/customize.py | 8 +++++++- HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py | 12 ------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/python/customize.py b/HeterogeneousCore/SonicTriton/python/customize.py index b4f9943423133..1c812a9d280ed 100644 --- a/HeterogeneousCore/SonicTriton/python/customize.py +++ b/HeterogeneousCore/SonicTriton/python/customize.py @@ -90,12 +90,18 @@ def applyOptions(process, options, applyToModules=False): return process def getClientOptions(options): + action = cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(options.tries)) + if options.retryAction != 'same': + action.retryType = cms.string('RetryActionDiffServer') + return dict( compression = cms.untracked.string(options.compression), useSharedMemory = cms.untracked.bool(not options.noShm), timeout = cms.untracked.uint32(options.timeout), timeoutUnit = cms.untracked.string(options.timeoutUnit), - allowedTries = cms.untracked.uint32(options.tries), + Retry = cms.VPSet(action) ) def applyClientOptions(client, options): diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 539d6444059df..99744d29b6ddf 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -70,18 +70,6 @@ modelName = cms.string(model), modelVersion = cms.string(""), modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)), - Retry = ( - cms.VPSet( - cms.PSet( - retryType = cms.string('RetrySameServerAction'), - allowedTries = cms.untracked.uint32(options.tries) - ) - ) if options.retryAction == 'same' else cms.VPSet( - cms.PSet( - retryType = cms.string('RetryActionDiffServer') - ) - ) - ) ) ) ) From 2670deea59eb01450f53c13e0bd273467100b43e Mon Sep 17 00:00:00 2001 From: Martin Date: Wed, 17 Sep 2025 18:47:35 -0500 Subject: [PATCH 16/26] more clean ups --- HeterogeneousCore/SonicTriton/python/customize.py | 1 + HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/python/customize.py b/HeterogeneousCore/SonicTriton/python/customize.py index 1c812a9d280ed..63618b6c155a5 100644 --- a/HeterogeneousCore/SonicTriton/python/customize.py +++ b/HeterogeneousCore/SonicTriton/python/customize.py @@ -35,6 +35,7 @@ def getParser(): parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server") parser.add_argument("--imageName", default="", type=str, help="container image name for fallback server") parser.add_argument("--tempDir", default="", type=str, help="temp directory for fallback server") + parser.add_argument("--retryAction", default="same", type=str, choices=["same","diff"], help="retry policy: same server or different server") return parser diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 99744d29b6ddf..0bfd04d095128 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -21,13 +21,6 @@ parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") -parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory") -parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression") -parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication") -parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server") -parser.add_argument("--container", default="apptainer", type=str.lower, choices=allowed_containers, help="specify container for fallback server") -parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request") -parser.add_argument("--retryAction", default="same", type=str, choices=["same","diff"], help="retry policy: same server or different server") options = parser.parse_args() options = getOptions(parser, verbose=True) From 143be1e3484935ed680c2409b286f83386346084 Mon Sep 17 00:00:00 2001 From: Martin Date: Wed, 17 Sep 2025 18:50:18 -0500 Subject: [PATCH 17/26] First draft for server health --- .../SonicTriton/interface/TritonService.h | 30 +++++ .../SonicTriton/src/TritonService.cc | 113 +++++++++++++++++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index 8f36f73566e06..d583a55d1c4bf 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -3,6 +3,7 @@ #include "FWCore/ParameterSet/interface/ParameterSet.h" #include "FWCore/Utilities/interface/GlobalIdentifier.h" +#include "oneapi/tbb/concurrent_hash_map.h" #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include "grpc_client.h" @@ -90,6 +92,16 @@ class TritonService { static const std::string fallbackAddress; static const std::string siteconfName; }; + //Dynamic quantities of servers + struct ServerHealth { + bool live{false}; + bool ready{false}; + + uint64_t inferenceCount{0}; + uint64_t failureCount{0}; + double avgQueueTimeMs{0.0}; + double avgInferTimeMs{0.0}; + }; struct Model { Model(const std::string& path_ = "") : path(path_) {} @@ -111,7 +123,23 @@ class TritonService { //accessors void addModel(const std::string& modelName, const std::string& path); + // Change to getServer? Server serverInfo(const std::string& model, const std::string& preferred = "") const; + + // update health stats of all servers + void updateServerHealth(const std::string& modelName = ""); + + // return the best server for retry, ignore the current server + std::optional getBestServer(const std::string& modelName, const std::string& IgnoreServer = ""); + + // helper functions to get server statistics? + // - getServerSideStatus() + // - updateServerStatus() + // - loop over servers_ get statistics + // - getBestServer(model) + // - call updateServerStatus() + // - loop over servers_ get their statistics, compute metric, return server name + const std::string& pid() const { return pid_; } void notifyCallStatus(bool status) const; @@ -139,6 +167,8 @@ class TritonService { std::unordered_map unservedModels_; //this represents a many:many:many map std::unordered_map servers_; + //server health needs concurrent-safe edits + tbb::concurrent_hash_map serversHealth_; std::unordered_map models_; std::unordered_map modules_; int numberOfThreads_; diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index d0d82bbaa9efc..a7b228a10b491 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -119,11 +119,14 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")"; } - //loop over all servers: check which models they have + //loop over all servers: check which models they have, populate serverHealth std::string msg; if (verbose_) msg = "List of models for each server:\n"; for (auto& [serverName, server] : servers_) { + //populate serverHealth + serversHealth_.emplace(serverName, ServerHealth{}); + std::unique_ptr client; TRITON_THROW_IF_ERROR( tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), @@ -245,6 +248,114 @@ TritonService::Server TritonService::serverInfo(const std::string& model, const return server; } +void TritonService::updateServerHealth(const std::string& modelName) { + for (auto& [serverName, server] : servers_) { + try { + std::unique_ptr client; + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), + "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", + false); + + bool live = false, ready = false; + client->IsServerLive(&live); + client->IsServerReady(&ready); + + inference::ModelStatisticsResponse stats; + if (!modelName.empty()) { + client->ModelInferenceStatistics(&stats, modelName); + } else { + for (const auto& m : server.models) { + client->ModelInferenceStatistics(&stats, m); + } + } + + uint64_t infer_count = 0, queue_count = 0, failures = 0; + double avgQueueTimeMs = 0.0; + double avgInferTimeMs = 0.0; + + for (const auto& mstat : stats.model_stats()) { + if (modelName.empty() || mstat.name() == modelName) { + const auto& infer = mstat.inference_stats(); + + infer_count += infer.compute_infer().count(); + avgInferTimeMs += infer.compute_infer().ns() / 1e3; + queue_count += infer.queue().count(); + avgQueueTimeMs += infer.queue().ns() / 1e3; + failures += infer.fail().count(); + } + } + // Update health map safely with accessor + tbb::concurrent_hash_map::accessor acc; + serversHealth_.find(acc, serverName); + + ServerHealth& health = acc->second; + health.live = live; + health.ready = ready; + health.failureCount = failures; + health.avgQueueTimeMs = avgQueueTimeMs / queue_count; + health.avgInferTimeMs = avgInferTimeMs / infer_count; + + } catch (const TritonException& e) { + // mark existing entry unhealthy if present + tbb::concurrent_hash_map::accessor acc; + if (serversHealth_.find(acc, serverName)) { + ServerHealth& health = acc->second; + health.live = false; + health.ready = false; + } + } catch (const std::exception& e) { + // fallback for other exceptions + tbb::concurrent_hash_map::accessor acc; + if (serversHealth_.find(acc, serverName)) { + ServerHealth& health = acc->second; + health.live = false; + health.ready = false; + } + } + } +} + +std::optional TritonService::getBestServer(const std::string& modelName, + const std::string& IgnoreServer) { + std::optional bestServer; + ServerHealth bestHealth; + + // get fresh ServerHealth statistics + updateServerHealth(modelName); + + for (auto& [serverName, server] : servers_) { + if (serverName == IgnoreServer) + continue; // skip ignored server + if (server.models.find(modelName) == server.models.end()) + continue; // server doesn't have model + + tbb::concurrent_hash_map::const_accessor acc; + if (!serversHealth_.find(acc, serverName)) + continue; // no health info + + const ServerHealth& health = acc->second; + + if (!health.live || !health.ready) + continue; // skip unhealthy + + // Select server according to rules: + // 1) lowest failureCount + // 2) tie-breaker: lowest avgQueueTimeMs + if (!bestServer || health.failureCount < bestHealth.failureCount || + (health.failureCount == bestHealth.failureCount && health.avgQueueTimeMs < bestHealth.avgQueueTimeMs)) { + bestServer = server; + bestHealth = health; + } + } + if (verbose_ && bestServer) { + edm::LogInfo("Chosen server for model '" + modelName + "': " + bestServer->url + + " (failures=" + std::to_string(bestHealth.failureCount) + + ", avgQueueTime=" + std::to_string(bestHealth.avgQueueTimeMs) + " ms)"); + } + return bestServer; +} + void TritonService::preBeginJob(edm::ProcessContext const&) { //only need fallback if there are unserved models if (!fallbackOpts_.enable or unservedModels_.empty()) From 9d264df044b479d5a769aeb74a9fe57b69451f33 Mon Sep 17 00:00:00 2001 From: Martin Date: Thu, 18 Sep 2025 16:48:36 -0500 Subject: [PATCH 18/26] remove redundant --- .../test/retry_action_diff_log_test.sh | 22 ------------------- 1 file changed, 22 deletions(-) delete mode 100755 HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh diff --git a/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh b/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh deleted file mode 100755 index 00e1cc1dce90a..0000000000000 --- a/HeterogeneousCore/SonicTriton/test/retry_action_diff_log_test.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -LOCALTOP=$1 - -tmpFile=$(mktemp -p ${LOCALTOP} RetryActionDiffLogXXXXXXXX.log) -cmsRun ${LOCALTOP}/src/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py \ - --modules TritonGraphProducer --models gat_test \ - --maxEvents 2 --unittest --device cpu --retryAction diff --verbose \ - > "$tmpFile" 2>&1 -status=$? - -if ! grep -q "Retry type: RetryActionDiffServer" "$tmpFile"; then - echo "Expected retry type log line not found" >&2 - cat "$tmpFile" - rm -f "$tmpFile" - exit 1 -fi - -rm -f "$tmpFile" -exit $status - - From 965cae190b717b8019d78875ac632440e8bc4aef Mon Sep 17 00:00:00 2001 From: Martin Date: Thu, 18 Sep 2025 16:53:45 -0500 Subject: [PATCH 19/26] Use getBestServer in RetryDiffServerAction --- .../SonicTriton/interface/TritonClient.h | 3 +++ .../SonicTriton/interface/TritonService.h | 6 +++--- .../SonicTriton/src/RetryActionDiffServer.cc | 16 ++++++++++++--- .../SonicTriton/src/TritonClient.cc | 8 +++++++- .../SonicTriton/src/TritonService.cc | 20 +++++++++---------- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index 84a9cf0328147..2dd6205442fe1 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -50,6 +50,8 @@ class TritonClient : public SonicClient { void reset() override; TritonServerType serverType() const { return serverType_; } bool isLocal() const { return isLocal_; } + std::string modelName() const { return options_[0].model_name_; } + std::string serverName() const { return serverName_; } virtual void connectToServer(const std::string& url); virtual void updateServer(const std::string& serverName); @@ -86,6 +88,7 @@ class TritonClient : public SonicClient { bool useSharedMemory_; TritonServerType serverType_; bool isLocal_; + std::string serverName_; grpc_compression_algorithm compressionAlgo_; triton::client::Headers headers_; diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index d583a55d1c4bf..44b3b98a18932 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -123,14 +123,14 @@ class TritonService { //accessors void addModel(const std::string& modelName, const std::string& path); - // Change to getServer? - Server serverInfo(const std::string& model, const std::string& preferred = "") const; + + const std::pair& serverInfo(const std::string& model, const std::string& preferred = "") const; // update health stats of all servers void updateServerHealth(const std::string& modelName = ""); // return the best server for retry, ignore the current server - std::optional getBestServer(const std::string& modelName, const std::string& IgnoreServer = ""); + std::optional getBestServer(const std::string& modelName, const std::string& IgnoreServer = ""); // helper functions to get server statistics? // - getServerSideStatus() diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index c61700030f61d..0abd2da6ca569 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -20,14 +20,24 @@ void RetryActionDiffServer::retry() { auto* tritonClient = static_cast(client_); edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; // TODO: Get the server name from TritonService, use fallback for testing - tritonClient->updateServer(TritonService::Server::fallbackName); - eval(); + edm::Service ts; + + // get best server, ignoring the current server + auto bestServerName = ts->getBestServer(tritonClient->modelName(),tritonClient->serverName()); + + if (bestServerName) { + tritonClient->updateServer(*bestServerName); + eval(); + } else { + edm::LogWarning("RetryActionDiffServer") + << "No alternative server found for model " << tritonClient->modelName(); + } } catch (TritonException& e) { e.convertToWarning(); } catch (std::exception& e) { edm::LogError("RetryActionDiffServer") << "Failed to retry with alternative server: " << e.what(); } catch (...) { - edm::LogError("RetryActionDiffServe: rUnknownFailure") << "An unknown exception was thrown"; + edm::LogError("RetryActionDiffServer: UnknownFailure") << "An unknown exception was thrown"; } this->shouldRetry_ = false; } diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index a17d379db2ed4..0becdb31758f7 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -578,7 +578,13 @@ void TritonClient::updateServer(const std::string& serverName) { //get appropriate server for this model edm::Service ts; - const auto& server = ts->serverInfo(options_[0].model_name_, serverName); + const auto& serverMap = ts->serverInfo(options_[0].model_name_, serverName); + + const auto& server = serverMap.second; + + //update server name + serverName_ = serverMap.first; + serverType_ = server.type; edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; //enforce sync mode for fallback CPU server to avoid contention diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index a7b228a10b491..fd5e0368bc80c 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -226,7 +226,7 @@ void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) { } //second return value is only true if fallback CPU server is being used -TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const { +const std::pair& TritonService::serverInfo(const std::string& model, const std::string& preferred) const { auto mit = models_.find(model); if (mit == models_.end()) throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model; @@ -244,8 +244,8 @@ TritonService::Server TritonService::serverInfo(const std::string& model, const const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred); //todo: use some algorithm to select server rather than just picking arbitrarily - const auto& server(servers_.find(serverName)->second); - return server; + const auto serverPair = servers_.find(serverName); + return *serverPair; } void TritonService::updateServerHealth(const std::string& modelName) { @@ -316,9 +316,9 @@ void TritonService::updateServerHealth(const std::string& modelName) { } } -std::optional TritonService::getBestServer(const std::string& modelName, +std::optional TritonService::getBestServer(const std::string& modelName, const std::string& IgnoreServer) { - std::optional bestServer; + std::optional bestServerName; ServerHealth bestHealth; // get fresh ServerHealth statistics @@ -342,18 +342,18 @@ std::optional TritonService::getBestServer(const std::str // Select server according to rules: // 1) lowest failureCount // 2) tie-breaker: lowest avgQueueTimeMs - if (!bestServer || health.failureCount < bestHealth.failureCount || + if (!bestServerName || health.failureCount < bestHealth.failureCount || (health.failureCount == bestHealth.failureCount && health.avgQueueTimeMs < bestHealth.avgQueueTimeMs)) { - bestServer = server; + bestServerName = serverName; bestHealth = health; } } - if (verbose_ && bestServer) { - edm::LogInfo("Chosen server for model '" + modelName + "': " + bestServer->url + + if (verbose_ && bestServerName) { + edm::LogInfo("Chosen server for model '" + modelName + "': " + *bestServerName + " (failures=" + std::to_string(bestHealth.failureCount) + ", avgQueueTime=" + std::to_string(bestHealth.avgQueueTimeMs) + " ms)"); } - return bestServer; + return bestServerName; } void TritonService::preBeginJob(edm::ProcessContext const&) { From a1070cfba4d425361e42b8f9d8b5725615b8ae17 Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 12 Oct 2025 23:35:00 -0700 Subject: [PATCH 20/26] Add dynamic model loading and unloading to TritonService - Introduced `loadModel` and `unloadModel` methods for managing model lifecycle. - Added mutex for thread safety during model operations. - Updated `TritonService` header and implementation to support dynamic model management. - Enhanced logging for model loading and unloading processes. - Updated test configurations to include dynamic model loading tests. --- .../SonicTriton/interface/TritonService.h | 15 +- .../SonicTriton/src/TritonService.cc | 223 ++++++++++++++++-- .../SonicTriton/test/BuildFile.xml | 7 +- .../test/DynamicModelLoadingProducer.cc | 85 +++++++ .../SonicTriton/test/tritonTest_cfg.py | 9 + 5 files changed, 311 insertions(+), 28 deletions(-) create mode 100644 HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index 44b3b98a18932..b70a6dc40a241 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -13,6 +13,7 @@ #include #include #include +#include #include "grpc_client.h" @@ -44,7 +45,8 @@ class TritonService { instanceName(pset.getUntrackedParameter("instanceName")), tempDir(pset.getUntrackedParameter("tempDir")), imageName(pset.getUntrackedParameter("imageName")), - sandboxName(pset.getUntrackedParameter("sandboxName")) { + sandboxName(pset.getUntrackedParameter("sandboxName")) + { //randomize instance name if (instanceName.empty()) { instanceName = @@ -104,7 +106,6 @@ class TritonService { }; struct Model { Model(const std::string& path_ = "") : path(path_) {} - //members std::string path; std::unordered_set servers; @@ -113,7 +114,6 @@ class TritonService { struct Module { //currently assumes that a module can only have one associated model Module(const std::string& model_) : model(model_) {} - //members std::string model; }; @@ -123,7 +123,6 @@ class TritonService { //accessors void addModel(const std::string& modelName, const std::string& path); - const std::pair& serverInfo(const std::string& model, const std::string& preferred = "") const; // update health stats of all servers @@ -144,6 +143,9 @@ class TritonService { void notifyCallStatus(bool status) const; static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); + + bool loadModel(const std::string& modelName, const std::string& path); + bool unloadModel(const std::string& modelName); private: void preallocate(edm::service::SystemBounds const&); @@ -172,6 +174,11 @@ class TritonService { std::unordered_map models_; std::unordered_map modules_; int numberOfThreads_; + + //Dynamic model loading and unloading + std::unordered_map modelRefCount_; + std::unordered_set loadedModels_; + std::mutex modelLoadMutex_; }; #endif diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index fd5e0368bc80c..97f22b0ad7523 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -51,7 +51,7 @@ namespace { thisErrno = ferror(pipe); if (thisErrno) throw cms::Exception("SystemError") - << "TritonService: failed reading command output with errno " << thisErrno; + << "TritonService: failed reading command output with errno " << thisErrno; } } @@ -129,18 +129,21 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr std::unique_ptr client; TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), - "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", - false); + tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), + "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", + false); if (verbose_) { inference::ServerMetadataResponse serverMetaResponse; auto err = client->ServerMetadata(&serverMetaResponse); if (err.IsOk()) - edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url - << ", version = " << serverMetaResponse.version(); + edm::LogInfo("TritonService") + << "Server " << serverName << ": url = " << server.url + << ", version = " << serverMetaResponse.version(); else - edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")"; + edm::LogInfo("TritonService") + << "unable to get metadata for " + serverName + " (" + server.url + ")" + << err.Message(); } //if this query fails, it indicates that the server is nonresponsive or saturated @@ -155,28 +158,27 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr for (const auto& modelIndex : repoIndexResponse.models()) { const auto& modelName = modelIndex.name(); auto mit = models_.find(modelName); - if (mit == models_.end()) - mit = models_.emplace(modelName, "").first; + if (mit == models_.end()) mit = models_.emplace(modelName, "").first; auto& modelInfo(mit->second); modelInfo.servers.insert(serverName); server.models.insert(modelName); - if (verbose_) - msg += modelName + ", "; + if (verbose_) msg += modelName + ", "; } } else { const std::string& baseMsg = "unable to get repository index"; const std::string& extraMsg = err.Message().empty() ? "" : ": " + err.Message(); - if (verbose_) - msg += baseMsg + extraMsg; + if (verbose_) msg += baseMsg + extraMsg; else - edm::LogWarning("TritonFailure") << "TritonService(): " << baseMsg << " for " << serverName << " (" - << server.url << ")" << extraMsg; + edm::LogWarning("TritonFailure") + << "TritonService(): " + << baseMsg << " for " + << serverName + << " (" << server.url << ")" + << extraMsg; } - if (verbose_) - msg += "\n"; + if (verbose_) msg += "\n"; } - if (verbose_) - edm::LogInfo("TritonDiscovery") << msg; + if (verbose_) edm::LogInfo("TritonDiscovery") << msg; } void TritonService::preallocate(edm::service::SystemBounds const& bounds) { @@ -548,10 +550,12 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description fallbackDesc.addUntracked("enable", false); fallbackDesc.addUntracked("debug", false); fallbackDesc.addUntracked("verbose", false); - fallbackDesc.ifValue(edm::ParameterDescription("container", "apptainer", false), - edm::allowedValues("apptainer", "docker", "podman")); - fallbackDesc.ifValue(edm::ParameterDescription("device", "auto", false), - edm::allowedValues("auto", "cpu", "gpu")); + fallbackDesc.ifValue( + edm::ParameterDescription("container", "apptainer", false), + edm::allowedValues("apptainer", "docker", "podman")); + fallbackDesc.ifValue( + edm::ParameterDescription("device", "auto", false), + edm::allowedValues("auto", "cpu", "gpu")); fallbackDesc.addUntracked("retries", -1); fallbackDesc.addUntracked("wait", -1); fallbackDesc.addUntracked("instanceBaseName", "triton_server_instance"); @@ -563,3 +567,176 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description descriptions.addWithDefaultLabel(desc); } + +bool TritonService::loadModel(const std::string& modelName, const std::string& path) { + std::lock_guard lock(modelLoadMutex_); + + bool isModelLoaded = loadedModels_.count(modelName); + if (isModelLoaded) { + modelRefCount_[modelName]++; + if (verbose_) + edm::LogInfo("TritonService") + << "Model " << modelName + << " already loaded, ref count: " + << modelRefCount_[modelName]; + return true; + } + + // Find which server can host this model + auto mit = models_.find(modelName); + bool isNoServerAvailable = (mit == models_.end() || mit->second.servers.empty()); + if (isNoServerAvailable) { + edm::LogWarning("TritonService") << "loadModel: No server available for model " << modelName; + return false; + } + + const std::string& serverName = *mit->second.servers.begin(); + auto sit = servers_.find(serverName); + if (sit == servers_.end()) { + edm::LogWarning("TritonService") << "loadModel: Server " << serverName << " not found"; + return false; + } // Gets first available server + + std::unique_ptr client; + auto err = tc::InferenceServerGrpcClient::Create( + &client, + sit->second.url, + false, + sit->second.useSsl, + sit->second.sslOptions + ); + + if (!err.IsOk()) { + edm::LogWarning("TritonService") + << "loadModel: Unable to create client for server " << serverName + << ": " << err.Message(); + return false; + } + + // Actually load the model on the server + err = client->LoadModel(modelName); + if (!err.IsOk()) { + edm::LogWarning("TritonService") + << "loadModel: Failed to load model " << modelName + << " on server " << serverName << ": " << err.Message(); + return false; + } + + loadedModels_.insert(modelName); + modelRefCount_[modelName] = 1; + models_[modelName].path = path; + + if (verbose_) + edm::LogInfo("TritonService") + << "Successfully loaded model " << modelName + << " on server " << serverName; + + return true; +} + +bool TritonService::unloadModel(const std::string& modelName) { + std::lock_guard lock(modelLoadMutex_); + + bool isModelLoaded = loadedModels_.count(modelName); + if (!isModelLoaded) { + edm::LogWarning("TritonService") + << "unloadModel: Model " + << modelName + << " is not loaded"; + return false; + } + + // Decrement reference count and check if still in use + modelRefCount_[modelName]--; + bool isStillReferenced = (modelRefCount_[modelName] > 0); + if (isStillReferenced) { + if (verbose_) + edm::LogInfo("TritonService") + << "Model " << modelName + << " still in use, ref count: " + << modelRefCount_[modelName]; + return true; + } + + // Reference count reached 0, determine which server hosts this model + auto mit = models_.find(modelName); + if (mit == models_.end() || mit->second.servers.empty()) { + edm::LogWarning("TritonService") + << "unloadModel: No server information for model " + << modelName; + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + return false; + } + + const std::string& serverName = *mit->second.servers.begin(); + + bool isFallbackServer = (serverName == Server::fallbackName); + if (!isFallbackServer) { + if (verbose_) + edm::LogInfo("TritonService") + << "Model " << modelName + << " is on shared server " << serverName + << ", not unloading (other jobs may be using it)"; + + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + return true; + } + + if (verbose_) + edm::LogInfo("TritonService") + << "Model " << modelName + << " is on fallback server, proceeding to unload"; + + auto sit = servers_.find(serverName); + bool isNotSafeToUnload = (sit == servers_.end()); + if (isNotSafeToUnload) { + edm::LogWarning("TritonService") << "unloadModel: Fallback server not found"; + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + return false; + } + + std::unique_ptr client; + auto err = tc::InferenceServerGrpcClient::Create( + &client, + sit->second.url, + false, + sit->second.useSsl, + sit->second.sslOptions + ); // Creates Triton gRPC client + + if (!err.IsOk()) { + edm::LogWarning("TritonService") + << "unloadModel: Unable to create client for fallback server: " + << err.Message(); + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + return false; + } + + err = client->UnloadModel(modelName); + + if (!err.IsOk()) { + edm::LogWarning("TritonService") + << "unloadModel: Failed to unload model " + << modelName + << " from fallback server: " + << err.Message(); + + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + return false; + } + + loadedModels_.erase(modelName); + modelRefCount_.erase(modelName); + + if (verbose_) + edm::LogInfo("TritonService") + << "Successfully unloaded model " << modelName + << " from fallback server"; + + return true; +} diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 2f6def462c7b0..4e8b765205e6f 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -5,11 +5,13 @@ - + + + @@ -20,3 +22,6 @@ + + + diff --git a/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc new file mode 100644 index 0000000000000..efa87ba8b4440 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc @@ -0,0 +1,85 @@ +#include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "DataFormats/TestObjects/interface/ToyProducts.h" +#include "FWCore/Framework/interface/MakerMacros.h" +#include "FWCore/MessageLogger/interface/MessageLogger.h" +#include "FWCore/ServiceRegistry/interface/Service.h" + +#include +#include +#include +#include + +// Test module that explicitly exercises dynamic model loading +// This tests the reference counting and thread safety of loadModel/unloadModel +class DynamicModelLoadingProducer : public TritonEDProducer<> { +public: + explicit DynamicModelLoadingProducer(edm::ParameterSet const& cfg) + : TritonEDProducer<>(cfg), + testModelName_(cfg.getParameter("testModelName")), + testModelPath_(cfg.getParameter("testModelPath")), + loadUnloadCycles_(cfg.getParameter("loadUnloadCycles")), + testConcurrency_(cfg.getParameter("testConcurrency")) { + putToken_ = produces(); + } + + void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override { + edm::Service ts; + + // Test dynamic loading and unloading + if (testConcurrency_) { + // Stress test with multiple rapid load/unload cycles + for (int i = 0; i < loadUnloadCycles_; ++i) { + bool loadResult = ts->loadModel(testModelName_, testModelPath_); + edm::LogInfo("DynamicModelLoadingProducer") + << "Load attempt " << i << ": " << (loadResult ? "success" : "failed"); + + // Small delay to allow other threads to interleave + if (i % 5 == 0) { + std::this_thread::yield(); + } + + bool unloadResult = ts->unloadModel(testModelName_); + edm::LogInfo("DynamicModelLoadingProducer") + << "Unload attempt " << i << ": " << (unloadResult ? "success" : "failed"); + } + } else { + // Simple test: load once, unload once + bool loadResult = ts->loadModel(testModelName_, testModelPath_); + edm::LogInfo("DynamicModelLoadingProducer") + << "Single load: " << (loadResult ? "success" : "failed"); + + bool unloadResult = ts->unloadModel(testModelName_); + edm::LogInfo("DynamicModelLoadingProducer") + << "Single unload: " << (unloadResult ? "success" : "failed"); + } + + // Fill dummy input for base class + iInput = std::make_shared>(); + } + + void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override { + // Produce dummy output + iEvent.emplace(putToken_, loadUnloadCycles_); + } + + static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) { + edm::ParameterSetDescription desc; + TritonEDProducer<>::fillPSetDescription(desc); + desc.add("testModelName", "test_model"); + desc.add("testModelPath", "/path/to/test_model"); + desc.add("loadUnloadCycles", 1); + desc.add("testConcurrency", false); + descriptions.addWithDefaultLabel(desc); + } + +private: + std::string testModelName_; + std::string testModelPath_; + int loadUnloadCycles_; + bool testConcurrency_; + edm::EDPutTokenT putToken_; +}; + +DEFINE_FWK_MODULE(DynamicModelLoadingProducer); + diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 0bfd04d095128..0f4477f21c03a 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -9,6 +9,7 @@ "TritonGraphFilter": ["gat_test"], "TritonGraphAnalyzer": ["gat_test"], "TritonIdentityProducer": ["ragged_io"], + "DynamicModelLoadingProducer": ["gat_test"], } # other choices @@ -21,6 +22,8 @@ parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") +parser.add_argument("--loadUnloadCycles", default=3, type=int, help="number of load/unload cycles for dynamic model loading test") +parser.add_argument("--testConcurrency", default=False, action="store_true", help="enable concurrent stress test for dynamic model loading") options = parser.parse_args() options = getOptions(parser, verbose=True) @@ -84,6 +87,12 @@ processModule.edgeMin = cms.uint32(8000) processModule.edgeMax = cms.uint32(15000) processModule.brief = cms.bool(options.brief) + elif module=="DynamicModelLoadingProducer": + # Configure dynamic model loading test + processModule.testModelName = cms.string(model) + processModule.testModelPath = cms.string("HeterogeneousCore/SonicTriton/data/models/{}".format(model)) + processModule.loadUnloadCycles = cms.int32(options.loadUnloadCycles) + processModule.testConcurrency = cms.bool(options.testConcurrency) process.p += processModule if options.testother: # clone modules to test both gRPC and shared memory From 1a8e84e26c2769975686ab4790c49e0347eec438 Mon Sep 17 00:00:00 2001 From: Trevin Date: Mon, 13 Oct 2025 11:59:36 -0700 Subject: [PATCH 21/26] Update DynamicModelLoadingProducer to use model input for base class requirements - Modified input handling to utilize actual model input for "x" instead of dummy data. - Adjusted shape and data allocation for input to meet base class expectations. - Updated parameter set description method to use TritonClient for configuration. --- .../SonicTriton/test/DynamicModelLoadingProducer.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc index efa87ba8b4440..b0b1d2ee2b43a 100644 --- a/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc +++ b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc @@ -54,8 +54,14 @@ class DynamicModelLoadingProducer : public TritonEDProducer<> { << "Single unload: " << (unloadResult ? "success" : "failed"); } - // Fill dummy input for base class - iInput = std::make_shared>(); + // Fill dummy input - use actual input from the model (gat_test expects "x" input) + // This is just to satisfy the base class requirements, not for actual inference + auto& input_x = iInput.at("x"); + auto data_x = input_x.allocate(); + // Minimal dummy data + (*data_x)[0] = std::vector{1.0f}; + input_x.setShape(0, 1, 0); + input_x.toServer(data_x); } void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override { @@ -65,7 +71,7 @@ class DynamicModelLoadingProducer : public TritonEDProducer<> { static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) { edm::ParameterSetDescription desc; - TritonEDProducer<>::fillPSetDescription(desc); + TritonClient::fillPSetDescription(desc); desc.add("testModelName", "test_model"); desc.add("testModelPath", "/path/to/test_model"); desc.add("loadUnloadCycles", 1); From 734604043f90905c0c7e8cd805fddd052abaf0df Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 2 Nov 2025 13:42:52 -0800 Subject: [PATCH 22/26] Fix formatting and syntax issues - Applied scram b code-format to fix formatting issues - Added TRITON_THROW_IF_ERROR for error handling in loadModel/unloadModel - Fixed unload semantics to not erase tracking data on failure - Added comments explaining dynamic loading test requirements - Removed spurious formatting changes --- .../SonicTriton/interface/TritonService.h | 10 +- .../SonicTriton/src/RetryActionDiffServer.cc | 5 +- .../SonicTriton/src/TritonClient.cc | 2 +- .../SonicTriton/src/TritonService.cc | 206 ++++++------------ .../SonicTriton/test/BuildFile.xml | 3 + .../test/DynamicModelLoadingProducer.cc | 21 +- .../SonicTriton/test/tritonTest_cfg.py | 2 + 7 files changed, 94 insertions(+), 155 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index b70a6dc40a241..e1a3f05eae77f 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -45,8 +45,7 @@ class TritonService { instanceName(pset.getUntrackedParameter("instanceName")), tempDir(pset.getUntrackedParameter("tempDir")), imageName(pset.getUntrackedParameter("imageName")), - sandboxName(pset.getUntrackedParameter("sandboxName")) - { + sandboxName(pset.getUntrackedParameter("sandboxName")) { //randomize instance name if (instanceName.empty()) { instanceName = @@ -123,7 +122,8 @@ class TritonService { //accessors void addModel(const std::string& modelName, const std::string& path); - const std::pair& serverInfo(const std::string& model, const std::string& preferred = "") const; + const std::pair& serverInfo(const std::string& model, + const std::string& preferred = "") const; // update health stats of all servers void updateServerHealth(const std::string& modelName = ""); @@ -143,7 +143,7 @@ class TritonService { void notifyCallStatus(bool status) const; static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); - + bool loadModel(const std::string& modelName, const std::string& path); bool unloadModel(const std::string& modelName); @@ -174,7 +174,7 @@ class TritonService { std::unordered_map models_; std::unordered_map modules_; int numberOfThreads_; - + //Dynamic model loading and unloading std::unordered_map modelRefCount_; std::unordered_set loadedModels_; diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index 0abd2da6ca569..073cc393c3ac2 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -23,14 +23,13 @@ void RetryActionDiffServer::retry() { edm::Service ts; // get best server, ignoring the current server - auto bestServerName = ts->getBestServer(tritonClient->modelName(),tritonClient->serverName()); + auto bestServerName = ts->getBestServer(tritonClient->modelName(), tritonClient->serverName()); if (bestServerName) { tritonClient->updateServer(*bestServerName); eval(); } else { - edm::LogWarning("RetryActionDiffServer") - << "No alternative server found for model " << tritonClient->modelName(); + edm::LogWarning("RetryActionDiffServer") << "No alternative server found for model " << tritonClient->modelName(); } } catch (TritonException& e) { e.convertToWarning(); diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 0becdb31758f7..e71e5cd37046c 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -583,7 +583,7 @@ void TritonClient::updateServer(const std::string& serverName) { const auto& server = serverMap.second; //update server name - serverName_ = serverMap.first; + serverName_ = serverMap.first; serverType_ = server.type; edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index 97f22b0ad7523..feffc1eb54b03 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -51,7 +51,7 @@ namespace { thisErrno = ferror(pipe); if (thisErrno) throw cms::Exception("SystemError") - << "TritonService: failed reading command output with errno " << thisErrno; + << "TritonService: failed reading command output with errno " << thisErrno; } } @@ -129,21 +129,19 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr std::unique_ptr client; TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), - "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", - false); + tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), + "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", + false); if (verbose_) { inference::ServerMetadataResponse serverMetaResponse; auto err = client->ServerMetadata(&serverMetaResponse); if (err.IsOk()) - edm::LogInfo("TritonService") - << "Server " << serverName << ": url = " << server.url - << ", version = " << serverMetaResponse.version(); + edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url + << ", version = " << serverMetaResponse.version(); else - edm::LogInfo("TritonService") - << "unable to get metadata for " + serverName + " (" + server.url + ")" - << err.Message(); + edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")" + << err.Message(); } //if this query fails, it indicates that the server is nonresponsive or saturated @@ -158,27 +156,28 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr for (const auto& modelIndex : repoIndexResponse.models()) { const auto& modelName = modelIndex.name(); auto mit = models_.find(modelName); - if (mit == models_.end()) mit = models_.emplace(modelName, "").first; + if (mit == models_.end()) + mit = models_.emplace(modelName, "").first; auto& modelInfo(mit->second); modelInfo.servers.insert(serverName); server.models.insert(modelName); - if (verbose_) msg += modelName + ", "; + if (verbose_) + msg += modelName + ", "; } } else { const std::string& baseMsg = "unable to get repository index"; const std::string& extraMsg = err.Message().empty() ? "" : ": " + err.Message(); - if (verbose_) msg += baseMsg + extraMsg; + if (verbose_) + msg += baseMsg + extraMsg; else - edm::LogWarning("TritonFailure") - << "TritonService(): " - << baseMsg << " for " - << serverName - << " (" << server.url << ")" - << extraMsg; + edm::LogWarning("TritonFailure") << "TritonService(): " << baseMsg << " for " << serverName << " (" + << server.url << ")" << extraMsg; } - if (verbose_) msg += "\n"; + if (verbose_) + msg += "\n"; } - if (verbose_) edm::LogInfo("TritonDiscovery") << msg; + if (verbose_) + edm::LogInfo("TritonDiscovery") << msg; } void TritonService::preallocate(edm::service::SystemBounds const& bounds) { @@ -228,7 +227,8 @@ void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) { } //second return value is only true if fallback CPU server is being used -const std::pair& TritonService::serverInfo(const std::string& model, const std::string& preferred) const { +const std::pair& TritonService::serverInfo( + const std::string& model, const std::string& preferred) const { auto mit = models_.find(model); if (mit == models_.end()) throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model; @@ -318,8 +318,7 @@ void TritonService::updateServerHealth(const std::string& modelName) { } } -std::optional TritonService::getBestServer(const std::string& modelName, - const std::string& IgnoreServer) { +std::optional TritonService::getBestServer(const std::string& modelName, const std::string& IgnoreServer) { std::optional bestServerName; ServerHealth bestHealth; @@ -346,7 +345,7 @@ std::optional TritonService::getBestServer(const std::string& model // 2) tie-breaker: lowest avgQueueTimeMs if (!bestServerName || health.failureCount < bestHealth.failureCount || (health.failureCount == bestHealth.failureCount && health.avgQueueTimeMs < bestHealth.avgQueueTimeMs)) { - bestServerName = serverName; + bestServerName = serverName; bestHealth = health; } } @@ -550,12 +549,10 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description fallbackDesc.addUntracked("enable", false); fallbackDesc.addUntracked("debug", false); fallbackDesc.addUntracked("verbose", false); - fallbackDesc.ifValue( - edm::ParameterDescription("container", "apptainer", false), - edm::allowedValues("apptainer", "docker", "podman")); - fallbackDesc.ifValue( - edm::ParameterDescription("device", "auto", false), - edm::allowedValues("auto", "cpu", "gpu")); + fallbackDesc.ifValue(edm::ParameterDescription("container", "apptainer", false), + edm::allowedValues("apptainer", "docker", "podman")); + fallbackDesc.ifValue(edm::ParameterDescription("device", "auto", false), + edm::allowedValues("auto", "cpu", "gpu")); fallbackDesc.addUntracked("retries", -1); fallbackDesc.addUntracked("wait", -1); fallbackDesc.addUntracked("instanceBaseName", "triton_server_instance"); @@ -570,18 +567,16 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description bool TritonService::loadModel(const std::string& modelName, const std::string& path) { std::lock_guard lock(modelLoadMutex_); - + bool isModelLoaded = loadedModels_.count(modelName); if (isModelLoaded) { modelRefCount_[modelName]++; if (verbose_) - edm::LogInfo("TritonService") - << "Model " << modelName - << " already loaded, ref count: " - << modelRefCount_[modelName]; + edm::LogInfo("TritonService") << "Model " << modelName + << " already loaded, ref count: " << modelRefCount_[modelName]; return true; } - + // Find which server can host this model auto mit = models_.find(modelName); bool isNoServerAvailable = (mit == models_.end() || mit->second.servers.empty()); @@ -589,154 +584,97 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p edm::LogWarning("TritonService") << "loadModel: No server available for model " << modelName; return false; } - + const std::string& serverName = *mit->second.servers.begin(); auto sit = servers_.find(serverName); if (sit == servers_.end()) { edm::LogWarning("TritonService") << "loadModel: Server " << serverName << " not found"; return false; - } // Gets first available server - + } // Gets first available server + std::unique_ptr client; - auto err = tc::InferenceServerGrpcClient::Create( - &client, - sit->second.url, - false, - sit->second.useSsl, - sit->second.sslOptions - ); - - if (!err.IsOk()) { - edm::LogWarning("TritonService") - << "loadModel: Unable to create client for server " << serverName - << ": " << err.Message(); - return false; - } - + TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create( + &client, sit->second.url, false, sit->second.useSsl, sit->second.sslOptions), + "loadModel: unable to create client for server " + serverName, + false); + // Actually load the model on the server - err = client->LoadModel(modelName); - if (!err.IsOk()) { - edm::LogWarning("TritonService") - << "loadModel: Failed to load model " << modelName - << " on server " << serverName << ": " << err.Message(); - return false; - } - + auto err = client->LoadModel(modelName); + TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + modelName + " on server " + serverName, false); + loadedModels_.insert(modelName); modelRefCount_[modelName] = 1; models_[modelName].path = path; - + if (verbose_) - edm::LogInfo("TritonService") - << "Successfully loaded model " << modelName - << " on server " << serverName; - + edm::LogInfo("TritonService") << "Successfully loaded model " << modelName << " on server " << serverName; + return true; } bool TritonService::unloadModel(const std::string& modelName) { std::lock_guard lock(modelLoadMutex_); - + bool isModelLoaded = loadedModels_.count(modelName); if (!isModelLoaded) { - edm::LogWarning("TritonService") - << "unloadModel: Model " - << modelName - << " is not loaded"; + edm::LogWarning("TritonService") << "unloadModel: Model " << modelName << " is not loaded"; return false; } - + // Decrement reference count and check if still in use modelRefCount_[modelName]--; bool isStillReferenced = (modelRefCount_[modelName] > 0); if (isStillReferenced) { - if (verbose_) - edm::LogInfo("TritonService") - << "Model " << modelName - << " still in use, ref count: " - << modelRefCount_[modelName]; + if (verbose_) + edm::LogInfo("TritonService") << "Model " << modelName + << " still in use, ref count: " << modelRefCount_[modelName]; return true; } - + // Reference count reached 0, determine which server hosts this model auto mit = models_.find(modelName); if (mit == models_.end() || mit->second.servers.empty()) { - edm::LogWarning("TritonService") - << "unloadModel: No server information for model " - << modelName; - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); + edm::LogWarning("TritonService") << "unloadModel: No server information for model " << modelName; return false; } - + const std::string& serverName = *mit->second.servers.begin(); - + bool isFallbackServer = (serverName == Server::fallbackName); if (!isFallbackServer) { if (verbose_) - edm::LogInfo("TritonService") - << "Model " << modelName - << " is on shared server " << serverName - << ", not unloading (other jobs may be using it)"; - + edm::LogInfo("TritonService") << "Model " << modelName << " is on shared server " << serverName + << ", not unloading (other jobs may be using it)"; + loadedModels_.erase(modelName); modelRefCount_.erase(modelName); return true; } - + if (verbose_) - edm::LogInfo("TritonService") - << "Model " << modelName - << " is on fallback server, proceeding to unload"; - + edm::LogInfo("TritonService") << "Model " << modelName << " is on fallback server, proceeding to unload"; + auto sit = servers_.find(serverName); bool isNotSafeToUnload = (sit == servers_.end()); if (isNotSafeToUnload) { edm::LogWarning("TritonService") << "unloadModel: Fallback server not found"; - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); return false; } - + std::unique_ptr client; - auto err = tc::InferenceServerGrpcClient::Create( - &client, - sit->second.url, - false, - sit->second.useSsl, - sit->second.sslOptions - ); // Creates Triton gRPC client - - if (!err.IsOk()) { - edm::LogWarning("TritonService") - << "unloadModel: Unable to create client for fallback server: " - << err.Message(); - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); - return false; - } - - err = client->UnloadModel(modelName); + TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create( + &client, sit->second.url, false, sit->second.useSsl, sit->second.sslOptions), + "unloadModel: unable to create client for fallback server", + false); - if (!err.IsOk()) { - edm::LogWarning("TritonService") - << "unloadModel: Failed to unload model " - << modelName - << " from fallback server: " - << err.Message(); + auto err = client->UnloadModel(modelName); + TRITON_THROW_IF_ERROR(err, "unloadModel: failed to unload model " + modelName + " from fallback server", false); - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); - return false; - } - loadedModels_.erase(modelName); modelRefCount_.erase(modelName); - + if (verbose_) - edm::LogInfo("TritonService") - << "Successfully unloaded model " << modelName - << " from fallback server"; - + edm::LogInfo("TritonService") << "Successfully unloaded model " << modelName << " from fallback server"; + return true; } diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 4e8b765205e6f..a7265a19efd91 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -22,6 +22,9 @@ + diff --git a/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc index b0b1d2ee2b43a..aa4742ec3782c 100644 --- a/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc +++ b/HeterogeneousCore/SonicTriton/test/DynamicModelLoadingProducer.cc @@ -25,35 +25,33 @@ class DynamicModelLoadingProducer : public TritonEDProducer<> { void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override { edm::Service ts; - + // Test dynamic loading and unloading if (testConcurrency_) { // Stress test with multiple rapid load/unload cycles for (int i = 0; i < loadUnloadCycles_; ++i) { bool loadResult = ts->loadModel(testModelName_, testModelPath_); - edm::LogInfo("DynamicModelLoadingProducer") + edm::LogInfo("DynamicModelLoadingProducer") << "Load attempt " << i << ": " << (loadResult ? "success" : "failed"); - + // Small delay to allow other threads to interleave if (i % 5 == 0) { std::this_thread::yield(); } - + bool unloadResult = ts->unloadModel(testModelName_); - edm::LogInfo("DynamicModelLoadingProducer") + edm::LogInfo("DynamicModelLoadingProducer") << "Unload attempt " << i << ": " << (unloadResult ? "success" : "failed"); } } else { // Simple test: load once, unload once bool loadResult = ts->loadModel(testModelName_, testModelPath_); - edm::LogInfo("DynamicModelLoadingProducer") - << "Single load: " << (loadResult ? "success" : "failed"); - + edm::LogInfo("DynamicModelLoadingProducer") << "Single load: " << (loadResult ? "success" : "failed"); + bool unloadResult = ts->unloadModel(testModelName_); - edm::LogInfo("DynamicModelLoadingProducer") - << "Single unload: " << (unloadResult ? "success" : "failed"); + edm::LogInfo("DynamicModelLoadingProducer") << "Single unload: " << (unloadResult ? "success" : "failed"); } - + // Fill dummy input - use actual input from the model (gat_test expects "x" input) // This is just to satisfy the base class requirements, not for actual inference auto& input_x = iInput.at("x"); @@ -88,4 +86,3 @@ class DynamicModelLoadingProducer : public TritonEDProducer<> { }; DEFINE_FWK_MODULE(DynamicModelLoadingProducer); - diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 0f4477f21c03a..224e139acb702 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -89,6 +89,8 @@ processModule.brief = cms.bool(options.brief) elif module=="DynamicModelLoadingProducer": # Configure dynamic model loading test + # NOTE: This test requires the fallback server to be started with explicit model control mode + # (--model-control-mode explicit flag), otherwise dynamic loading will fail processModule.testModelName = cms.string(model) processModule.testModelPath = cms.string("HeterogeneousCore/SonicTriton/data/models/{}".format(model)) processModule.loadUnloadCycles = cms.int32(options.loadUnloadCycles) From 34b3d117104c62dc4d502e577b23ee9df37a989a Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 2 Nov 2025 13:48:44 -0800 Subject: [PATCH 23/26] Restrict dynamic model loading to fallback server only - Removed loadedModels_ set, now derive loaded status from modelRefCount_ > 0 - Simplified loadModel() and unloadModel() to only work with fallback server - Removed server searching logic since only fallback is supported - Added documentation clarifying fallback-only limitation - Updated error messages to reflect this architectural decision - Fixed XML comment syntax in BuildFile.xml --- .../SonicTriton/interface/TritonService.h | 6 +- .../SonicTriton/src/TritonService.cc | 93 ++++++++----------- .../SonicTriton/test/BuildFile.xml | 4 +- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index e1a3f05eae77f..b7cbff54f9b78 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -144,6 +144,9 @@ class TritonService { static void fillDescriptions(edm::ConfigurationDescriptions& descriptions); + // Dynamic model loading/unloading - only supported for the fallback server + // The fallback server must be started with explicit model control mode + // (--model-control-mode explicit) for these functions to work bool loadModel(const std::string& modelName, const std::string& path); bool unloadModel(const std::string& modelName); @@ -175,9 +178,8 @@ class TritonService { std::unordered_map modules_; int numberOfThreads_; - //Dynamic model loading and unloading + //Dynamic model loading and unloading (fallback server only) std::unordered_map modelRefCount_; - std::unordered_set loadedModels_; std::mutex modelLoadMutex_; }; diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index feffc1eb54b03..768145b32df20 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -568,46 +568,51 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description bool TritonService::loadModel(const std::string& modelName, const std::string& path) { std::lock_guard lock(modelLoadMutex_); - bool isModelLoaded = loadedModels_.count(modelName); - if (isModelLoaded) { - modelRefCount_[modelName]++; + // Check if model is already loaded (ref count > 0) + auto refIt = modelRefCount_.find(modelName); + if (refIt != modelRefCount_.end() && refIt->second > 0) { + refIt->second++; if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName - << " already loaded, ref count: " << modelRefCount_[modelName]; + edm::LogInfo("TritonService") << "Model " << modelName << " already loaded, ref count: " << refIt->second; return true; } - // Find which server can host this model - auto mit = models_.find(modelName); - bool isNoServerAvailable = (mit == models_.end() || mit->second.servers.empty()); - if (isNoServerAvailable) { - edm::LogWarning("TritonService") << "loadModel: No server available for model " << modelName; + // Dynamic loading is only supported for the fallback server + auto sit = servers_.find(Server::fallbackName); + if (sit == servers_.end()) { + edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " + << Server::fallbackName << ": server not found"; return false; } - const std::string& serverName = *mit->second.servers.begin(); - auto sit = servers_.find(serverName); - if (sit == servers_.end()) { - edm::LogWarning("TritonService") << "loadModel: Server " << serverName << " not found"; + // Verify that the fallback server is actually running + if (!startedFallback_) { + edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " + << Server::fallbackName << ": server not started"; return false; - } // Gets first available server + } std::unique_ptr client; TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create( &client, sit->second.url, false, sit->second.useSsl, sit->second.sslOptions), - "loadModel: unable to create client for server " + serverName, + "loadModel: unable to create client for fallback server", false); // Actually load the model on the server auto err = client->LoadModel(modelName); - TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + modelName + " on server " + serverName, false); + TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + modelName + " on fallback server", false); - loadedModels_.insert(modelName); + // Update tracking modelRefCount_[modelName] = 1; - models_[modelName].path = path; + + // Add model to unservedModels_ if not already tracked + auto umit = unservedModels_.find(modelName); + if (umit == unservedModels_.end()) { + unservedModels_.emplace(modelName, path); + } if (verbose_) - edm::LogInfo("TritonService") << "Successfully loaded model " << modelName << " on server " << serverName; + edm::LogInfo("TritonService") << "Successfully loaded model " << modelName << " on fallback server"; return true; } @@ -615,51 +620,31 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p bool TritonService::unloadModel(const std::string& modelName) { std::lock_guard lock(modelLoadMutex_); - bool isModelLoaded = loadedModels_.count(modelName); - if (!isModelLoaded) { + // Check if model is loaded (exists in refcount map) + auto refIt = modelRefCount_.find(modelName); + if (refIt == modelRefCount_.end() || refIt->second == 0) { edm::LogWarning("TritonService") << "unloadModel: Model " << modelName << " is not loaded"; return false; } // Decrement reference count and check if still in use - modelRefCount_[modelName]--; - bool isStillReferenced = (modelRefCount_[modelName] > 0); - if (isStillReferenced) { + refIt->second--; + if (refIt->second > 0) { if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName - << " still in use, ref count: " << modelRefCount_[modelName]; + edm::LogInfo("TritonService") << "Model " << modelName << " still in use, ref count: " << refIt->second; return true; } - // Reference count reached 0, determine which server hosts this model - auto mit = models_.find(modelName); - if (mit == models_.end() || mit->second.servers.empty()) { - edm::LogWarning("TritonService") << "unloadModel: No server information for model " << modelName; + // Reference count reached 0, unload from fallback server only + // (dynamic unloading is only supported for fallback server) + auto sit = servers_.find(Server::fallbackName); + if (sit == servers_.end()) { + edm::LogWarning("TritonService") << "unloadModel: Fallback server not found"; return false; } - const std::string& serverName = *mit->second.servers.begin(); - - bool isFallbackServer = (serverName == Server::fallbackName); - if (!isFallbackServer) { - if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName << " is on shared server " << serverName - << ", not unloading (other jobs may be using it)"; - - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); - return true; - } - if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName << " is on fallback server, proceeding to unload"; - - auto sit = servers_.find(serverName); - bool isNotSafeToUnload = (sit == servers_.end()); - if (isNotSafeToUnload) { - edm::LogWarning("TritonService") << "unloadModel: Fallback server not found"; - return false; - } + edm::LogInfo("TritonService") << "Model " << modelName << " ref count is 0, unloading from fallback server"; std::unique_ptr client; TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create( @@ -670,8 +655,8 @@ bool TritonService::unloadModel(const std::string& modelName) { auto err = client->UnloadModel(modelName); TRITON_THROW_IF_ERROR(err, "unloadModel: failed to unload model " + modelName + " from fallback server", false); - loadedModels_.erase(modelName); - modelRefCount_.erase(modelName); + // Remove from tracking + modelRefCount_.erase(refIt); if (verbose_) edm::LogInfo("TritonService") << "Successfully unloaded model " << modelName << " from fallback server"; diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index a7265a19efd91..d6365368d6ef0 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -23,8 +23,8 @@ + the fallback server to be started with explicit model control mode (model-control-mode explicit). + The current cmsTriton script does not support this flag yet. --> From 96e21b1618ee1d10d974fb20f0063cf252c29b65 Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 2 Nov 2025 17:19:03 -0800 Subject: [PATCH 24/26] Enhance RetryActionDiffServer to utilize fallback server for model loading - Updated retry logic to switch to the fallback server directly for model loading. - Removed previous best server selection logic, simplifying the retry process. - Added logging for fallback loading failures and re-evaluation after loading. - Ensured dynamic loading is explicitly handled in TritonService with updated command options. --- .../SonicTriton/src/RetryActionDiffServer.cc | 21 ++++++++------ .../SonicTriton/src/TritonService.cc | 28 ++++++------------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc index 073cc393c3ac2..71d2e0dabe6d3 100644 --- a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -19,18 +19,23 @@ void RetryActionDiffServer::retry() { try { auto* tritonClient = static_cast(client_); edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; - // TODO: Get the server name from TritonService, use fallback for testing edm::Service ts; - // get best server, ignoring the current server - auto bestServerName = ts->getBestServer(tritonClient->modelName(), tritonClient->serverName()); + // Fallback-only: update client to fallback server and dynamically load the model there + const std::string& fallbackName = TritonService::Server::fallbackName; + tritonClient->updateServer(fallbackName); - if (bestServerName) { - tritonClient->updateServer(*bestServerName); - eval(); - } else { - edm::LogWarning("RetryActionDiffServer") << "No alternative server found for model " << tritonClient->modelName(); + // Load model on fallback (path not required for explicit control mode) + bool loaded = ts->loadModel(tritonClient->modelName(), ""); + if (!loaded) { + edm::LogWarning("RetryActionDiffServer") << "Fallback dynamic load failed for model " + << tritonClient->modelName(); + this->shouldRetry_ = false; + return; } + + // Re-evaluate on fallback + eval(); } catch (TritonException& e) { e.convertToWarning(); } catch (std::exception& e) { diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index 768145b32df20..feb7e950c1dbc 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -350,9 +350,9 @@ std::optional TritonService::getBestServer(const std::string& model } } if (verbose_ && bestServerName) { - edm::LogInfo("Chosen server for model '" + modelName + "': " + *bestServerName + - " (failures=" + std::to_string(bestHealth.failureCount) + - ", avgQueueTime=" + std::to_string(bestHealth.avgQueueTimeMs) + " ms)"); + edm::LogInfo("TritonDiscovery") << "Chosen server for model '" << modelName << "': " << *bestServerName + << " (failures=" << bestHealth.failureCount + << ", avgQueueTime=" << bestHealth.avgQueueTimeMs << " ms)"; } return bestServerName; } @@ -399,6 +399,8 @@ void TritonService::preBeginJob(edm::ProcessContext const&) { fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries); if (fallbackOpts_.wait >= 0) fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait); + // Explicit model control mode is required for dynamic loading + fallbackOpts_.command += " --model-control-mode explicit"; for (const auto& [modelName, model] : unservedModels_) { fallbackOpts_.command += " -m " + model.path; } @@ -568,7 +570,6 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description bool TritonService::loadModel(const std::string& modelName, const std::string& path) { std::lock_guard lock(modelLoadMutex_); - // Check if model is already loaded (ref count > 0) auto refIt = modelRefCount_.find(modelName); if (refIt != modelRefCount_.end() && refIt->second > 0) { refIt->second++; @@ -577,7 +578,6 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p return true; } - // Dynamic loading is only supported for the fallback server auto sit = servers_.find(Server::fallbackName); if (sit == servers_.end()) { edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " @@ -585,7 +585,6 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p return false; } - // Verify that the fallback server is actually running if (!startedFallback_) { edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " << Server::fallbackName << ": server not started"; @@ -598,22 +597,18 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p "loadModel: unable to create client for fallback server", false); - // Actually load the model on the server auto err = client->LoadModel(modelName); TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + modelName + " on fallback server", false); - // Update tracking modelRefCount_[modelName] = 1; - // Add model to unservedModels_ if not already tracked - auto umit = unservedModels_.find(modelName); - if (umit == unservedModels_.end()) { - unservedModels_.emplace(modelName, path); - } + // Track dynamically loaded model in service maps + auto& modelInfo(models_.emplace(modelName, path).first->second); + modelInfo.servers.insert(Server::fallbackName); + sit->second.models.insert(modelName); if (verbose_) edm::LogInfo("TritonService") << "Successfully loaded model " << modelName << " on fallback server"; - return true; } @@ -627,7 +622,6 @@ bool TritonService::unloadModel(const std::string& modelName) { return false; } - // Decrement reference count and check if still in use refIt->second--; if (refIt->second > 0) { if (verbose_) @@ -635,8 +629,6 @@ bool TritonService::unloadModel(const std::string& modelName) { return true; } - // Reference count reached 0, unload from fallback server only - // (dynamic unloading is only supported for fallback server) auto sit = servers_.find(Server::fallbackName); if (sit == servers_.end()) { edm::LogWarning("TritonService") << "unloadModel: Fallback server not found"; @@ -655,11 +647,9 @@ bool TritonService::unloadModel(const std::string& modelName) { auto err = client->UnloadModel(modelName); TRITON_THROW_IF_ERROR(err, "unloadModel: failed to unload model " + modelName + " from fallback server", false); - // Remove from tracking modelRefCount_.erase(refIt); if (verbose_) edm::LogInfo("TritonService") << "Successfully unloaded model " << modelName << " from fallback server"; - return true; } From f3abf98e86e66598b2db5d64a50e7894e59369a5 Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 2 Nov 2025 17:33:48 -0800 Subject: [PATCH 25/26] Add FallbackModelState structure and enhance model loading/unloading in TritonService - Introduced FallbackModelState to track dynamic model state, including reference count and model path. - Updated loadModel and unloadModel methods to utilize FallbackModelState for managing model lifecycle. - Enhanced logging to reflect model state changes and operations. - Adjusted test configuration to reflect changes in module handling for dynamic model loading. --- .../SonicTriton/interface/TritonService.h | 12 +++ .../SonicTriton/src/TritonService.cc | 79 +++++++++++++------ .../SonicTriton/test/tritonTest_cfg.py | 2 +- 3 files changed, 69 insertions(+), 24 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index b7cbff54f9b78..60ee250947682 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -110,6 +110,13 @@ class TritonService { std::unordered_set servers; std::unordered_set modules; }; + // Tracks fallback dynamic model state (refcount and path) + struct FallbackModelState { + std::string modelName; // canonical model id for Triton calls + std::string path; // model repository path for fallback server + int refCount{0}; + bool isLoaded() const { return refCount > 0; } + }; struct Module { //currently assumes that a module can only have one associated model Module(const std::string& model_) : model(model_) {} @@ -161,6 +168,9 @@ class TritonService { //helper template void printFallbackServerLog() const; + // Internal helpers that operate on FallbackModelState directly (caller holds lock) + bool loadModel(FallbackModelState& state); + bool unloadModel(FallbackModelState& state); bool verbose_; FallbackOpts fallbackOpts_; @@ -181,6 +191,8 @@ class TritonService { //Dynamic model loading and unloading (fallback server only) std::unordered_map modelRefCount_; std::mutex modelLoadMutex_; + // Fallback dynamic model states, keyed by model name + std::unordered_map fallbackModels_; }; #endif diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index feb7e950c1dbc..613b491bdcdd6 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -570,23 +570,34 @@ void TritonService::fillDescriptions(edm::ConfigurationDescriptions& description bool TritonService::loadModel(const std::string& modelName, const std::string& path) { std::lock_guard lock(modelLoadMutex_); - auto refIt = modelRefCount_.find(modelName); - if (refIt != modelRefCount_.end() && refIt->second > 0) { - refIt->second++; + // Resolve state and canonicalize fields + auto& state = fallbackModels_[modelName]; + if (state.modelName.empty()) state.modelName = modelName; + if (state.path.empty() && !path.empty()) state.path = path; + + return loadModel(state); +} + +bool TritonService::loadModel(FallbackModelState& state) { + // if already loaded, bump refcount + if (state.refCount > 0) { + ++state.refCount; + modelRefCount_[state.modelName] = state.refCount; if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName << " already loaded, ref count: " << refIt->second; + edm::LogInfo("TritonService") << "Model " << state.modelName << " already loaded, ref count: " + << state.refCount; return true; } auto sit = servers_.find(Server::fallbackName); if (sit == servers_.end()) { - edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " + edm::LogWarning("TritonService") << "loadModel: Failed to load model " << state.modelName << " on server " << Server::fallbackName << ": server not found"; return false; } if (!startedFallback_) { - edm::LogWarning("TritonService") << "loadModel: Failed to load model " << modelName << " on server " + edm::LogWarning("TritonService") << "loadModel: Failed to load model " << state.modelName << " on server " << Server::fallbackName << ": server not started"; return false; } @@ -597,35 +608,48 @@ bool TritonService::loadModel(const std::string& modelName, const std::string& p "loadModel: unable to create client for fallback server", false); - auto err = client->LoadModel(modelName); - TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + modelName + " on fallback server", false); + auto err = client->LoadModel(state.modelName); + TRITON_THROW_IF_ERROR(err, "loadModel: failed to load model " + state.modelName + " on fallback server", false); - modelRefCount_[modelName] = 1; + // Update state and tracking + state.refCount = 1; + modelRefCount_[state.modelName] = state.refCount; // Track dynamically loaded model in service maps - auto& modelInfo(models_.emplace(modelName, path).first->second); + auto& modelInfo(models_.emplace(state.modelName, state.path).first->second); modelInfo.servers.insert(Server::fallbackName); - sit->second.models.insert(modelName); + sit->second.models.insert(state.modelName); if (verbose_) - edm::LogInfo("TritonService") << "Successfully loaded model " << modelName << " on fallback server"; + edm::LogInfo("TritonService") << "Successfully loaded model " << state.modelName << " on fallback server"; return true; } bool TritonService::unloadModel(const std::string& modelName) { std::lock_guard lock(modelLoadMutex_); - // Check if model is loaded (exists in refcount map) - auto refIt = modelRefCount_.find(modelName); - if (refIt == modelRefCount_.end() || refIt->second == 0) { + auto it = fallbackModels_.find(modelName); + if (it == fallbackModels_.end()) { edm::LogWarning("TritonService") << "unloadModel: Model " << modelName << " is not loaded"; return false; } + // Ensure struct has canonical name + if (it->second.modelName.empty()) it->second.modelName = modelName; + return unloadModel(it->second); +} + +bool TritonService::unloadModel(FallbackModelState& state) { + if (state.refCount == 0) { + edm::LogWarning("TritonService") << "unloadModel: Model " << state.modelName << " is not loaded"; + return false; + } - refIt->second--; - if (refIt->second > 0) { + if (state.refCount > 1) { + --(state.refCount); + modelRefCount_[state.modelName] = state.refCount; if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName << " still in use, ref count: " << refIt->second; + edm::LogInfo("TritonService") << "Model " << state.modelName << " still in use, ref count: " + << state.refCount; return true; } @@ -636,7 +660,7 @@ bool TritonService::unloadModel(const std::string& modelName) { } if (verbose_) - edm::LogInfo("TritonService") << "Model " << modelName << " ref count is 0, unloading from fallback server"; + edm::LogInfo("TritonService") << "Model " << state.modelName << " ref count is 1, unloading from fallback server"; std::unique_ptr client; TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create( @@ -644,12 +668,21 @@ bool TritonService::unloadModel(const std::string& modelName) { "unloadModel: unable to create client for fallback server", false); - auto err = client->UnloadModel(modelName); - TRITON_THROW_IF_ERROR(err, "unloadModel: failed to unload model " + modelName + " from fallback server", false); + auto err = client->UnloadModel(state.modelName); + TRITON_THROW_IF_ERROR(err, "unloadModel: failed to unload model " + state.modelName + " from fallback server", + false); - modelRefCount_.erase(refIt); + modelRefCount_.erase(state.modelName); + state.refCount = 0; + + // Update dynamic tracking: remove fallback association + auto mit = models_.find(state.modelName); + if (mit != models_.end()) { + mit->second.servers.erase(Server::fallbackName); + } + sit->second.models.erase(state.modelName); if (verbose_) - edm::LogInfo("TritonService") << "Successfully unloaded model " << modelName << " from fallback server"; + edm::LogInfo("TritonService") << "Successfully unloaded model " << state.modelName << " from fallback server"; return true; } diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 224e139acb702..3052a409fae8c 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -16,7 +16,7 @@ allowed_modes = ["Async","PseudoAsync","Sync"] parser = getParser() -parser.add_argument("--modules", metavar=("MODULES"), default=["TritonGraphProducer"], nargs='+', type=str, choices=list(models), help="list of modules to run (choices: %(choices)s)") +parser.add_argument("--modules", metavar=("MODULES"), default=["DynamicModelLoadingProducer"], nargs='+', type=str, choices=list(models), help="list of modules to run (choices: %(choices)s)") parser.add_argument("--models", default=["gat_test"], nargs='+', type=str, help="list of models (same length as modules, or just 1 entry if all modules use same model)") parser.add_argument("--mode", default="Async", type=str, choices=allowed_modes, help="mode for client") parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") From 01e1c306d148e5746a0851e70632e94a55c3bd48 Mon Sep 17 00:00:00 2001 From: Trevin Date: Sun, 2 Nov 2025 17:43:40 -0800 Subject: [PATCH 26/26] Refactor TritonService to remove unservedModels_ and streamline model handling - Eliminated unservedModels_ from TritonService, simplifying model management. - Updated addModel and preBeginJob methods to work directly with models_. - Enhanced error handling in loadModel for fallback server scenarios. - Improved comments for clarity on model path handling and fallback server usage. --- .../SonicTriton/interface/TritonService.h | 1 - .../SonicTriton/src/TritonService.cc | 50 ++++++++----------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index 60ee250947682..bf83cd24e4681 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -179,7 +179,6 @@ class TritonService { bool startedFallback_; mutable std::atomic callFails_; std::string pid_; - std::unordered_map unservedModels_; //this represents a many:many:many map std::unordered_map servers_; //server health needs concurrent-safe edits diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index 613b491bdcdd6..fd70c05632b64 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -194,33 +194,25 @@ void TritonService::addModel(const std::string& modelName, const std::string& pa if (!allowAddModel_) throw cms::Exception("DisallowedAddModel") << "TritonService: Attempt to call addModel() outside of module constructors"; - //if model is not in the list, then no specified server provides it - auto mit = models_.find(modelName); - if (mit == models_.end()) { - auto& modelInfo(unservedModels_.emplace(modelName, path).first->second); - modelInfo.modules.insert(currentModuleId_); - //only keep track of modules that need unserved models - modules_.emplace(currentModuleId_, modelName); - } + // Ensure model exists in declared models; preserve non-empty path if provided + auto& modelInfo(models_.emplace(modelName, path).first->second); + if (modelInfo.path.empty() && !path.empty()) modelInfo.path = path; + // Track the module using this model + modelInfo.modules.insert(currentModuleId_); + modules_.emplace(currentModuleId_, modelName); } void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; } void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) { - //remove destructed modules from unserved list - if (unservedModels_.empty()) - return; auto id = desc.id(); auto oit = modules_.find(id); if (oit != modules_.end()) { const auto& moduleInfo(oit->second); - auto mit = unservedModels_.find(moduleInfo.model); - if (mit != unservedModels_.end()) { + auto mit = models_.find(moduleInfo.model); + if (mit != models_.end()) { auto& modelInfo(mit->second); modelInfo.modules.erase(id); - //remove a model if it is no longer needed by any modules - if (modelInfo.modules.empty()) - unservedModels_.erase(mit); } modules_.erase(oit); } @@ -359,7 +351,7 @@ std::optional TritonService::getBestServer(const std::string& model void TritonService::preBeginJob(edm::ProcessContext const&) { //only need fallback if there are unserved models - if (!fallbackOpts_.enable or unservedModels_.empty()) + if (!fallbackOpts_.enable) return; //include fallback server in set @@ -373,10 +365,12 @@ void TritonService::preBeginJob(edm::ProcessContext const&) { std::string msg; if (verbose_) msg = "List of models for fallback server: "; - //all unserved models are provided by fallback server + // Provide all declared models with known paths via the fallback server auto& server(servers_.find(Server::fallbackName)->second); - for (const auto& [modelName, model] : unservedModels_) { - auto& modelInfo(models_.emplace(modelName, model).first->second); + for (const auto& [modelName, model] : models_) { + // Only seed models for which we have a repository path + if (model.path.empty()) continue; + auto& modelInfo(models_.find(modelName)->second); modelInfo.servers.insert(Server::fallbackName); server.models.insert(modelName); if (verbose_) @@ -401,7 +395,8 @@ void TritonService::preBeginJob(edm::ProcessContext const&) { fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait); // Explicit model control mode is required for dynamic loading fallbackOpts_.command += " --model-control-mode explicit"; - for (const auto& [modelName, model] : unservedModels_) { + for (const auto& [modelName, model] : models_) { + if (model.path.empty()) continue; fallbackOpts_.command += " -m " + model.path; } std::string thread_string = " -I " + std::to_string(numberOfThreads_); @@ -410,8 +405,7 @@ void TritonService::preBeginJob(edm::ProcessContext const&) { fallbackOpts_.command += " -i " + fallbackOpts_.imageName; if (!fallbackOpts_.sandboxName.empty()) fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName; - //don't need this anymore - unservedModels_.clear(); + // models_ remains for runtime queries; nothing to clear here //get a random temporary directory if none specified if (fallbackOpts_.tempDir.empty()) { @@ -591,15 +585,13 @@ bool TritonService::loadModel(FallbackModelState& state) { auto sit = servers_.find(Server::fallbackName); if (sit == servers_.end()) { - edm::LogWarning("TritonService") << "loadModel: Failed to load model " << state.modelName << " on server " - << Server::fallbackName << ": server not found"; - return false; + throw cms::Exception("TritonService") + << "loadModel: fallback server not found for model '" << state.modelName << "'"; } if (!startedFallback_) { - edm::LogWarning("TritonService") << "loadModel: Failed to load model " << state.modelName << " on server " - << Server::fallbackName << ": server not started"; - return false; + throw cms::Exception("TritonService") + << "loadModel: fallback server not started; cannot load model '" << state.modelName << "'"; } std::unique_ptr client;