@@ -13,7 +13,10 @@ AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptu
1313
1414AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard () { (void )cudaThreadExchangeStreamCaptureMode (&mode_); }
1515
16+ CudaStreamWithFlags::CudaStreamWithFlags () : stream_(nullptr ) { MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId_)); }
17+
1618CudaStreamWithFlags::CudaStreamWithFlags (unsigned int flags) {
19+ MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId_));
1720 MSCCLPP_CUDATHROW (cudaStreamCreateWithFlags (&stream_, flags));
1821}
1922
@@ -23,22 +26,29 @@ CudaStreamWithFlags::~CudaStreamWithFlags() {
2326
2427void CudaStreamWithFlags::set (unsigned int flags) {
2528 if (!empty ()) throw Error (" CudaStreamWithFlags already set" , ErrorCode::InvalidUsage);
29+ int originalDeviceId;
30+ MSCCLPP_CUDATHROW (cudaGetDevice (&originalDeviceId)); // Save the current device
31+ MSCCLPP_CUDATHROW (cudaSetDevice (deviceId_));
2632 MSCCLPP_CUDATHROW (cudaStreamCreateWithFlags (&stream_, flags));
33+ MSCCLPP_CUDATHROW (cudaSetDevice (originalDeviceId)); // Restore the original device
2734}
2835
2936bool CudaStreamWithFlags::empty () const { return stream_ == nullptr ; }
3037
3138GpuStream::GpuStream (std::shared_ptr<GpuStreamPool> pool, std::shared_ptr<CudaStreamWithFlags> stream)
3239 : pool_(pool), stream_(stream) {}
3340
34- GpuStream::~GpuStream () { pool_->streams_ .push_back (stream_); }
41+ GpuStream::~GpuStream () { pool_->streams_ [ deviceId ()] .push_back (stream_); }
3542
3643GpuStreamPool::GpuStreamPool () {}
3744
3845GpuStream GpuStreamPool::getStream () {
39- if (!streams_.empty ()) {
40- auto stream = streams_.back ();
41- streams_.pop_back ();
46+ int deviceId;
47+ MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId));
48+ auto & streamVec = streams_[deviceId];
49+ if (!streamVec.empty ()) {
50+ auto stream = streamVec.back ();
51+ streamVec.pop_back ();
4252 return GpuStream (gpuStreamPool (), stream);
4353 }
4454 return GpuStream (gpuStreamPool (), std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
@@ -47,7 +57,7 @@ GpuStream GpuStreamPool::getStream() {
4757void GpuStreamPool::clear () { streams_.clear (); }
4858
4959// A global pool instance
50- std::shared_ptr<GpuStreamPool> gGpuStreamPool_ ;
60+ static std::shared_ptr<GpuStreamPool> gGpuStreamPool_ ;
5161
5262std::shared_ptr<GpuStreamPool> gpuStreamPool () {
5363 if (!gGpuStreamPool_ ) {
0 commit comments