diff --git a/include/mscclpp/port_channel.hpp b/include/mscclpp/port_channel.hpp index dfcdf0cd3..ff374f11e 100644 --- a/include/mscclpp/port_channel.hpp +++ b/include/mscclpp/port_channel.hpp @@ -4,6 +4,8 @@ #ifndef MSCCLPP_PORT_CHANNEL_HPP_ #define MSCCLPP_PORT_CHANNEL_HPP_ +#include <set> + #include "core.hpp" #include "port_channel_device.hpp" #include "proxy.hpp" @@ -45,6 +47,13 @@ class ProxyService : public BaseProxyService { /// @return The ID of the memory region. MemoryId addMemory(RegisteredMemory memory); + /// Unregister a memory region from the proxy service. + /// @note It is the caller's responsibility to manage memory lifetimes safely. + /// ProxyService only ensures that memory remains valid while it is in use by the service; + /// other peers may still hold references to that memory beyond this scope. + /// @param memoryId The ID of the memory region to unregister. + void removeMemory(MemoryId memoryId); + /// Get a semaphore by ID. /// @param id The ID of the semaphore. /// @return The semaphore. @@ -72,8 +81,10 @@ class ProxyService : public BaseProxyService { std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_; std::vector<RegisteredMemory> memories_; std::shared_ptr<Proxy> proxy_; - int deviceNumaNode; - std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests; + std::set<MemoryId> reusableMemoryIds_; + int deviceNumaNode_; + std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests_; + std::atomic_flag lock_; void bindThread(); diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index 882622fca..b04a2f63f 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -35,6 +35,16 @@ std::string getIBDeviceName(Transport ibTransport); /// @return The InfiniBand transport associated with the specified device name. Transport getIBTransportByDeviceName(const std::string& ibDeviceName); +/// A simple spinlock implementation using std::atomic_flag. +/// It is used to protect shared resources in a multi-threaded environment.
+class SpinLock { + public: + SpinLock(std::atomic_flag& flag, bool yield = true); + ~SpinLock(); + + private: + std::atomic_flag& flag_; +}; } // namespace mscclpp #endif // MSCCLPP_UTILS_HPP_ diff --git a/src/ib.cc index 475579638..e7468c56b 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -59,8 +59,11 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { MSCCLPP_CUTHROW(cuCtxGetDevice(&dev)); MSCCLPP_CUTHROW(cuDeviceGetAttribute(&dmaBufSupported, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev)); #endif // !defined(__HIP_PLATFORM_AMD__) - if (cuMemAlloc && dmaBufSupported) { + if (cuMemAlloc) { #if !defined(__HIP_PLATFORM_AMD__) + if (!dmaBufSupported) { + throw mscclpp::Error("Please make sure dma buffer is supported by the device", ErrorCode::InvalidUsage); + } int fd; MSCCLPP_CUTHROW(cuMemGetHandleForAddressRange(&fd, addr, pages * pageSize, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0)); diff --git a/src/nvls.cc index 188028aad..9317cb640 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -109,7 +109,8 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) { } NvlsConnection::Impl::~Impl() { - // we don't need to free multicast handle object according to NCCL. + // Please ensure that all memory mappings are unmapped from the handle before calling the connection destructor.
+ cuMemRelease(mcHandle_); if (rootPid_ == getpid()) { close(mcFileDesc_); } diff --git a/src/port_channel.cc index d05162be5..74119ffdb 100644 --- a/src/port_channel.cc +++ b/src/port_channel.cc @@ -20,28 +20,49 @@ MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, std::shared_pt MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize) : proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, - [&]() { bindThread(); }, fifoSize)) { + [&]() { bindThread(); }, fifoSize)), + lock_(false) { int cudaDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); - deviceNumaNode = getDeviceNumaNode(cudaDevice); + deviceNumaNode_ = getDeviceNumaNode(cudaDevice); } MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection) { + SpinLock spin(lock_); semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection)); return semaphores_.size() - 1; } MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) { + SpinLock spin(lock_); semaphores_.push_back(semaphore); return semaphores_.size() - 1; } MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) { + SpinLock spin(lock_); + if (!reusableMemoryIds_.empty()) { + auto it = reusableMemoryIds_.begin(); + MemoryId memoryId = *it; + reusableMemoryIds_.erase(it); + memories_[memoryId] = memory; + return memoryId; + } memories_.push_back(memory); return memories_.size() - 1; } +MSCCLPP_API_CPP void ProxyService::removeMemory(MemoryId memoryId) { + SpinLock spin(lock_); + if (reusableMemoryIds_.find(memoryId) != reusableMemoryIds_.end() || memoryId >= memories_.size()) { + WARN("Attempted to remove a memory that is not registered or already removed: %u", memoryId); + return; + } + memories_[memoryId] = RegisteredMemory(); + reusableMemoryIds_.insert(memoryId); +} + MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(SemaphoreId id) const { return
semaphores_[id]; } @@ -59,13 +80,14 @@ MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); } MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); } MSCCLPP_API_CPP void ProxyService::bindThread() { - if (deviceNumaNode >= 0) { - numaBind(deviceNumaNode); - INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode); + if (deviceNumaNode_ >= 0) { + numaBind(deviceNumaNode_); + INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode_); } } ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { + SpinLock spin(lock_, false); ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw); std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.semaphoreId]; @@ -77,19 +99,19 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size); - inflightRequests[semaphore->connection()]++; + inflightRequests_[semaphore->connection()]++; } if (trigger->fields.type & TriggerFlag) { semaphore->signal(); - inflightRequests[semaphore->connection()]++; + inflightRequests_[semaphore->connection()]++; } if (trigger->fields.type & TriggerSync || - (maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) { + (maxWriteQueueSize != -1 && inflightRequests_[semaphore->connection()] > maxWriteQueueSize)) { semaphore->connection()->flush(); result = ProxyHandlerResult::FlushFifoTailAndContinue; - inflightRequests[semaphore->connection()] = 0; + inflightRequests_[semaphore->connection()] = 0; } return result; diff --git a/src/registered_memory.cc index e9855dd3a..58b91c2af 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -277,7 +277,15 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
RegisteredMemory::Impl::~Impl() { // Close the CUDA IPC handle if it was opened during deserialization - if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) { + if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) { + if (getPidHash() == this->pidHash) { + // For local registered memory + if (fileDesc >= 0) { + close(fileDesc); + fileDesc = -1; + } + return; + } void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase; if (this->isCuMemMapAlloc) { CUmemGenericAllocationHandle handle; @@ -288,9 +296,6 @@ RegisteredMemory::Impl::~Impl() { MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size)); MSCCLPP_CULOG_WARN(cuMemRelease(handle)); MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size)); - if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR && fileDesc >= 0) { - close(fileDesc); - } } else { cudaError_t err = cudaIpcCloseMemHandle(base); if (err != cudaSuccess) { diff --git a/src/utils.cc index 803a5fc62..e450e2bd3 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -21,4 +21,14 @@ std::string getHostName(int maxlen, const char delim) { return hostname.substr(0, i); } +SpinLock::SpinLock(std::atomic_flag& flag, bool yield) : flag_(flag) { + while (flag_.test_and_set(std::memory_order_acq_rel)) { + if (yield) { + std::this_thread::yield(); + } + } +} + +SpinLock::~SpinLock() { flag_.clear(std::memory_order_release); } + } // namespace mscclpp diff --git a/test/unit/core_tests.cc index d2a53d434..0a0fb7ef1 100644 --- a/test/unit/core_tests.cc +++ b/test/unit/core_tests.cc @@ -5,6 +5,7 @@ #include <gtest/gtest.h> #include <mscclpp/core.hpp> +#include <mscclpp/port_channel.hpp> class LocalCommunicatorTest : public ::testing::Test { protected: @@ -12,10 +13,12 @@ bootstrap = std::make_shared<mscclpp::TcpBootstrap>(0, 1); bootstrap->initialize(bootstrap->createUniqueId()); comm = std::make_shared<mscclpp::Communicator>(bootstrap); +
proxyService = std::make_shared<mscclpp::ProxyService>(); } std::shared_ptr<mscclpp::TcpBootstrap> bootstrap; std::shared_ptr<mscclpp::Communicator> comm; + std::shared_ptr<mscclpp::ProxyService> proxyService; }; TEST_F(LocalCommunicatorTest, RegisterMemory) { @@ -36,3 +39,12 @@ TEST_F(LocalCommunicatorTest, SendMemoryToSelf) { EXPECT_EQ(sameMemory.size(), memory.size()); EXPECT_EQ(sameMemory.transports(), memory.transports()); } + +TEST_F(LocalCommunicatorTest, ProxyServiceAddRemoveMemory) { + auto memory = mscclpp::RegisteredMemory(); + auto memoryId = proxyService->addMemory(memory); + EXPECT_EQ(memoryId, 0); + proxyService->removeMemory(memoryId); + memoryId = proxyService->addMemory(memory); + EXPECT_EQ(memoryId, 0); +}