diff --git a/src/collectives/all_gather.cc b/src/collectives/all_gather.cc index 767ab49..b9a5700 100644 --- a/src/collectives/all_gather.cc +++ b/src/collectives/all_gather.cc @@ -7,9 +7,12 @@ #include "enqueue.h" #include "collectives.h" +#include "param.h" #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); + NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, @@ -21,14 +24,21 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun size_t msgsize = sendcount * ncclTypeSize(datatype); NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize) + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream); } + else{ + struct ncclInfo info = { ncclFuncAllGather, "AllGather", + sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ + ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; + ret = ncclEnqueueCheck(&info); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } - struct ncclInfo info = { ncclFuncAllGather, "AllGather", - sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ - ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; - return ncclEnqueueCheck(&info); + return ret; } diff --git a/src/collectives/all_reduce.cc b/src/collectives/all_reduce.cc index 365c73c..bea6313 100644 --- a/src/collectives/all_reduce.cc +++ b/src/collectives/all_reduce.cc @@ -7,8 +7,11 @@ #include "enqueue.h" #include "nccl.h" +#include "param.h" +#include #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* 
recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); @@ -26,15 +29,22 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, }; NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op}; NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload) - + INFO(NCCL_INIT, "MSCCL: Enter into ncclAllReduce now"); + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream); } + else{ + struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", + sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ + ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; + ret = ncclEnqueueCheck(&info); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } - struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", - sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ - ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; - return ncclEnqueueCheck(&info); + return ret; } diff --git a/src/collectives/all_to_all.cc b/src/collectives/all_to_all.cc index a1bc655..81a8913 100644 --- a/src/collectives/all_to_all.cc +++ b/src/collectives/all_to_all.cc @@ -8,28 +8,39 @@ #include "enqueue.h" #include "collectives.h" #include "graph/topo.h" +#include "param.h" #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); + NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { + ncclResult_t ret; + if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, 
datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream); } - - size_t rankOffset = count * ncclTypeSize(datatype); - int nRanks; - NCCLCHECK(ncclCommCount(comm, &nRanks)); - if (count == 0) return ncclSuccess; - NCCLCHECK(ncclGroupStart()); - for (int r=0; r + template class Primitives< @@ -111,7 +113,8 @@ class Primitives< inline __device__ bool checkAbort(int &spins) { spins++; if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { - if (*ncclShmem.comm.abortFlag) { + if (*ncclShmem.comm.abortFlag || *ncclShmem.comm.resilientRepairing) { + printf("checkAbort Simple, resilientRepairingFlag:%d, abortFlag:%u \n", *ncclShmem.comm.resilientRepairing, *ncclShmem.comm.abortFlag); flags |= Aborted; ncclShmem.aborted = 1; } diff --git a/src/collectives/reduce.cc b/src/collectives/reduce.cc index 8cc9e67..7372309 100644 --- a/src/collectives/reduce.cc +++ b/src/collectives/reduce.cc @@ -8,9 +8,12 @@ #include "enqueue.h" #include "collectives.h" #include "nccl.h" +#include "param.h" #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); + NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, @@ -29,14 +32,21 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op}; NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload) + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, root, 0, op, mscclFuncReduce, comm, stream); } - - struct ncclInfo info = { ncclFuncReduce, "Reduce", + else{ + struct ncclInfo info = { ncclFuncReduce, "Reduce", sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */ 
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS }; - return ncclEnqueueCheck(&info); + ret = ncclEnqueueCheck(&info); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } + + return ret; } diff --git a/src/collectives/reduce_scatter.cc b/src/collectives/reduce_scatter.cc index 1b06a6d..b86bd1c 100644 --- a/src/collectives/reduce_scatter.cc +++ b/src/collectives/reduce_scatter.cc @@ -8,9 +8,12 @@ #include "enqueue.h" #include "collectives.h" #include "nccl.h" +#include "param.h" #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); + NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, @@ -27,14 +30,21 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op}; NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload) + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream); } - - struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", + else{ + struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS }; - return ncclEnqueueCheck(&info); + ret = ncclEnqueueCheck(&info); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } + + return ret; } diff --git a/src/collectives/sendrecv.cc b/src/collectives/sendrecv.cc index 61cd2fc..3ac1ced 100644 --- a/src/collectives/sendrecv.cc +++ b/src/collectives/sendrecv.cc @@ -8,9 +8,12 @@ #include "enqueue.h" #include "collectives.h" 
#include "argcheck.h" // Need some checks here since we access comm +#include "param.h" #include "msccl/msccl_lifecycle.h" +extern int64_t ncclParamResilientEnabled(); + struct NvtxParamsSendRecv { size_t bytes; int peer; @@ -27,19 +30,24 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload) + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream); } - - struct ncclInfo info = { ncclFuncSend, "Send", + else{ + struct ncclInfo info = { ncclFuncSend, "Send", NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; - ncclResult_t ret; - NCCLCHECK(ncclGroupStart()); - ret = ncclEnqueueCheck(&info); - NCCLCHECK(ncclGroupEnd()); + NCCLCHECK(ncclGroupStart()); + ret = ncclEnqueueCheck(&info); + NCCLCHECK(ncclGroupEnd()); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } + return ret; } @@ -50,18 +58,23 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload) + ncclResult_t ret; if (mscclAvailable() && !mscclIsCaller()) { - return mscclEnqueueCheck( + ret = mscclEnqueueCheck( nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream); } - - struct ncclInfo info = { ncclFuncRecv, "Recv", + else{ + struct ncclInfo info = { ncclFuncRecv, "Recv", NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ 1, 1 }; - ncclResult_t ret; - NCCLCHECK(ncclGroupStart()); - ret = ncclEnqueueCheck(&info); - NCCLCHECK(ncclGroupEnd()); + NCCLCHECK(ncclGroupStart()); + ret = 
ncclEnqueueCheck(&info); + NCCLCHECK(ncclGroupEnd()); + } + if (ncclParamResilientEnabled()){ + cudaStreamSynchronize(stream); + } + return ret; } diff --git a/src/include/comm.h b/src/include/comm.h index 88e3ee7..12edf6c 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -347,6 +347,8 @@ struct ncclComm { int finalizeRankCnt; // Whether this comm is compatible with MSCCL bool mscclCompatible; + // Whether this comm is current in resilient repairing mode + volatile bool *resilientRepairing; }; enum ncclLaunchMode { diff --git a/src/include/devcomm.h b/src/include/devcomm.h index 7aca4ae..1182fce 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -312,7 +312,8 @@ struct ncclDevComm { NpKitEventCollectContext* npKitEventCollectContexts; uint64_t* cpuTimestamp; #endif - + // Whether this comm is current in resilient repairing mode + volatile bool *resilientRepairing; }; struct alignas(16) ncclDevCommAndChannels { diff --git a/src/include/msccl/msccl_scheduler.h b/src/include/msccl/msccl_scheduler.h index 1dd9f6b..9fd6e6a 100644 --- a/src/include/msccl/msccl_scheduler.h +++ b/src/include/msccl/msccl_scheduler.h @@ -36,17 +36,35 @@ struct mscclSchedulerParam { int nRanks; bool scheduled; mscclAlgoHandle_t handle; + bool repair; + void* bootstrap; + ncclResult_t (*send)(void* commState, int peer, int tag, void* data, int size); + ncclResult_t (*receive)(void* commState, int peer, int tag, void* data, int size); + ncclResult_t (*allgather)(void* commState, void* allData, int size); }; typedef struct { // Name of the scheduler (mainly for logs) const char* name; // Load all algorithms - ncclResult_t (*init)(); + ncclResult_t (*init)(struct mscclSchedulerInitParam *initParam); // Select an algorithm ncclResult_t (*selectAlgo)(struct mscclSchedulerParam* param); // Unload all algorithms ncclResult_t (*teardown)(); } mscclSchedulerInterface; +struct mscclSchedulerInitParam{ + int rank; + int nRanks; + int nNodes; + void* bootstrap; + // bootstrap send 
operation + ncclResult_t (*send)(void* commState, int peer, int tag, void* data, int size); + // bootstrap receive operation + ncclResult_t (*receive)(void* commState, int peer, int tag, void* data, int size); + // bootstrap allgather operation + ncclResult_t (*allgather)(void* commState, void* allData, int size); +}; + #endif diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index a387e66..7ee3fbf 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -48,6 +48,10 @@ typedef struct { ncclResult_t (*devices)(int* ndev); // Get various device properties. ncclResult_t (*getProperties)(int dev, ncclNetProperties_v6_t* props); + // Get device status. + ncclResult_t (*getStatus)(int* nstat); + // Set device status. + ncclResult_t (*setStatus)(int nstat); // Create a receiving object and provide a handle to connect to it. The // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged // between ranks to create a connection. diff --git a/src/include/proxy.h b/src/include/proxy.h index d0067b1..ac3c552 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -228,6 +228,8 @@ struct ncclProxyState { // Queue of expected responses from the proxy struct ncclExpectedProxyResponse* expectedResponses; + // Whether this comm is current in resilient repairing mode + volatile bool *resilientRepairing; }; enum proxyConnectState { diff --git a/src/init.cc b/src/init.cc index e466363..a37edda 100644 --- a/src/init.cc +++ b/src/init.cc @@ -239,6 +239,7 @@ static ncclResult_t commFree(ncclComm_t comm) { if (ncclAtomicRefCountDecrement(comm->abortFlagRefCount) == 0) { NCCLCHECK(ncclCudaHostFree((void *)comm->abortFlag)); + NCCLCHECK(ncclCudaHostFree((void *)comm->resilientRepairing)); free(comm->abortFlagRefCount); } free((void*)comm->config.netName); @@ -400,6 +401,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.comm.rank = comm->rank; tmpCommAndChans.comm.nRanks = nRanks; tmpCommAndChans.comm.abortFlag = 
comm->abortFlag;
+  tmpCommAndChans.comm.resilientRepairing = comm->resilientRepairing;
   for (int p=0; p < NCCL_NUM_PROTOCOLS; p++) {
     tmpCommAndChans.comm.buffSizes[p] = comm->buffSizes[p];
   }
@@ -1617,6 +1619,9 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni
   NCCLCHECKGOTO(ncclCudaHostCalloc((uint32_t**)&comm->abortFlag, 1), res, fail);
   NCCLCHECKGOTO(ncclCalloc((uint32_t**)&comm->abortFlagRefCount, 1), res, fail);
   *comm->abortFlagRefCount = 1;
+  NCCLCHECKGOTO(ncclCudaHostCalloc((bool**)&comm->resilientRepairing, 1), res, fail);
+  *comm->resilientRepairing = false;
+
   NCCLCHECKGOTO(parseCommConfig(comm, config), res, fail);
   /* start with ncclInternalError and will be changed to ncclSuccess if init succeeds. */
   comm->initState = ncclInternalError;
@@ -1636,6 +1641,7 @@
   if (comm) {
     if (comm->abortFlag) ncclCudaHostFree((void *)comm->abortFlag);
     if (comm->abortFlagRefCount) free(comm->abortFlagRefCount);
+    if (comm->resilientRepairing) ncclCudaHostFree((void *)comm->resilientRepairing);
     free(comm);
   }
   if (newcomm) *newcomm = NULL;
diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc
index 9ec9695..307ab86 100644
--- a/src/misc/msccl/msccl_lifecycle.cc
+++ b/src/misc/msccl/msccl_lifecycle.cc
@@ -14,6 +14,7 @@
 #include
 #include "alloc.h"
+#include "bootstrap.h"
 #include "checks.h"
 #include "graph/topo.h"
@@ -21,11 +22,17 @@
 #include "msccl/msccl_parser.h"
 #include "msccl/msccl_setup.h"
 #include "msccl/msccl_status.h"
+#include "msccl/msccl_scheduler.h"

 NCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1);
 static std::atomic<bool> mscclInitialized;
 static bool mscclSchedulerTriedLoadAlgo = false;
 static std::mutex mscclLifecycleMutex;
+extern ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
+extern ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
+extern ncclResult_t
bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
+extern ncclNet_t ncclNetIb;
+extern int64_t ncclParamResilientEnabled();

 int getEnvInt(const char* env, int64_t deftVal) {
   char* str = getenv(env);
@@ -57,6 +64,11 @@ bool mscclAvailable() {
   return mscclEnabled() && mscclInitialized.load(std::memory_order_acquire);
 }

+mscclSchedulerInitParam& mscclGetSchedulerInitParam() {
+  static mscclSchedulerInitParam initParam;
+  return initParam;
+}
+
 static bool mscclCommCompatible(ncclComm_t comm) {
   std::map> hostHashToPidHashes;
   for (int i = 0; i < comm->nRanks; i++) {
@@ -154,7 +166,7 @@ static ncclResult_t mscclInternalSchedulerInit() {
   return ncclSuccess;
 }

-static ncclResult_t mscclSchedulerInit() {
+static ncclResult_t mscclSchedulerInit(mscclSchedulerInitParam *initParam) {
   mscclStatus& status = mscclGetStatus();
   bool useInternalScheduler = false;
@@ -177,7 +189,7 @@
   if (useInternalScheduler) {
     NCCLCHECK(mscclInternalSchedulerInit());
   } else {
-    NCCLCHECK(status.mscclSchedulerPtr->init());
+    NCCLCHECK(status.mscclSchedulerPtr->init(initParam));
   }
   return ncclSuccess;
 }
@@ -217,7 +229,17 @@ ncclResult_t mscclInit(ncclComm_t comm) {
     status.needsProxy = false;
     mscclSchedulerTriedLoadAlgo = false;
-    NCCLCHECK(mscclSchedulerInit());
+    // Bind by reference: the external scheduler plugin's init() receives a pointer
+    // to this struct and may retain it; a by-value copy here would hand it the
+    // address of a stack local that dangles once mscclInit returns.
+    mscclSchedulerInitParam& initParam = mscclGetSchedulerInitParam();
+    initParam.nRanks = comm->nRanks;
+    initParam.rank = comm->rank;
+    initParam.nNodes = comm->nNodes;
+    initParam.bootstrap = comm->bootstrap;
+    initParam.send = &bootstrapSend;
+    initParam.receive = &bootstrapRecv;
+    initParam.allgather = &bootstrapAllGather;
+
+    NCCLCHECK(mscclSchedulerInit(&initParam));

     mscclInitialized.store(true, std::memory_order_release);
   }
@@ -312,7 +333,10 @@ static ncclResult_t mscclSetSavedSchedulerParam(
   const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
   void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
   size_t count, ncclDataType_t dataType, int
root, int peer, ncclRedOp_t op, - mscclFunc_t func, ncclComm_t comm, cudaStream_t stream, + mscclFunc_t func, ncclComm_t comm, cudaStream_t stream, bool repair, + ncclResult_t (*send)(void* commState, int peer, int tag, void* data, int size), + ncclResult_t (*receive)(void* commState, int peer, int tag, void* data, int size), + ncclResult_t (*allgather)(void* commState, void* allData, int size), struct mscclSavedSchedulerParam* param) { param->p.sendBuff = sendBuff; param->p.sendCounts = sendCounts; @@ -328,6 +352,11 @@ static ncclResult_t mscclSetSavedSchedulerParam( param->p.func = func; param->p.rank = comm->rank; param->p.nRanks = comm->nRanks; + param->p.repair = repair; + param->p.bootstrap = comm->bootstrap; + param->p.send = send; + param->p.receive = receive; + param->p.allgather = allgather; param->comm = comm; param->stream = stream; return ncclSuccess; @@ -416,19 +445,37 @@ ncclResult_t mscclEnqueueCheck( mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); threadLocalStatus.savedSchedulerParams.push_back({}); + bool repair = false; + INFO(NCCL_INIT, "MSCCL: Enter into mscclEnqueueCheck now"); + // if(ncclParamResilientEnabled() && *comm->resilientRepairing) + // { + // INFO(NCCL_INIT, "MSCCL: Enter into mscclEnqueueCheck and in resilient repairing mode now"); + // *comm->resilientRepairing = false; + // // *comm->abortFlag = 0; + // ncclNetIb.setStatus(0); + // repair = true; + // } NCCLCHECK(mscclSetSavedSchedulerParam( sendBuff, sendCounts, sDisPls, recvBuff, recvCounts, rDisPls, - count, dataType, root, peer, op, func, comm, stream, + count, dataType, root, peer, op, func, comm, stream, *comm->resilientRepairing, &bootstrapSend, &bootstrapRecv, &bootstrapAllGather, &threadLocalStatus.savedSchedulerParams.back())); switch (threadLocalStatus.groupStatus) { case mscclNoGroup: if (comm->mscclCompatible) { - NCCLCHECK(mscclSchedulerSelectAlgo(&threadLocalStatus.savedSchedulerParams.back())); - if 
(threadLocalStatus.savedSchedulerParams.back().p.scheduled) { - NCCLCHECK(mscclRunSavedParams()); - break; + NCCLCHECK(mscclSchedulerSelectAlgo(&threadLocalStatus.savedSchedulerParams.back())); + if (threadLocalStatus.savedSchedulerParams.back().p.scheduled) { + if(ncclParamResilientEnabled() && *comm->resilientRepairing) + { + INFO(NCCL_INIT, "MSCCL: Enter into mscclNoGroup's mscclEnqueueCheck and in resilient repairing mode now"); + *comm->resilientRepairing = false; + // *comm->abortFlag = 0; + ncclNetIb.setStatus(0); } + INFO(NCCL_INIT, "MSCCL: mscclRunSavedParams for rank: %d", comm->rank); + NCCLCHECK(mscclRunSavedParams()); + break; + } } NCCLCHECK(mscclFallBackSavedParams()); break; @@ -436,6 +483,13 @@ ncclResult_t mscclEnqueueCheck( if (comm->mscclCompatible) { NCCLCHECK(mscclSchedulerSelectAlgo(&threadLocalStatus.savedSchedulerParams.back())); if (threadLocalStatus.savedSchedulerParams.back().p.scheduled) { + if(ncclParamResilientEnabled() && *comm->resilientRepairing) + { + INFO(NCCL_INIT, "MSCCL: Enter into mscclEnqueueCheck and in resilient repairing mode now"); + *comm->resilientRepairing = false; + // *comm->abortFlag = 0; + ncclNetIb.setStatus(0); + } // Only save counts and displs when there is suitable MSCCL algorithm for this NCCLCHECK(mscclSaveCountsAndDispls(&threadLocalStatus.savedSchedulerParams.back())); break; diff --git a/src/proxy.cc b/src/proxy.cc index c16ebd2..0080e1e 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -8,15 +8,23 @@ #include "comm.h" #include "info.h" #include "collectives.h" +#include "bootstrap.h" #include "socket.h" #include "shm.h" #include "profiler.h" #define ENABLE_TIMER 0 #include "timer.h" +#include "param.h" #include #include +extern ncclNet_t ncclNetIb; +static bool resilientDaemonRunning = false; +static bool resilientRepairing = false; +NCCL_PARAM(ResilientEnabled, "RESILIENT_ENABLED", 0); +NCCL_PARAM(ResilientCheckInterval, "RESILIENT_CHECK_INTERVAL", 5); + static bool NeedProxy(int type, int pattern, int 
root, struct ncclRing* ring, int nranks) { if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice) return true; @@ -680,7 +688,14 @@ static ncclResult_t progressOps(struct ncclProxyState* proxyState, struct ncclPr while (op) { if (op->state == ncclProxyOpNone) return ncclInternalError; TIME_START(0); TIME_START(1); - NCCLCHECK(op->progress(proxyState, op)); + if (*proxyState->resilientRepairing) + { + // if in resilient repairing state, we need to clean up all the ops from the proxy + INFO(NCCL_NET|NCCL_PROXY,"detect in resilient repairing mode, will start to remove Ops from the queue for rank: %d", proxyState->tpRank); + op->state = ncclProxyOpNone; + }else{ + NCCLCHECK(op->progress(proxyState, op)); + } if (op->idle) { TIME_STOP(1); TIME_CANCEL(0); } else { TIME_CANCEL(1); TIME_STOP(0); } *idle &= op->idle; if (op->state == ncclProxyOpNone) { @@ -880,6 +895,7 @@ void* ncclProxyProgress(void *proxyState_) { } lastIdle = idle; } + INFO(NCCL_ALL,"[Proxy Thread] ncclProxyProgress exit"); return NULL; } @@ -1383,7 +1399,6 @@ static ncclResult_t proxyServiceInitOp(int type, struct ncclProxyLocalPeer* peer } #include - static bool proxyMatchOpType(int type) { switch (type) { case ncclProxyMsgInit: @@ -1434,7 +1449,10 @@ void* ncclProxyService(void* _args) { /* Even if local comm aborts, we cannot let proxy thread exit if we still have peer * connections. Need to wait until all other related comms call abort and safely exit * together, or we could face segmentation fault. */ - if (*proxyState->abortFlag != 0) stop = 1; + if (*proxyState->abortFlag != 0) { + INFO(NCCL_INIT|NCCL_NET,"ncclProxyService: receive the abortFlag signal now \n"); + stop = 1; + } /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. 
*/ int ret; do { @@ -1551,6 +1569,46 @@ void* ncclProxyService(void* _args) { return NULL; } +void* ncclResilientDaemon(void* _args) { + ncclComm *comm = (ncclComm*)_args; + INFO(NCCL_INIT, "[Proxy Service] start the ncclResilientDaemon now, ranks:%d, rank:%d", comm->nRanks, comm->rank); + int interval = ncclParamResilientCheckInterval(); + int *status = (int*)malloc(comm->nRanks * sizeof(int)); + if (status == NULL) { + WARN("[Proxy Service] ncclResilientDaemon, memory allocation failed, ncclResilientDaemon will quit soon"); + } + while(resilientDaemonRunning) + { + if (!*comm->resilientRepairing) + { + int nicStat = 0; + memset(status, 0, comm->nRanks * sizeof(int)); + + ncclNetIb.getStatus(&nicStat); + status[comm->rank] = nicStat; + + bootstrapAllGather(comm->bootstrap, status, comm->nRanks * sizeof(int)); + int all_status = 0; + for (int i = 0; i < comm->nRanks; ++i) { + all_status |= status[i]; + } + + if (0 != all_status) + { + INFO(NCCL_INIT, "[Proxy Service] ncclResilientDaemon, detect the nic failure, will abort the kernel execution now"); + *comm->resilientRepairing = true; + resilientRepairing = true; + // *comm->abortFlag = 1; + } + } + sleep(interval); + } + free(status); + + INFO(NCCL_INIT, "[Proxy Service] will quit the ncclResilientDaemon now"); + return NULL; +} + ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses) { assert(comm->sharedRes->proxyState == NULL); NCCLCHECK(ncclCalloc(&comm->sharedRes->proxyState, 1)); @@ -1579,10 +1637,17 @@ ncclResult_t ncclProxyCreate(struct ncclComm* comm) { proxyState->dmaBufSupport = comm->dmaBufSupport; proxyState->ncclNet = comm->ncclNet; proxyState->ncclCollNet = comm->ncclCollNet; + proxyState->resilientRepairing = comm->resilientRepairing; memcpy(proxyState->buffSizes, comm->buffSizes, sizeof(comm->buffSizes)); pthread_create(&comm->proxyState->thread, NULL, ncclProxyService, comm->proxyState); ncclSetThreadName(comm->proxyState->thread, "NCCL 
Service %2d", comm->cudaDev); + + if (ncclParamResilientEnabled()){ + pthread_t resilientDaemonthread; + resilientDaemonRunning = true; + pthread_create(&resilientDaemonthread, NULL, ncclResilientDaemon, comm); + } } return ncclSuccess; } @@ -1625,7 +1690,7 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { } } } - + resilientDaemonRunning = false; return ncclSuccess; } diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index 861fa57..e0b9e99 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -28,6 +28,7 @@ #define MAXNAMESIZE 64 static char ncclIbIfName[MAX_IF_NAME_SIZE+1]; static union ncclSocketAddress ncclIbIfAddr; +static int nicStatus = 0; struct ncclIbMr { uintptr_t addr; @@ -92,7 +93,13 @@ static void* ncclIbAsyncThreadMain(void* args) { char *str; if (ncclSuccess != wrap_ibv_event_type_str(&str, event.event_type)) { break; } if (event.event_type != IBV_EVENT_COMM_EST) - WARN("NET/IB : Got async event : %s", str); + { + WARN("NET/IB : Got async event : %s, event type: %d", str, event.event_type); + if (strcmp(str, "local catastrophic error") == 0) { + WARN("NET/IB : Detect Nic failure, will repair soon, event type: %d", event.event_type); + nicStatus = 1; + } + } if (ncclSuccess != wrap_ibv_ack_async_event(&event)) { break; } } return NULL; @@ -336,6 +343,16 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { return ncclSuccess; } +ncclResult_t ncclIbGetStatus(int* nstat) { + *nstat = nicStatus; + return ncclSuccess; +} + +ncclResult_t ncclIbSetStatus(int nstat) { + nicStatus = nstat; + return ncclSuccess; +} + // We need to support NCCL_NET_MAX_REQUESTS for each concurrent receive #define MAX_REQUESTS (NCCL_NET_MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS) static_assert(MAX_REQUESTS <= 256, "request id are encoded in wr_id and we need up to 8 requests ids per completion"); @@ -1368,6 +1385,8 @@ ncclNet_t ncclNetIb = { ncclIbInit, ncclIbDevices, ncclIbGetProperties, + ncclIbGetStatus, + ncclIbSetStatus, 
ncclIbListen, ncclIbConnect, ncclIbAccept, diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 08a8c3a..801d194 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -104,6 +104,14 @@ ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { return ncclSuccess; } +ncclResult_t ncclNetSocketGetStatus(int* nstat){ + return ncclSuccess; +} +ncclResult_t ncclNetSocketSetStatus(int nstat){ + return ncclSuccess; +} + + /* Communication functions */ #define MAX_SOCKETS 64 @@ -597,6 +605,8 @@ ncclNet_t ncclNetSocket = { ncclNetSocketInit, ncclNetSocketDevices, ncclNetSocketGetProperties, + ncclNetSocketGetStatus, + ncclNetSocketSetStatus, ncclNetSocketListen, ncclNetSocketConnect, ncclNetSocketAccept,