diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..9fd45e0 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,22 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose diff --git a/.gitignore b/.gitignore index 96ef6c0..057c110 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ +.vscode/ + /target Cargo.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f49d8a..62f2c51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.0] - 2025-06-29 + +### Added + +- Added AsyncRead/AsyncWrite support for Tag and Stream (requires `util` feature flag) +- Added `connect_addr_vec` function to Worker + +### Changed + +- Updated to UCX 1.18 with latest API compatibility +- Updated multiple dependency versions +- Migrated to Rust 2021 edition + +### Fixed + +- Fixed various bugs and issues + ## [0.1.1] - 2022-09-01 ### Changed diff --git a/Cargo.toml b/Cargo.toml index 0b06a69..afa3184 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "async-ucx" -version = "0.1.1" -authors = ["Runji Wang ", "Yiyuan Liu "] +version = "0.2.0" +authors = ["Runji Wang ", "Yiyuan Liu ", "Kaiwei Li "] edition = "2021" description = "Asynchronous Rust bindings to UCX." 
homepage = "https://github.com/madsys-dev/async-ucx" @@ -15,21 +15,23 @@ categories = ["asynchronous", "api-bindings", "network-programming"] [features] event = ["tokio"] am = ["tokio/sync", "crossbeam"] +util = ["tokio"] [dependencies] -ucx1-sys = { version = "0.1", path = "ucx1-sys" } +ucx1-sys = { version = "0.2", path = "ucx1-sys" } socket2 = "0.4" futures = "0.3" futures-lite = "1.11" lazy_static = "1.4" log = "0.4" tokio = { version = "1.0", features = ["net"], optional = true } -crossbeam = { version = "0.8", optional = true } +crossbeam = { version = "0.8", features = ["alloc"], optional = true } derivative = "2.2.0" thiserror = "1.0" +pin-project = "1.1.10" [dev-dependencies] -tokio = { version = "1.0", features = ["rt", "time", "macros", "sync"] } +tokio = { version = "1.0", features = ["rt", "time", "macros", "sync", "io-util"] } env_logger = "0.9" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.2.17", default-features = false, features = ["env-filter", "fmt"] } diff --git a/README.md b/README.md index 2f47499..04e7414 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,94 @@ [![Docs](https://docs.rs/async-ucx/badge.svg)](https://docs.rs/async-ucx) [![CI](https://github.com/madsys-dev/async-ucx/workflows/CI/badge.svg?branch=main)](https://github.com/madsys-dev/async-ucx/actions) -Async Rust UCX bindings. +Async Rust UCX bindings providing high-performance networking capabilities for distributed systems and HPC applications. + +## Features + +- **Asynchronous UCP Operations**: Full async/await support for UCX operations +- **Multiple Communication Models**: Support for RMA, Stream, Tag, and Active Message APIs +- **High Performance**: Optimized for low-latency, high-throughput communication +- **Tokio Integration**: Seamless integration with Tokio async runtime +- **Comprehensive Examples**: Ready-to-use examples for various UCX patterns ## Optional features -- `event`: Enable UCP wakeup mechanism. 
-- `am`: Enable UCP Active Message API. +- `event`: Enable UCP wakeup mechanism for event-driven applications +- `am`: Enable UCP Active Message API for flexible message handling +- `util`: Enable additional utility functions for UCX integration + +## Quick Start + +Add to your `Cargo.toml`: + +```toml +[dependencies] +async-ucx = "0.2" +tokio = { version = "1.0", features = ["rt", "net"] } +``` + +Basic usage example: + +```rust +use async_ucx::ucp::*; +use std::mem::MaybeUninit; +use std::net::SocketAddr; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + // Create UCP contexts and workers + let context1 = Context::new()?; + let worker1 = context1.create_worker()?; + let context2 = Context::new()?; + let worker2 = context2.create_worker()?; + + // Start polling for both workers + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // Create listener on worker1 + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap())?; + let listen_port = listener.socket_addr()?.port(); + + // Connect worker2 to worker1 + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Send and receive tag message + tokio::join!( + async { + let msg = b"Hello UCX!"; + endpoint2.tag_send(1, msg).await.unwrap(); + println!("Message sent"); + }, + async { + let mut buf = vec![MaybeUninit::<u8>::uninit(); 10]; + worker1.tag_recv(1, &mut buf).await.unwrap(); + println!("Message received"); + } + ); + + Ok(()) +} +``` + +## Examples + +Check the `examples/` directory for comprehensive examples: +- `rma.rs`: Remote Memory Access operations +- `stream.rs`: Stream-based communication +- `tag.rs`: Tag-based message matching +- `bench.rs`: Performance 
benchmarking +- `bench-multi-thread.rs`: Multi-threaded benchmarking ## License diff --git a/examples/bench-multi-thread.rs b/examples/bench-multi-thread.rs index e6b1ef2..9651c5d 100644 --- a/examples/bench-multi-thread.rs +++ b/examples/bench-multi-thread.rs @@ -123,7 +123,7 @@ impl WorkerThread { .build() .unwrap(); let local = tokio::task::LocalSet::new(); - #[cfg(not(event))] + #[cfg(not(feature = "event"))] local.spawn_local(worker.clone().polling()); #[cfg(feature = "event")] local.spawn_local(worker.clone().event_poll()); diff --git a/src/lib.rs b/src/lib.rs index 02cc343..23ac9fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,3 +163,44 @@ impl Error { } } } + +impl From for std::io::Error { + fn from(val: Error) -> Self { + use std::io::ErrorKind::*; + let kind = match val { + Error::Inprogress => WouldBlock, + Error::NoMessage => WouldBlock, + Error::NoReource => WouldBlock, + Error::IoError => Other, + Error::NoMemory => OutOfMemory, + Error::InvalidParam => InvalidInput, + Error::Unreachable => NotConnected, + Error::InvalidAddr => InvalidInput, + Error::NotImplemented => Unsupported, + Error::MessageTruncated => InvalidData, + Error::NoProgress => WouldBlock, + Error::BufferTooSmall => UnexpectedEof, + Error::NoElem => NotFound, + Error::SomeConnectsFailed => ConnectionAborted, + Error::NoDevice => NotFound, + Error::Busy => ResourceBusy, + Error::Canceled => Interrupted, + Error::ShmemSegment => Other, + Error::AlreadyExists => AlreadyExists, + Error::OutOfRange => InvalidInput, + Error::Timeout => TimedOut, + Error::ExceedsLimit => Other, + Error::Unsupported => Unsupported, + Error::Rejected => ConnectionRefused, + Error::NotConnected => NotConnected, + Error::ConnectionReset => ConnectionReset, + Error::FirstLinkFailure => Other, + Error::LastLinkFailure => Other, + Error::FirstEndpointFailure => Other, + Error::LastEndpointFailure => Other, + Error::EndpointTimeout => TimedOut, + Error::Unknown => Other, + }; + std::io::Error::new(kind, val) + } 
+} diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 2002020..046098e 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -1,6 +1,7 @@ use crossbeam::queue::SegQueue; use tokio::sync::Notify; +use super::param::RequestParam; use super::*; use std::{ io::{IoSlice, IoSliceMut}, @@ -8,7 +9,7 @@ use std::{ sync::atomic::AtomicBool, }; -//// Active message protocol. +/// Active message protocol. /// Active message protocol is a mechanism for sending and receiving messages /// between processes in a distributed system. /// It allows a process to send a message to another process, which can then @@ -221,7 +222,7 @@ impl<'a> AmMsg<'a> { } Some(AmData::Data(data)) => { // data message, no need to receive - let size = copy_data_to_iov(&data, iov)?; + let size = copy_data_to_iov(data, iov)?; self.drop_msg(AmData::Data(data)); Ok(size) } @@ -249,22 +250,12 @@ impl<'a> AmMsg<'a> { self.worker.handle, iov.len() ); - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; - param.cb = ucp_request_param_t__bindgen_ty_1 { - recv_am: Some(callback), - }; - - if iov.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (iov[0].as_ptr(), iov[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (iov.as_ptr() as _, iov.len()) - } + + let param = RequestParam::new().cb_recv_am(Some(callback)); + let (buffer, count, param) = if iov.len() == 1 { + (iov[0].as_ptr(), iov[0].len(), param) + } else { + (iov.as_ptr() as _, iov.len(), param.iov()) }; let status = unsafe { @@ -273,7 +264,7 @@ impl<'a> AmMsg<'a> { data_desc as _, buffer as _, count as _, - param.as_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -282,9 +273,9 @@ impl<'a> AmMsg<'a> { } else if UCS_PTR_IS_PTR(status) { RequestHandle { ptr: status, - poll_fn: poll_recv, + poll_fn: 
poll_normal, } - .await; + .await?; Ok(data_len) } else { Err(Error::from_ptr(status).unwrap_err()) @@ -304,12 +295,21 @@ impl<'a> AmMsg<'a> { && !self.msg.reply_ep.is_null() } + /// return endpoint handler + pub fn reply_ep(&self) -> Option { + if self.need_reply() { + Some(EndpointHandler(self.msg.reply_ep)) + } else { + None + } + } + /// Send reply /// # Safety /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -327,7 +327,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -439,8 +439,8 @@ impl Worker { param: *const ucp_am_recv_param_t, ) -> ucs_status_t { let handler = &*(arg as *const AmStreamInner); - let header = slice::from_raw_parts(header as *const u8, header_len as usize); - let data = slice::from_raw_parts(data as *const u8, data_len as usize); + let header = slice::from_raw_parts(header as *const u8, header_len); + let data = slice::from_raw_parts(data as *const u8, data_len); let param = &*param; handler.callback(header, data, param.reply_ep, param.recv_attr); @@ -460,7 +460,7 @@ impl Worker { } self.am_streams.write().unwrap().insert(id, stream.clone()); - return Ok(AmStream::new(self, stream)); + Ok(AmStream::new(self, stream)) } /// Register active message handler for `id`. @@ -497,7 +497,7 @@ impl Endpoint { /// Send active message. pub async fn am_send( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -511,7 +511,7 @@ impl Endpoint { /// Send active message. 
pub async fn am_send_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -534,7 +534,7 @@ pub enum AmProto { async fn am_send( endpoint: ucp_ep_h, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -546,45 +546,33 @@ async fn am_send( request.waker.wake(); } - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; - param.flags = 0; - param.cb = ucp_request_param_t__bindgen_ty_1 { - send: Some(callback), - }; - - match proto { - Some(AmProto::Eager) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0, - Some(AmProto::Rndv) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0, - _ => (), - } - - if need_reply { - param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; - } - - if data.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (data[0].as_ptr(), data[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (data.as_ptr() as _, data.len()) - } + // Use RequestParam builder for send + let param = RequestParam::new().cb_send(Some(callback)); + let param = match proto { + Some(AmProto::Eager) => param.set_flag_eager(), + Some(AmProto::Rndv) => param.set_flag_rndv(), + None => param, + }; + let param = if need_reply { + param.set_flag_reply() + } else { + param + }; + let (buffer, count, param) = if data.len() == 1 { + (data[0].as_ptr(), data[0].len(), param) + } else { + (data.as_ptr() as _, data.len(), param.iov()) }; let status = unsafe { ucp_am_send_nbx( endpoint, - id, + id as u32, header.as_ptr() as _, header.len() as _, buffer as _, count as _, - param.as_mut_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -601,15 +589,6 @@ async fn am_send( } } -unsafe fn poll_recv(ptr: ucs_status_ptr_t) -> 
Poll<()> { - let status = ucp_request_check_status(ptr as _); - if status == ucs_status_t::UCS_INPROGRESS { - Poll::Pending - } else { - Poll::Ready(()) - } -} - #[cfg(test)] #[cfg(feature = "am")] mod tests { @@ -617,7 +596,7 @@ mod tests { #[test_log::test] fn am() { - let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; + let protos = [None, Some(AmProto::Eager), Some(AmProto::Rndv)]; for block_size_shift in 0..20_usize { for p in protos.iter() { let rt = tokio::runtime::Builder::new_current_thread() @@ -671,13 +650,13 @@ mod tests { let msg = stream1.wait_msg().await; let mut msg = msg.expect("no msg"); assert_eq!(msg.header(), &header); - assert_eq!(msg.contains_data(), true); + assert!(msg.contains_data()); assert_eq!(msg.data_len(), data.len()); let mut recv_data = vec![0_u8; msg.data_len()]; let recv_len = msg.recv_data_single(&mut recv_data).await.unwrap(); assert_eq!(data.len(), recv_len); assert_eq!(data, recv_data); - assert_eq!(msg.contains_data(), false); + assert!(!msg.contains_data()); msg } ); @@ -695,11 +674,11 @@ mod tests { let reply = stream2.wait_msg().await; let mut reply = reply.expect("no reply"); assert_eq!(reply.header(), &header); - assert_eq!(reply.contains_data(), true); + assert!(reply.contains_data()); assert_eq!(reply.data_len(), data.len()); let recv_data = reply.recv_data().await.unwrap(); assert_eq!(data, recv_data); - assert_eq!(reply.contains_data(), false); + assert!(!reply.contains_data()); } ); diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index b477468..dd2958d 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -9,13 +9,18 @@ use std::task::Poll; #[cfg(feature = "am")] mod am; +mod param; mod rma; mod stream; mod tag; +#[cfg(feature = "util")] +mod util; #[cfg(feature = "am")] pub use self::am::*; pub use self::rma::*; +#[cfg(feature = "util")] +pub use self::util::*; // State associate with ucp_ep_h // todo: Add a `get_user_data` to UCX @@ -78,6 +83,12 @@ pub struct 
Endpoint { inner: Rc, } +/// Type alias of Endpoint handler +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct EndpointHandler(ucp_ep_h); +unsafe impl Sync for EndpointHandler {} +unsafe impl Send for EndpointHandler {} + impl Endpoint { fn create(worker: &Rc, mut params: ucp_ep_params) -> Result { let inner = Rc::new(EndpointInner::new(worker.clone())); @@ -200,6 +211,11 @@ impl Endpoint { Ok(self.handle) } + /// Get the endpoint handler + pub fn handler(&self) -> Result { + Ok(EndpointHandler(self.get_handle()?)) + } + /// Print endpoint information to stderr. pub fn print_to_stderr(&self) { if !self.inner.is_closed() { @@ -252,7 +268,7 @@ impl Endpoint { Ok(()) } else if UCS_PTR_IS_PTR(status) { let result = loop { - if let Poll::Ready(result) = unsafe { poll_normal(status) } { + if let Poll::Ready(result) = poll_normal(status) { unsafe { ucp_request_free(status as _) }; break result; } else { @@ -303,18 +319,18 @@ impl Drop for Endpoint { /// A handle to the request returned from async IO functions. 
struct RequestHandle { ptr: ucs_status_ptr_t, - poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll, + poll_fn: fn(ucs_status_ptr_t) -> Poll, } impl Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } { + if let ret @ Poll::Ready(_) = { (self.poll_fn)(self.ptr) } { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; request.waker.register(cx.waker()); - unsafe { (self.poll_fn)(self.ptr) } + (self.poll_fn)(self.ptr) } } @@ -325,8 +341,32 @@ impl Drop for RequestHandle { } } -unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { - let status = ucp_request_check_status(ptr as _); +enum Status { + Completed(Result), + Scheduled(RequestHandle>), +} + +impl Status { + pub fn from( + status: *mut c_void, + immediate: MaybeUninit, + poll_fn: fn(ucs_status_ptr_t) -> Poll>, + ) -> Self { + if status.is_null() { + Self::Completed(Ok(unsafe { immediate.assume_init() })) + } else if UCS_PTR_IS_ERR(status) { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } else { + Self::Scheduled(RequestHandle { + ptr: status, + poll_fn, + }) + } + } +} + +fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { + let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs new file mode 100644 index 0000000..92ea65c --- /dev/null +++ b/src/ucp/endpoint/param.rs @@ -0,0 +1,97 @@ +use ucx1_sys::*; + +pub struct RequestParam { + inner: ucp_request_param_t, +} + +impl RequestParam { + pub fn new() -> Self { + // Zeroed for safety, as in C + Self { + inner: unsafe { std::mem::zeroed() }, + } + } + + pub fn cb_send(mut self, callback: ucp_send_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = 
std::mem::zeroed(); + cb.send = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_tag_recv(mut self, callback: ucp_tag_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_stream_recv(mut self, callback: ucp_stream_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_stream = callback; + self.inner.cb = cb; + } + self + } + + pub fn recv_tag_info(mut self, info: *mut ucp_tag_recv_info) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_RECV_INFO as u32; + self.inner.recv_info.tag_info = info; + self + } + + #[cfg(feature = "am")] + pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_am = callback; + self.inner.cb = cb; + } + self + } + + pub fn iov(mut self) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; + self.inner.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; + self + } + + #[cfg(feature = "am")] + fn set_flag(&mut self) { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; + } + + #[cfg(feature = "am")] + pub fn set_flag_eager(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_rndv(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_reply(mut self) -> Self { + self.set_flag(); + self.inner.flags |= 
ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; + self + } + + pub fn as_ref(&self) -> *const ucp_request_param_t { + &self.inner as *const _ + } +} diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index e9cd41e..c08e5b5 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -1,3 +1,5 @@ +use super::param::RequestParam; + use super::*; /// A memory region allocated through UCP library, @@ -112,19 +114,24 @@ impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("put: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_put_nb( + ucp_put_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { @@ -144,19 +151,24 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("get: complete. 
req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_get_nb( + ucp_get_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 411a80b..d9c2f5b 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,10 +1,14 @@ +use super::param::RequestParam; use super::*; impl Endpoint { - /// Sends data through stream. - pub async fn stream_send(&self, buf: &[u8]) -> Result { + pub(super) fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -13,34 +17,46 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_stream_send_nb( + ucp_stream_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), - 0, + param.as_ref(), ) }; - if status.is_null() { - trace!("stream_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends data through stream. + pub async fn stream_send(&self, buf: &[u8]) -> Result { + match self.stream_send_impl(buf)? 
{ + Status::Completed(r) => { + match &r { + Ok(()) => trace!("stream_send: complete"), + Err(e) => error!("stream_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } - /// Receives data from stream. - pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + pub(super) fn stream_recv_impl( + &self, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + length: usize, + _user_data: *mut c_void, + ) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -51,39 +67,111 @@ impl Endpoint { request.waker.wake(); } let mut length = MaybeUninit::::uninit(); + let param = RequestParam::new().cb_stream_recv(Some(callback)); let status = unsafe { - ucp_stream_recv_nb( + ucp_stream_recv_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), length.as_mut_ptr(), - 0, + param.as_ref(), ) }; - if status.is_null() { - let length = unsafe { length.assume_init() } as usize; - trace!("stream_recv: complete. len={}", length); - Ok(length) - } else if UCS_PTR_IS_PTR(status) { - Ok(RequestHandle { - ptr: status, - poll_fn: poll_stream, + Ok(Status::from(status, length, poll_stream)) + } + + /// Receives data from stream. + pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + match self.stream_recv_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(x) => trace!("stream_recv: complete. 
len={}", x), + Err(e) => error!("stream_recv: error : {:?}", e), + } + r } - .await) - } else { - Err(Error::from_ptr(status).unwrap_err()) + Status::Scheduled(request_handle) => request_handle.await, } } } -unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { let mut len = MaybeUninit::::uninit(); - let status = ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _); + let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(len.assume_init()) + Poll::Ready(Error::from_status(status).map(|_| unsafe { len.assume_init() })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test_log::test] + fn stream() { + for i in 0..20_usize { + spawn_thread!(_stream(4 << i)).join().unwrap(); + } + } + + async fn _stream(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.stream_send(&buf).await.unwrap(); + println!("stream sent"); + }, + async { + // recv + let mut buf = vec![std::mem::MaybeUninit::::uninit(); msg_size]; + let mut start = 0; + 
while start < msg_size { + let len = endpoint1.stream_recv(&mut buf[start..]).await.unwrap(); + if len == 0 { + break; // no more data + } + start += len; + } + let buf: Vec = unsafe { buf.into_iter().map(|b| b.assume_init()).collect() }; + assert_eq!(buf, vec![42u8; msg_size]); + println!("stream received"); + } + ); + + println!("status {:?}", endpoint2.get_status()); + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); } } diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 6f04b4d..5d5e75c 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -1,3 +1,4 @@ +use super::param::RequestParam; use super::*; use std::io::{IoSlice, IoSliceMut}; @@ -16,126 +17,154 @@ impl Worker { tag_mask: u64, buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { + match self.tag_recv_impl(tag, tag_mask, buf)? { + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length)), + Status::Scheduled(request_handle) => { + let info = request_handle.await?; + Ok((info.sender_tag, info.length as usize)) + } + } + } + + /// Like `tag_recv`, except that it reads into a slice of buffers. + pub async fn tag_recv_vectored( + &self, + tag: u64, + iov: &mut [IoSliceMut<'_>], + ) -> Result { trace!( - "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", + "tag_recv_vectored: worker={:?} iov.len={}", self.handle, - tag, - tag_mask, - buf.len() + iov.len() ); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let tag = (*info).sender_tag; trace!( - "tag_recv: complete. 
req={:?}, status={:?}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, - buf.as_mut_ptr() as _, - buf.len() as _, - ucp_dt_make_contig(1), + iov.as_ptr() as _, + iov.len() as _, tag, - tag_mask, - Some(callback), + u64::max_value(), + param.as_ref(), ) }; - Error::from_ptr(status)?; RequestHandle { ptr: status, poll_fn: poll_tag, } .await + .map(|info| info.length) } - /// Like `tag_recv`, except that it reads into a slice of buffers. - pub async fn tag_recv_vectored( + pub(super) fn tag_recv_impl( &self, tag: u64, - iov: &mut [IoSliceMut<'_>], - ) -> Result { + tag_mask: u64, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!( - "tag_recv_vectored: worker={:?} iov.len={}", + "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, - iov.len() + tag, + tag_mask, + buf.len() ); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, len={}", + "tag_recv: complete. 
req={:?}, status={:?}, tag={}, len={}", request, status, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + let mut info = MaybeUninit::::uninit(); + let param = RequestParam::new() + .cb_tag_recv(Some(callback)) + .recv_tag_info(info.as_mut_ptr() as _); let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, - iov.as_ptr() as _, - iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, + buf.as_mut_ptr() as _, + buf.len() as _, tag, - u64::max_value(), - Some(callback), + tag_mask, + param.as_ref(), ) }; - Error::from_ptr(status)?; - RequestHandle { - ptr: status, - poll_fn: poll_tag, - } - .await - .map(|info| info.1) + Ok(Status::from(status, info, poll_tag)) } } impl Endpoint { - /// Sends a messages with `tag`. - pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + pub(super) fn tag_send_impl(&self, tag: u64, buf: &[u8]) -> Result, Error> { trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("tag_send: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, - Some(callback), + param.as_ref(), ) }; - if status.is_null() { - trace!("tag_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends a messages with `tag`. + pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + match self.tag_send_impl(tag, buf)? 
{ + Status::Completed(r) => { + match &r { + Ok(()) => trace!("tag_send: complete"), + Err(e) => error!("tag_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } /// Like `tag_send`, except that it reads into a slice of buffers. @@ -145,7 +174,11 @@ impl Endpoint { self.handle, iov.len() ); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "tag_send_vectored: complete. req={:?}, status={:?}", request, @@ -154,14 +187,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + // Use RequestParam builder for iov + let param = RequestParam::new().cb_send(Some(callback)).iov(); let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, - Some(callback), + param.as_ref(), ) }; let total_len = iov.iter().map(|v| v.len()).sum(); @@ -180,14 +214,14 @@ impl Endpoint { } } -unsafe fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); - let status = ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _); + let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { - let info = info.assume_init(); - Poll::Ready(Ok((info.sender_tag, info.length as usize))) + let info = unsafe { info.assume_init() }; + Poll::Ready(Ok(info)) } status => Poll::Ready(Err(Error::from_error(status))), } @@ -253,4 +287,64 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + 
fn multi_tag() { + for i in 0..20_usize { + spawn_thread!(_multi_tag(4 << i)).join().unwrap(); + } + } + + async fn _multi_tag(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send tag message + tokio::join!( + async { + // send + let mut buf = vec![0; msg_size]; + endpoint2.tag_send(3, &mut buf).await.unwrap(); + println!("tag sended"); + }, + async { + // recv + let mut buf = vec![MaybeUninit::::uninit(); msg_size]; + let (tag, size) = worker1.tag_recv_mask(0, 0, &mut buf).await.unwrap(); + assert_eq!(size, msg_size); + assert_eq!(tag, 3); + println!("tag recved"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); + } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs new file mode 100644 index 0000000..b007fb2 --- /dev/null +++ b/src/ucp/endpoint/util.rs @@ -0,0 +1,431 @@ +use super::*; +use futures::FutureExt; +use 
pin_project::pin_project; +use std::task::ready; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream<'_> { + WriteStream { + endpoint: self, + request: None, + } + } + /// make read stream + pub fn read_stream(&self) -> ReadStream<'_> { + ReadStream { + endpoint: self, + request: None, + } + } + /// make tag write stream + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream<'_> { + TagWriteStream { + endpoint: self, + tag, + request: None, + } + } +} + +impl Worker { + /// make tag read stream + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream<'_> { + TagReadStream { + worker: self, + tag, + tag_mask: u64::max_value(), + request: None, + } + } + /// make tag read stream with mask + /// not suggested to use this function, because actual received tag should be checked by user + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream<'_> { + TagReadStream { + worker: self, + tag, + tag_mask, + request: None, + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` stream. 
+pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())) + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), + } + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// A stream for reading data asynchronously from an `Endpoint` stream. +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. 
+ unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` using tag. +pub struct TagWriteStream<'a> { + endpoint: &'a Endpoint, + tag: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for TagWriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())); + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + assert!(self.request.is_none()); + // if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + // let r = ready!(req.poll_unpin(cx)); + // self.request = None; + // Poll::Ready(r.map_err(|e| e.into())) + // } else { + Poll::Ready(Ok(())) + // } + } + + fn poll_shutdown( + self: 
Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// A stream for reading data asynchronously from a `Worker` using tag. +#[pin_project] +pub struct TagReadStream<'a> { + worker: &'a Worker, + tag: u64, + tag_mask: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for TagReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(info) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(info.length); + out_buf.advance(info.length); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(info) => { + // Safety: The buffer was filled by the recv operation. 
+ unsafe { + out_buf.assume_init(info.length); + out_buf.advance(info.length); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[cfg(test)] +mod test { + + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[test_log::test] + fn stream_send_recv() { + spawn_thread!(_stream_send_recv()).join().unwrap(); + } + + #[test_log::test] + fn tag_send_recv() { + spawn_thread!(_tag_send_recv()).join().unwrap(); + } + + async fn _stream_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + 
}, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } + + async fn _tag_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, tag, repeat count) + let test_cases = vec![ + (vec![], 1u64, 1), + (vec![0u8], 2u64, 10), + (vec![1, 2, 3, 4, 5], 3u64, 5), + ((0..128).collect::>(), 4u64, 3), + ( + (0..1024).map(|x| (x % 256) as u8).collect::>(), + 5u64, + 2, + ), + ( + (0..4096).map(|x| (x % 256) as u8).collect::>(), + 6u64, + 1, + ), + ]; + for (data, tag, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2 + .tag_write_stream(tag) + .write_all(&send_buf) + .await + .unwrap(); + }, + async { + worker1 + .tag_read_stream(tag) + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!( + recv_buf, send_buf, + "data mismatch for tag={}, len={}", + tag, recv_len + ); + } + ); + } + } + + // Clean up + let _ = endpoint1.close(false).await; + let _ = endpoint2.close(false).await; + } +} diff --git a/src/ucp/mod.rs b/src/ucp/mod.rs index 
440e6be..8f003d8 100644 --- a/src/ucp/mod.rs +++ b/src/ucp/mod.rs @@ -91,7 +91,7 @@ impl Context { | ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED) .0 as u64, features: features.0 as u64, - request_size: std::mem::size_of::() as usize, + request_size: std::mem::size_of::(), request_init: Some(Request::init), request_cleanup: Some(Request::cleanup), mt_workers_shared: 1, diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index 00abd16..1f142e7 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -107,7 +107,7 @@ impl Worker { Ok(WorkerAddress { handle: unsafe { handle.assume_init() }, - length: unsafe { length.assume_init() } as usize, + length: unsafe { length.assume_init() }, worker: self, }) } @@ -122,6 +122,11 @@ impl Worker { Endpoint::connect_addr(self, addr.handle) } + /// Connect to a remote worker by address. + pub fn connect_addr_vec(self: &Rc, addr: &[u8]) -> Result { + Endpoint::connect_addr(self, addr.as_ptr() as _) + } + /// Connect to a remote listener. pub async fn connect_socket(self: &Rc, addr: SocketAddr) -> Result { Endpoint::connect_socket(self, addr).await @@ -193,6 +198,11 @@ impl<'a> AsRef<[u8]> for WorkerAddress<'a> { impl<'a> Drop for WorkerAddress<'a> { fn drop(&mut self) { + trace!( + "destroy worker address= {:?} {:?}", + self.worker.handle, + self.handle + ); unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } } } diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index d71ff71..6b0b713 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ucx1-sys" -version = "0.1.0" +version = "0.2.0" authors = ["Runji Wang "] edition = "2021" description = "Rust FFI bindings to UCX." diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index 938ffcd..d9aa565 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit 938ffcd10122742d0f46a4f609e7395d1648c969 +Subproject commit d9aa5650d4cbcbb00d61af980614dbe9dd27a1f2