Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/collectives/all_gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

#include "enqueue.h"
#include "collectives.h"
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
Expand All @@ -21,14 +24,21 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
size_t msgsize = sendcount * ncclTypeSize(datatype);
NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
}
else{
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
ret = ncclEnqueueCheck(&info);
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
return ncclEnqueueCheck(&info);
return ret;
}
22 changes: 16 additions & 6 deletions src/collectives/all_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

#include "enqueue.h"
#include "nccl.h"
#include "param.h"
#include <cuda_runtime.h>

#include "msccl/msccl_lifecycle.h"
extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
Expand All @@ -26,15 +29,22 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
};
NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op};
NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload)

INFO(NCCL_INIT, "MSCCL: Enter into ncclAllReduce now");
ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
}
else{
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
ret = ncclEnqueueCheck(&info);
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
return ncclEnqueueCheck(&info);
return ret;
}
35 changes: 23 additions & 12 deletions src/collectives/all_to_all.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,39 @@
#include "enqueue.h"
#include "collectives.h"
#include "graph/topo.h"
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream) {
ncclResult_t ret;

if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
}

size_t rankOffset = count * ncclTypeSize(datatype);
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
if (count == 0) return ncclSuccess;
NCCLCHECK(ncclGroupStart());
for (int r=0; r<nRanks; r++) {
NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
else{
size_t rankOffset = count * ncclTypeSize(datatype);
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
if (count == 0) return ncclSuccess;
NCCLCHECK(ncclGroupStart());
for (int r=0; r<nRanks; r++) {
NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count, datatype, r, comm, stream));
NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count, datatype, r, comm, stream));
}
NCCLCHECK(ncclGroupEnd());
ret = ncclSuccess;
}
NCCLCHECK(ncclGroupEnd());
return ncclSuccess;
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}
15 changes: 13 additions & 2 deletions src/collectives/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

#include "enqueue.h"
#include "collectives.h"
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclBroadcast, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root,
Expand All @@ -25,15 +28,23 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n
NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
}
else{
struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
return ncclEnqueueCheck(&info);
ret = ncclEnqueueCheck(&info);
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}
/* Deprecated original "in place" function, similar to MPI */
NCCL_API(ncclResult_t, ncclBcast, void* buff, size_t count, ncclDataType_t datatype, int root,
Expand Down
5 changes: 4 additions & 1 deletion src/collectives/device/prims_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "npkit/npkit.h"
#endif

#include <stdio.h>

template<typename T, typename RedOp, typename Fan, int Direct,
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts>
class Primitives<
Expand Down Expand Up @@ -111,7 +113,8 @@ class Primitives<
inline __device__ bool checkAbort(int &spins) {
spins++;
if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
if (*ncclShmem.comm.abortFlag) {
if (*ncclShmem.comm.abortFlag || *ncclShmem.comm.resilientRepairing) {
printf("checkAbort Simple, resilientRepairingFlag:%d, abortFlag:%u \n", *ncclShmem.comm.resilientRepairing, *ncclShmem.comm.abortFlag);
flags |= Aborted;
ncclShmem.aborted = 1;
}
Expand Down
18 changes: 14 additions & 4 deletions src/collectives/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#include "enqueue.h"
#include "collectives.h"
#include "nccl.h"
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclReduce, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
Expand All @@ -29,14 +32,21 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op};
NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
}

struct ncclInfo info = { ncclFuncReduce, "Reduce",
else{
struct ncclInfo info = { ncclFuncReduce, "Reduce",
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
return ncclEnqueueCheck(&info);
ret = ncclEnqueueCheck(&info);
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}
18 changes: 14 additions & 4 deletions src/collectives/reduce_scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#include "enqueue.h"
#include "collectives.h"
#include "nccl.h"
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream);
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
Expand All @@ -27,14 +30,21 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv
NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op};
NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
}

struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
else{
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
return ncclEnqueueCheck(&info);
ret = ncclEnqueueCheck(&info);
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}
41 changes: 27 additions & 14 deletions src/collectives/sendrecv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#include "enqueue.h"
#include "collectives.h"
#include "argcheck.h" // Need some checks here since we access comm
#include "param.h"

#include "msccl/msccl_lifecycle.h"

extern int64_t ncclParamResilientEnabled();

struct NvtxParamsSendRecv {
size_t bytes;
int peer;
Expand All @@ -27,19 +30,24 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
}

struct ncclInfo info = { ncclFuncSend, "Send",
else{
struct ncclInfo info = { ncclFuncSend, "Send",
NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
ncclResult_t ret;
NCCLCHECK(ncclGroupStart());
ret = ncclEnqueueCheck(&info);
NCCLCHECK(ncclGroupEnd());
NCCLCHECK(ncclGroupStart());
ret = ncclEnqueueCheck(&info);
NCCLCHECK(ncclGroupEnd());
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}

Expand All @@ -50,18 +58,23 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload)

ncclResult_t ret;
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
ret = mscclEnqueueCheck(
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
}

struct ncclInfo info = { ncclFuncRecv, "Recv",
else{
struct ncclInfo info = { ncclFuncRecv, "Recv",
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
ncclResult_t ret;
NCCLCHECK(ncclGroupStart());
ret = ncclEnqueueCheck(&info);
NCCLCHECK(ncclGroupEnd());
NCCLCHECK(ncclGroupStart());
ret = ncclEnqueueCheck(&info);
NCCLCHECK(ncclGroupEnd());
}
if (ncclParamResilientEnabled()){
cudaStreamSynchronize(stream);
}

return ret;
}
2 changes: 2 additions & 0 deletions src/include/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ struct ncclComm {
int finalizeRankCnt;
// Whether this comm is compatible with MSCCL
bool mscclCompatible;
// Whether this comm is current in resilient repairing mode
volatile bool *resilientRepairing;
};

enum ncclLaunchMode {
Expand Down
3 changes: 2 additions & 1 deletion src/include/devcomm.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ struct ncclDevComm {
NpKitEventCollectContext* npKitEventCollectContexts;
uint64_t* cpuTimestamp;
#endif

// Whether this comm is current in resilient repairing mode
volatile bool *resilientRepairing;
};

struct alignas(16) ncclDevCommAndChannels {
Expand Down
Loading