Skip to content

Commit edd8d77

Browse files
committed
DPL: earlier forwarding
This anticipates the forwarding to the earliest possible moment, i.e. when we are about to insert the messages in a slot. This is the earliest moment we can guarantee messages will be seen only once.
1 parent 96e2f45 commit edd8d77

File tree

6 files changed

+143
-7
lines changed

6 files changed

+143
-7
lines changed

Framework/Core/include/Framework/DataRelayer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ class DataRelayer
114114

115115
using OnDropCallback = std::function<void(TimesliceSlot, std::vector<MessageSet>&, TimesliceIndex::OldestOutputInfo info)>;
116116

117+
// Callback for when some messages are about to be owned by the the DataRelayer
118+
using OnInsertionCallback = std::function<void(ServiceRegistryRef&, std::span<fair::mq::MessagePtr>&)>;
119+
117120
/// Prune all the pending entries in the cache.
118121
void prunePending(OnDropCallback);
119122
/// Prune the cache for a given slot
@@ -135,6 +138,7 @@ class DataRelayer
135138
InputInfo const& info,
136139
size_t nMessages,
137140
size_t nPayloads = 1,
141+
OnInsertionCallback onInsertion = nullptr,
138142
OnDropCallback onDrop = nullptr);
139143

140144
/// This is to set the oldest possible @a timeslice this relayer can

Framework/Core/src/DataProcessingDevice.cxx

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,11 +1854,59 @@ void DataProcessingDevice::handleData(ServiceRegistryRef ref, InputChannelInfo&
18541854
VariableContextHelpers::getTimeslice(variables);
18551855
forwardInputs(ref, slot, dropped, oldestOutputInfo, false, true);
18561856
};
1857+
1858+
auto onInsertion = [](ServiceRegistryRef& ref, std::span<fair::mq::MessagePtr>& messages) {
1859+
O2_LOG_ENABLE(forwarding);
1860+
O2_SIGNPOST_ID_GENERATE(sid, forwarding);
1861+
1862+
auto& spec = ref.get<DeviceSpec const>();
1863+
auto& context = ref.get<DataProcessorContext>();
1864+
if (!context.canForwardEarly || spec.forwards.empty()) {
1865+
O2_SIGNPOST_EVENT_EMIT(device, sid, "device", "Early forwardinding not enabled / needed.");
1866+
return;
1867+
}
1868+
1869+
O2_SIGNPOST_EVENT_EMIT(device, sid, "device", "Early forwardinding before injecting data into relayer.");
1870+
auto& timesliceIndex = ref.get<TimesliceIndex>();
1871+
auto oldestTimeslice = timesliceIndex.getOldestPossibleOutput();
1872+
1873+
auto& proxy = ref.get<FairMQDeviceProxy>();
1874+
1875+
O2_SIGNPOST_START(forwarding, sid, "forwardInputs",
1876+
"Starting forwarding for incoming messages with oldestTimeslice %zu with copy",
1877+
oldestTimeslice.timeslice.value);
1878+
std::vector<fair::mq::Parts> forwardedParts(proxy.getNumForwardChannels());
1879+
DataProcessingHelpers::routeForwardedMessages(proxy, messages, forwardedParts, true, false);
1880+
1881+
for (int fi = 0; fi < proxy.getNumForwardChannels(); fi++) {
1882+
if (forwardedParts[fi].Size() == 0) {
1883+
continue;
1884+
}
1885+
ForwardChannelInfo info = proxy.getForwardChannelInfo(ChannelIndex{fi});
1886+
auto& parts = forwardedParts[fi];
1887+
if (info.policy == nullptr) {
1888+
O2_SIGNPOST_EVENT_EMIT_ERROR(forwarding, sid, "forwardInputs", "Forwarding to %{public}s %d has no policy.", info.name.c_str(), fi);
1889+
continue;
1890+
}
1891+
O2_SIGNPOST_EVENT_EMIT(forwarding, sid, "forwardInputs", "Forwarding to %{public}s %d", info.name.c_str(), fi);
1892+
info.policy->forward(parts, ChannelIndex{fi}, ref);
1893+
}
1894+
auto& asyncQueue = ref.get<AsyncQueue>();
1895+
auto& decongestion = ref.get<DecongestionService>();
1896+
O2_SIGNPOST_ID_GENERATE(aid, async_queue);
1897+
O2_SIGNPOST_EVENT_EMIT(async_queue, aid, "forwardInputs", "Queuing forwarding oldestPossible %zu", oldestTimeslice.timeslice.value);
1898+
AsyncQueueHelpers::post(asyncQueue, AsyncTask{.timeslice = oldestTimeslice.timeslice, .id = decongestion.oldestPossibleTimesliceTask, .debounce = -1, .callback = decongestionCallbackLate}
1899+
.user<DecongestionContext>({.ref = ref, .oldestTimeslice = oldestTimeslice}));
1900+
O2_SIGNPOST_END(forwarding, sid, "forwardInputs", "Forwarding done");
1901+
O2_LOG_DISABLE(forwarding);
1902+
};
1903+
18571904
auto relayed = relayer.relay(parts.At(headerIndex)->GetData(),
18581905
&parts.At(headerIndex),
18591906
input,
18601907
nMessages,
18611908
nPayloadsPerHeader,
1909+
onInsertion,
18621910
onDrop);
18631911
switch (relayed.type) {
18641912
case DataRelayer::RelayChoice::Type::Backpressured:
@@ -2273,9 +2321,13 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v
22732321
bool consumeSomething = action.op == CompletionPolicy::CompletionOp::Consume || action.op == CompletionPolicy::CompletionOp::ConsumeExisting;
22742322

22752323
if (context.canForwardEarly && hasForwards && consumeSomething) {
2276-
O2_SIGNPOST_EVENT_EMIT(device, aid, "device", "Early forwainding: %{public}s.", fmt::format("{}", action.op).c_str());
2277-
auto& timesliceIndex = ref.get<TimesliceIndex>();
2278-
forwardInputs(ref, action.slot, currentSetOfInputs, timesliceIndex.getOldestPossibleOutput(), true, action.op == CompletionPolicy::CompletionOp::Consume);
2324+
// We used to do fowarding here, however we now do it much earlier.
2325+
// We still need to clean the inputs which were already consumed
2326+
// via ConsumeExisting and which still have an header to hold the slot.
2327+
// FIXME: do we? This should really happen when we do the forwarding on
2328+
// insertion, because otherwise we lose the relevant information on how to
2329+
// navigate the set of headers. We could actually rely on the messageset index,
2330+
// is that the right thing to do though?
22792331
}
22802332
markInputsAsDone(action.slot);
22812333

Framework/Core/src/DataProcessingHelpers.cxx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,7 @@ auto DataProcessingHelpers::routeForwardedMessageSet(FairMQDeviceProxy& proxy,
343343
const bool copyByDefault, bool consume) -> std::vector<fair::mq::Parts>
344344
{
345345
// we collect all messages per forward in a map and send them together
346-
std::vector<fair::mq::Parts> forwardedParts;
347-
forwardedParts.resize(proxy.getNumForwards());
346+
std::vector<fair::mq::Parts> forwardedParts(proxy.getNumForwardChannels());
348347
std::vector<ChannelIndex> forwardingChoices{};
349348

350349
for (size_t ii = 0, ie = currentSetOfInputs.size(); ii < ie; ++ii) {

Framework/Core/src/DataRefUtils.cxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ void* extractFromTFile(TFile& file, TClass const* cl, const char* what)
7272
return result;
7373
}
7474
} // namespace
75+
7576
// Adapted from CcdbApi private method interpretAsTMemFileAndExtract
7677
// If the former is moved to public, throws on error and could be changed to
7778
// not require a mutex we could use it.

Framework/Core/src/DataRelayer.cxx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ DataRelayer::RelayChoice
436436
InputInfo const& info,
437437
size_t nMessages,
438438
size_t nPayloads,
439-
std::function<void(TimesliceSlot, std::vector<MessageSet>&, TimesliceIndex::OldestOutputInfo)> onDrop)
439+
OnInsertionCallback onInsertion,
440+
OnDropCallback onDrop)
440441
{
441442
std::scoped_lock<O2_LOCKABLE(std::recursive_mutex)> lock(mMutex);
442443
DataProcessingHeader const* dph = o2::header::get<DataProcessingHeader*>(rawHeader);
@@ -482,6 +483,7 @@ DataRelayer::RelayChoice
482483
&messages,
483484
&nMessages,
484485
&nPayloads,
486+
&onInsertion,
485487
&cache = mCache,
486488
&services = mContext,
487489
numInputTypes = mDistinctRoutesIndex.size()](TimesliceId timeslice, int input, TimesliceSlot slot, InputInfo const& info) -> size_t {
@@ -512,7 +514,11 @@ DataRelayer::RelayChoice
512514
mi += nPayloads;
513515
continue;
514516
}
515-
target.add([&messages, &mi](size_t i) -> fair::mq::MessagePtr& { return messages[mi + i]; }, nPayloads + 1);
517+
auto span = std::span<fair::mq::MessagePtr>(messages + mi, messages + mi + nPayloads + 1);
518+
if (onInsertion) {
519+
onInsertion(services, span);
520+
}
521+
target.add([&span](size_t i) -> fair::mq::MessagePtr& { return span[i]; }, nPayloads + 1);
516522
mi += nPayloads;
517523
saved += nPayloads;
518524
}

Framework/Core/test/test_ForwardInputs.cxx

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,80 @@ TEST_CASE("ForwardInputsSplitPayload")
616616
CHECK(result[1].Size() == 3);
617617
}
618618

619+
TEST_CASE("ForwardInputsSplitPayloadNoMessageSet")
620+
{
621+
o2::header::DataHeader dh;
622+
dh.dataOrigin = "TST";
623+
dh.dataDescription = "A";
624+
dh.subSpecification = 0;
625+
dh.splitPayloadIndex = 2;
626+
dh.splitPayloadParts = 2;
627+
628+
o2::header::DataHeader dh2;
629+
dh2.dataOrigin = "TST";
630+
dh2.dataDescription = "B";
631+
dh2.subSpecification = 0;
632+
dh2.splitPayloadIndex = 0;
633+
dh2.splitPayloadParts = 1;
634+
635+
o2::framework::DataProcessingHeader dph{0, 1};
636+
637+
std::vector<fair::mq::Channel> channels{
638+
fair::mq::Channel("from_A_to_B"),
639+
fair::mq::Channel("from_A_to_C"),
640+
};
641+
642+
bool consume = true;
643+
bool copyByDefault = true;
644+
FairMQDeviceProxy proxy;
645+
std::vector<ForwardRoute> routes{
646+
ForwardRoute{
647+
.timeslice = 0,
648+
.maxTimeslices = 1,
649+
.matcher = {"binding", ConcreteDataMatcher{"TST", "B", 0}},
650+
.channel = "from_A_to_B",
651+
.policy = nullptr,
652+
},
653+
ForwardRoute{
654+
.timeslice = 0,
655+
.maxTimeslices = 1,
656+
.matcher = {"binding", ConcreteDataMatcher{"TST", "A", 0}},
657+
.channel = "from_A_to_C",
658+
.policy = nullptr,
659+
}};
660+
661+
auto findChannelByName = [&channels](std::string const& channelName) -> fair::mq::Channel& {
662+
for (auto& channel : channels) {
663+
if (channel.GetName() == channelName) {
664+
return channel;
665+
}
666+
}
667+
throw std::runtime_error("Channel not found");
668+
};
669+
670+
proxy.bind({}, {}, routes, findChannelByName, nullptr);
671+
672+
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
673+
fair::mq::MessagePtr payload1(transport->CreateMessage());
674+
fair::mq::MessagePtr payload2(transport->CreateMessage());
675+
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
676+
auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
677+
std::vector<std::unique_ptr<fair::mq::Message>> messages;
678+
messages.push_back(std::move(header));
679+
messages.push_back(std::move(payload1));
680+
messages.push_back(std::move(payload2));
681+
auto header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph});
682+
messages.push_back(std::move(header2));
683+
messages.push_back(transport->CreateMessage());
684+
685+
std::vector<fair::mq::Parts> result(2);
686+
auto span = std::span(messages);
687+
o2::framework::DataProcessingHelpers::routeForwardedMessages(proxy, span, result, copyByDefault, consume);
688+
REQUIRE(result.size() == 2); // Two routes
689+
CHECK(result[0].Size() == 2); // No messages on this route
690+
CHECK(result[1].Size() == 3);
691+
}
692+
619693
TEST_CASE("ForwardInputEOSSingleRoute")
620694
{
621695
o2::framework::SourceInfoHeader sih{};

0 commit comments

Comments
 (0)