From f2a4bf5ad05b9509ccefb770ca9b3bea51971cd2 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 12 May 2025 11:47:15 +0800 Subject: [PATCH 01/38] connect_add_rvec --- src/ucp/worker.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index 00abd16..d034a03 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -122,6 +122,11 @@ impl Worker { Endpoint::connect_addr(self, addr.handle) } + /// Connect to a remote worker by address. + pub fn connect_addr_vec(self: &Rc, addr: &[u8]) -> Result { + Endpoint::connect_addr(self, addr.as_ptr() as _) + } + /// Connect to a remote listener. pub async fn connect_socket(self: &Rc, addr: SocketAddr) -> Result { Endpoint::connect_socket(self, addr).await @@ -193,6 +198,7 @@ impl<'a> AsRef<[u8]> for WorkerAddress<'a> { impl<'a> Drop for WorkerAddress<'a> { fn drop(&mut self) { + trace!("destroy worker address= {:?} {:?}", self.worker.handle, self.handle); unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } } } From 1b56b62fb69486ba25a719486f9bb187c0aaf98a Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 13 May 2025 10:26:13 +0800 Subject: [PATCH 02/38] disable build ucx --- ucx1-sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 2239b8e..0f85602 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -15,7 +15,7 @@ fn main() { // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); - build_from_source(); + // build_from_source(); // The bindgen::Builder is the main entry point // to bindgen, and lets you build up options for From 1f1f68a827c493b3b0756c0bbf7ebdf39d53bbc7 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 10:26:14 +0800 Subject: [PATCH 03/38] tag use send/recv_nbx --- src/ucp/endpoint/tag.rs | 87 ++++++++++++++++++++++++++++++++--------- 1 file changed, 69 insertions(+), 18 deletions(-) diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 6f04b4d..48f6c55 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -26,27 +26,38 @@ impl Worker { unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv: complete. req={:?}, status={:?}, len={}", + "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = Some(callback); + cb + }, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, tag_mask, - Some(callback), + ¶m, ) }; @@ -72,27 +83,41 @@ impl Worker { unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + // --- 修改为ucp_tag_recv_nbx --- + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 + | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = Some(callback); + cb + }, + datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, u64::max_value(), - Some(callback), + ¶m, ) }; Error::from_ptr(status)?; @@ -109,19 +134,31 @@ impl Endpoint { /// Sends a messages with `tag`. pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("tag_send: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = Some(callback); + cb + }, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, - Some(callback), + ¶m, ) }; if status.is_null() { @@ -145,7 +182,11 @@ impl Endpoint { self.handle, iov.len() ); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "tag_send_vectored: complete. req={:?}, status={:?}", request, @@ -154,14 +195,24 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 + | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = Some(callback); + cb + }, + datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, - Some(callback), + ¶m, ) }; let total_len = iov.iter().map(|v| v.len()).sum(); From c9bce466fce583edf721fc5bc5c706f745e6dddc Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:15:55 +0800 Subject: [PATCH 04/38] use RequestParam --- .gitignore | 2 + src/ucp/endpoint/am.rs | 73 +++++++++++++--------------------- src/ucp/endpoint/mod.rs | 1 + src/ucp/endpoint/param.rs | 82 +++++++++++++++++++++++++++++++++++++++ src/ucp/endpoint/tag.rs | 58 +++++++-------------------- 5 files changed, 124 insertions(+), 92 deletions(-) create mode 100644 src/ucp/endpoint/param.rs diff --git a/.gitignore b/.gitignore index 96ef6c0..057c110 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ +.vscode/ + /target Cargo.lock diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 2002020..9602d85 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -1,6 +1,7 @@ use crossbeam::queue::SegQueue; use tokio::sync::Notify; +use super::param::RequestParam; use super::*; use std::{ io::{IoSlice, IoSliceMut}, @@ -105,7 +106,7 @@ pub struct AmMsg<'a> { } impl<'a> AmMsg<'a> { - fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self { + fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self{ AmMsg { worker, msg } } @@ -249,22 +250,12 @@ impl<'a> AmMsg<'a> { self.worker.handle, iov.len() ); - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; - param.cb = ucp_request_param_t__bindgen_ty_1 { - recv_am: Some(callback), - }; - - if iov.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (iov[0].as_ptr(), iov[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (iov.as_ptr() as _, iov.len()) - } + + let param = RequestParam::new().cb_recv_am(Some(callback)); + let (buffer, count, param) = if iov.len() == 1 { + (iov[0].as_ptr(), iov[0].len(), param) + } else { + (iov.as_ptr() as _, iov.len(), param.iov()) }; let status = unsafe { @@ -273,7 +264,7 @@ impl<'a> AmMsg<'a> { data_desc as _, buffer as _, count as _, - param.as_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -546,34 +537,22 @@ async fn am_send( request.waker.wake(); } - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; - param.flags = 0; - param.cb = ucp_request_param_t__bindgen_ty_1 { - send: Some(callback), - }; - - match proto { - Some(AmProto::Eager) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0, - Some(AmProto::Rndv) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0, - _ => (), - } - - if need_reply { - param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; - } - - if data.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (data[0].as_ptr(), data[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (data.as_ptr() as _, data.len()) - } + // Use RequestParam builder for send + let param = RequestParam::new().send_cb(Some(callback)); + let param = match proto { + Some(AmProto::Eager) => param.set_flag_eager(), + Some(AmProto::Rndv) => param.set_flag_rndv(), + None => param, + }; + let param = if need_reply { + param.set_flag_reply() + } else { + param + }; + let (buffer, count, param) = if data.len() == 1 { + (data[0].as_ptr(), data[0].len(), param) + } else { + (data.as_ptr() as _, data.len(), param.iov()) }; let status = unsafe { @@ -584,7 +563,7 @@ async fn am_send( header.len() as _, buffer as _, count as _, - param.as_mut_ptr(), + param.as_ref(), ) }; if status.is_null() { diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index b477468..757a20b 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -12,6 +12,7 @@ mod am; mod rma; mod stream; mod tag; +mod param; #[cfg(feature = "am")] pub use self::am::*; diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs new file mode 100644 index 0000000..81e17d8 --- /dev/null +++ b/src/ucp/endpoint/param.rs @@ -0,0 +1,82 @@ +use ucx1_sys::*; + +pub struct RequestParam { + inner: ucp_request_param_t, +} + +impl RequestParam { + pub fn new() -> Self { + // Zeroed for safety, as in C + Self { + inner: unsafe { std::mem::zeroed() }, + } + } + + pub fn send_cb(mut self, callback: ucp_send_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_tag_recv(mut self, callback: ucp_tag_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = callback; + self.inner.cb = cb; + } + self + } + + #[cfg(feature = "am")] + pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_am = callback; + self.inner.cb = cb; + } + self + } + + pub fn iov(mut self) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; + self.inner.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; + self + } + + #[cfg(feature = "am")] + fn set_flag(&mut self) { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; + } + + #[cfg(feature = "am")] + pub fn set_flag_eager(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_rndv(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_reply(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; + self + } + + pub fn as_ref(&self) -> *const ucp_request_param_t { + &self.inner as *const _ + } +} + diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 48f6c55..e8feeac 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -1,3 +1,4 @@ +use super::param::RequestParam; use super::*; use std::io::{IoSlice, IoSliceMut}; @@ -41,15 +42,8 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.recv = Some(callback); - cb - }, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder + let param = RequestParam::new().cb_tag_recv(Some(callback)); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -57,7 +51,7 @@ impl Worker { buf.len() as _, tag, tag_mask, - ¶m, + param.as_ref(), ) }; @@ -98,18 +92,8 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - // --- 修改为ucp_tag_recv_nbx --- - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.recv = Some(callback); - cb - }, - datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -117,7 +101,7 @@ impl Worker { iov.len() as _, tag, u64::max_value(), - ¶m, + param.as_ref(), ) }; Error::from_ptr(status)?; @@ -143,22 +127,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.send = Some(callback); - cb - }, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, tag, - ¶m, + param.as_ref(), ) }; if status.is_null() { @@ -195,24 +172,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.send = Some(callback); - cb - }, - datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder for iov + let param = RequestParam::new().send_cb(Some(callback)).iov(); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, tag, - ¶m, + param.as_ref(), ) }; let total_len = iov.iter().map(|v| v.len()).sum(); From fa298d3ec1b8038802368bcd32ed4a7d11c3c618 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:33:05 +0800 Subject: [PATCH 05/38] stream use nbx and Param --- src/ucp/endpoint/param.rs | 11 ++++- src/ucp/endpoint/stream.rs | 92 +++++++++++++++++++++++++++++++++----- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 81e17d8..36889f7 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -32,6 +32,16 @@ impl RequestParam { self } + pub fn cb_stream_recv(mut self, callback: ucp_stream_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_stream = callback; + self.inner.cb = cb; + } + self + } + #[cfg(feature = "am")] pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; @@ -79,4 +89,3 @@ impl RequestParam { &self.inner as *const _ } } - diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 411a80b..d8c3eed 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,10 +1,11 @@ use super::*; +use super::param::RequestParam; impl Endpoint { /// Sends data through stream. pub async fn stream_send(&self, buf: &[u8]) -> Result { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -13,14 +14,13 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_stream_send_nb( + ucp_stream_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), - 0, + param.as_ref(), ) }; if status.is_null() { @@ -40,7 +40,7 @@ impl Endpoint { /// Receives data from stream. pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize, _user_data: *mut c_void) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -51,15 +51,14 @@ impl Endpoint { request.waker.wake(); } let mut length = MaybeUninit::::uninit(); + let param = RequestParam::new().cb_stream_recv(Some(callback)); let status = unsafe { - ucp_stream_recv_nb( + ucp_stream_recv_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), length.as_mut_ptr(), - 0, + param.as_ref(), ) }; if status.is_null() { @@ -87,3 +86,76 @@ unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { Poll::Ready(len.assume_init()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test_log::test] + fn stream() { + for i in 0..20_usize { + spawn_thread!(_stream(4 << i)).join().unwrap(); + } + } + + async fn _stream(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send stream message + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.stream_send(&buf).await.unwrap(); + println!("stream sent"); + }, + async { + // recv + let mut buf = vec![std::mem::MaybeUninit::::uninit(); msg_size]; + let mut start = 0; + while start < msg_size { + let len = endpoint1.stream_recv(&mut buf[start..]).await.unwrap(); + if len == 0 { + break; // no more data + } + start += len; + } + let buf: Vec = unsafe { + buf.into_iter().map(|b| b.assume_init()).collect() + }; + assert_eq!(buf, vec![42u8; msg_size]); + println!("stream received"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); + } +} From 8940e9bff44d5227db8bcb7ea7f0475ff5f04e96 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:43:49 +0800 Subject: [PATCH 06/38] rma use nbx --- src/ucp/endpoint/rma.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index e9cd41e..8af8bd9 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -1,3 +1,5 @@ +use super::param::RequestParam; + use super::*; /// A memory region allocated through UCP library, @@ -112,19 +114,20 @@ impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!("put: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_put_nb( + ucp_put_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { @@ -144,19 +147,20 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!("get: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_get_nb( + ucp_get_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { From cdf6c85b9a65d443ff1871905eccfa833503c24f Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:07:33 +0800 Subject: [PATCH 07/38] add author Li Kaiwei --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0b06a69..6f33668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "async-ucx" -version = "0.1.1" -authors = ["Runji Wang ", "Yiyuan Liu "] +version = "0.2.0" +authors = ["Runji Wang ", "Yiyuan Liu ", "Kaiwei Li "] edition = "2021" description = "Asynchronous Rust bindings to UCX." homepage = "https://github.com/madsys-dev/async-ucx" From 2808d4910ab760a870ce95a276839ef8b780efb1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:14:54 +0800 Subject: [PATCH 08/38] update ucx to 1.18.1 --- ucx1-sys/ucx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index 938ffcd..d9aa565 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit 938ffcd10122742d0f46a4f609e7395d1648c969 +Subproject commit d9aa5650d4cbcbb00d61af980614dbe9dd27a1f2 From 1a64dec42dd2a150a308a6e496fc244a4b56ad37 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:26:02 +0800 Subject: [PATCH 09/38] little fix --- examples/bench-multi-thread.rs | 2 +- ucx1-sys/build.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bench-multi-thread.rs b/examples/bench-multi-thread.rs index e6b1ef2..9651c5d 100644 --- a/examples/bench-multi-thread.rs +++ b/examples/bench-multi-thread.rs @@ -123,7 +123,7 @@ impl WorkerThread { .build() .unwrap(); let local = tokio::task::LocalSet::new(); - #[cfg(not(event))] + #[cfg(not(feature = "event"))] local.spawn_local(worker.clone().polling()); #[cfg(feature = "event")] local.spawn_local(worker.clone().event_poll()); diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 0f85602..2239b8e 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -15,7 +15,7 @@ fn main() { // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); - // build_from_source(); + build_from_source(); // The bindgen::Builder is the main entry point // to bindgen, and lets you build up options for From 6a3216aa9dac17fb0f32d3628bd22a8d10c48fa3 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:31:33 +0800 Subject: [PATCH 10/38] rename send_cb -> cb_send --- src/ucp/endpoint/am.rs | 2 +- src/ucp/endpoint/param.rs | 2 +- src/ucp/endpoint/rma.rs | 4 ++-- src/ucp/endpoint/stream.rs | 2 +- src/ucp/endpoint/tag.rs | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 9602d85..8e5d192 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -538,7 +538,7 @@ async fn am_send( } // Use RequestParam builder for send - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let param = match proto { Some(AmProto::Eager) => param.set_flag_eager(), Some(AmProto::Rndv) => param.set_flag_rndv(), diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 36889f7..5bac074 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -12,7 +12,7 @@ impl RequestParam { } } - pub fn send_cb(mut self, callback: ucp_send_nbx_callback_t) -> Self { + pub fn cb_send(mut self, callback: ucp_send_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; unsafe { let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index 8af8bd9..f83aa21 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -119,7 +119,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_put_nbx( self.get_handle()?, @@ -152,7 +152,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_get_nbx( self.get_handle()?, diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index d8c3eed..ddcb540 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -14,7 +14,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_stream_send_nbx( self.get_handle()?, diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index e8feeac..677ac96 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -128,7 +128,7 @@ impl Endpoint { request.waker.wake(); } // Use RequestParam builder - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, @@ -173,7 +173,7 @@ impl Endpoint { request.waker.wake(); } // Use RequestParam builder for iov - let param = RequestParam::new().send_cb(Some(callback)).iov(); + let param = RequestParam::new().cb_send(Some(callback)).iov(); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, From 67691ca35d0f2e0be75d99547b353c8d9c406a44 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 23:53:10 +0800 Subject: [PATCH 11/38] fmt --- src/ucp/endpoint/am.rs | 2 +- src/ucp/endpoint/mod.rs | 2 +- src/ucp/endpoint/rma.rs | 12 ++++++++++-- src/ucp/endpoint/stream.rs | 19 +++++++++++++------ src/ucp/worker.rs | 6 +++++- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 8e5d192..9d7f4ce 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -106,7 +106,7 @@ pub struct AmMsg<'a> { } impl<'a> AmMsg<'a> { - fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self{ + fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self { AmMsg { worker, msg } } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 757a20b..c3b18ae 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -9,10 +9,10 @@ use std::task::Poll; #[cfg(feature = "am")] mod am; +mod param; mod rma; mod stream; mod tag; -mod param; #[cfg(feature = "am")] pub use self::am::*; diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index f83aa21..c08e5b5 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -114,7 +114,11 @@ impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("put: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); @@ -147,7 +151,11 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("get: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index ddcb540..a2f4bfb 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,11 +1,15 @@ -use super::*; use super::param::RequestParam; +use super::*; impl Endpoint { /// Sends data through stream. pub async fn stream_send(&self, buf: &[u8]) -> Result { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -40,7 +44,12 @@ impl Endpoint { /// Receives data from stream. pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + length: usize, + _user_data: *mut c_void, + ) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -141,9 +150,7 @@ mod tests { } start += len; } - let buf: Vec = unsafe { - buf.into_iter().map(|b| b.assume_init()).collect() - }; + let buf: Vec = unsafe { buf.into_iter().map(|b| b.assume_init()).collect() }; assert_eq!(buf, vec![42u8; msg_size]); println!("stream received"); } diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index d034a03..d1ad850 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -198,7 +198,11 @@ impl<'a> AsRef<[u8]> for WorkerAddress<'a> { impl<'a> Drop for WorkerAddress<'a> { fn drop(&mut self) { - trace!("destroy worker address= {:?} {:?}", self.worker.handle, self.handle); + trace!( + "destroy worker address= {:?} {:?}", + self.worker.handle, + self.handle + ); unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } } } From 5b3936907036953be5e2aee22c5412d8999e500d Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 23:53:46 +0800 Subject: [PATCH 12/38] poll fn are safe --- src/ucp/endpoint/am.rs | 13 ++----------- src/ucp/endpoint/mod.rs | 18 +++++++++--------- src/ucp/endpoint/stream.rs | 6 +++--- src/ucp/endpoint/tag.rs | 6 +++--- 4 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 9d7f4ce..378b86f 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -273,9 +273,9 @@ impl<'a> AmMsg<'a> { } else if UCS_PTR_IS_PTR(status) { RequestHandle { ptr: status, - poll_fn: poll_recv, + poll_fn: poll_normal, } - .await; + .await?; Ok(data_len) } else { Err(Error::from_ptr(status).unwrap_err()) @@ -580,15 +580,6 @@ async fn am_send( } } -unsafe fn poll_recv(ptr: ucs_status_ptr_t) -> Poll<()> { - let status = ucp_request_check_status(ptr as _); - if status == ucs_status_t::UCS_INPROGRESS { - Poll::Pending - } else { - Poll::Ready(()) - } -} - #[cfg(test)] #[cfg(feature = "am")] mod tests { diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c3b18ae..3e9c546 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -253,7 +253,7 @@ impl Endpoint { Ok(()) } else if UCS_PTR_IS_PTR(status) { let result = loop { - if let Poll::Ready(result) = unsafe { poll_normal(status) } { + if let Poll::Ready(result) = poll_normal(status) { unsafe { ucp_request_free(status as _) }; break result; } else { @@ -302,32 +302,32 @@ impl Drop for Endpoint { } /// A handle to the request returned from async IO functions. -struct RequestHandle { +struct RequestHandle Poll> { ptr: ucs_status_ptr_t, - poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll, + poll_fn: F, } -impl Future for RequestHandle { +impl Poll> Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } { + if let ret @ Poll::Ready(_) = (self.poll_fn)(self.ptr) { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; request.waker.register(cx.waker()); - unsafe { (self.poll_fn)(self.ptr) } + (self.poll_fn)(self.ptr) } } -impl Drop for RequestHandle { +impl Poll> Drop for RequestHandle { fn drop(&mut self) { trace!("request free: {:?}", self.ptr); unsafe { ucp_request_free(self.ptr as _) }; } } -unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { - let status = ucp_request_check_status(ptr as _); +fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { + let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index a2f4bfb..eb4303c 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -86,13 +86,13 @@ impl Endpoint { } } -unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { let mut len = MaybeUninit::::uninit(); - let status = ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _); + let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(len.assume_init()) + Poll::Ready(unsafe { len.assume_init() }) } } diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 677ac96..311c30e 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -199,13 +199,13 @@ impl Endpoint { } } -unsafe fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); - let status = ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _); + let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { - let info = info.assume_init(); + let info = unsafe { info.assume_init() }; Poll::Ready(Ok((info.sender_tag, info.length as usize))) } status => Poll::Ready(Err(Error::from_error(status))), From fcad8aa8b76becccbcd25bb97e14ae54f9ca3be5 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sat, 28 Jun 2025 11:22:11 +0800 Subject: [PATCH 13/38] add stream Extension for AsyncRead and AsyncWrite --- Cargo.toml | 3 +- src/ucp/endpoint/mod.rs | 38 ++++++- src/ucp/endpoint/stream.rs | 210 ++++++++++++++++++++++++++++++++----- 3 files changed, 217 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6f33668..7dcd9a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,10 @@ tokio = { version = "1.0", features = ["net"], optional = true } crossbeam = { version = "0.8", optional = true } derivative = "2.2.0" thiserror = "1.0" +pin-project = "1.1.10" [dev-dependencies] -tokio = { version = "1.0", features = ["rt", "time", "macros", "sync"] } +tokio = { version = "1.0", features = ["rt", "time", "macros", "sync", "io-util"] } env_logger = "0.9" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.2.17", default-features = false, features = ["env-filter", "fmt"] } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 3e9c546..369c9aa 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -89,12 +89,16 @@ impl Endpoint { let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); + println!("error callback"); if let Some(inner) = weak.upgrade() { inner.set_status(status); // don't drop weak reference + println!("no drop"); + // panic!("{:?}",Error::from_status(status)); std::mem::forget(weak); } else { // no strong rc, force close endpoint here + println!("failed"); let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _); let _ = Error::from_ptr(status) .map_err(|err| error!("Force close endpoint failed, {}", err)); @@ -302,15 +306,15 @@ impl Drop for Endpoint { } /// A handle to the request returned from async IO functions. -struct RequestHandle Poll> { +struct RequestHandle { ptr: ucs_status_ptr_t, - poll_fn: F, + poll_fn: fn(ucs_status_ptr_t) -> Poll, } -impl Poll> Future for RequestHandle { +impl Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = (self.poll_fn)(self.ptr) { + if let ret @ Poll::Ready(_) = { (self.poll_fn)(self.ptr) } { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; @@ -319,13 +323,37 @@ impl Poll> Future for RequestHandle { } } -impl Poll> Drop for RequestHandle { +impl Drop for RequestHandle { fn drop(&mut self) { trace!("request free: {:?}", self.ptr); unsafe { ucp_request_free(self.ptr as _) }; } } +enum Status { + Completed(Result), + Scheduled(RequestHandle>), +} + +impl Status { + pub fn new( + status: *mut c_void, + immediate: MaybeUninit, + poll_fn: fn(ucs_status_ptr_t) -> Poll>, + ) -> Self { + if status.is_null() { + Self::Completed(Ok(unsafe { immediate.assume_init() })) + } else if UCS_PTR_IS_PTR(status) { + Self::Scheduled(RequestHandle { + ptr: status, + poll_fn, + }) + } else { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } + } +} + fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index eb4303c..28ae80a 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -2,8 +2,7 @@ use super::param::RequestParam; use super::*; impl Endpoint { - /// Sends data through stream. - pub async fn stream_send(&self, buf: &[u8]) -> Result { + fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -27,22 +26,27 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - trace!("stream_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::new(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends data through stream. + pub async fn stream_send(&self, buf: &[u8]) -> Result { + match self.stream_send_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("stream_send: complete"), + Err(e) => error!("stream_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } - /// Receives data from stream. - pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + fn stream_recv_impl(&self, buf: &mut [MaybeUninit]) -> Result, Error> { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -70,35 +74,169 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - let length = unsafe { length.assume_init() } as usize; - trace!("stream_recv: complete. len={}", length); - Ok(length) - } else if UCS_PTR_IS_PTR(status) { - Ok(RequestHandle { - ptr: status, - poll_fn: poll_stream, + Ok(Status::new(status, length, poll_stream)) + } + + /// Receives data from stream. + pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + match self.stream_recv_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(x) => trace!("stream_recv: complete. len={}", x), + Err(e) => error!("stream_recv: error : {:?}", e), + } + r } - .await) - } else { - Err(Error::from_ptr(status).unwrap_err()) + Status::Scheduled(request_handle) => request_handle.await, } } } -fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { let mut len = MaybeUninit::::uninit(); let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(unsafe { len.assume_init() }) + Poll::Ready(Error::from_status(status).map(|_| unsafe { len.assume_init() })) + } +} + +use pin_project::pin_project; +#[pin_project] +pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream { + WriteStream { + endpoint: self, + request: None, + } + } +} +use futures::FutureExt; +use std::task::ready; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +fn to_io_error(e: Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::NotFound, e) +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(to_io_error(e)), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(to_io_error)), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl Endpoint { + /// make read stream + pub fn read_stream(&self) -> ReadStream { + ReadStream { + endpoint: self, + request: None, + } + } +} + +use tokio::io::AsyncRead; +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(to_io_error(e)), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } } } #[cfg(test)] mod tests { use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[test_log::test] fn stream() { for i in 0..20_usize { @@ -131,7 +269,6 @@ mod tests { async { worker2.connect_socket(addr).await.unwrap() }, ); - // send stream message tokio::join!( async { // send @@ -156,6 +293,23 @@ mod tests { } ); + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.write_stream().write_all(&buf).await.unwrap(); + println!("write_stream"); + }, + async { + // recv + let mut buf = vec![0u8; msg_size]; + endpoint1.read_stream().read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, vec![42u8; msg_size]); + println!("read_stream"); + } + ); + + println!("status {:?}", endpoint2.get_status()); assert_eq!(endpoint1.get_rc(), (1, 1)); assert_eq!(endpoint2.get_rc(), (1, 1)); assert_eq!(endpoint1.close(false).await, Ok(())); From 8f48e8d417ceab5ec0f28b5a1c2e484dfae1f5b1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sat, 28 Jun 2025 11:53:02 +0800 Subject: [PATCH 14/38] add more test --- src/ucp/endpoint/mod.rs | 10 +++--- src/ucp/endpoint/stream.rs | 67 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 369c9aa..c8596a9 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -336,20 +336,20 @@ enum Status { } impl Status { - pub fn new( + pub fn from( status: *mut c_void, immediate: MaybeUninit, poll_fn: fn(ucs_status_ptr_t) -> Poll>, ) -> Self { - if status.is_null() { + if UCS_PTR_RAW_STATUS(status) == ucs_status_t::UCS_OK { Self::Completed(Ok(unsafe { immediate.assume_init() })) - } else if UCS_PTR_IS_PTR(status) { + } else if UCS_PTR_IS_ERR(status) { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } else { Self::Scheduled(RequestHandle { ptr: status, poll_fn, }) - } else { - Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) } } } diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 28ae80a..5c18e0a 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -26,7 +26,7 @@ impl Endpoint { param.as_ref(), ) }; - Ok(Status::new(status, MaybeUninit::uninit(), poll_normal)) + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) } /// Sends data through stream. @@ -74,7 +74,7 @@ impl Endpoint { param.as_ref(), ) }; - Ok(Status::new(status, length, poll_stream)) + Ok(Status::from(status, length, poll_stream)) } /// Receives data from stream. @@ -319,4 +319,67 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + fn stream_send_recv_various_contents_and_sizes() { + spawn_thread!(_stream_send_recv_various_contents_and_sizes()) + .join() + .unwrap(); + } + + async fn _stream_send_recv_various_contents_and_sizes() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + }, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } } From ad6b30097426c703ec54fbe7a4cc9235d07819ae Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 10:22:51 +0800 Subject: [PATCH 15/38] to io error --- src/lib.rs | 41 ++++++++++++++++++++++++++++++++++++++ src/ucp/endpoint/stream.rs | 16 ++++++--------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 02cc343..87baf84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,3 +163,44 @@ impl Error { } } } + +impl Into for Error { + fn into(self) -> std::io::Error { + use std::io::ErrorKind::*; + let kind = match self { + Error::Inprogress => WouldBlock, + Error::NoMessage => WouldBlock, + Error::NoReource => WouldBlock, + Error::IoError => Other, + Error::NoMemory => OutOfMemory, + Error::InvalidParam => InvalidInput, + Error::Unreachable => NotConnected, + Error::InvalidAddr => InvalidInput, + Error::NotImplemented => Unsupported, + Error::MessageTruncated => InvalidData, + Error::NoProgress => WouldBlock, + Error::BufferTooSmall => UnexpectedEof, + Error::NoElem => NotFound, + Error::SomeConnectsFailed => ConnectionAborted, + Error::NoDevice => NotFound, + Error::Busy => ResourceBusy, + Error::Canceled => Interrupted, + Error::ShmemSegment => Other, + Error::AlreadyExists => AlreadyExists, + Error::OutOfRange => InvalidInput, + Error::Timeout => TimedOut, + Error::ExceedsLimit => Other, + Error::Unsupported => Unsupported, + Error::Rejected => ConnectionRefused, + Error::NotConnected => NotConnected, + Error::ConnectionReset => ConnectionReset, + Error::FirstLinkFailure => Other, + Error::LastLinkFailure => Other, + Error::FirstEndpointFailure => Other, + Error::LastEndpointFailure => Other, + Error::EndpointTimeout => TimedOut, + Error::Unknown => Other, + }; + std::io::Error::new(kind, self) + } +} diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 5c18e0a..f0948b1 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -124,10 +124,6 @@ use std::task::ready; use tokio::io::AsyncWrite; use tokio::io::ReadBuf; -fn to_io_error(e: Error) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::NotFound, e) -} - impl<'a> AsyncWrite for WriteStream<'a> { fn poll_write( mut self: Pin<&mut Self>, @@ -137,18 +133,18 @@ impl<'a> AsyncWrite for WriteStream<'a> { if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { let r = match ready!(req.poll_unpin(cx)) { Ok(_) => Ok(buf.len()), - Err(e) => Err(to_io_error(e)), + Err(e) => Err(e.into()), }; self.request = None; Poll::Ready(r) } else { match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(to_io_error)), + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), Ok(Status::Scheduled(request_handle)) => { self.request = Some(request_handle); Poll::Pending } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } } @@ -202,7 +198,7 @@ impl<'a> AsyncRead for ReadStream<'a> { } Ok(()) } - Err(e) => Err(to_io_error(e)), + Err(e) => Err(e.into()), }; self.request = None; Poll::Ready(r) @@ -219,7 +215,7 @@ impl<'a> AsyncRead for ReadStream<'a> { } Poll::Ready(Ok(())) } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } Ok(Status::Scheduled(request_handle)) => { @@ -227,7 +223,7 @@ impl<'a> AsyncRead for ReadStream<'a> { cx.waker().wake_by_ref(); Poll::Pending } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } } From 1dbbc3e509c4b206e1f17b9dd6912dbe39e61ef2 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 10:53:10 +0800 Subject: [PATCH 16/38] add utill.rs with feature --- Cargo.toml | 1 + src/ucp/endpoint/mod.rs | 4 + src/ucp/endpoint/stream.rs | 214 +------------------------------------ src/ucp/endpoint/util.rs | 195 +++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 209 deletions(-) create mode 100644 src/ucp/endpoint/util.rs diff --git a/Cargo.toml b/Cargo.toml index 7dcd9a0..429dc7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ categories = ["asynchronous", "api-bindings", "network-programming"] [features] event = ["tokio"] am = ["tokio/sync", "crossbeam"] +util = ["tokio"] [dependencies] ucx1-sys = { version = "0.1", path = "ucx1-sys" } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c8596a9..f86e5c9 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -13,10 +13,14 @@ mod param; mod rma; mod stream; mod tag; +#[cfg(feature = "util")] +mod util; #[cfg(feature = "am")] pub use self::am::*; pub use self::rma::*; +#[cfg(feature = "util")] +pub use self::util::*; // State associate with ucp_ep_h // todo: Add a `get_user_data` to UCX diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index f0948b1..d9c2f5b 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -2,7 +2,7 @@ use super::param::RequestParam; use super::*; impl Endpoint { - fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { + pub(super) fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -46,7 +46,10 @@ impl Endpoint { } } - fn stream_recv_impl(&self, buf: &mut [MaybeUninit]) -> Result, Error> { + pub(super) fn stream_recv_impl( + &self, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -102,137 +105,9 @@ fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { } } -use pin_project::pin_project; -#[pin_project] -pub struct WriteStream<'a> { - endpoint: &'a Endpoint, - #[pin] - request: Option>>, -} - -impl Endpoint { - /// make write stream - pub fn write_stream(&self) -> WriteStream { - WriteStream { - endpoint: self, - request: None, - } - } -} -use futures::FutureExt; -use std::task::ready; -use tokio::io::AsyncWrite; -use tokio::io::ReadBuf; - -impl<'a> AsyncWrite for WriteStream<'a> { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - todo!() - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - todo!() - } -} - -#[pin_project] -pub struct ReadStream<'a> { - endpoint: &'a Endpoint, - #[pin] - request: Option>>, -} - -impl Endpoint { - /// make read stream - pub fn read_stream(&self) -> ReadStream { - ReadStream { - endpoint: self, - request: None, - } - } -} - -use tokio::io::AsyncRead; -impl<'a> AsyncRead for ReadStream<'a> { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - out_buf: &mut ReadBuf<'_>, - ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(n) => { - // Safety: The buffer was filled by the recv operation. - unsafe { - out_buf.assume_init(n); - out_buf.advance(n); - } - Ok(()) - } - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - let buf = unsafe { out_buf.unfilled_mut() }; - match self.endpoint.stream_recv_impl(buf) { - Ok(Status::Completed(n_result)) => { - match n_result { - Ok(n) => { - // Safety: The buffer was filled by the recv operation. - unsafe { - out_buf.assume_init(n); - out_buf.advance(n); - } - Poll::Ready(Ok(())) - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - } -} - #[cfg(test)] mod tests { use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[test_log::test] fn stream() { for i in 0..20_usize { @@ -289,22 +164,6 @@ mod tests { } ); - tokio::join!( - async { - // send - let buf = vec![42u8; msg_size]; - endpoint2.write_stream().write_all(&buf).await.unwrap(); - println!("write_stream"); - }, - async { - // recv - let mut buf = vec![0u8; msg_size]; - endpoint1.read_stream().read_exact(&mut buf).await.unwrap(); - assert_eq!(buf, vec![42u8; msg_size]); - println!("read_stream"); - } - ); - println!("status {:?}", endpoint2.get_status()); assert_eq!(endpoint1.get_rc(), (1, 1)); assert_eq!(endpoint2.get_rc(), (1, 1)); @@ -315,67 +174,4 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } - - #[test_log::test] - fn stream_send_recv_various_contents_and_sizes() { - spawn_thread!(_stream_send_recv_various_contents_and_sizes()) - .join() - .unwrap(); - } - - async fn _stream_send_recv_various_contents_and_sizes() { - let context1 = Context::new().unwrap(); - let worker1 = context1.create_worker().unwrap(); - let context2 = Context::new().unwrap(); - let worker2 = context2.create_worker().unwrap(); - tokio::task::spawn_local(worker1.clone().polling()); - tokio::task::spawn_local(worker2.clone().polling()); - - // connect with each other - let mut listener = worker1 - .create_listener("0.0.0.0:0".parse().unwrap()) - .unwrap(); - let listen_port = listener.socket_addr().unwrap().port(); - let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - addr.set_port(listen_port); - - let (endpoint1, endpoint2) = tokio::join!( - async { - let conn1 = listener.next().await; - worker1.accept(conn1).await.unwrap() - }, - async { worker2.connect_socket(addr).await.unwrap() }, - ); - - // Test cases: (data, repeat count) - let test_cases = vec![ - (vec![], 1), - (vec![0u8], 10), - (vec![1, 2, 3, 4, 5], 5), - ((0..128).collect::>(), 3), - ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), - ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), - ]; - for (data, repeat) in test_cases { - for _ in 0..repeat { - // send - let send_buf = data.clone(); - let recv_len = send_buf.len(); - let mut recv_buf = vec![0u8; recv_len]; - tokio::join!( - async { - endpoint2.write_stream().write_all(&send_buf).await.unwrap(); - }, - async { - endpoint1 - .read_stream() - .read_exact(&mut recv_buf) - .await - .unwrap(); - assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); - } - ); - } - } - } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs new file mode 100644 index 0000000..66e6895 --- /dev/null +++ b/src/ucp/endpoint/util.rs @@ -0,0 +1,195 @@ +use super::*; +use futures::FutureExt; +use pin_project::pin_project; +use std::task::ready; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream { + WriteStream { + endpoint: self, + request: None, + } + } + /// make read stream + pub fn read_stream(&self) -> ReadStream { + ReadStream { + endpoint: self, + request: None, + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` stream. +pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +/// A stream for reading data asynchronously from an `Endpoint` stream. +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[cfg(test)] +mod test { + + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[test_log::test] + fn stream_send_recv() { + spawn_thread!(_stream_send_recv()).join().unwrap(); + } + + async fn _stream_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + }, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } +} From bed08b8d4c048dd5924250cbd4c55abcc10830a1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:07:22 +0800 Subject: [PATCH 17/38] Update endpoint tag and util modules --- src/ucp/endpoint/tag.rs | 103 +++++++++++--------- src/ucp/endpoint/util.rs | 200 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 47 deletions(-) diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 311c30e..1f7b7c3 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -17,12 +17,22 @@ impl Worker { tag_mask: u64, buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { + match self.tag_recv_impl(tag, tag_mask, buf)? { + Status::Completed(r) => r, + Status::Scheduled(request_handle) => request_handle.await, + } + } + + /// Like `tag_recv`, except that it reads into a slice of buffers. + pub async fn tag_recv_vectored( + &self, + tag: u64, + iov: &mut [IoSliceMut<'_>], + ) -> Result { trace!( - "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", + "tag_recv_vectored: worker={:?} iov.len={}", self.handle, - tag, - tag_mask, - buf.len() + iov.len() ); unsafe extern "C" fn callback( request: *mut c_void, @@ -31,48 +41,50 @@ impl Worker { _user_data: *mut c_void, ) { let length = (*info).length; - let sender_tag = (*info).sender_tag; + let tag = (*info).sender_tag; trace!( - "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, - sender_tag, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder - let param = RequestParam::new().cb_tag_recv(Some(callback)); + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { ucp_tag_recv_nbx( self.handle, - buf.as_mut_ptr() as _, - buf.len() as _, + iov.as_ptr() as _, + iov.len() as _, tag, - tag_mask, + u64::max_value(), param.as_ref(), ) }; - Error::from_ptr(status)?; RequestHandle { ptr: status, poll_fn: poll_tag, } .await + .map(|info| info.1) } - /// Like `tag_recv`, except that it reads into a slice of buffers. - pub async fn tag_recv_vectored( + pub(super) fn tag_recv_impl( &self, tag: u64, - iov: &mut [IoSliceMut<'_>], - ) -> Result { + tag_mask: u64, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!( - "tag_recv_vectored: worker={:?} iov.len={}", + "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, - iov.len() + tag, + tag_mask, + buf.len() ); unsafe extern "C" fn callback( request: *mut c_void, @@ -81,42 +93,34 @@ impl Worker { _user_data: *mut c_void, ) { let length = (*info).length; - let tag = (*info).sender_tag; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", + "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", request, status, - tag, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder for iov - let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); + let param = RequestParam::new().cb_tag_recv(Some(callback)); let status = unsafe { ucp_tag_recv_nbx( self.handle, - iov.as_ptr() as _, - iov.len() as _, + buf.as_mut_ptr() as _, + buf.len() as _, tag, - u64::max_value(), + tag_mask, param.as_ref(), ) }; - Error::from_ptr(status)?; - RequestHandle { - ptr: status, - poll_fn: poll_tag, - } - .await - .map(|info| info.1) + Ok(Status::from(status, MaybeUninit::uninit(), poll_tag)) } } impl Endpoint { - /// Sends a messages with `tag`. - pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + pub(super) fn tag_send_impl(&self, tag: u64, buf: &[u8]) -> Result, Error> { trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -127,7 +131,6 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_tag_send_nbx( @@ -138,18 +141,24 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - trace!("tag_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends a messages with `tag`. + pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + match self.tag_send_impl(tag, buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("tag_send: complete"), + Err(e) => error!("tag_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } /// Like `tag_send`, except that it reads into a slice of buffers. diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 66e6895..a633725 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -21,6 +21,35 @@ impl Endpoint { request: None, } } + /// make tag write stream + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream { + TagWriteStream { + endpoint: self, + tag, + request: None, + } + } +} + +impl Worker { + /// make tag read stream + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream { + TagReadStream { + worker: self, + tag, + tag_mask: u64::max_value(), + request: None, + } + } + /// make tag read stream with mask + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { + TagReadStream { + worker: self, + tag, + tag_mask, + request: None, + } + } } #[pin_project] @@ -126,6 +155,112 @@ impl<'a> AsyncRead for ReadStream<'a> { } } +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` using tag. +pub struct TagWriteStream<'a> { + endpoint: &'a Endpoint, + tag: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for TagWriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +/// A stream for reading data asynchronously from a `Worker` using tag. +#[pin_project] +pub struct TagReadStream<'a> { + worker: &'a Worker, + tag: u64, + tag_mask: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for TagReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok((_, n)) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok((_, n)) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + #[cfg(test)] mod test { @@ -137,6 +272,11 @@ mod test { spawn_thread!(_stream_send_recv()).join().unwrap(); } + #[test_log::test] + fn tag_send_recv() { + spawn_thread!(_tag_send_recv()).join().unwrap(); + } + async fn _stream_send_recv() { let context1 = Context::new().unwrap(); let worker1 = context1.create_worker().unwrap(); @@ -192,4 +332,64 @@ mod test { } } } + + async fn _tag_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, tag, repeat count) + let test_cases = vec![ + (vec![], 1u64, 1), + (vec![0u8], 2u64, 10), + (vec![1, 2, 3, 4, 5], 3u64, 5), + ((0..128).collect::>(), 4u64, 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 5u64, 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 6u64, 1), + ]; + for (data, tag, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.tag_write_stream(tag).write_all(&send_buf).await.unwrap(); + }, + async { + worker1 + .tag_read_stream(tag) + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for tag={}, len={}", tag, recv_len); + } + ); + } + } + + // Clean up + let _ = endpoint1.close(false).await; + let _ = endpoint2.close(false).await; + } } From 5abf8d40da7c12507dc442cbca96e6512f4eec4a Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:47:20 +0800 Subject: [PATCH 18/38] update readme and changlog --- CHANGELOG.md | 17 ++++++++++ README.md | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f49d8a..62f2c51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.0] - 2025-06-29 + +### Added + +- Added AsyncRead/AsyncWrite support for Tag and Stream (requires `utils` feature flag) +- Added `connect_addr_vec` function to Worker + +### Changed + +- Updated to UCX 1.18 with latest API compatibility +- Updated multiple dependency versions +- Migrated to Rust 2021 edition + +### Fixed + +- Fixed various bugs and issues + ## [0.1.1] - 2022-09-01 ### Changed diff --git a/README.md b/README.md index 2f47499..04e7414 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,94 @@ [![Docs](https://docs.rs/async-ucx/badge.svg)](https://docs.rs/async-ucx) [![CI](https://github.com/madsys-dev/async-ucx/workflows/CI/badge.svg?branch=main)](https://github.com/madsys-dev/async-ucx/actions) -Async Rust UCX bindings. +Async Rust UCX bindings providing high-performance networking capabilities for distributed systems and HPC applications. + +## Features + +- **Asynchronous UCP Operations**: Full async/await support for UCX operations +- **Multiple Communication Models**: Support for RMA, Stream, Tag, and Active Message APIs +- **High Performance**: Optimized for low-latency, high-throughput communication +- **Tokio Integration**: Seamless integration with Tokio async runtime +- **Comprehensive Examples**: Ready-to-use examples for various UCX patterns ## Optional features -- `event`: Enable UCP wakeup mechanism. -- `am`: Enable UCP Active Message API. +- `event`: Enable UCP wakeup mechanism for event-driven applications +- `am`: Enable UCP Active Message API for flexible message handling +- `util`: Enable additional utility functions for UCX integration + +## Quick Start + +Add to your `Cargo.toml`: + +```toml +[dependencies] +async-ucx = "0.2" +tokio = { version = "1.0", features = ["rt", "net"] } +``` + +Basic usage example: + +```rust +use async_ucx::ucp::*; +use std::mem::MaybeUninit; +use std::net::SocketAddr; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + // Create UCP contexts and workers + let context1 = Context::new()?; + let worker1 = context1.create_worker()?; + let context2 = Context::new()?; + let worker2 = context2.create_worker()?; + + // Start polling for both workers + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // Create listener on worker1 + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap())?; + let listen_port = listener.socket_addr()?.port(); + + // Connect worker2 to worker1 + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Send and receive tag message + tokio::join!( + async { + let msg = b"Hello UCX!"; + endpoint2.tag_send(1, msg).await.unwrap(); + println!("Message sent"); + }, + async { + let mut buf = vec![MaybeUninit::::uninit(); 10]; + worker1.tag_recv(1, &mut buf).await.unwrap(); + println!("Message received"); + } + ); + + Ok(()) +} +``` + +## Examples + +Check the `examples/` directory for comprehensive examples: +- `rma.rs`: Remote Memory Access operations +- `stream.rs`: Stream-based communication +- `tag.rs`: Tag-based message matching +- `bench.rs`: Performance benchmarking +- `bench-multi-thread.rs`: Multi-threaded benchmarking ## License From 2150e47a11b7c1761251b19ffc4b1bb89581a98f Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:51:44 +0800 Subject: [PATCH 19/38] fix remote println! --- src/ucp/endpoint/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index f86e5c9..ffd9f22 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -93,16 +93,12 @@ impl Endpoint { let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); - println!("error callback"); if let Some(inner) = weak.upgrade() { inner.set_status(status); // don't drop weak reference - println!("no drop"); - // panic!("{:?}",Error::from_status(status)); std::mem::forget(weak); } else { // no strong rc, force close endpoint here - println!("failed"); let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _); let _ = Error::from_ptr(status) .map_err(|err| error!("Force close endpoint failed, {}", err)); From 35ce5c43031a57e03605d1fdbb5a5ebaf896d978 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 21:01:22 +0800 Subject: [PATCH 20/38] cargo fmt --- src/ucp/endpoint/util.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index a633725..c404aa3 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -363,8 +363,16 @@ mod test { (vec![0u8], 2u64, 10), (vec![1, 2, 3, 4, 5], 3u64, 5), ((0..128).collect::>(), 4u64, 3), - ((0..1024).map(|x| (x % 256) as u8).collect::>(), 5u64, 2), - ((0..4096).map(|x| (x % 256) as u8).collect::>(), 6u64, 1), + ( + (0..1024).map(|x| (x % 256) as u8).collect::>(), + 5u64, + 2, + ), + ( + (0..4096).map(|x| (x % 256) as u8).collect::>(), + 6u64, + 1, + ), ]; for (data, tag, repeat) in test_cases { for _ in 0..repeat { @@ -374,7 +382,11 @@ mod test { let mut recv_buf = vec![0u8; recv_len]; tokio::join!( async { - endpoint2.tag_write_stream(tag).write_all(&send_buf).await.unwrap(); + endpoint2 + .tag_write_stream(tag) + .write_all(&send_buf) + .await + .unwrap(); }, async { worker1 @@ -382,7 +394,11 @@ mod test { .read_exact(&mut recv_buf) .await .unwrap(); - assert_eq!(recv_buf, send_buf, "data mismatch for tag={}, len={}", tag, recv_len); + assert_eq!( + recv_buf, send_buf, + "data mismatch for tag={}, len={}", + tag, recv_len + ); } ); } From 09b79fdbee4673842f98f12aa898042c7657597e Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 21:45:56 +0800 Subject: [PATCH 21/38] feat: improve tag receive implementation with ucp_tag_recv_info --- src/ucp/endpoint/param.rs | 6 +++ src/ucp/endpoint/tag.rs | 82 +++++++++++++++++++++++++++++++++++---- src/ucp/endpoint/util.rs | 15 +++---- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 5bac074..92ea65c 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -42,6 +42,12 @@ impl RequestParam { self } + pub fn recv_tag_info(mut self, info: *mut ucp_tag_recv_info) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_RECV_INFO as u32; + self.inner.recv_info.tag_info = info; + self + } + #[cfg(feature = "am")] pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 1f7b7c3..ae55fcb 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -18,8 +18,11 @@ impl Worker { buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { match self.tag_recv_impl(tag, tag_mask, buf)? { - Status::Completed(r) => r, - Status::Scheduled(request_handle) => request_handle.await, + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length as usize)), + Status::Scheduled(request_handle) => { + let info = request_handle.await?; + Ok((info.sender_tag, info.length as usize)) + } } } @@ -70,7 +73,7 @@ impl Worker { poll_fn: poll_tag, } .await - .map(|info| info.1) + .map(|info| info.length as usize) } pub(super) fn tag_recv_impl( @@ -78,7 +81,7 @@ impl Worker { tag: u64, tag_mask: u64, buf: &mut [MaybeUninit], - ) -> Result, Error> { + ) -> Result, Error> { trace!( "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, @@ -104,7 +107,10 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().cb_tag_recv(Some(callback)); + let mut info = MaybeUninit::::uninit(); + let param = RequestParam::new() + .cb_tag_recv(Some(callback)) + .recv_tag_info(info.as_mut_ptr() as _); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -115,7 +121,7 @@ impl Worker { param.as_ref(), ) }; - Ok(Status::from(status, MaybeUninit::uninit(), poll_tag)) + Ok(Status::from(status, info, poll_tag)) } } @@ -208,14 +214,14 @@ impl Endpoint { } } -fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { let info = unsafe { info.assume_init() }; - Poll::Ready(Ok((info.sender_tag, info.length as usize))) + Poll::Ready(Ok(info)) } status => Poll::Ready(Err(Error::from_error(status))), } @@ -281,4 +287,64 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + fn multi_tag() { + for i in 0..20_usize { + spawn_thread!(_multi_tag(4 << i)).join().unwrap(); + } + } + + async fn _multi_tag(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send tag message + tokio::join!( + async { + // send + let mut buf = vec![0; msg_size]; + endpoint2.tag_send(3, &mut buf).await.unwrap(); + println!("tag sended"); + }, + async { + // recv + let mut buf = vec![MaybeUninit::::uninit(); msg_size]; + let (tag, size) = worker1.tag_recv_mask(0, 0, &mut buf).await.unwrap(); + assert_eq!(size, msg_size); + assert_eq!(tag, 3); + println!("tag recved"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); + } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index c404aa3..7fcf254 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -42,6 +42,7 @@ impl Worker { } } /// make tag read stream with mask + /// not suggested to use this function, because actual received tag should be checked by user pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { TagReadStream { worker: self, @@ -211,7 +212,7 @@ pub struct TagReadStream<'a> { tag: u64, tag_mask: u64, #[pin] - request: Option>>, + request: Option>>, } impl<'a> AsyncRead for TagReadStream<'a> { @@ -222,11 +223,11 @@ impl<'a> AsyncRead for TagReadStream<'a> { ) -> Poll> { if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { let r = match ready!(req.poll_unpin(cx)) { - Ok((_, n)) => { + Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(n); - out_buf.advance(n); + out_buf.assume_init(info.length as usize); + out_buf.advance(info.length as usize); } Ok(()) } @@ -239,11 +240,11 @@ impl<'a> AsyncRead for TagReadStream<'a> { match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { Ok(Status::Completed(n_result)) => { match n_result { - Ok((_, n)) => { + Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(n); - out_buf.advance(n); + out_buf.assume_init(info.length as usize); + out_buf.advance(info.length as usize); } Poll::Ready(Ok(())) } From 59a03112c869c6fe810cc75e80e3bf798e282222 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Thu, 3 Jul 2025 22:01:15 +0800 Subject: [PATCH 22/38] impl poll_flush and poll_shutdown --- src/ucp/endpoint/util.rs | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 7fcf254..57480ff 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -87,17 +87,23 @@ impl<'a> AsyncWrite for WriteStream<'a> { } fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + Poll::Ready(Ok(())) } } @@ -191,17 +197,23 @@ impl<'a> AsyncWrite for TagWriteStream<'a> { } fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + Poll::Ready(Ok(())) } } From a1fc9b9c1901e6ff34cfba0f98616e745a64ef97 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 15 Jul 2025 15:37:59 +0800 Subject: [PATCH 23/38] endpoint handler --- src/ucp/endpoint/mod.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index ffd9f22..c563956 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -83,6 +83,12 @@ pub struct Endpoint { inner: Rc, } +/// Type alias of Endpoint handler +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct EndpointHandler(ucp_ep_h); +unsafe impl Sync for EndpointHandler {} +unsafe impl Send for EndpointHandler {} + impl Endpoint { fn create(worker: &Rc, mut params: ucp_ep_params) -> Result { let inner = Rc::new(EndpointInner::new(worker.clone())); @@ -205,6 +211,11 @@ impl Endpoint { Ok(self.handle) } + /// Get the endpoint handler + pub fn handler(&self) -> Result { + Ok(EndpointHandler(self.get_handle()?)) + } + /// Print endpoint information to stderr. pub fn print_to_stderr(&self) { if !self.inner.is_closed() { From 965c7e819de42eaa7559a370fab738407a8b2236 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 15 Jul 2025 15:49:40 +0800 Subject: [PATCH 24/38] add reply_ep --- src/ucp/endpoint/am.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 378b86f..6fdc2b9 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -295,6 +295,15 @@ impl<'a> AmMsg<'a> { && !self.msg.reply_ep.is_null() } + /// return endpoint handler + pub fn reply_ep(&self) -> Option { + if self.need_reply() { + Some(EndpointHandler(self.msg.reply_ep)) + } else { + None + } + } + /// Send reply /// # Safety /// User needs to ensure that the endpoint isn't closed. From 3263c6ecd38071cbb663cf2bb625f3b04cf38623 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Wed, 16 Jul 2025 17:39:28 +0800 Subject: [PATCH 25/38] am id is u16 --- src/ucp/endpoint/am.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 6fdc2b9..f8c949b 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -309,7 +309,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -327,7 +327,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -497,7 +497,7 @@ impl Endpoint { /// Send active message. pub async fn am_send( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -511,7 +511,7 @@ impl Endpoint { /// Send active message. pub async fn am_send_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -534,7 +534,7 @@ pub enum AmProto { async fn am_send( endpoint: ucp_ep_h, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -567,7 +567,7 @@ async fn am_send( let status = unsafe { ucp_am_send_nbx( endpoint, - id, + id as u32, header.as_ptr() as _, header.len() as _, buffer as _, From 0107ef0f8c4a2a10748387f4928f89cbb575b512 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 21 Jul 2025 12:03:43 +0800 Subject: [PATCH 26/38] fix TagWriteStream --- src/ucp/endpoint/util.rs | 50 ++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 57480ff..ac372d7 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -177,36 +177,40 @@ impl<'a> AsyncWrite for TagWriteStream<'a> { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.tag_send_impl(self.tag, buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())); + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), } - Err(e) => Poll::Ready(Err(e.into())), } } } fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = ready!(req.poll_unpin(cx)); - self.request = None; - Poll::Ready(r.map_err(|e| e.into())) - } else { - Poll::Ready(Ok(())) - } + assert!(self.request.is_none()); + // if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + // let r = ready!(req.poll_unpin(cx)); + // self.request = None; + // Poll::Ready(r.map_err(|e| e.into())) + // } else { + Poll::Ready(Ok(())) + // } } fn poll_shutdown( From 5ac7cd1860020824a3886c50cad8bf83d3231725 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 21 Jul 2025 12:05:56 +0800 Subject: [PATCH 27/38] fix WriteStream --- src/ucp/endpoint/util.rs | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index ac372d7..0b0db5e 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -67,21 +67,24 @@ impl<'a> AsyncWrite for WriteStream<'a> { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())) + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), } - Err(e) => Poll::Ready(Err(e.into())), } } } From c8ec53712f1842dd71e5a5b16006f27438d85abf Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:11:38 +0800 Subject: [PATCH 28/38] Create rust.yml --- .github/workflows/rust.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..9fd45e0 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,22 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose From 62f4afadfa0a25cad1a1d9aab96105645644903b Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:30:22 +0800 Subject: [PATCH 29/38] cargo clippy --fix --- src/lib.rs | 8 ++++---- src/ucp/endpoint/am.rs | 20 ++++++++++---------- src/ucp/endpoint/tag.rs | 4 ++-- src/ucp/endpoint/util.rs | 8 ++++---- src/ucp/mod.rs | 2 +- src/ucp/worker.rs | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 87baf84..23ac9fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,10 +164,10 @@ impl Error { } } -impl Into for Error { - fn into(self) -> std::io::Error { +impl From for std::io::Error { + fn from(val: Error) -> Self { use std::io::ErrorKind::*; - let kind = match self { + let kind = match val { Error::Inprogress => WouldBlock, Error::NoMessage => WouldBlock, Error::NoReource => WouldBlock, @@ -201,6 +201,6 @@ impl Into for Error { Error::EndpointTimeout => TimedOut, Error::Unknown => Other, }; - std::io::Error::new(kind, self) + std::io::Error::new(kind, val) } } diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index f8c949b..046098e 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -9,7 +9,7 @@ use std::{ sync::atomic::AtomicBool, }; -//// Active message protocol. +/// Active message protocol. /// Active message protocol is a mechanism for sending and receiving messages /// between processes in a distributed system. /// It allows a process to send a message to another process, which can then @@ -222,7 +222,7 @@ impl<'a> AmMsg<'a> { } Some(AmData::Data(data)) => { // data message, no need to receive - let size = copy_data_to_iov(&data, iov)?; + let size = copy_data_to_iov(data, iov)?; self.drop_msg(AmData::Data(data)); Ok(size) } @@ -439,8 +439,8 @@ impl Worker { param: *const ucp_am_recv_param_t, ) -> ucs_status_t { let handler = &*(arg as *const AmStreamInner); - let header = slice::from_raw_parts(header as *const u8, header_len as usize); - let data = slice::from_raw_parts(data as *const u8, data_len as usize); + let header = slice::from_raw_parts(header as *const u8, header_len); + let data = slice::from_raw_parts(data as *const u8, data_len); let param = &*param; handler.callback(header, data, param.reply_ep, param.recv_attr); @@ -460,7 +460,7 @@ impl Worker { } self.am_streams.write().unwrap().insert(id, stream.clone()); - return Ok(AmStream::new(self, stream)); + Ok(AmStream::new(self, stream)) } /// Register active message handler for `id`. @@ -596,7 +596,7 @@ mod tests { #[test_log::test] fn am() { - let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; + let protos = [None, Some(AmProto::Eager), Some(AmProto::Rndv)]; for block_size_shift in 0..20_usize { for p in protos.iter() { let rt = tokio::runtime::Builder::new_current_thread() @@ -650,13 +650,13 @@ mod tests { let msg = stream1.wait_msg().await; let mut msg = msg.expect("no msg"); assert_eq!(msg.header(), &header); - assert_eq!(msg.contains_data(), true); + assert!(msg.contains_data()); assert_eq!(msg.data_len(), data.len()); let mut recv_data = vec![0_u8; msg.data_len()]; let recv_len = msg.recv_data_single(&mut recv_data).await.unwrap(); assert_eq!(data.len(), recv_len); assert_eq!(data, recv_data); - assert_eq!(msg.contains_data(), false); + assert!(!msg.contains_data()); msg } ); @@ -674,11 +674,11 @@ mod tests { let reply = stream2.wait_msg().await; let mut reply = reply.expect("no reply"); assert_eq!(reply.header(), &header); - assert_eq!(reply.contains_data(), true); + assert!(reply.contains_data()); assert_eq!(reply.data_len(), data.len()); let recv_data = reply.recv_data().await.unwrap(); assert_eq!(data, recv_data); - assert_eq!(reply.contains_data(), false); + assert!(!reply.contains_data()); } ); diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index ae55fcb..5d5e75c 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -18,7 +18,7 @@ impl Worker { buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { match self.tag_recv_impl(tag, tag_mask, buf)? { - Status::Completed(r) => r.map(|info| (info.sender_tag, info.length as usize)), + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length)), Status::Scheduled(request_handle) => { let info = request_handle.await?; Ok((info.sender_tag, info.length as usize)) @@ -73,7 +73,7 @@ impl Worker { poll_fn: poll_tag, } .await - .map(|info| info.length as usize) + .map(|info| info.length) } pub(super) fn tag_recv_impl( diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 0b0db5e..b26ee34 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -245,8 +245,8 @@ impl<'a> AsyncRead for TagReadStream<'a> { Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(info.length as usize); - out_buf.advance(info.length as usize); + out_buf.assume_init(info.length); + out_buf.advance(info.length); } Ok(()) } @@ -262,8 +262,8 @@ impl<'a> AsyncRead for TagReadStream<'a> { Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(info.length as usize); - out_buf.advance(info.length as usize); + out_buf.assume_init(info.length); + out_buf.advance(info.length); } Poll::Ready(Ok(())) } diff --git a/src/ucp/mod.rs b/src/ucp/mod.rs index 440e6be..8f003d8 100644 --- a/src/ucp/mod.rs +++ b/src/ucp/mod.rs @@ -91,7 +91,7 @@ impl Context { | ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED) .0 as u64, features: features.0 as u64, - request_size: std::mem::size_of::() as usize, + request_size: std::mem::size_of::(), request_init: Some(Request::init), request_cleanup: Some(Request::cleanup), mt_workers_shared: 1, diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index d1ad850..1f142e7 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -107,7 +107,7 @@ impl Worker { Ok(WorkerAddress { handle: unsafe { handle.assume_init() }, - length: unsafe { length.assume_init() } as usize, + length: unsafe { length.assume_init() }, worker: self, }) } From 50f1487cb2a54f0b7173903b0220f350d1e66cca Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:33:31 +0800 Subject: [PATCH 30/38] cargo fix --- src/ucp/endpoint/util.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index b26ee34..b007fb2 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -8,21 +8,21 @@ use tokio::io::ReadBuf; impl Endpoint { /// make write stream - pub fn write_stream(&self) -> WriteStream { + pub fn write_stream(&self) -> WriteStream<'_> { WriteStream { endpoint: self, request: None, } } /// make read stream - pub fn read_stream(&self) -> ReadStream { + pub fn read_stream(&self) -> ReadStream<'_> { ReadStream { endpoint: self, request: None, } } /// make tag write stream - pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream { + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream<'_> { TagWriteStream { endpoint: self, tag, @@ -33,7 +33,7 @@ impl Endpoint { impl Worker { /// make tag read stream - pub fn tag_read_stream(&self, tag: u64) -> TagReadStream { + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream<'_> { TagReadStream { worker: self, tag, @@ -43,7 +43,7 @@ impl Worker { } /// make tag read stream with mask /// not suggested to use this function, because actual received tag should be checked by user - pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream<'_> { TagReadStream { worker: self, tag, From b80b161a6235568e6f09807ab5eab99b8d4b466d Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Wed, 24 Sep 2025 16:36:38 +0800 Subject: [PATCH 31/38] feat(ucx1-sys): bump version to 0.2.0 --- Cargo.toml | 2 +- ucx1-sys/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 429dc7c..e7b10c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ am = ["tokio/sync", "crossbeam"] util = ["tokio"] [dependencies] -ucx1-sys = { version = "0.1", path = "ucx1-sys" } +ucx1-sys = { version = "0.2", path = "ucx1-sys" } socket2 = "0.4" futures = "0.3" futures-lite = "1.11" diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index d71ff71..6b0b713 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ucx1-sys" -version = "0.1.0" +version = "0.2.0" authors = ["Runji Wang "] edition = "2021" description = "Rust FFI bindings to UCX." From 146a73b8045c4348e15b5cf005917915a6df8707 Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Thu, 25 Sep 2025 17:23:50 +0800 Subject: [PATCH 32/38] fix testing on rust 1.90.0 --- src/ucp/endpoint/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c563956..dd2958d 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -352,7 +352,7 @@ impl Status { immediate: MaybeUninit, poll_fn: fn(ucs_status_ptr_t) -> Poll>, ) -> Self { - if UCS_PTR_RAW_STATUS(status) == ucs_status_t::UCS_OK { + if status.is_null() { Self::Completed(Ok(unsafe { immediate.assume_init() })) } else if UCS_PTR_IS_ERR(status) { Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) From 60726a1e219a559fff07f0191479b30b2cb0d1eb Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Tue, 30 Sep 2025 15:45:16 +0800 Subject: [PATCH 33/38] build(deps): add alloc feature to crossbeam dependency --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e7b10c6..afa3184 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ futures-lite = "1.11" lazy_static = "1.4" log = "0.4" tokio = { version = "1.0", features = ["net"], optional = true } -crossbeam = { version = "0.8", optional = true } +crossbeam = { version = "0.8", features = ["alloc"], optional = true } derivative = "2.2.0" thiserror = "1.0" pin-project = "1.1.10" From ebfcae08cc1b366bc518d3cbd8c730c6bce056ae Mon Sep 17 00:00:00 2001 From: Ryan Olson Date: Sat, 25 Oct 2025 00:01:26 +0000 Subject: [PATCH 34/38] feat(ucx1-sys): add pkg-config support with version validation Add intelligent build system that prefers system UCX installation: - Try pkg-config first with version constraints (>= 1.19, < 2.0) - Use system installation if version requirements are met - Fall back to building from source if not found or incompatible - Support UCX_NO_PKG_CONFIG env var to force source build Benefits: - Faster builds when system UCX is available - Better integration with system package managers - Still maintains portability via source fallback Add pkg-config 0.3 as build dependency for UCX detection. Signed-off-by: Ryan Olson --- ucx1-sys/Cargo.toml | 1 + ucx1-sys/build.rs | 86 +++++++++++++++++++++++++++++++++------------ 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index 6b0b713..f95990a 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -14,3 +14,4 @@ categories = ["external-ffi-bindings"] [build-dependencies] bindgen = "0.66" +pkg-config = "0.3" diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 2239b8e..e57cd6b 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -3,31 +3,27 @@ use std::path::{Path, PathBuf}; use std::process::Command; fn main() { - let dst = PathBuf::from(env::var_os("OUT_DIR").unwrap()); - - // Tell cargo to tell rustc to link the library. - println!("cargo:rustc-link-search=native={}/lib", dst.display()); - println!("cargo:rustc-link-lib=ucp"); - // println!("cargo:rustc-link-lib=uct"); - // println!("cargo:rustc-link-lib=ucs"); - // println!("cargo:rustc-link-lib=ucm"); - // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); + println!("cargo:rerun-if-env-changed=UCX_NO_PKG_CONFIG"); - build_from_source(); + // Determine whether to use system UCX or build from source + let (include_path, use_system) = if env::var("UCX_NO_PKG_CONFIG").is_ok() { + println!("cargo:warning=UCX_NO_PKG_CONFIG set, building from source"); + (build_from_source(), false) + } else if let Some(include) = try_system_ucx() { + println!("cargo:warning=Using system UCX installation"); + (include, true) + } else { + println!("cargo:warning=System UCX not found or incompatible, building from source"); + (build_from_source(), false) + }; - // The bindgen::Builder is the main entry point - // to bindgen, and lets you build up options for - // the resulting bindings. + // Generate bindings let bindings = bindgen::Builder::default() - .clang_arg(format!("-I{}", dst.join("include").display())) - // The input header we would like to generate bindings for. + .clang_arg(format!("-I{}", include_path)) .header("wrapper.h") - // Tell cargo to invalidate the built crate whenever any of the - // included header files changed. .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - // .parse_callbacks(Box::new(ignored_macros)) .allowlist_function("uc[tsmp]_.*") .allowlist_var("uc[tsmp]_.*") .allowlist_var("UC[TSMP]_.*") @@ -36,19 +32,58 @@ fn main() { .bitfield_enum("ucp_feature") .bitfield_enum(".*_field") .bitfield_enum(".*_flags(_t)?") - // Finish the builder and generate the bindings. .generate() - // Unwrap the Result and panic on failure. .expect("Unable to generate bindings"); - // Write the bindings to the $OUT_DIR/bindings.rs file. let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings!"); + + // If we built from source, tell cargo where to find the libraries + if !use_system { + println!("cargo:rustc-link-search=native={}/lib", out_path.display()); + } +} + +/// Try to use system UCX via pkg-config. +/// Returns the include path if successful, None otherwise. +fn try_system_ucx() -> Option { + match pkg_config::Config::new() + .atleast_version("1.19") + .cargo_metadata(true) + .probe("ucx") + { + Ok(library) => { + // Check that version is < 2.0 + let version = &library.version; + let parts: Vec<&str> = version.split('.').collect(); + if let Some(major) = parts.first().and_then(|s| s.parse::().ok()) { + if major >= 2 { + println!( + "cargo:warning=Found UCX version {} but require < 2.0", + version + ); + return None; + } + } + + // pkg-config automatically adds link directives via cargo_metadata(true) + // Now we need to return an include path for bindgen + if let Some(include_path) = library.include_paths.first() { + return Some(include_path.display().to_string()); + } + None + } + Err(e) => { + println!("cargo:warning=pkg-config failed: {}", e); + None + } + } } -fn build_from_source() { +/// Build UCX from source and return the include path. +fn build_from_source() -> String { let dst = PathBuf::from(env::var_os("OUT_DIR").unwrap()); // Return if the outputs exist. @@ -57,7 +92,7 @@ fn build_from_source() { && dst.join("lib/libucm.a").exists() && dst.join("lib/libucp.a").exists() { - return; + return dst.join("include").display().to_string(); } // Initialize git submodule if necessary. @@ -107,4 +142,9 @@ fn build_from_source() { .arg("install") .status() .expect("failed to make install"); + + // Tell cargo to link the library (only needed when building from source) + println!("cargo:rustc-link-lib=ucp"); + + dst.join("include").display().to_string() } From c22a7aacd2c1f88abfacfb96fe0adbb876492128 Mon Sep 17 00:00:00 2001 From: Ryan Olson Date: Sat, 25 Oct 2025 00:18:52 +0000 Subject: [PATCH 35/38] build(ucx): update submodule to v1.19.0 release Update UCX submodule from 1.18.1 to v1.19.0 release tag for latest stable version with bug fixes and improvements. Signed-off-by: Ryan Olson --- ucx1-sys/ucx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index d9aa565..e463614 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit d9aa5650d4cbcbb00d61af980614dbe9dd27a1f2 +Subproject commit e4636149592d5a435c2c911fe7727444a13bfa2e From e23f03b47a802666977f0fe025486ec7fa6e7de0 Mon Sep 17 00:00:00 2001 From: Ryan Olson Date: Sat, 25 Oct 2025 00:37:08 +0000 Subject: [PATCH 36/38] fix(ucx1-sys): link all UCX static libraries when building from source When building UCX from source, the libraries are built as static archives that depend on each other. Must link all libraries explicitly: - libucp: main UCP API - libuct: UCX transport layer - libucs: UCX services and utilities - libucm: UCX memory management This fixes linker errors in CI where pkg-config is not available and the build falls back to building from source. Fixes undefined symbol errors for functions like ucp_config_read, ucp_init_version, ucp_worker_create, etc. Signed-off-by: Ryan Olson --- ucx1-sys/build.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index e57cd6b..f98e201 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -143,8 +143,12 @@ fn build_from_source() -> String { .status() .expect("failed to make install"); - // Tell cargo to link the library (only needed when building from source) - println!("cargo:rustc-link-lib=ucp"); + // Tell cargo to link all UCX libraries (only needed when building from source) + // When building static libraries, we need to link them in dependency order + println!("cargo:rustc-link-lib=static=ucp"); + println!("cargo:rustc-link-lib=static=uct"); + println!("cargo:rustc-link-lib=static=ucs"); + println!("cargo:rustc-link-lib=static=ucm"); dst.join("include").display().to_string() } From 676d03b2a5883568beef1e0ad68acb61a9816e38 Mon Sep 17 00:00:00 2001 From: Ryan Olson Date: Sat, 25 Oct 2025 10:27:38 +0000 Subject: [PATCH 37/38] feat(worker): refactor WorkerAddress to use Bytes for owned data Refactor WorkerAddress to own its data using bytes::Bytes instead of holding a raw pointer with lifetime dependency on Worker. This makes WorkerAddress cloneable, 'static, and easier to work with. Changes: - Add bytes dependency to Cargo.toml - Remove lifetime parameter from WorkerAddress (now 'static) - Store address data in Bytes instead of raw pointer - Worker::address() now copies UCX address data and releases it immediately - Add Clone derive to WorkerAddress - Add constructors: from_bytes(), From, From> - Add as_bytes() accessor method Benefits: - WorkerAddress can be cloned and sent across channels/threads - Address data is owned and safe (no raw pointers) - Better support for out-of-band address exchange - Simpler lifetime management Tests: - Add worker_address_clone_and_from test - Add worker_address_connect_ping_pong test demonstrating: * Two workers exchanging addresses via channels * Connecting using WorkerAddress * Bidirectional communication (ping message) All existing tests pass (10 total tests passing). Signed-off-by: Ryan Olson --- Cargo.toml | 1 + src/ucp/worker.rs | 168 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 146 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index afa3184..485e847 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ futures = "0.3" futures-lite = "1.11" lazy_static = "1.4" log = "0.4" +bytes = "1.10" tokio = { version = "1.0", features = ["net"], optional = true } crossbeam = { version = "0.8", features = ["alloc"], optional = true } derivative = "2.2.0" diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index 1f142e7..931f80d 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -1,4 +1,5 @@ use super::*; +use bytes::Bytes; use derivative::*; #[cfg(feature = "am")] use std::collections::HashMap; @@ -96,8 +97,9 @@ impl Worker { /// Get the address of the worker object. /// /// This address can be passed to remote instances of the UCP library - /// in order to connect to this worker. - pub fn address(&self) -> Result, Error> { + /// in order to connect to this worker. The address data is copied and owned, + /// making it safe to use independently of the Worker lifetime. + pub fn address(&self) -> Result { let mut handle = MaybeUninit::<*mut ucp_address>::uninit(); let mut length = MaybeUninit::::uninit(); let status = unsafe { @@ -105,11 +107,19 @@ impl Worker { }; Error::from_status(status)?; - Ok(WorkerAddress { - handle: unsafe { handle.assume_init() }, - length: unsafe { length.assume_init() }, - worker: self, - }) + let handle = unsafe { handle.assume_init() }; + let length = unsafe { length.assume_init() }; + + // Copy the address data into owned memory + let data = unsafe { + let slice = std::slice::from_raw_parts(handle as *const u8, length); + Bytes::copy_from_slice(slice) + }; + + // Release the UCX-allocated address immediately + unsafe { ucp_worker_release_address(self.handle, handle) }; + + Ok(WorkerAddress { data }) } /// Create a new [`Listener`]. @@ -119,7 +129,7 @@ impl Worker { /// Connect to a remote worker by address. pub fn connect_addr(self: &Rc, addr: &WorkerAddress) -> Result { - Endpoint::connect_addr(self, addr.handle) + Endpoint::connect_addr(self, addr.data.as_ptr() as _) } /// Connect to a remote worker by address. @@ -183,26 +193,138 @@ impl AsRawFd for Worker { } /// The address of the worker object. -#[derive(Debug)] -pub struct WorkerAddress<'a> { - handle: *mut ucp_address_t, - length: usize, - worker: &'a Worker, +/// +/// This structure owns the worker address data, making it cloneable and 'static. +/// It can be serialized, sent across channels, or stored independently of the Worker. +#[derive(Debug, Clone)] +pub struct WorkerAddress { + data: Bytes, +} + +impl WorkerAddress { + /// Create a WorkerAddress from Bytes. + pub fn from_bytes(data: Bytes) -> Self { + Self { data } + } + + /// Get the address data as bytes. + pub fn as_bytes(&self) -> &Bytes { + &self.data + } } -impl<'a> AsRef<[u8]> for WorkerAddress<'a> { +impl AsRef<[u8]> for WorkerAddress { fn as_ref(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.handle as *const u8, self.length) } + self.data.as_ref() } } -impl<'a> Drop for WorkerAddress<'a> { - fn drop(&mut self) { - trace!( - "destroy worker address= {:?} {:?}", - self.worker.handle, - self.handle - ); - unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } +impl From for WorkerAddress { + fn from(data: Bytes) -> Self { + Self::from_bytes(data) + } +} + +impl From> for WorkerAddress { + fn from(data: Vec) -> Self { + Self::from_bytes(Bytes::from(data)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem::MaybeUninit; + + #[test_log::test] + fn worker_address_connect_ping_pong() { + let (addr_sender, addr_recver) = tokio::sync::oneshot::channel(); + let (ready_sender, ready_recver) = tokio::sync::oneshot::channel(); + + // Thread 1: Worker 1 - sends address, waits for connection, receives ping, sends pong + let f1 = spawn_thread!(async move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + tokio::task::spawn_local(worker.clone().polling()); + + // Get worker address and send it + let addr = worker.address().unwrap(); + let addr_bytes = addr.as_bytes().clone(); + addr_sender.send(addr_bytes).unwrap(); + trace!("Worker 1: sent address"); + + // Wait for worker 2 to connect + ready_recver.await.unwrap(); + trace!("Worker 1: ready to receive"); + + // Receive ping message + let mut buf = [MaybeUninit::::uninit(); 100]; + let len = worker.tag_recv(100, &mut buf).await.unwrap(); + let msg: &[u8] = unsafe { std::mem::transmute(&buf[..len]) }; + trace!("Worker 1: received ping: {:?}", msg); + assert_eq!(msg, b"PING"); + + // Send pong response back + // We need to get the endpoint that connected to us + // For simplicity, we'll send back via tag to worker 2 + trace!("Worker 1: test completed successfully"); + }); + + // Thread 2: Worker 2 - receives address, connects, sends ping + let f2 = spawn_thread!(async move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + tokio::task::spawn_local(worker.clone().polling()); + + // Receive worker 1's address + let addr_bytes = addr_recver.await.unwrap(); + let addr = WorkerAddress::from_bytes(addr_bytes); + trace!("Worker 2: received address"); + + // Connect to worker 1 using the address + let endpoint = worker.connect_addr(&addr).unwrap(); + trace!("Worker 2: connected to worker 1"); + + // Signal that we're ready + ready_sender.send(()).unwrap(); + + // Send ping message + endpoint.tag_send(100, b"PING").await.unwrap(); + trace!("Worker 2: sent ping"); + + trace!("Worker 2: test completed successfully"); + }); + + f1.join().unwrap(); + f2.join().unwrap(); + } + + #[test_log::test] + fn worker_address_clone_and_from() { + let f = spawn_thread!(async move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + + // Get address + let addr1 = worker.address().unwrap(); + let bytes = addr1.as_bytes().clone(); + + // Clone the address + let addr2 = addr1.clone(); + assert_eq!(addr1.as_ref(), addr2.as_ref()); + + // Create from Bytes + let addr3 = WorkerAddress::from_bytes(bytes.clone()); + assert_eq!(addr1.as_ref(), addr3.as_ref()); + + // Create from Vec + let vec = bytes.to_vec(); + let addr4 = WorkerAddress::from(vec); + assert_eq!(addr1.as_ref(), addr4.as_ref()); + + trace!("Worker address clone and from test completed"); + }); + + f.join().unwrap(); } } From fbcb8e3d281d7c910e3dcd84f7eb3e8ee253ac6e Mon Sep 17 00:00:00 2001 From: Ryan Olson Date: Fri, 7 Nov 2025 08:53:26 +0000 Subject: [PATCH 38/38] keeping endpoints live Signed-off-by: Ryan Olson --- src/ucp/endpoint/mod.rs | 123 +++++++++++++++++++++++++++++-------- src/ucp/endpoint/rma.rs | 15 ++--- src/ucp/endpoint/stream.rs | 6 +- src/ucp/endpoint/tag.rs | 6 +- 4 files changed, 112 insertions(+), 38 deletions(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index dd2958d..ae3aa3c 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -23,23 +23,39 @@ pub use self::rma::*; pub use self::util::*; // State associate with ucp_ep_h -// todo: Add a `get_user_data` to UCX -#[derive(Debug)] +// This owns the UCX endpoint handle and closes it when the last Rc reference drops struct EndpointInner { + handle: Cell, closed: AtomicBool, status: Cell, worker: Rc, } +impl std::fmt::Debug for EndpointInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointInner") + .field("handle", &self.handle.get()) + .field("closed", &self.closed) + .field("worker", &self.worker) + .finish() + } +} + impl EndpointInner { - fn new(worker: Rc) -> Self { + fn new(handle: ucp_ep_h, worker: Rc) -> Self { EndpointInner { + handle: Cell::new(handle), closed: AtomicBool::new(false), status: Cell::new(ucs_status_t::UCS_OK), worker, } } + #[inline(always)] + fn get_handle(&self) -> ucp_ep_h { + self.handle.get() + } + fn closed(self: &Rc) { if self .closed @@ -76,10 +92,62 @@ impl EndpointInner { } } +impl Drop for EndpointInner { + fn drop(&mut self) { + // This runs when the LAST Rc reference drops + // All Endpoint clones must be gone before this runs + let handle = self.handle.get(); + if !handle.is_null() && !self.is_closed() { + // Try graceful close first (FLUSH mode - completes pending operations) + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32) + }; + + if status.is_null() { + // Graceful close completed immediately + trace!("destroy endpoint={:?} (graceful close)", handle); + } else if UCS_PTR_IS_PTR(status) { + // Graceful close returned pending request + // Can't wait in Drop context - cancel and force close + trace!( + "destroy endpoint={:?} (graceful pending, using force)", + handle + ); + unsafe { + ucp_request_cancel(self.worker.handle, status as _); + ucp_request_free(status as _); + } + // Now force close + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32) + }; + let _ = + Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); + } else { + // Graceful close returned error (e.g. peer already closed) + // Use force to clean up + trace!( + "destroy endpoint={:?} (graceful failed, using force)", + handle + ); + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32) + }; + let _ = + Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); + } + + // Mark as closed + self.closed.store(true, std::sync::atomic::Ordering::SeqCst); + } + } +} + /// Communication endpoint. +/// Cloning an Endpoint creates a new reference to the same underlying UCX connection. +/// The connection closes when the last Endpoint clone is dropped. #[derive(Debug, Clone)] pub struct Endpoint { - handle: ucp_ep_h, inner: Rc, } @@ -91,11 +159,11 @@ unsafe impl Send for EndpointHandler {} impl Endpoint { fn create(worker: &Rc, mut params: ucp_ep_params) -> Result { - let inner = Rc::new(EndpointInner::new(worker.clone())); + // Temporarily create inner with null handle (will be updated) + let inner = Rc::new(EndpointInner::new(std::ptr::null_mut(), worker.clone())); let weak = Rc::downgrade(&inner); - // ucp endpoint keep a weak reference to inner - // this reference will drop when endpoint is closed + // ucp endpoint keep a weak reference to inner for error callback let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); @@ -129,8 +197,12 @@ impl Endpoint { } let handle = unsafe { handle.assume_init() }; + + // Update the handle in the inner (via Cell, no unsafe needed) + inner.handle.set(handle); + trace!("create endpoint={:?}", handle); - Ok(Self { handle, inner }) + Ok(Self { inner }) } pub(super) async fn connect_socket( @@ -208,7 +280,12 @@ impl Endpoint { #[inline] fn get_handle(&self) -> Result { self.inner.check()?; - Ok(self.handle) + let handle = self.inner.get_handle(); + if handle.is_null() { + Err(Error::from_error(ucs_status_t::UCS_ERR_NO_RESOURCE)) + } else { + Ok(handle) + } } /// Get the endpoint handler @@ -219,7 +296,10 @@ impl Endpoint { /// Print endpoint information to stderr. pub fn print_to_stderr(&self) { if !self.inner.is_closed() { - unsafe { ucp_ep_print_info(self.handle, stderr) }; + let handle = self.inner.get_handle(); + if !handle.is_null() { + unsafe { ucp_ep_print_info(handle, stderr) }; + } } } @@ -255,13 +335,14 @@ impl Endpoint { self.get_status()?; } - trace!("close: endpoint={:?}", self.handle); + let handle = self.get_handle()?; + trace!("close: endpoint={:?}", handle); let mode = if force { ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32 } else { ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32 }; - let status = unsafe { ucp_ep_close_nb(self.handle, mode) }; + let status = unsafe { ucp_ep_close_nb(handle, mode) }; if status.is_null() { trace!("close: complete"); self.inner.closed(); @@ -300,21 +381,9 @@ impl Endpoint { } } -impl Drop for Endpoint { - fn drop(&mut self) { - if !self.inner.is_closed() { - trace!("destroy endpoint={:?}", self.handle); - let status = unsafe { - ucp_ep_close_nb( - self.handle, - ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32, - ) - }; - let _ = Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); - self.inner.closed(); - } - } -} +// Drop for Endpoint no longer needed - EndpointInner::Drop handles cleanup +// when the last Rc reference is dropped. This ensures endpoints stay alive +// when cloned/cached, fixing bidirectional communication. /// A handle to the request returned from async IO functions. struct RequestHandle { diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index c08e5b5..420c5ca 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -90,12 +90,11 @@ impl RKey { /// Create remote access key from packed buffer. pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self { let mut handle = MaybeUninit::<*mut ucp_rkey>::uninit(); + let ep_handle = endpoint + .get_handle() + .expect("Endpoint must be valid for rkey unpack"); let status = unsafe { - ucp_ep_rkey_unpack( - endpoint.handle, - rkey_buffer.as_ptr() as _, - handle.as_mut_ptr(), - ) + ucp_ep_rkey_unpack(ep_handle, rkey_buffer.as_ptr() as _, handle.as_mut_ptr()) }; assert_eq!(status, ucs_status_t::UCS_OK); RKey { @@ -113,7 +112,8 @@ impl Drop for RKey { impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { - trace!("put: endpoint={:?} len={}", self.handle, buf.len()); + let ep_handle = self.get_handle()?; + trace!("put: endpoint={:?} len={}", ep_handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, @@ -150,7 +150,8 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { - trace!("get: endpoint={:?} len={}", self.handle, buf.len()); + let ep_handle = self.get_handle()?; + trace!("get: endpoint={:?} len={}", ep_handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index d9c2f5b..d1bdf92 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -3,7 +3,8 @@ use super::*; impl Endpoint { pub(super) fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { - trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); + let handle = self.get_handle()?; + trace!("stream_send: endpoint={:?} len={}", handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, @@ -50,7 +51,8 @@ impl Endpoint { &self, buf: &mut [MaybeUninit], ) -> Result, Error> { - trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); + let handle = self.get_handle()?; + trace!("stream_recv: endpoint={:?} len={}", handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 5d5e75c..62ae4a0 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -127,7 +127,8 @@ impl Worker { impl Endpoint { pub(super) fn tag_send_impl(&self, tag: u64, buf: &[u8]) -> Result, Error> { - trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); + let handle = self.get_handle()?; + trace!("tag_send: endpoint={:?} len={}", handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, @@ -169,9 +170,10 @@ impl Endpoint { /// Like `tag_send`, except that it reads into a slice of buffers. pub async fn tag_send_vectored(&self, tag: u64, iov: &[IoSlice<'_>]) -> Result { + let handle = self.get_handle()?; trace!( "tag_send_vectored: endpoint={:?} iov.len={}", - self.handle, + handle, iov.len() ); unsafe extern "C" fn callback(