From f2a4bf5ad05b9509ccefb770ca9b3bea51971cd2 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 12 May 2025 11:47:15 +0800 Subject: [PATCH 01/33] connect_add_rvec --- src/ucp/worker.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index 00abd16..d034a03 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -122,6 +122,11 @@ impl Worker { Endpoint::connect_addr(self, addr.handle) } + /// Connect to a remote worker by address. + pub fn connect_addr_vec(self: &Rc, addr: &[u8]) -> Result { + Endpoint::connect_addr(self, addr.as_ptr() as _) + } + /// Connect to a remote listener. pub async fn connect_socket(self: &Rc, addr: SocketAddr) -> Result { Endpoint::connect_socket(self, addr).await @@ -193,6 +198,7 @@ impl<'a> AsRef<[u8]> for WorkerAddress<'a> { impl<'a> Drop for WorkerAddress<'a> { fn drop(&mut self) { + trace!("destroy worker address= {:?} {:?}", self.worker.handle, self.handle); unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } } } From 1b56b62fb69486ba25a719486f9bb187c0aaf98a Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 13 May 2025 10:26:13 +0800 Subject: [PATCH 02/33] disable build ucx --- ucx1-sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 2239b8e..0f85602 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -15,7 +15,7 @@ fn main() { // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); - build_from_source(); + // build_from_source(); // The bindgen::Builder is the main entry point // to bindgen, and lets you build up options for From 1f1f68a827c493b3b0756c0bbf7ebdf39d53bbc7 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 10:26:14 +0800 Subject: [PATCH 03/33] tag use send/recv_nbx --- src/ucp/endpoint/tag.rs | 87 ++++++++++++++++++++++++++++++++--------- 1 file changed, 69 insertions(+), 18 deletions(-) diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 6f04b4d..48f6c55 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -26,27 +26,38 @@ impl Worker { unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv: complete. req={:?}, status={:?}, len={}", + "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = Some(callback); + cb + }, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, tag_mask, - Some(callback), + ¶m, ) }; @@ -72,27 +83,41 @@ impl Worker { unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + // --- 修改为ucp_tag_recv_nbx --- + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 + | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = Some(callback); + cb + }, + datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, u64::max_value(), - Some(callback), + ¶m, ) }; Error::from_ptr(status)?; @@ -109,19 +134,31 @@ impl Endpoint { /// Sends a messages with `tag`. pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("tag_send: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = Some(callback); + cb + }, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, - Some(callback), + ¶m, ) }; if status.is_null() { @@ -145,7 +182,11 @@ impl Endpoint { self.handle, iov.len() ); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "tag_send_vectored: complete. req={:?}, status={:?}", request, @@ -154,14 +195,24 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = ucp_request_param_t { + op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 + | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, + cb: unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = Some(callback); + cb + }, + datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, + ..unsafe { std::mem::zeroed() } + }; let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, - Some(callback), + ¶m, ) }; let total_len = iov.iter().map(|v| v.len()).sum(); From c9bce466fce583edf721fc5bc5c706f745e6dddc Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:15:55 +0800 Subject: [PATCH 04/33] use RequestParam --- .gitignore | 2 + src/ucp/endpoint/am.rs | 73 +++++++++++++--------------------- src/ucp/endpoint/mod.rs | 1 + src/ucp/endpoint/param.rs | 82 +++++++++++++++++++++++++++++++++++++++ src/ucp/endpoint/tag.rs | 58 +++++++-------------------- 5 files changed, 124 insertions(+), 92 deletions(-) create mode 100644 src/ucp/endpoint/param.rs diff --git a/.gitignore b/.gitignore index 96ef6c0..057c110 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ +.vscode/ + /target Cargo.lock diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 2002020..9602d85 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -1,6 +1,7 @@ use crossbeam::queue::SegQueue; use tokio::sync::Notify; +use super::param::RequestParam; use super::*; use std::{ io::{IoSlice, IoSliceMut}, @@ -105,7 +106,7 @@ pub struct AmMsg<'a> { } impl<'a> AmMsg<'a> { - fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self { + fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self{ AmMsg { worker, msg } } @@ -249,22 +250,12 @@ impl<'a> AmMsg<'a> { self.worker.handle, iov.len() ); - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; - param.cb = ucp_request_param_t__bindgen_ty_1 { - recv_am: Some(callback), - }; - - if iov.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (iov[0].as_ptr(), iov[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (iov.as_ptr() as _, iov.len()) - } + + let param = RequestParam::new().cb_recv_am(Some(callback)); + let (buffer, count, param) = if iov.len() == 1 { + (iov[0].as_ptr(), iov[0].len(), param) + } else { + (iov.as_ptr() as _, iov.len(), param.iov()) }; let status = unsafe { @@ -273,7 +264,7 @@ impl<'a> AmMsg<'a> { data_desc as _, buffer as _, count as _, - param.as_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -546,34 +537,22 @@ async fn am_send( request.waker.wake(); } - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; - param.flags = 0; - param.cb = ucp_request_param_t__bindgen_ty_1 { - send: Some(callback), - }; - - match proto { - Some(AmProto::Eager) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0, - Some(AmProto::Rndv) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0, - _ => (), - } - - if need_reply { - param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; - } - - if data.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (data[0].as_ptr(), data[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (data.as_ptr() as _, data.len()) - } + // Use RequestParam builder for send + let param = RequestParam::new().send_cb(Some(callback)); + let param = match proto { + Some(AmProto::Eager) => param.set_flag_eager(), + Some(AmProto::Rndv) => param.set_flag_rndv(), + None => param, + }; + let param = if need_reply { + param.set_flag_reply() + } else { + param + }; + let (buffer, count, param) = if data.len() == 1 { + (data[0].as_ptr(), data[0].len(), param) + } else { + (data.as_ptr() as _, data.len(), param.iov()) }; let status = unsafe { @@ -584,7 +563,7 @@ async fn am_send( header.len() as _, buffer as _, count as _, - param.as_mut_ptr(), + param.as_ref(), ) }; if status.is_null() { diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index b477468..757a20b 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -12,6 +12,7 @@ mod am; mod rma; mod stream; mod tag; +mod param; #[cfg(feature = "am")] pub use self::am::*; diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs new file mode 100644 index 0000000..81e17d8 --- /dev/null +++ b/src/ucp/endpoint/param.rs @@ -0,0 +1,82 @@ +use ucx1_sys::*; + +pub struct RequestParam { + inner: ucp_request_param_t, +} + +impl RequestParam { + pub fn new() -> Self { + // Zeroed for safety, as in C + Self { + inner: unsafe { std::mem::zeroed() }, + } + } + + pub fn send_cb(mut self, callback: ucp_send_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_tag_recv(mut self, callback: ucp_tag_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = callback; + self.inner.cb = cb; + } + self + } + + #[cfg(feature = "am")] + pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_am = callback; + self.inner.cb = cb; + } + self + } + + pub fn iov(mut self) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; + self.inner.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; + self + } + + #[cfg(feature = "am")] + fn set_flag(&mut self) { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; + } + + #[cfg(feature = "am")] + pub fn set_flag_eager(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_rndv(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_reply(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; + self + } + + pub fn as_ref(&self) -> *const ucp_request_param_t { + &self.inner as *const _ + } +} + diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 48f6c55..e8feeac 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -1,3 +1,4 @@ +use super::param::RequestParam; use super::*; use std::io::{IoSlice, IoSliceMut}; @@ -41,15 +42,8 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.recv = Some(callback); - cb - }, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder + let param = RequestParam::new().cb_tag_recv(Some(callback)); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -57,7 +51,7 @@ impl Worker { buf.len() as _, tag, tag_mask, - ¶m, + param.as_ref(), ) }; @@ -98,18 +92,8 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - // --- 修改为ucp_tag_recv_nbx --- - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.recv = Some(callback); - cb - }, - datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -117,7 +101,7 @@ impl Worker { iov.len() as _, tag, u64::max_value(), - ¶m, + param.as_ref(), ) }; Error::from_ptr(status)?; @@ -143,22 +127,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.send = Some(callback); - cb - }, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, tag, - ¶m, + param.as_ref(), ) }; if status.is_null() { @@ -195,24 +172,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = ucp_request_param_t { - op_attr_mask: ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32, - cb: unsafe { - let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); - cb.send = Some(callback); - cb - }, - datatype: ucp_dt_type::UCP_DATATYPE_IOV as _, - ..unsafe { std::mem::zeroed() } - }; + // Use RequestParam builder for iov + let param = RequestParam::new().send_cb(Some(callback)).iov(); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, tag, - ¶m, + param.as_ref(), ) }; let total_len = iov.iter().map(|v| v.len()).sum(); From fa298d3ec1b8038802368bcd32ed4a7d11c3c618 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:33:05 +0800 Subject: [PATCH 05/33] stream use nbx and Param --- src/ucp/endpoint/param.rs | 11 ++++- src/ucp/endpoint/stream.rs | 92 +++++++++++++++++++++++++++++++++----- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 81e17d8..36889f7 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -32,6 +32,16 @@ impl RequestParam { self } + pub fn cb_stream_recv(mut self, callback: ucp_stream_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_stream = callback; + self.inner.cb = cb; + } + self + } + #[cfg(feature = "am")] pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; @@ -79,4 +89,3 @@ impl RequestParam { &self.inner as *const _ } } - diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 411a80b..d8c3eed 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,10 +1,11 @@ use super::*; +use super::param::RequestParam; impl Endpoint { /// Sends data through stream. pub async fn stream_send(&self, buf: &[u8]) -> Result { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -13,14 +14,13 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_stream_send_nb( + ucp_stream_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), - 0, + param.as_ref(), ) }; if status.is_null() { @@ -40,7 +40,7 @@ impl Endpoint { /// Receives data from stream. pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize, _user_data: *mut c_void) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -51,15 +51,14 @@ impl Endpoint { request.waker.wake(); } let mut length = MaybeUninit::::uninit(); + let param = RequestParam::new().cb_stream_recv(Some(callback)); let status = unsafe { - ucp_stream_recv_nb( + ucp_stream_recv_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), length.as_mut_ptr(), - 0, + param.as_ref(), ) }; if status.is_null() { @@ -87,3 +86,76 @@ unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { Poll::Ready(len.assume_init()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test_log::test] + fn stream() { + for i in 0..20_usize { + spawn_thread!(_stream(4 << i)).join().unwrap(); + } + } + + async fn _stream(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send stream message + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.stream_send(&buf).await.unwrap(); + println!("stream sent"); + }, + async { + // recv + let mut buf = vec![std::mem::MaybeUninit::::uninit(); msg_size]; + let mut start = 0; + while start < msg_size { + let len = endpoint1.stream_recv(&mut buf[start..]).await.unwrap(); + if len == 0 { + break; // no more data + } + start += len; + } + let buf: Vec = unsafe { + buf.into_iter().map(|b| b.assume_init()).collect() + }; + assert_eq!(buf, vec![42u8; msg_size]); + println!("stream received"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); + } +} From 8940e9bff44d5227db8bcb7ea7f0475ff5f04e96 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 16:43:49 +0800 Subject: [PATCH 06/33] rma use nbx --- src/ucp/endpoint/rma.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index e9cd41e..8af8bd9 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -1,3 +1,5 @@ +use super::param::RequestParam; + use super::*; /// A memory region allocated through UCP library, @@ -112,19 +114,20 @@ impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!("put: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_put_nb( + ucp_put_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { @@ -144,19 +147,20 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { trace!("get: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().send_cb(Some(callback)); let status = unsafe { - ucp_get_nb( + ucp_get_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { From cdf6c85b9a65d443ff1871905eccfa833503c24f Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:07:33 +0800 Subject: [PATCH 07/33] add author Li Kaiwei --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0b06a69..6f33668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "async-ucx" -version = "0.1.1" -authors = ["Runji Wang ", "Yiyuan Liu "] +version = "0.2.0" +authors = ["Runji Wang ", "Yiyuan Liu ", "Kaiwei Li "] edition = "2021" description = "Asynchronous Rust bindings to UCX." homepage = "https://github.com/madsys-dev/async-ucx" From 2808d4910ab760a870ce95a276839ef8b780efb1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:14:54 +0800 Subject: [PATCH 08/33] update ucx to 1.18.1 --- ucx1-sys/ucx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index 938ffcd..d9aa565 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit 938ffcd10122742d0f46a4f609e7395d1648c969 +Subproject commit d9aa5650d4cbcbb00d61af980614dbe9dd27a1f2 From 1a64dec42dd2a150a308a6e496fc244a4b56ad37 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:26:02 +0800 Subject: [PATCH 09/33] little fix --- examples/bench-multi-thread.rs | 2 +- ucx1-sys/build.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bench-multi-thread.rs b/examples/bench-multi-thread.rs index e6b1ef2..9651c5d 100644 --- a/examples/bench-multi-thread.rs +++ b/examples/bench-multi-thread.rs @@ -123,7 +123,7 @@ impl WorkerThread { .build() .unwrap(); let local = tokio::task::LocalSet::new(); - #[cfg(not(event))] + #[cfg(not(feature = "event"))] local.spawn_local(worker.clone().polling()); #[cfg(feature = "event")] local.spawn_local(worker.clone().event_poll()); diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 0f85602..2239b8e 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -15,7 +15,7 @@ fn main() { // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); - // build_from_source(); + build_from_source(); // The bindgen::Builder is the main entry point // to bindgen, and lets you build up options for From 6a3216aa9dac17fb0f32d3628bd22a8d10c48fa3 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 20:31:33 +0800 Subject: [PATCH 10/33] rename send_cb -> cb_send --- src/ucp/endpoint/am.rs | 2 +- src/ucp/endpoint/param.rs | 2 +- src/ucp/endpoint/rma.rs | 4 ++-- src/ucp/endpoint/stream.rs | 2 +- src/ucp/endpoint/tag.rs | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 9602d85..8e5d192 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -538,7 +538,7 @@ async fn am_send( } // Use RequestParam builder for send - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let param = match proto { Some(AmProto::Eager) => param.set_flag_eager(), Some(AmProto::Rndv) => param.set_flag_rndv(), diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 36889f7..5bac074 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -12,7 +12,7 @@ impl RequestParam { } } - pub fn send_cb(mut self, callback: ucp_send_nbx_callback_t) -> Self { + pub fn cb_send(mut self, callback: ucp_send_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; unsafe { let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index 8af8bd9..f83aa21 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -119,7 +119,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_put_nbx( self.get_handle()?, @@ -152,7 +152,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_get_nbx( self.get_handle()?, diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index d8c3eed..ddcb540 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -14,7 +14,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_stream_send_nbx( self.get_handle()?, diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index e8feeac..677ac96 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -128,7 +128,7 @@ impl Endpoint { request.waker.wake(); } // Use RequestParam builder - let param = RequestParam::new().send_cb(Some(callback)); + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, @@ -173,7 +173,7 @@ impl Endpoint { request.waker.wake(); } // Use RequestParam builder for iov - let param = RequestParam::new().send_cb(Some(callback)).iov(); + let param = RequestParam::new().cb_send(Some(callback)).iov(); let status = unsafe { ucp_tag_send_nbx( self.get_handle()?, From 67691ca35d0f2e0be75d99547b353c8d9c406a44 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 23:53:10 +0800 Subject: [PATCH 11/33] fmt --- src/ucp/endpoint/am.rs | 2 +- src/ucp/endpoint/mod.rs | 2 +- src/ucp/endpoint/rma.rs | 12 ++++++++++-- src/ucp/endpoint/stream.rs | 19 +++++++++++++------ src/ucp/worker.rs | 6 +++++- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 8e5d192..9d7f4ce 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -106,7 +106,7 @@ pub struct AmMsg<'a> { } impl<'a> AmMsg<'a> { - fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self{ + fn from_raw(worker: &'a Worker, msg: RawMsg) -> Self { AmMsg { worker, msg } } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 757a20b..c3b18ae 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -9,10 +9,10 @@ use std::task::Poll; #[cfg(feature = "am")] mod am; +mod param; mod rma; mod stream; mod tag; -mod param; #[cfg(feature = "am")] pub use self::am::*; diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index f83aa21..c08e5b5 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -114,7 +114,11 @@ impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("put: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); @@ -147,7 +151,11 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("get: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index ddcb540..a2f4bfb 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,11 +1,15 @@ -use super::*; use super::param::RequestParam; +use super::*; impl Endpoint { /// Sends data through stream. pub async fn stream_send(&self, buf: &[u8]) -> Result { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -40,7 +44,12 @@ impl Endpoint { /// Receives data from stream. pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize, _user_data: *mut c_void) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + length: usize, + _user_data: *mut c_void, + ) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -141,9 +150,7 @@ mod tests { } start += len; } - let buf: Vec = unsafe { - buf.into_iter().map(|b| b.assume_init()).collect() - }; + let buf: Vec = unsafe { buf.into_iter().map(|b| b.assume_init()).collect() }; assert_eq!(buf, vec![42u8; msg_size]); println!("stream received"); } diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index d034a03..d1ad850 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -198,7 +198,11 @@ impl<'a> AsRef<[u8]> for WorkerAddress<'a> { impl<'a> Drop for WorkerAddress<'a> { fn drop(&mut self) { - trace!("destroy worker address= {:?} {:?}", self.worker.handle, self.handle); + trace!( + "destroy worker address= {:?} {:?}", + self.worker.handle, + self.handle + ); unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } } } From 5b3936907036953be5e2aee22c5412d8999e500d Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Fri, 27 Jun 2025 23:53:46 +0800 Subject: [PATCH 12/33] poll fn are safe --- src/ucp/endpoint/am.rs | 13 ++----------- src/ucp/endpoint/mod.rs | 18 +++++++++--------- src/ucp/endpoint/stream.rs | 6 +++--- src/ucp/endpoint/tag.rs | 6 +++--- 4 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 9d7f4ce..378b86f 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -273,9 +273,9 @@ impl<'a> AmMsg<'a> { } else if UCS_PTR_IS_PTR(status) { RequestHandle { ptr: status, - poll_fn: poll_recv, + poll_fn: poll_normal, } - .await; + .await?; Ok(data_len) } else { Err(Error::from_ptr(status).unwrap_err()) @@ -580,15 +580,6 @@ async fn am_send( } } -unsafe fn poll_recv(ptr: ucs_status_ptr_t) -> Poll<()> { - let status = ucp_request_check_status(ptr as _); - if status == ucs_status_t::UCS_INPROGRESS { - Poll::Pending - } else { - Poll::Ready(()) - } -} - #[cfg(test)] #[cfg(feature = "am")] mod tests { diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c3b18ae..3e9c546 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -253,7 +253,7 @@ impl Endpoint { Ok(()) } else if UCS_PTR_IS_PTR(status) { let result = loop { - if let Poll::Ready(result) = unsafe { poll_normal(status) } { + if let Poll::Ready(result) = poll_normal(status) { unsafe { ucp_request_free(status as _) }; break result; } else { @@ -302,32 +302,32 @@ impl Drop for Endpoint { } /// A handle to the request returned from async IO functions. -struct RequestHandle { +struct RequestHandle Poll> { ptr: ucs_status_ptr_t, - poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll, + poll_fn: F, } -impl Future for RequestHandle { +impl Poll> Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } { + if let ret @ Poll::Ready(_) = (self.poll_fn)(self.ptr) { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; request.waker.register(cx.waker()); - unsafe { (self.poll_fn)(self.ptr) } + (self.poll_fn)(self.ptr) } } -impl Drop for RequestHandle { +impl Poll> Drop for RequestHandle { fn drop(&mut self) { trace!("request free: {:?}", self.ptr); unsafe { ucp_request_free(self.ptr as _) }; } } -unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { - let status = ucp_request_check_status(ptr as _); +fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { + let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index a2f4bfb..eb4303c 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -86,13 +86,13 @@ impl Endpoint { } } -unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { let mut len = MaybeUninit::::uninit(); - let status = ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _); + let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(len.assume_init()) + Poll::Ready(unsafe { len.assume_init() }) } } diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 677ac96..311c30e 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -199,13 +199,13 @@ impl Endpoint { } } -unsafe fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); - let status = ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _); + let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { - let info = info.assume_init(); + let info = unsafe { info.assume_init() }; Poll::Ready(Ok((info.sender_tag, info.length as usize))) } status => Poll::Ready(Err(Error::from_error(status))), From fcad8aa8b76becccbcd25bb97e14ae54f9ca3be5 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sat, 28 Jun 2025 11:22:11 +0800 Subject: [PATCH 13/33] add stream Extension for AsyncRead and AsyncWrite --- Cargo.toml | 3 +- src/ucp/endpoint/mod.rs | 38 ++++++- src/ucp/endpoint/stream.rs | 210 ++++++++++++++++++++++++++++++++----- 3 files changed, 217 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6f33668..7dcd9a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,10 @@ tokio = { version = "1.0", features = ["net"], optional = true } crossbeam = { version = "0.8", optional = true } derivative = "2.2.0" thiserror = "1.0" +pin-project = "1.1.10" [dev-dependencies] -tokio = { version = "1.0", features = ["rt", "time", "macros", "sync"] } +tokio = { version = "1.0", features = ["rt", "time", "macros", "sync", "io-util"] } env_logger = "0.9" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.2.17", default-features = false, features = ["env-filter", "fmt"] } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 3e9c546..369c9aa 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -89,12 +89,16 @@ impl Endpoint { let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); + println!("error callback"); if let Some(inner) = weak.upgrade() { inner.set_status(status); // don't drop weak reference + println!("no drop"); + // panic!("{:?}",Error::from_status(status)); std::mem::forget(weak); } else { // no strong rc, force close endpoint here + println!("failed"); let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _); let _ = Error::from_ptr(status) .map_err(|err| error!("Force close endpoint failed, {}", err)); @@ -302,15 +306,15 @@ impl Drop for Endpoint { } /// A handle to the request returned from async IO functions. -struct RequestHandle Poll> { +struct RequestHandle { ptr: ucs_status_ptr_t, - poll_fn: F, + poll_fn: fn(ucs_status_ptr_t) -> Poll, } -impl Poll> Future for RequestHandle { +impl Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = (self.poll_fn)(self.ptr) { + if let ret @ Poll::Ready(_) = { (self.poll_fn)(self.ptr) } { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; @@ -319,13 +323,37 @@ impl Poll> Future for RequestHandle { } } -impl Poll> Drop for RequestHandle { +impl Drop for RequestHandle { fn drop(&mut self) { trace!("request free: {:?}", self.ptr); unsafe { ucp_request_free(self.ptr as _) }; } } +enum Status { + Completed(Result), + Scheduled(RequestHandle>), +} + +impl Status { + pub fn new( + status: *mut c_void, + immediate: MaybeUninit, + poll_fn: fn(ucs_status_ptr_t) -> Poll>, + ) -> Self { + if status.is_null() { + Self::Completed(Ok(unsafe { immediate.assume_init() })) + } else if UCS_PTR_IS_PTR(status) { + Self::Scheduled(RequestHandle { + ptr: status, + poll_fn, + }) + } else { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } + } +} + fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index eb4303c..28ae80a 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -2,8 +2,7 @@ use super::param::RequestParam; use super::*; impl Endpoint { - /// Sends data through stream. - pub async fn stream_send(&self, buf: &[u8]) -> Result { + fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -27,22 +26,27 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - trace!("stream_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::new(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends data through stream. + pub async fn stream_send(&self, buf: &[u8]) -> Result { + match self.stream_send_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("stream_send: complete"), + Err(e) => error!("stream_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } - /// Receives data from stream. - pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + fn stream_recv_impl(&self, buf: &mut [MaybeUninit]) -> Result, Error> { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -70,35 +74,169 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - let length = unsafe { length.assume_init() } as usize; - trace!("stream_recv: complete. len={}", length); - Ok(length) - } else if UCS_PTR_IS_PTR(status) { - Ok(RequestHandle { - ptr: status, - poll_fn: poll_stream, + Ok(Status::new(status, length, poll_stream)) + } + + /// Receives data from stream. + pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + match self.stream_recv_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(x) => trace!("stream_recv: complete. len={}", x), + Err(e) => error!("stream_recv: error : {:?}", e), + } + r } - .await) - } else { - Err(Error::from_ptr(status).unwrap_err()) + Status::Scheduled(request_handle) => request_handle.await, } } } -fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { let mut len = MaybeUninit::::uninit(); let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(unsafe { len.assume_init() }) + Poll::Ready(Error::from_status(status).map(|_| unsafe { len.assume_init() })) + } +} + +use pin_project::pin_project; +#[pin_project] +pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream { + WriteStream { + endpoint: self, + request: None, + } + } +} +use futures::FutureExt; +use std::task::ready; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +fn to_io_error(e: Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::NotFound, e) +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(to_io_error(e)), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(to_io_error)), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl Endpoint { + /// make read stream + pub fn read_stream(&self) -> ReadStream { + ReadStream { + endpoint: self, + request: None, + } + } +} + +use tokio::io::AsyncRead; +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(to_io_error(e)), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(to_io_error(e))), + } + } } } #[cfg(test)] mod tests { use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[test_log::test] fn stream() { for i in 0..20_usize { @@ -131,7 +269,6 @@ mod tests { async { worker2.connect_socket(addr).await.unwrap() }, ); - // send stream message tokio::join!( async { // send @@ -156,6 +293,23 @@ mod tests { } ); + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.write_stream().write_all(&buf).await.unwrap(); + println!("write_stream"); + }, + async { + // recv + let mut buf = vec![0u8; msg_size]; + endpoint1.read_stream().read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, vec![42u8; msg_size]); + println!("read_stream"); + } + ); + + println!("status {:?}", endpoint2.get_status()); assert_eq!(endpoint1.get_rc(), (1, 1)); assert_eq!(endpoint2.get_rc(), (1, 1)); assert_eq!(endpoint1.close(false).await, Ok(())); From 8f48e8d417ceab5ec0f28b5a1c2e484dfae1f5b1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sat, 28 Jun 2025 11:53:02 +0800 Subject: [PATCH 14/33] add more test --- src/ucp/endpoint/mod.rs | 10 +++--- src/ucp/endpoint/stream.rs | 67 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 369c9aa..c8596a9 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -336,20 +336,20 @@ enum Status { } impl Status { - pub fn new( + pub fn from( status: *mut c_void, immediate: MaybeUninit, poll_fn: fn(ucs_status_ptr_t) -> Poll>, ) -> Self { - if status.is_null() { + if UCS_PTR_RAW_STATUS(status) == ucs_status_t::UCS_OK { Self::Completed(Ok(unsafe { immediate.assume_init() })) - } else if UCS_PTR_IS_PTR(status) { + } else if UCS_PTR_IS_ERR(status) { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } else { Self::Scheduled(RequestHandle { ptr: status, poll_fn, }) - } else { - Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) } } } diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 28ae80a..5c18e0a 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -26,7 +26,7 @@ impl Endpoint { param.as_ref(), ) }; - Ok(Status::new(status, MaybeUninit::uninit(), poll_normal)) + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) } /// Sends data through stream. @@ -74,7 +74,7 @@ impl Endpoint { param.as_ref(), ) }; - Ok(Status::new(status, length, poll_stream)) + Ok(Status::from(status, length, poll_stream)) } /// Receives data from stream. @@ -319,4 +319,67 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + fn stream_send_recv_various_contents_and_sizes() { + spawn_thread!(_stream_send_recv_various_contents_and_sizes()) + .join() + .unwrap(); + } + + async fn _stream_send_recv_various_contents_and_sizes() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + }, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } } From ad6b30097426c703ec54fbe7a4cc9235d07819ae Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 10:22:51 +0800 Subject: [PATCH 15/33] to io error --- src/lib.rs | 41 ++++++++++++++++++++++++++++++++++++++ src/ucp/endpoint/stream.rs | 16 ++++++--------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 02cc343..87baf84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,3 +163,44 @@ impl Error { } } } + +impl Into for Error { + fn into(self) -> std::io::Error { + use std::io::ErrorKind::*; + let kind = match self { + Error::Inprogress => WouldBlock, + Error::NoMessage => WouldBlock, + Error::NoReource => WouldBlock, + Error::IoError => Other, + Error::NoMemory => OutOfMemory, + Error::InvalidParam => InvalidInput, + Error::Unreachable => NotConnected, + Error::InvalidAddr => InvalidInput, + Error::NotImplemented => Unsupported, + Error::MessageTruncated => InvalidData, + Error::NoProgress => WouldBlock, + Error::BufferTooSmall => UnexpectedEof, + Error::NoElem => NotFound, + Error::SomeConnectsFailed => ConnectionAborted, + Error::NoDevice => NotFound, + Error::Busy => ResourceBusy, + Error::Canceled => Interrupted, + Error::ShmemSegment => Other, + Error::AlreadyExists => AlreadyExists, + Error::OutOfRange => InvalidInput, + Error::Timeout => TimedOut, + Error::ExceedsLimit => Other, + Error::Unsupported => Unsupported, + Error::Rejected => ConnectionRefused, + Error::NotConnected => NotConnected, + Error::ConnectionReset => ConnectionReset, + Error::FirstLinkFailure => Other, + Error::LastLinkFailure => Other, + Error::FirstEndpointFailure => Other, + Error::LastEndpointFailure => Other, + Error::EndpointTimeout => TimedOut, + Error::Unknown => Other, + }; + std::io::Error::new(kind, self) + } +} diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 5c18e0a..f0948b1 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -124,10 +124,6 @@ use std::task::ready; use tokio::io::AsyncWrite; use tokio::io::ReadBuf; -fn to_io_error(e: Error) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::NotFound, e) -} - impl<'a> AsyncWrite for WriteStream<'a> { fn poll_write( mut self: Pin<&mut Self>, @@ -137,18 +133,18 @@ impl<'a> AsyncWrite for WriteStream<'a> { if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { let r = match ready!(req.poll_unpin(cx)) { Ok(_) => Ok(buf.len()), - Err(e) => Err(to_io_error(e)), + Err(e) => Err(e.into()), }; self.request = None; Poll::Ready(r) } else { match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(to_io_error)), + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), Ok(Status::Scheduled(request_handle)) => { self.request = Some(request_handle); Poll::Pending } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } } @@ -202,7 +198,7 @@ impl<'a> AsyncRead for ReadStream<'a> { } Ok(()) } - Err(e) => Err(to_io_error(e)), + Err(e) => Err(e.into()), }; self.request = None; Poll::Ready(r) @@ -219,7 +215,7 @@ impl<'a> AsyncRead for ReadStream<'a> { } Poll::Ready(Ok(())) } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } Ok(Status::Scheduled(request_handle)) => { @@ -227,7 +223,7 @@ impl<'a> AsyncRead for ReadStream<'a> { cx.waker().wake_by_ref(); Poll::Pending } - Err(e) => Poll::Ready(Err(to_io_error(e))), + Err(e) => Poll::Ready(Err(e.into())), } } } From 1dbbc3e509c4b206e1f17b9dd6912dbe39e61ef2 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 10:53:10 +0800 Subject: [PATCH 16/33] add utill.rs with feature --- Cargo.toml | 1 + src/ucp/endpoint/mod.rs | 4 + src/ucp/endpoint/stream.rs | 214 +------------------------------------ src/ucp/endpoint/util.rs | 195 +++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+), 209 deletions(-) create mode 100644 src/ucp/endpoint/util.rs diff --git a/Cargo.toml b/Cargo.toml index 7dcd9a0..429dc7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ categories = ["asynchronous", "api-bindings", "network-programming"] [features] event = ["tokio"] am = ["tokio/sync", "crossbeam"] +util = ["tokio"] [dependencies] ucx1-sys = { version = "0.1", path = "ucx1-sys" } diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c8596a9..f86e5c9 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -13,10 +13,14 @@ mod param; mod rma; mod stream; mod tag; +#[cfg(feature = "util")] +mod util; #[cfg(feature = "am")] pub use self::am::*; pub use self::rma::*; +#[cfg(feature = "util")] +pub use self::util::*; // State associate with ucp_ep_h // todo: Add a `get_user_data` to UCX diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index f0948b1..d9c2f5b 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -2,7 +2,7 @@ use super::param::RequestParam; use super::*; impl Endpoint { - fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { + pub(super) fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -46,7 +46,10 @@ impl Endpoint { } } - fn stream_recv_impl(&self, buf: &mut [MaybeUninit]) -> Result, Error> { + pub(super) fn stream_recv_impl( + &self, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -102,137 +105,9 @@ fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { } } -use pin_project::pin_project; -#[pin_project] -pub struct WriteStream<'a> { - endpoint: &'a Endpoint, - #[pin] - request: Option>>, -} - -impl Endpoint { - /// make write stream - pub fn write_stream(&self) -> WriteStream { - WriteStream { - endpoint: self, - request: None, - } - } -} -use futures::FutureExt; -use std::task::ready; -use tokio::io::AsyncWrite; -use tokio::io::ReadBuf; - -impl<'a> AsyncWrite for WriteStream<'a> { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - todo!() - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - todo!() - } -} - -#[pin_project] -pub struct ReadStream<'a> { - endpoint: &'a Endpoint, - #[pin] - request: Option>>, -} - -impl Endpoint { - /// make read stream - pub fn read_stream(&self) -> ReadStream { - ReadStream { - endpoint: self, - request: None, - } - } -} - -use tokio::io::AsyncRead; -impl<'a> AsyncRead for ReadStream<'a> { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - out_buf: &mut ReadBuf<'_>, - ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(n) => { - // Safety: The buffer was filled by the recv operation. - unsafe { - out_buf.assume_init(n); - out_buf.advance(n); - } - Ok(()) - } - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - let buf = unsafe { out_buf.unfilled_mut() }; - match self.endpoint.stream_recv_impl(buf) { - Ok(Status::Completed(n_result)) => { - match n_result { - Ok(n) => { - // Safety: The buffer was filled by the recv operation. - unsafe { - out_buf.assume_init(n); - out_buf.advance(n); - } - Poll::Ready(Ok(())) - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(e) => Poll::Ready(Err(e.into())), - } - } - } -} - #[cfg(test)] mod tests { use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[test_log::test] fn stream() { for i in 0..20_usize { @@ -289,22 +164,6 @@ mod tests { } ); - tokio::join!( - async { - // send - let buf = vec![42u8; msg_size]; - endpoint2.write_stream().write_all(&buf).await.unwrap(); - println!("write_stream"); - }, - async { - // recv - let mut buf = vec![0u8; msg_size]; - endpoint1.read_stream().read_exact(&mut buf).await.unwrap(); - assert_eq!(buf, vec![42u8; msg_size]); - println!("read_stream"); - } - ); - println!("status {:?}", endpoint2.get_status()); assert_eq!(endpoint1.get_rc(), (1, 1)); assert_eq!(endpoint2.get_rc(), (1, 1)); @@ -315,67 +174,4 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } - - #[test_log::test] - fn stream_send_recv_various_contents_and_sizes() { - spawn_thread!(_stream_send_recv_various_contents_and_sizes()) - .join() - .unwrap(); - } - - async fn _stream_send_recv_various_contents_and_sizes() { - let context1 = Context::new().unwrap(); - let worker1 = context1.create_worker().unwrap(); - let context2 = Context::new().unwrap(); - let worker2 = context2.create_worker().unwrap(); - tokio::task::spawn_local(worker1.clone().polling()); - tokio::task::spawn_local(worker2.clone().polling()); - - // connect with each other - let mut listener = worker1 - .create_listener("0.0.0.0:0".parse().unwrap()) - .unwrap(); - let listen_port = listener.socket_addr().unwrap().port(); - let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - addr.set_port(listen_port); - - let (endpoint1, endpoint2) = tokio::join!( - async { - let conn1 = listener.next().await; - worker1.accept(conn1).await.unwrap() - }, - async { worker2.connect_socket(addr).await.unwrap() }, - ); - - // Test cases: (data, repeat count) - let test_cases = vec![ - (vec![], 1), - (vec![0u8], 10), - (vec![1, 2, 3, 4, 5], 5), - ((0..128).collect::>(), 3), - ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), - ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), - ]; - for (data, repeat) in test_cases { - for _ in 0..repeat { - // send - let send_buf = data.clone(); - let recv_len = send_buf.len(); - let mut recv_buf = vec![0u8; recv_len]; - tokio::join!( - async { - endpoint2.write_stream().write_all(&send_buf).await.unwrap(); - }, - async { - endpoint1 - .read_stream() - .read_exact(&mut recv_buf) - .await - .unwrap(); - assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); - } - ); - } - } - } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs new file mode 100644 index 0000000..66e6895 --- /dev/null +++ b/src/ucp/endpoint/util.rs @@ -0,0 +1,195 @@ +use super::*; +use futures::FutureExt; +use pin_project::pin_project; +use std::task::ready; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream { + WriteStream { + endpoint: self, + request: None, + } + } + /// make read stream + pub fn read_stream(&self) -> ReadStream { + ReadStream { + endpoint: self, + request: None, + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` stream. +pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +/// A stream for reading data asynchronously from an `Endpoint` stream. +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[cfg(test)] +mod test { + + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[test_log::test] + fn stream_send_recv() { + spawn_thread!(_stream_send_recv()).join().unwrap(); + } + + async fn _stream_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + }, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } +} From bed08b8d4c048dd5924250cbd4c55abcc10830a1 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:07:22 +0800 Subject: [PATCH 17/33] Update endpoint tag and util modules --- src/ucp/endpoint/tag.rs | 103 +++++++++++--------- src/ucp/endpoint/util.rs | 200 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 47 deletions(-) diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 311c30e..1f7b7c3 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -17,12 +17,22 @@ impl Worker { tag_mask: u64, buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { + match self.tag_recv_impl(tag, tag_mask, buf)? { + Status::Completed(r) => r, + Status::Scheduled(request_handle) => request_handle.await, + } + } + + /// Like `tag_recv`, except that it reads into a slice of buffers. + pub async fn tag_recv_vectored( + &self, + tag: u64, + iov: &mut [IoSliceMut<'_>], + ) -> Result { trace!( - "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", + "tag_recv_vectored: worker={:?} iov.len={}", self.handle, - tag, - tag_mask, - buf.len() + iov.len() ); unsafe extern "C" fn callback( request: *mut c_void, @@ -31,48 +41,50 @@ impl Worker { _user_data: *mut c_void, ) { let length = (*info).length; - let sender_tag = (*info).sender_tag; + let tag = (*info).sender_tag; trace!( - "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, - sender_tag, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder - let param = RequestParam::new().cb_tag_recv(Some(callback)); + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { ucp_tag_recv_nbx( self.handle, - buf.as_mut_ptr() as _, - buf.len() as _, + iov.as_ptr() as _, + iov.len() as _, tag, - tag_mask, + u64::max_value(), param.as_ref(), ) }; - Error::from_ptr(status)?; RequestHandle { ptr: status, poll_fn: poll_tag, } .await + .map(|info| info.1) } - /// Like `tag_recv`, except that it reads into a slice of buffers. - pub async fn tag_recv_vectored( + pub(super) fn tag_recv_impl( &self, tag: u64, - iov: &mut [IoSliceMut<'_>], - ) -> Result { + tag_mask: u64, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!( - "tag_recv_vectored: worker={:?} iov.len={}", + "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, - iov.len() + tag, + tag_mask, + buf.len() ); unsafe extern "C" fn callback( request: *mut c_void, @@ -81,42 +93,34 @@ impl Worker { _user_data: *mut c_void, ) { let length = (*info).length; - let tag = (*info).sender_tag; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", + "tag_recv: complete. req={:?}, status={:?}, tag={}, len={}", request, status, - tag, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder for iov - let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); + let param = RequestParam::new().cb_tag_recv(Some(callback)); let status = unsafe { ucp_tag_recv_nbx( self.handle, - iov.as_ptr() as _, - iov.len() as _, + buf.as_mut_ptr() as _, + buf.len() as _, tag, - u64::max_value(), + tag_mask, param.as_ref(), ) }; - Error::from_ptr(status)?; - RequestHandle { - ptr: status, - poll_fn: poll_tag, - } - .await - .map(|info| info.1) + Ok(Status::from(status, MaybeUninit::uninit(), poll_tag)) } } impl Endpoint { - /// Sends a messages with `tag`. - pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + pub(super) fn tag_send_impl(&self, tag: u64, buf: &[u8]) -> Result, Error> { trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); unsafe extern "C" fn callback( request: *mut c_void, @@ -127,7 +131,6 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - // Use RequestParam builder let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { ucp_tag_send_nbx( @@ -138,18 +141,24 @@ impl Endpoint { param.as_ref(), ) }; - if status.is_null() { - trace!("tag_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends a messages with `tag`. + pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + match self.tag_send_impl(tag, buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("tag_send: complete"), + Err(e) => error!("tag_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } /// Like `tag_send`, except that it reads into a slice of buffers. diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 66e6895..a633725 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -21,6 +21,35 @@ impl Endpoint { request: None, } } + /// make tag write stream + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream { + TagWriteStream { + endpoint: self, + tag, + request: None, + } + } +} + +impl Worker { + /// make tag read stream + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream { + TagReadStream { + worker: self, + tag, + tag_mask: u64::max_value(), + request: None, + } + } + /// make tag read stream with mask + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { + TagReadStream { + worker: self, + tag, + tag_mask, + request: None, + } + } } #[pin_project] @@ -126,6 +155,112 @@ impl<'a> AsyncRead for ReadStream<'a> { } } +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` using tag. +pub struct TagWriteStream<'a> { + endpoint: &'a Endpoint, + tag: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for TagWriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + todo!() + } +} + +/// A stream for reading data asynchronously from a `Worker` using tag. +#[pin_project] +pub struct TagReadStream<'a> { + worker: &'a Worker, + tag: u64, + tag_mask: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for TagReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok((_, n)) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok((_, n)) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + #[cfg(test)] mod test { @@ -137,6 +272,11 @@ mod test { spawn_thread!(_stream_send_recv()).join().unwrap(); } + #[test_log::test] + fn tag_send_recv() { + spawn_thread!(_tag_send_recv()).join().unwrap(); + } + async fn _stream_send_recv() { let context1 = Context::new().unwrap(); let worker1 = context1.create_worker().unwrap(); @@ -192,4 +332,64 @@ mod test { } } } + + async fn _tag_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, tag, repeat count) + let test_cases = vec![ + (vec![], 1u64, 1), + (vec![0u8], 2u64, 10), + (vec![1, 2, 3, 4, 5], 3u64, 5), + ((0..128).collect::>(), 4u64, 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 5u64, 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 6u64, 1), + ]; + for (data, tag, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.tag_write_stream(tag).write_all(&send_buf).await.unwrap(); + }, + async { + worker1 + .tag_read_stream(tag) + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for tag={}, len={}", tag, recv_len); + } + ); + } + } + + // Clean up + let _ = endpoint1.close(false).await; + let _ = endpoint2.close(false).await; + } } From 5abf8d40da7c12507dc442cbca96e6512f4eec4a Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:47:20 +0800 Subject: [PATCH 18/33] update readme and changlog --- CHANGELOG.md | 17 ++++++++++ README.md | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f49d8a..62f2c51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.0] - 2025-06-29 + +### Added + +- Added AsyncRead/AsyncWrite support for Tag and Stream (requires `utils` feature flag) +- Added `connect_addr_vec` function to Worker + +### Changed + +- Updated to UCX 1.18 with latest API compatibility +- Updated multiple dependency versions +- Migrated to Rust 2021 edition + +### Fixed + +- Fixed various bugs and issues + ## [0.1.1] - 2022-09-01 ### Changed diff --git a/README.md b/README.md index 2f47499..04e7414 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,94 @@ [![Docs](https://docs.rs/async-ucx/badge.svg)](https://docs.rs/async-ucx) [![CI](https://github.com/madsys-dev/async-ucx/workflows/CI/badge.svg?branch=main)](https://github.com/madsys-dev/async-ucx/actions) -Async Rust UCX bindings. +Async Rust UCX bindings providing high-performance networking capabilities for distributed systems and HPC applications. + +## Features + +- **Asynchronous UCP Operations**: Full async/await support for UCX operations +- **Multiple Communication Models**: Support for RMA, Stream, Tag, and Active Message APIs +- **High Performance**: Optimized for low-latency, high-throughput communication +- **Tokio Integration**: Seamless integration with Tokio async runtime +- **Comprehensive Examples**: Ready-to-use examples for various UCX patterns ## Optional features -- `event`: Enable UCP wakeup mechanism. -- `am`: Enable UCP Active Message API. +- `event`: Enable UCP wakeup mechanism for event-driven applications +- `am`: Enable UCP Active Message API for flexible message handling +- `util`: Enable additional utility functions for UCX integration + +## Quick Start + +Add to your `Cargo.toml`: + +```toml +[dependencies] +async-ucx = "0.2" +tokio = { version = "1.0", features = ["rt", "net"] } +``` + +Basic usage example: + +```rust +use async_ucx::ucp::*; +use std::mem::MaybeUninit; +use std::net::SocketAddr; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + // Create UCP contexts and workers + let context1 = Context::new()?; + let worker1 = context1.create_worker()?; + let context2 = Context::new()?; + let worker2 = context2.create_worker()?; + + // Start polling for both workers + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // Create listener on worker1 + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap())?; + let listen_port = listener.socket_addr()?.port(); + + // Connect worker2 to worker1 + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Send and receive tag message + tokio::join!( + async { + let msg = b"Hello UCX!"; + endpoint2.tag_send(1, msg).await.unwrap(); + println!("Message sent"); + }, + async { + let mut buf = vec![MaybeUninit::::uninit(); 10]; + worker1.tag_recv(1, &mut buf).await.unwrap(); + println!("Message received"); + } + ); + + Ok(()) +} +``` + +## Examples + +Check the `examples/` directory for comprehensive examples: +- `rma.rs`: Remote Memory Access operations +- `stream.rs`: Stream-based communication +- `tag.rs`: Tag-based message matching +- `bench.rs`: Performance benchmarking +- `bench-multi-thread.rs`: Multi-threaded benchmarking ## License From 2150e47a11b7c1761251b19ffc4b1bb89581a98f Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 20:51:44 +0800 Subject: [PATCH 19/33] fix remote println! --- src/ucp/endpoint/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index f86e5c9..ffd9f22 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -93,16 +93,12 @@ impl Endpoint { let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); - println!("error callback"); if let Some(inner) = weak.upgrade() { inner.set_status(status); // don't drop weak reference - println!("no drop"); - // panic!("{:?}",Error::from_status(status)); std::mem::forget(weak); } else { // no strong rc, force close endpoint here - println!("failed"); let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _); let _ = Error::from_ptr(status) .map_err(|err| error!("Force close endpoint failed, {}", err)); From 35ce5c43031a57e03605d1fdbb5a5ebaf896d978 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 21:01:22 +0800 Subject: [PATCH 20/33] cargo fmt --- src/ucp/endpoint/util.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index a633725..c404aa3 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -363,8 +363,16 @@ mod test { (vec![0u8], 2u64, 10), (vec![1, 2, 3, 4, 5], 3u64, 5), ((0..128).collect::>(), 4u64, 3), - ((0..1024).map(|x| (x % 256) as u8).collect::>(), 5u64, 2), - ((0..4096).map(|x| (x % 256) as u8).collect::>(), 6u64, 1), + ( + (0..1024).map(|x| (x % 256) as u8).collect::>(), + 5u64, + 2, + ), + ( + (0..4096).map(|x| (x % 256) as u8).collect::>(), + 6u64, + 1, + ), ]; for (data, tag, repeat) in test_cases { for _ in 0..repeat { @@ -374,7 +382,11 @@ mod test { let mut recv_buf = vec![0u8; recv_len]; tokio::join!( async { - endpoint2.tag_write_stream(tag).write_all(&send_buf).await.unwrap(); + endpoint2 + .tag_write_stream(tag) + .write_all(&send_buf) + .await + .unwrap(); }, async { worker1 @@ -382,7 +394,11 @@ mod test { .read_exact(&mut recv_buf) .await .unwrap(); - assert_eq!(recv_buf, send_buf, "data mismatch for tag={}, len={}", tag, recv_len); + assert_eq!( + recv_buf, send_buf, + "data mismatch for tag={}, len={}", + tag, recv_len + ); } ); } From 09b79fdbee4673842f98f12aa898042c7657597e Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 29 Jun 2025 21:45:56 +0800 Subject: [PATCH 21/33] feat: improve tag receive implementation with ucp_tag_recv_info --- src/ucp/endpoint/param.rs | 6 +++ src/ucp/endpoint/tag.rs | 82 +++++++++++++++++++++++++++++++++++---- src/ucp/endpoint/util.rs | 15 +++---- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs index 5bac074..92ea65c 100644 --- a/src/ucp/endpoint/param.rs +++ b/src/ucp/endpoint/param.rs @@ -42,6 +42,12 @@ impl RequestParam { self } + pub fn recv_tag_info(mut self, info: *mut ucp_tag_recv_info) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_RECV_INFO as u32; + self.inner.recv_info.tag_info = info; + self + } + #[cfg(feature = "am")] pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 1f7b7c3..ae55fcb 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -18,8 +18,11 @@ impl Worker { buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { match self.tag_recv_impl(tag, tag_mask, buf)? { - Status::Completed(r) => r, - Status::Scheduled(request_handle) => request_handle.await, + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length as usize)), + Status::Scheduled(request_handle) => { + let info = request_handle.await?; + Ok((info.sender_tag, info.length as usize)) + } } } @@ -70,7 +73,7 @@ impl Worker { poll_fn: poll_tag, } .await - .map(|info| info.1) + .map(|info| info.length as usize) } pub(super) fn tag_recv_impl( @@ -78,7 +81,7 @@ impl Worker { tag: u64, tag_mask: u64, buf: &mut [MaybeUninit], - ) -> Result, Error> { + ) -> Result, Error> { trace!( "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, @@ -104,7 +107,10 @@ impl Worker { let request = &mut *(request as *mut Request); request.waker.wake(); } - let param = RequestParam::new().cb_tag_recv(Some(callback)); + let mut info = MaybeUninit::::uninit(); + let param = RequestParam::new() + .cb_tag_recv(Some(callback)) + .recv_tag_info(info.as_mut_ptr() as _); let status = unsafe { ucp_tag_recv_nbx( self.handle, @@ -115,7 +121,7 @@ impl Worker { param.as_ref(), ) }; - Ok(Status::from(status, MaybeUninit::uninit(), poll_tag)) + Ok(Status::from(status, info, poll_tag)) } } @@ -208,14 +214,14 @@ impl Endpoint { } } -fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { let info = unsafe { info.assume_init() }; - Poll::Ready(Ok((info.sender_tag, info.length as usize))) + Poll::Ready(Ok(info)) } status => Poll::Ready(Err(Error::from_error(status))), } @@ -281,4 +287,64 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + fn multi_tag() { + for i in 0..20_usize { + spawn_thread!(_multi_tag(4 << i)).join().unwrap(); + } + } + + async fn _multi_tag(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send tag message + tokio::join!( + async { + // send + let mut buf = vec![0; msg_size]; + endpoint2.tag_send(3, &mut buf).await.unwrap(); + println!("tag sended"); + }, + async { + // recv + let mut buf = vec![MaybeUninit::::uninit(); msg_size]; + let (tag, size) = worker1.tag_recv_mask(0, 0, &mut buf).await.unwrap(); + assert_eq!(size, msg_size); + assert_eq!(tag, 3); + println!("tag recved"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); + } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index c404aa3..7fcf254 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -42,6 +42,7 @@ impl Worker { } } /// make tag read stream with mask + /// not suggested to use this function, because actual received tag should be checked by user pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { TagReadStream { worker: self, @@ -211,7 +212,7 @@ pub struct TagReadStream<'a> { tag: u64, tag_mask: u64, #[pin] - request: Option>>, + request: Option>>, } impl<'a> AsyncRead for TagReadStream<'a> { @@ -222,11 +223,11 @@ impl<'a> AsyncRead for TagReadStream<'a> { ) -> Poll> { if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { let r = match ready!(req.poll_unpin(cx)) { - Ok((_, n)) => { + Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(n); - out_buf.advance(n); + out_buf.assume_init(info.length as usize); + out_buf.advance(info.length as usize); } Ok(()) } @@ -239,11 +240,11 @@ impl<'a> AsyncRead for TagReadStream<'a> { match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { Ok(Status::Completed(n_result)) => { match n_result { - Ok((_, n)) => { + Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(n); - out_buf.advance(n); + out_buf.assume_init(info.length as usize); + out_buf.advance(info.length as usize); } Poll::Ready(Ok(())) } From 59a03112c869c6fe810cc75e80e3bf798e282222 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Thu, 3 Jul 2025 22:01:15 +0800 Subject: [PATCH 22/33] impl poll_flush and poll_shutdown --- src/ucp/endpoint/util.rs | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 7fcf254..57480ff 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -87,17 +87,23 @@ impl<'a> AsyncWrite for WriteStream<'a> { } fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + Poll::Ready(Ok(())) } } @@ -191,17 +197,23 @@ impl<'a> AsyncWrite for TagWriteStream<'a> { } fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> Poll> { - todo!() + Poll::Ready(Ok(())) } } From a1fc9b9c1901e6ff34cfba0f98616e745a64ef97 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 15 Jul 2025 15:37:59 +0800 Subject: [PATCH 23/33] endpoint handler --- src/ucp/endpoint/mod.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index ffd9f22..c563956 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -83,6 +83,12 @@ pub struct Endpoint { inner: Rc, } +/// Type alias of Endpoint handler +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct EndpointHandler(ucp_ep_h); +unsafe impl Sync for EndpointHandler {} +unsafe impl Send for EndpointHandler {} + impl Endpoint { fn create(worker: &Rc, mut params: ucp_ep_params) -> Result { let inner = Rc::new(EndpointInner::new(worker.clone())); @@ -205,6 +211,11 @@ impl Endpoint { Ok(self.handle) } + /// Get the endpoint handler + pub fn handler(&self) -> Result { + Ok(EndpointHandler(self.get_handle()?)) + } + /// Print endpoint information to stderr. pub fn print_to_stderr(&self) { if !self.inner.is_closed() { From 965c7e819de42eaa7559a370fab738407a8b2236 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Tue, 15 Jul 2025 15:49:40 +0800 Subject: [PATCH 24/33] add reply_ep --- src/ucp/endpoint/am.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 378b86f..6fdc2b9 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -295,6 +295,15 @@ impl<'a> AmMsg<'a> { && !self.msg.reply_ep.is_null() } + /// return endpoint handler + pub fn reply_ep(&self) -> Option { + if self.need_reply() { + Some(EndpointHandler(self.msg.reply_ep)) + } else { + None + } + } + /// Send reply /// # Safety /// User needs to ensure that the endpoint isn't closed. From 3263c6ecd38071cbb663cf2bb625f3b04cf38623 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Wed, 16 Jul 2025 17:39:28 +0800 Subject: [PATCH 25/33] am id is u16 --- src/ucp/endpoint/am.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 6fdc2b9..f8c949b 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -309,7 +309,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -327,7 +327,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -497,7 +497,7 @@ impl Endpoint { /// Send active message. pub async fn am_send( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -511,7 +511,7 @@ impl Endpoint { /// Send active message. pub async fn am_send_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -534,7 +534,7 @@ pub enum AmProto { async fn am_send( endpoint: ucp_ep_h, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -567,7 +567,7 @@ async fn am_send( let status = unsafe { ucp_am_send_nbx( endpoint, - id, + id as u32, header.as_ptr() as _, header.len() as _, buffer as _, From 0107ef0f8c4a2a10748387f4928f89cbb575b512 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 21 Jul 2025 12:03:43 +0800 Subject: [PATCH 26/33] fix TagWriteStream --- src/ucp/endpoint/util.rs | 50 ++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 57480ff..ac372d7 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -177,36 +177,40 @@ impl<'a> AsyncWrite for TagWriteStream<'a> { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.tag_send_impl(self.tag, buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())); + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), } - Err(e) => Poll::Ready(Err(e.into())), } } } fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = ready!(req.poll_unpin(cx)); - self.request = None; - Poll::Ready(r.map_err(|e| e.into())) - } else { - Poll::Ready(Ok(())) - } + assert!(self.request.is_none()); + // if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + // let r = ready!(req.poll_unpin(cx)); + // self.request = None; + // Poll::Ready(r.map_err(|e| e.into())) + // } else { + Poll::Ready(Ok(())) + // } } fn poll_shutdown( From 5ac7cd1860020824a3886c50cad8bf83d3231725 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Mon, 21 Jul 2025 12:05:56 +0800 Subject: [PATCH 27/33] fix WriteStream --- src/ucp/endpoint/util.rs | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index ac372d7..0b0db5e 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -67,21 +67,24 @@ impl<'a> AsyncWrite for WriteStream<'a> { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { - let r = match ready!(req.poll_unpin(cx)) { - Ok(_) => Ok(buf.len()), - Err(e) => Err(e.into()), - }; - self.request = None; - Poll::Ready(r) - } else { - match self.endpoint.stream_send_impl(buf) { - Ok(Status::Completed(r)) => Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())), - Ok(Status::Scheduled(request_handle)) => { - self.request = Some(request_handle); - Poll::Pending + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())) + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), } - Err(e) => Poll::Ready(Err(e.into())), } } } From c8ec53712f1842dd71e5a5b16006f27438d85abf Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:11:38 +0800 Subject: [PATCH 28/33] Create rust.yml --- .github/workflows/rust.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..9fd45e0 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,22 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose From 62f4afadfa0a25cad1a1d9aab96105645644903b Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:30:22 +0800 Subject: [PATCH 29/33] cargo clippy --fix --- src/lib.rs | 8 ++++---- src/ucp/endpoint/am.rs | 20 ++++++++++---------- src/ucp/endpoint/tag.rs | 4 ++-- src/ucp/endpoint/util.rs | 8 ++++---- src/ucp/mod.rs | 2 +- src/ucp/worker.rs | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 87baf84..23ac9fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,10 +164,10 @@ impl Error { } } -impl Into for Error { - fn into(self) -> std::io::Error { +impl From for std::io::Error { + fn from(val: Error) -> Self { use std::io::ErrorKind::*; - let kind = match self { + let kind = match val { Error::Inprogress => WouldBlock, Error::NoMessage => WouldBlock, Error::NoReource => WouldBlock, @@ -201,6 +201,6 @@ impl Into for Error { Error::EndpointTimeout => TimedOut, Error::Unknown => Other, }; - std::io::Error::new(kind, self) + std::io::Error::new(kind, val) } } diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index f8c949b..046098e 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -9,7 +9,7 @@ use std::{ sync::atomic::AtomicBool, }; -//// Active message protocol. +/// Active message protocol. /// Active message protocol is a mechanism for sending and receiving messages /// between processes in a distributed system. /// It allows a process to send a message to another process, which can then @@ -222,7 +222,7 @@ impl<'a> AmMsg<'a> { } Some(AmData::Data(data)) => { // data message, no need to receive - let size = copy_data_to_iov(&data, iov)?; + let size = copy_data_to_iov(data, iov)?; self.drop_msg(AmData::Data(data)); Ok(size) } @@ -439,8 +439,8 @@ impl Worker { param: *const ucp_am_recv_param_t, ) -> ucs_status_t { let handler = &*(arg as *const AmStreamInner); - let header = slice::from_raw_parts(header as *const u8, header_len as usize); - let data = slice::from_raw_parts(data as *const u8, data_len as usize); + let header = slice::from_raw_parts(header as *const u8, header_len); + let data = slice::from_raw_parts(data as *const u8, data_len); let param = &*param; handler.callback(header, data, param.reply_ep, param.recv_attr); @@ -460,7 +460,7 @@ impl Worker { } self.am_streams.write().unwrap().insert(id, stream.clone()); - return Ok(AmStream::new(self, stream)); + Ok(AmStream::new(self, stream)) } /// Register active message handler for `id`. @@ -596,7 +596,7 @@ mod tests { #[test_log::test] fn am() { - let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; + let protos = [None, Some(AmProto::Eager), Some(AmProto::Rndv)]; for block_size_shift in 0..20_usize { for p in protos.iter() { let rt = tokio::runtime::Builder::new_current_thread() @@ -650,13 +650,13 @@ mod tests { let msg = stream1.wait_msg().await; let mut msg = msg.expect("no msg"); assert_eq!(msg.header(), &header); - assert_eq!(msg.contains_data(), true); + assert!(msg.contains_data()); assert_eq!(msg.data_len(), data.len()); let mut recv_data = vec![0_u8; msg.data_len()]; let recv_len = msg.recv_data_single(&mut recv_data).await.unwrap(); assert_eq!(data.len(), recv_len); assert_eq!(data, recv_data); - assert_eq!(msg.contains_data(), false); + assert!(!msg.contains_data()); msg } ); @@ -674,11 +674,11 @@ mod tests { let reply = stream2.wait_msg().await; let mut reply = reply.expect("no reply"); assert_eq!(reply.header(), &header); - assert_eq!(reply.contains_data(), true); + assert!(reply.contains_data()); assert_eq!(reply.data_len(), data.len()); let recv_data = reply.recv_data().await.unwrap(); assert_eq!(data, recv_data); - assert_eq!(reply.contains_data(), false); + assert!(!reply.contains_data()); } ); diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index ae55fcb..5d5e75c 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -18,7 +18,7 @@ impl Worker { buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { match self.tag_recv_impl(tag, tag_mask, buf)? { - Status::Completed(r) => r.map(|info| (info.sender_tag, info.length as usize)), + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length)), Status::Scheduled(request_handle) => { let info = request_handle.await?; Ok((info.sender_tag, info.length as usize)) @@ -73,7 +73,7 @@ impl Worker { poll_fn: poll_tag, } .await - .map(|info| info.length as usize) + .map(|info| info.length) } pub(super) fn tag_recv_impl( diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index 0b0db5e..b26ee34 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -245,8 +245,8 @@ impl<'a> AsyncRead for TagReadStream<'a> { Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(info.length as usize); - out_buf.advance(info.length as usize); + out_buf.assume_init(info.length); + out_buf.advance(info.length); } Ok(()) } @@ -262,8 +262,8 @@ impl<'a> AsyncRead for TagReadStream<'a> { Ok(info) => { // Safety: The buffer was filled by the recv operation. unsafe { - out_buf.assume_init(info.length as usize); - out_buf.advance(info.length as usize); + out_buf.assume_init(info.length); + out_buf.advance(info.length); } Poll::Ready(Ok(())) } diff --git a/src/ucp/mod.rs b/src/ucp/mod.rs index 440e6be..8f003d8 100644 --- a/src/ucp/mod.rs +++ b/src/ucp/mod.rs @@ -91,7 +91,7 @@ impl Context { | ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED) .0 as u64, features: features.0 as u64, - request_size: std::mem::size_of::() as usize, + request_size: std::mem::size_of::(), request_init: Some(Request::init), request_cleanup: Some(Request::cleanup), mt_workers_shared: 1, diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index d1ad850..1f142e7 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -107,7 +107,7 @@ impl Worker { Ok(WorkerAddress { handle: unsafe { handle.assume_init() }, - length: unsafe { length.assume_init() } as usize, + length: unsafe { length.assume_init() }, worker: self, }) } From 50f1487cb2a54f0b7173903b0220f350d1e66cca Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Wed, 24 Sep 2025 13:33:31 +0800 Subject: [PATCH 30/33] cargo fix --- src/ucp/endpoint/util.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs index b26ee34..b007fb2 100644 --- a/src/ucp/endpoint/util.rs +++ b/src/ucp/endpoint/util.rs @@ -8,21 +8,21 @@ use tokio::io::ReadBuf; impl Endpoint { /// make write stream - pub fn write_stream(&self) -> WriteStream { + pub fn write_stream(&self) -> WriteStream<'_> { WriteStream { endpoint: self, request: None, } } /// make read stream - pub fn read_stream(&self) -> ReadStream { + pub fn read_stream(&self) -> ReadStream<'_> { ReadStream { endpoint: self, request: None, } } /// make tag write stream - pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream { + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream<'_> { TagWriteStream { endpoint: self, tag, @@ -33,7 +33,7 @@ impl Endpoint { impl Worker { /// make tag read stream - pub fn tag_read_stream(&self, tag: u64) -> TagReadStream { + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream<'_> { TagReadStream { worker: self, tag, @@ -43,7 +43,7 @@ impl Worker { } /// make tag read stream with mask /// not suggested to use this function, because actual received tag should be checked by user - pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream { + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream<'_> { TagReadStream { worker: self, tag, From b80b161a6235568e6f09807ab5eab99b8d4b466d Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Wed, 24 Sep 2025 16:36:38 +0800 Subject: [PATCH 31/33] feat(ucx1-sys): bump version to 0.2.0 --- Cargo.toml | 2 +- ucx1-sys/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 429dc7c..e7b10c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ am = ["tokio/sync", "crossbeam"] util = ["tokio"] [dependencies] -ucx1-sys = { version = "0.1", path = "ucx1-sys" } +ucx1-sys = { version = "0.2", path = "ucx1-sys" } socket2 = "0.4" futures = "0.3" futures-lite = "1.11" diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index d71ff71..6b0b713 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ucx1-sys" -version = "0.1.0" +version = "0.2.0" authors = ["Runji Wang "] edition = "2021" description = "Rust FFI bindings to UCX." From 146a73b8045c4348e15b5cf005917915a6df8707 Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Thu, 25 Sep 2025 17:23:50 +0800 Subject: [PATCH 32/33] fix testing on rust 1.90.0 --- src/ucp/endpoint/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index c563956..dd2958d 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -352,7 +352,7 @@ impl Status { immediate: MaybeUninit, poll_fn: fn(ucs_status_ptr_t) -> Poll>, ) -> Self { - if UCS_PTR_RAW_STATUS(status) == ucs_status_t::UCS_OK { + if status.is_null() { Self::Completed(Ok(unsafe { immediate.assume_init() })) } else if UCS_PTR_IS_ERR(status) { Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) From 60726a1e219a559fff07f0191479b30b2cb0d1eb Mon Sep 17 00:00:00 2001 From: Kaiwei Li Date: Tue, 30 Sep 2025 15:45:16 +0800 Subject: [PATCH 33/33] build(deps): add alloc feature to crossbeam dependency --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e7b10c6..afa3184 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ futures-lite = "1.11" lazy_static = "1.4" log = "0.4" tokio = { version = "1.0", features = ["net"], optional = true } -crossbeam = { version = "0.8", optional = true } +crossbeam = { version = "0.8", features = ["alloc"], optional = true } derivative = "2.2.0" thiserror = "1.0" pin-project = "1.1.10"