Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions comms/common/algorithms/AlgoUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ namespace meta {
namespace comms {

inline uint32_t divRoundUp(size_t a, size_t b) {
return static_cast<uint32_t>((a + b - 1) / b);
uint32_t y = static_cast<uint32_t>((a + b - 1) / b);
if (y == 0) {
y = 1;
}
return y;
}

constexpr uint32_t
Expand All @@ -18,7 +22,11 @@ calcBlockCount(size_t numThreads, size_t threadsPerBlock, size_t maxBlocks) {
// Overflow safe variant of (a + b - 1) / b
const uint64_t blocks =
uNumThreads / uThreadsPerBlock + (uNumThreads % uThreadsPerBlock != 0);
return static_cast<uint32_t>(std::min(blocks, maxBlocks));
uint32_t y = static_cast<uint32_t>(std::min(blocks, maxBlocks));
if (y == 0) {
y = 1;
}
return y;
}

std::pair<dim3, dim3>
Expand Down
17 changes: 13 additions & 4 deletions comms/rcclx/develop/meta/algorithms/AlgoInit.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
#include "param.h"

// Meta custom algorithm configs
RCCL_PARAM(DdaNRanks, "DDA_NRANKS", 8);
RCCL_PARAM(DdaMaxBlocks, "DDA_MAX_BLOCKS", 24);
RCCL_PARAM(DdaSendbufBytes, "DDA_SENDBUF_BYTES", 32 * 1024 * 1024);

RCCL_PARAM(EnableDdaAllReduce, "ENABLE_DDA_ALL_REDUCE", 0);
RCCL_PARAM(EnableDdaAllReduce, "ENABLE_DDA_ALL_REDUCE", 1);
RCCL_PARAM(
DdaAllReduceFlatMaxBytes,
"DDA_ALL_REDUCE_FLAT_MAX_BYTES",
Expand All @@ -22,16 +23,16 @@ RCCL_PARAM(
"DDA_ALL_REDUCE_TREE_MAX_BYTES",
29 * 1024 * 1024);

RCCL_PARAM(EnableDdaAllGather, "ENABLE_DDA_ALL_GATHER", 0);
RCCL_PARAM(EnableDdaAllGather, "ENABLE_DDA_ALL_GATHER", 1);
RCCL_PARAM(DdaAllGatherMaxBytes, "DDA_ALL_GATHER_MAX_BYTES", 16 * 1024 * 1024);

RCCL_PARAM(EnableDdaReduceScatter, "ENABLE_DDA_REDUCE_SCATTER", 0);
RCCL_PARAM(EnableDdaReduceScatter, "ENABLE_DDA_REDUCE_SCATTER", 1);
RCCL_PARAM(
DdaReduceScatterMaxBytes,
"DDA_REDUCE_SCATTER_MAX_BYTES",
8 * 1024 * 1024);

RCCL_PARAM(EnableDdaAllToAll, "ENABLE_DDA_ALL_TO_ALL", 0);
RCCL_PARAM(EnableDdaAllToAll, "ENABLE_DDA_ALL_TO_ALL", 1);
RCCL_PARAM(DdaAllToAllMaxBytes, "DDA_ALL_TO_ALL_MAX_BYTES", 2 * 1024 * 1024);

std::unique_ptr<meta::comms::AlgoFactoryDev> initAlgoFactory(ncclComm_t comm) {
Expand All @@ -46,6 +47,14 @@ std::unique_ptr<meta::comms::AlgoFactoryDev> initAlgoFactory(ncclComm_t comm) {
return nullptr;
}

if (comm->nRanks != rcclParamDdaNRanks()) {
INFO(
NCCL_INIT,
"Disabling DDA for single-node when nRanks != 8 setup (nRanks=%d)",
comm->nRanks);
return nullptr;
}

return std::make_unique<::meta::comms::AlgoFactoryDev>(
std::make_shared<::rcclx::BaselineBootstrap>(comm),
comm->nRanks,
Expand Down
Loading