diff --git a/Cargo.toml b/Cargo.toml index 6996cad..85d5929 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,27 @@ [package] -name = "jsonrpc-ws" +name = "jsonrpc" version = "0.1.0" authors = ["tiannian "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "jsonrpc" +path = "src/lib.rs" + +[workspace] +members = [ + ".", + "jsonrpc-websocket", +] + [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +jsonrpc-lite = "0.5.0" +fxhash = "0.2.1" +futures-util = "0.3.5" + +[dev-dependencies] +tokio = { version = "0.2", features = ["full"] } diff --git a/jsonrpc-websocket/Cargo.toml b/jsonrpc-websocket/Cargo.toml new file mode 100644 index 0000000..0df1174 --- /dev/null +++ b/jsonrpc-websocket/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "jsonrpc-websocket" +version = "0.1.0" +authors = ["xujian "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "jsonrpc_websocket" +path = "src/lib.rs" + + +[dependencies] +tokio = { version = "0.2", features = ["full"] } +chrono = "0.4.11" +tokio-tungstenite = "0.10.1" +futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"] } +url = "2.0.0" +jsonrpc-lite = "0.5.0" +jsonrpc = { path = "../../jsonrpc-ws" } +log = "0.4.8" + + +[dev-dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +env_logger = "0.7" \ No newline at end of file diff --git a/jsonrpc-websocket/examples/example_client.rs b/jsonrpc-websocket/examples/example_client.rs new file mode 100644 index 0000000..ce87788 --- /dev/null +++ b/jsonrpc-websocket/examples/example_client.rs @@ -0,0 +1,168 @@ +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{SinkExt, StreamExt}; +use std::env; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::net::TcpStream; +use tokio::time; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::WebSocketStream; +use url::Url; + +static LOCAL_SERVER: &'static str = "ws://127.0.0.1:9000"; + +const RECONN_INTERVAL: u64 = 3000; + +struct WebSockWriteHalf(pub Option, Message>>); +struct WebSockReadHalf(pub Option>>); + +async fn set_conn_none( + lock_ws_receiver: Arc>, + lock_ws_sender: Arc>, +) -> bool { + let mut ws_receiver = lock_ws_receiver.lock().unwrap(); + let mut ws_sender = lock_ws_sender.lock().unwrap(); + + ws_receiver.0 = None; + ws_sender.0 = None; + return true; +} + +async fn client_check_conn( + case_url: Url, + lock_ws_receiver: Arc>, + lock_ws_sender: Arc>, +) -> bool { + let ws_receiver = lock_ws_receiver.lock().unwrap(); + + if let None = ws_receiver.0 { + drop(ws_receiver); + + if let Ok((ws_stream, _)) = connect_async(case_url).await { + let (sender, receiver) = ws_stream.split(); + let mut ws_receiver = lock_ws_receiver.lock().unwrap(); + let mut ws_sender = lock_ws_sender.lock().unwrap(); + + ws_sender.0 = Some(sender); + ws_receiver.0 = Some(receiver); + log::info!("connect success"); + return true; + } else { + log::info!("connect fail, reconning ..."); + return false; + } + } + return true; +} + +async fn receiver_loop( + case_url: Url, + lock_ws_receiver: Arc>, + lock_ws_sender: Arc>, +) { + loop { + let mut ws_receiver = lock_ws_receiver.lock().unwrap(); + + let result: Result = match &mut ws_receiver.0 { + Some(ws_receiver) => match ws_receiver.next().await { + Some(Ok(msg)) => { + if msg.is_text() { + Ok(msg.into_text().unwrap()) + } else { + log::warn!("Peer receive data format error, not text"); + Err(false) + } + } + Some(Err(_)) => { + log::warn!("server close connect"); + Err(true) + } + None => Err(true), + }, + None => Err(true), + }; + drop(ws_receiver); + + match result { + Ok(msg) => { + println!("resp: {}", msg); + } + Err(is_reconn) => { + if is_reconn { + set_conn_none(lock_ws_receiver.clone(), lock_ws_sender.clone()).await; + if client_check_conn( + case_url.clone(), + lock_ws_receiver.clone(), + lock_ws_sender.clone(), + ) + .await + { + log::info!("re_conn: {}", case_url); + continue; + } else { + time::delay_for(Duration::from_millis(RECONN_INTERVAL)).await; + } + } + } + } + } +} + +async fn ws_send(str_cmd: String, lock_ws_sender: Arc>) { + let mut ws_sender = match lock_ws_sender.try_lock() { + Ok(ws_sender) => ws_sender, + Err(_) => { + time::delay_for(Duration::from_millis(100)).await; + + log::warn!("ws_stream close, skip send"); + return; + } + }; + + if let Some(ws_sender) = &mut ws_sender.0 { + if let Err(err) = ws_sender.send(Message::Text(str_cmd)).await { + log::warn!("ws_stream send failed with err: {}", err); + } + } else { + log::warn!("ws_stream close, skip send"); + } +} + +#[tokio::main] +async fn main() { + use env_logger::Env; + env_logger::from_env(Env::default().default_filter_or("warn")).init(); + + let connect_transport = env::args() + .nth(1) + .unwrap_or_else(|| LOCAL_SERVER.to_string()); + + let case_url = Url::parse(&connect_transport).expect("Bad testcase URL"); + + let lock_ws_receiver = Arc::new(Mutex::new(WebSockReadHalf(None))); + let lock_ws_sender = Arc::new(Mutex::new(WebSockWriteHalf(None))); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async move { + tokio::task::spawn_local(receiver_loop( + case_url, + lock_ws_receiver.clone(), + lock_ws_sender.clone(), + )); + + let mut reader = BufReader::new(tokio::io::stdin()); + loop { + let mut str_cmd = String::new(); + reader.read_line(&mut str_cmd).await.unwrap(); + str_cmd.pop(); + + ws_send(str_cmd, lock_ws_sender.clone()).await; + } + }) + .await; +} diff --git a/jsonrpc-websocket/examples/example_server.rs b/jsonrpc-websocket/examples/example_server.rs new file mode 100644 index 0000000..eb8d188 --- /dev/null +++ b/jsonrpc-websocket/examples/example_server.rs @@ -0,0 +1,105 @@ +extern crate jsonrpc_websocket; + +use jsonrpc::route::Route; +use jsonrpc::Data; +use jsonrpc_lite::Error as JsonRpcError; +use jsonrpc_websocket::WsServer; +use serde::{Deserialize, Serialize}; +use std::env; +use std::sync::{Arc, RwLock}; + +#[derive(Serialize)] +pub enum ExampleError { + // websock 错误 + ParamIsNone, +} + +impl Into for ExampleError { + fn into(self) -> JsonRpcError { + let (code, message) = match self { + ExampleError::ParamIsNone => (1000i64, "Param is none"), + }; + + JsonRpcError { + code, + message: message.to_string(), + data: None, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct CurrencyDetail { + value: u64, + id: String, + dcds: String, + locked: bool, + owner: String, +} + +pub struct CurrencyStore {} +impl CurrencyStore { + pub fn get_detail_by_ids( + &self, + req: GetDetailParam, + ) -> Result, ExampleError> { + if req.ids.len() == 0 { + return Err(ExampleError::ParamIsNone); + } + Ok(Vec::::new()) + } +} + +#[derive(Serialize, Deserialize)] +pub struct GetDetailParam { + ids: Vec, +} + +pub async fn get_detail_by_ids( + wallet: Data, + req: GetDetailParam, +) -> Result, ExampleError> { + let store = wallet.get_ref().store.try_read().unwrap(); + store.get_detail_by_ids(req) +} + +pub struct TSystem { + pub store: RwLock, +} + +impl TSystem { + pub fn new() -> Self { + Self { + store: RwLock::new(CurrencyStore {}), + } + } +} + +pub async fn start_ws_server(bind_transport: String) { + let route: Arc = Arc::new( + Route::new() + .data(TSystem::new()) + .to("currency.ids.detail".to_string(), get_detail_by_ids), + ); + + let ws_server = match WsServer::bind(bind_transport).await { + Ok(ws_server) => ws_server, + Err(err) => panic!("{}", err), + }; + + ws_server.listen_loop(route).await; +} + +static LOCAL_SERVER: &'static str = "127.0.0.1:9000"; + +#[tokio::main] +async fn main() { + use env_logger::Env; + env_logger::from_env(Env::default().default_filter_or("warn")).init(); + + let bind_transport = env::args() + .nth(1) + .unwrap_or_else(|| LOCAL_SERVER.to_string()); + + start_ws_server(bind_transport).await; +} diff --git a/jsonrpc-websocket/src/lib.rs b/jsonrpc-websocket/src/lib.rs new file mode 100644 index 0000000..bfa8fa5 --- /dev/null +++ b/jsonrpc-websocket/src/lib.rs @@ -0,0 +1,2 @@ +mod server; +pub use server::WsServer; diff --git a/jsonrpc-websocket/src/server.rs b/jsonrpc-websocket/src/server.rs new file mode 100644 index 0000000..3ab62e4 --- /dev/null +++ b/jsonrpc-websocket/src/server.rs @@ -0,0 +1,117 @@ +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{SinkExt, StreamExt}; +use jsonrpc::route::{route_jsonrpc, Route}; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::WebSocketStream; + +type WebSockWriteHalf = SplitSink, Message>; +type WebSockReadHalf = SplitStream>; + +const REQ_QUEUE_LEN: usize = 10; + +pub struct WsServer { + listener: TcpListener, +} + +impl WsServer { + pub async fn bind(bind_transport: String) -> Result { + let listener = TcpListener::bind(&bind_transport) + .await + .map_err(|err| err.to_string())?; + + log::info!("Listening on: {}", &bind_transport); + + let instance = Self { listener }; + + Ok(instance) + } + + pub async fn listen_loop(mut self, route: Arc) { + while let Ok((stream, _)) = self.listener.accept().await { + let route_ = route.clone(); + tokio::spawn(async move { + if let Err(err) = Self::client_loop(stream, route_).await { + log::warn!("{}", err); + } + }); + } + } + + async fn client_loop(stream: TcpStream, route: Arc) -> Result<(), String> { + let peer = stream + .peer_addr() + .map_err(|err| format!("get client peer_addr error, with info: {}", err))?; + + let ws_stream = accept_async(stream) + .await + .map_err(|err| format!("ws_stream accept error, with info: {}", err))?; + + log::info!("client {} connect", peer); + let (write_half, read_half) = ws_stream.split(); + + let (req_pipe_in, req_pipe_out) = mpsc::channel(REQ_QUEUE_LEN); + let (resp_pipe_in, resp_pipe_out) = mpsc::channel(REQ_QUEUE_LEN); + + tokio::select! { + _ = Self::dispatch_loop(route, req_pipe_out, resp_pipe_in) => { + log::info!("client {} close because dispatch_loop", peer); + }, + _ = Self::read_half_loop(read_half, req_pipe_in) => { + log::info!("client {} close because read_half", peer); + }, + _ = Self::write_half_loop(write_half, resp_pipe_out) => { + log::info!("client {} close because write_half", peer); + }, + }; + + Ok(()) + } + + async fn dispatch_loop( + route: Arc, + mut req_pipe: mpsc::Receiver, + mut resp_pipe: mpsc::Sender, + ) { + while let Some(req_str) = req_pipe.recv().await { + let route_ = route.clone(); + let resp_str = route_jsonrpc(route_, &req_str).await; + if let Err(_) = resp_pipe.send(resp_str).await { + // 处理完客户端已断开,忽略 + return; + } + } + } + + async fn read_half_loop(mut read_half: WebSockReadHalf, mut req_pipe_in: mpsc::Sender) { + while let Some(ans) = read_half.next().await { + match ans { + Err(_) => { + return; + } + Ok(Message::Text(msg_str)) => { + if let Err(_) = req_pipe_in.send(msg_str).await { + return; + } + } + Ok(Message::Ping(_)) => log::debug!("recv message ping/pong"), + Ok(Message::Pong(_)) => log::debug!("recv message ping/pong"), + Ok(_) => log::debug!("data format not String, ignore this item"), + } + } + } + + async fn write_half_loop( + mut write_half: WebSockWriteHalf, + mut resp_pipe_out: mpsc::Receiver, + ) { + while let Some(msg_str) = resp_pipe_out.recv().await { + if let Err(_) = write_half.send(Message::Text(msg_str)).await { + return; + } + } + } +} diff --git a/src/data.rs b/src/data.rs new file mode 100644 index 0000000..9d4f0bf --- /dev/null +++ b/src/data.rs @@ -0,0 +1,80 @@ +use std::any::{Any, TypeId}; +use std::sync::Arc; + +use fxhash::FxHashMap; + +pub struct Data(Arc); + +impl Data { + pub fn new(d: T) -> Self { + Self(Arc::new(d)) + } + + pub fn get_ref(&self) -> &T { + self.0.as_ref() + } +} + +impl Clone for Data { + fn clone(&self) -> Data { + Data(self.0.clone()) + } +} + +pub(crate) trait DataFactory { + fn get(&self) -> Option<&D>; +} + +pub struct DataExtensions(FxHashMap>); + +impl DataExtensions { + pub fn insert(&mut self, t: T) { + self.0.insert(TypeId::of::(), Box::new(t)); + } +} + +unsafe impl Sync for DataExtensions {} + +unsafe impl Send for DataExtensions {} + +impl Default for DataExtensions { + fn default() -> Self { + Self(FxHashMap::>::default()) + } +} + +impl DataFactory for DataExtensions { + fn get(&self) -> Option<&D> { + self.0 + .get(&TypeId::of::()) + .and_then(|boxed| boxed.downcast_ref()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + #[derive(Debug)] + pub struct Test { + pub a: Mutex, + pub b: Mutex, + } + + #[test] + fn test_data() { + let data = Data::new(Test { + a: Mutex::new(99u32), + b: Mutex::new("abcdefg".to_string()), + }); + + let mut extension = DataExtensions::default(); + extension.insert(data); + + let obj = extension.get::>().unwrap().clone(); + + assert_eq!(99, *obj.get_ref().a.lock().unwrap()); + assert_eq!("abcdefg".to_string(), *obj.get_ref().b.lock().unwrap()); + } +} diff --git a/src/factory.rs b/src/factory.rs new file mode 100644 index 0000000..4431a0e --- /dev/null +++ b/src/factory.rs @@ -0,0 +1,24 @@ +use crate::data::Data; +use serde::{Deserialize, Serialize}; +use std::future::Future; + +pub(crate) trait Factory: Clone + 'static +where + O: Serialize, + R: Future + Send, +{ + fn call(&self, params: T) -> R; +} + +impl Factory<(Data, P), R, O> for F +where + F: Fn(Data, P) -> R + Clone + 'static, + O: Serialize, + R: Future + Send, + P: for<'de> Deserialize<'de>, + T: 'static, +{ + fn call(&self, params: (Data, P)) -> R { + (self)(params.0, params.1) + } +} diff --git a/src/lib.rs b/src/lib.rs index 089e222..f1eff90 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,172 +1,19 @@ -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; +#![feature(type_alias_impl_trait)] +#![feature(fn_traits)] -#[derive(Deserialize, Debug)] -pub(crate) struct Request { - pub jsonrpc: String, - pub method: String, - pub params: serde_json::Value, - pub id: i64, -} - -#[derive(Serialize, Debug)] -pub(crate) struct Response { - jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, - id: i64, -} - -impl Response { - pub fn new(id: i64, result: Result) -> Self { - match result { - Ok(r) => Response { - jsonrpc: "2.0".to_string(), - error: None, - result: Some(r), - id, - }, - Err(e) => Response { - jsonrpc: "2.0".to_string(), - result: None, - error: Some(e), - id, - }, - } - } -} - -pub(crate) trait DataFactory {} - -pub struct Data(T); - -impl Data { - pub fn new(t: T) -> Self { - Data(t) - } -} - -impl DataFactory for Data {} - -pub(crate) trait Factory: Clone + 'static -where - O: Serialize, - R: Future, -{ - fn call(&self, params: T) -> R; -} - -impl Factory<(), R, O> for F -where - F: Fn() -> R + Clone + 'static, - O: Serialize, - R: Future, -{ - fn call(&self, _: ()) -> R { - (self)() - } -} - -impl Factory<(Data,), R, O> for F -where - F: Fn(Data) -> R + Clone + 'static, - O: Serialize, - R: Future, -{ - fn call(&self, params: (Data,)) -> R { - (self)(params.0) - } -} - -impl Factory<(P,), R, O> for F -where - F: Fn(P) -> R + Clone + 'static, - O: Serialize, - R: Future, - P: for<'de> Deserialize<'de>, -{ - fn call(&self, params: (P,)) -> R { - (self)(params.0) - } -} - -impl Factory<(Data, P), R, O> for F -where - F: Fn(Data, P) -> R + Clone + 'static, - O: Serialize, - R: Future, - P: for<'de> Deserialize<'de>, -{ - fn call(&self, params: (Data, P)) -> R { - (self)(params.0, params.1) - } -} - -impl Factory<(P, Data), R, O> for F -where - F: Fn(P, Data) -> R + Clone + 'static, - O: Serialize, - R: Future, - P: for<'de> Deserialize<'de>, -{ - fn call(&self, params: (P, Data)) -> R { - (self)(params.0, params.1) - } -} +mod data; +pub use data::Data; -pub struct Server { - map: HashMap< - String, - Box Pin + Send>>>, - >, - state: Option>, -} +mod factory; -impl Server { - pub fn new() -> Self { - Server { - map: HashMap::new(), - state: None, - } - } +pub mod route; - pub fn to(mut self, key: String, handle: H) -> Self - where - P: for<'de> Deserialize<'de> + Send + 'static, - R: Serialize + 'static, - E: Serialize + 'static, - F: Future> + Send + 'static, - H: Fn(P) -> F + 'static + Send, - { - let inner_handle = - move |req: Request| -> Pin + Send>> { - async fn inner(req: Request, handle: H) -> serde_json::Value - where - P: for<'de> Deserialize<'de> + Send + 'static, - R: Serialize + 'static, - E: Serialize + 'static, - F: Future> + Send + 'static, - H: Fn(P) -> F + 'static + Send, - { - let params: P = serde_json::from_value(req.params).unwrap(); - let _r = (handle)(params); - let result = Response::new(req.id, _r.await); - serde_json::to_value(result).unwrap() - } - Box::pin(inner(req, handle)) - }; - self.map.insert(key, Box::new(inner_handle)); - self - } +use jsonrpc_lite::Error as JsonRpcError; - pub fn data(mut self, d: D) -> Self { - if self.state.is_none() { - self.state = Some(Box::new(Data::new(d))) - } - self +fn server_route_error() -> JsonRpcError { + JsonRpcError { + code: -32500, + message: "Server Internal Route error".to_string(), + data: None, } } diff --git a/src/route.rs b/src/route.rs new file mode 100644 index 0000000..2a7ec41 --- /dev/null +++ b/src/route.rs @@ -0,0 +1,181 @@ +use crate::data::DataFactory; +use crate::data::{Data, DataExtensions}; +use crate::server_route_error; +use futures_util::future::join_all; +use jsonrpc_lite::Error as JsonRpcError; +use jsonrpc_lite::JsonRpc; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + +#[derive(Deserialize, Debug)] +pub struct Request { + pub jsonrpc: String, + pub method: String, + pub params: Value, + pub id: i64, +} + +pub struct Route { + map: HashMap< + String, + Box, Request) -> Pin + Send>>>, + >, + extensions: Arc, +} + +unsafe impl Sync for Route {} + +unsafe impl Send for Route {} + +impl Route { + pub fn new() -> Self { + Route { + map: HashMap::new(), + extensions: Arc::new(DataExtensions::default()), + } + } + + pub fn to<'a, P, F, R, E, H, T>(mut self, key: String, handle: H) -> Self + where + P: for<'de> Deserialize<'de> + Send + 'static, + R: Serialize + 'static, + E: Serialize + Into + 'static, + F: Future> + Send + 'static, + H: Fn(Data, P) -> F + 'static + Clone + Send + Sync, + T: 'static + Sync + Send, + { + let inner_handle = move |extensions: Arc, + req: Request| + -> Pin + Send>> { + async fn inner( + extensions: Arc, + req: Request, + handle: H, + ) -> Value + where + P: for<'de> Deserialize<'de> + Send + 'static, + R: Serialize + 'static, + E: Serialize + Into + 'static, + F: Future> + Send + 'static, + H: Fn(Data, P) -> F + 'static + Clone + Send + Sync, + T: 'static + Sync + Send, + { + let params: P = match serde_json::from_value(req.params) { + Ok(params) => params, + Err(_) => { + return serde_json::to_value(JsonRpc::error( + req.id, + JsonRpcError::invalid_params(), + )) + .unwrap() + } + }; + + let data_t = extensions.get::>().unwrap().clone(); + match (handle).call((data_t, params)).await { + Ok(result) => serde_json::to_value(JsonRpc::success( + req.id, + &serde_json::to_value(result).unwrap(), + )) + .unwrap(), + Err(err) => serde_json::to_value(JsonRpc::error(req.id, err.into())).unwrap(), + } + } + Box::pin(inner(extensions, req, handle.clone())) + }; + self.map.insert(key, Box::new(inner_handle)); + self + } + + pub fn data(mut self, d: D) -> Self { + Arc::get_mut(&mut self.extensions) + .unwrap() + .insert(Data::new(d)); + self + } + + /// 传入一个Value格式的json-rpc单独请求 + /// 立刻返回响应执行Future或者错误结果 + pub async fn route_once( + &self, + req_str: Value, + ) -> Result + Send>>, Value> { + let req: Request = serde_json::from_value(req_str).unwrap(); + let handle = match self.map.get(&req.method) { + Some(handle) => handle, + None => { + return Err(serde_json::to_value(JsonRpc::error( + req.id, + JsonRpcError::method_not_found(), + )) + .unwrap()) + } + }; + + Ok(handle(self.extensions.clone(), req)) + } +} + +/// 传入jsonrpc请求 +/// 返回结果 +pub async fn route_jsonrpc(server: Arc, req_str: &str) -> String { + let req: Value = match serde_json::from_str(req_str) { + Ok(req) => req, + Err(_) => { + return serde_json::to_value(JsonRpc::error((), JsonRpcError::parse_error())) + .unwrap() + .to_string() + } + }; + let resp = match req { + Value::Object(_) => match server.route_once(req).await { + Ok(fut) => fut.await, + Err(err) => err, + }, + Value::Array(array) => { + let share_outputs = Arc::new(Mutex::new(Vec::::new())); + let mut tasks = Vec::new(); + + for each in array { + let inner_server = Arc::downgrade(&server); + let share_outputs = share_outputs.clone(); + + tasks.push(async move { + // task开始执行是尝试获取server对象 + let output = match inner_server.upgrade() { + Some(server) => match server.route_once(each).await { + Ok(fut) => fut.await, + Err(err) => err, + }, + None => serde_json::to_value(server_route_error()).unwrap(), + }; + + let mut outputs = share_outputs.lock().unwrap(); + outputs.push(output); + }); + } + join_all(tasks).await; + + // TODO 内部panic可能要处理 + // outputs Arc持有者只剩下一个,此处取出不会失败,也不考虑失败处理 + let output = if let Ok(outputs) = Arc::try_unwrap(share_outputs) { + // 锁持有者同理 + outputs.into_inner().unwrap() + } else { + panic!("Arc> into_inner failed"); + }; + Value::Array(output) + } + _ => { + return serde_json::to_value(JsonRpc::error((), JsonRpcError::parse_error())) + .unwrap() + .to_string() + } + }; + + resp.to_string() +} diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..f34e4bd --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,194 @@ +use jsonrpc::route::route_jsonrpc; +use jsonrpc::route::Route; +use jsonrpc::Data; +use jsonrpc_lite::Error as JsonRpcError; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use serde_json::Value; +use std::sync::{Arc, Mutex}; +use tokio::time::{self, Duration}; + +#[derive(Debug)] +pub struct ShareStateTest { + pub a: Mutex, + pub b: Mutex, +} + +#[derive(Deserialize)] +pub struct ReqTest { + pub a: u64, + pub b: String, + pub c: Vec, +} + +#[derive(Deserialize, Serialize)] +pub struct RespTest { + pub a: u64, + pub b: String, + pub c: Vec, +} + +#[derive(Debug, Serialize)] +pub enum TestError { + // websock 错误 + WebSockServerBindError, + WebSockServerAcceptConnError, + WebSockServerGetPeerError, +} + +impl Into for TestError { + fn into(self) -> JsonRpcError { + JsonRpcError { + code: 1000i64, + message: "test".to_string(), + data: None, + } + } +} + +async fn route_b(local_test: Data, req: ReqTest) -> Result { + time_sleep(1000).await; + + let mut a = local_test.get_ref().a.lock().unwrap(); + *a += 1; + + let mut new_resp_c = Vec::::new(); + new_resp_c.extend_from_slice(&req.c[..]); + new_resp_c.push(format!(" add {}", *a)); + + Ok(RespTest { + a: *a + req.a, + b: req.b, + c: new_resp_c, + }) +} + +async fn time_sleep(timeout_ms: u64) { + time::delay_for(Duration::from_millis(timeout_ms.into())).await; +} + +#[test] +fn test_server_simple() { + let route = Route::new() + .data(ShareStateTest { + a: Mutex::new(100u64), + b: Mutex::new("abcdefg".to_string()), + }) + .to("route_b".to_string(), route_b); + + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + + runtime.block_on(async move { + let resp = route + .route_once(json!({ + "jsonrpc": "2.0", + "method": "route_b", + "params": {"err_param": 1}, + "id": 99, + })) + .await; + + assert_eq!( + json!({ + "error":{ + "code":-32602, + "message":"Invalid params" + }, + "id":99, + "jsonrpc":"2.0" + }), + resp.unwrap().await + ); + }); +} + +#[tokio::test] +async fn test_server_route_and_array() { + let route = Arc::new( + Route::new() + .data(ShareStateTest { + a: Mutex::new(100u64), + b: Mutex::new("abcdefg".to_string()), + }) + .to("route_b".to_string(), route_b), + ); + + let tasks = async move { + let resp: Value = serde_json::from_str( + &route_jsonrpc( + route.clone(), + &json!({ + "jsonrpc": "2.0", + "method": "route_b", + "params": {"err_param": 1}, + "id": 99, + }) + .to_string(), + ) + .await, + ) + .unwrap(); + + assert_eq!( + json!({"error":{"code":-32602,"message":"Invalid params"},"id":99,"jsonrpc":"2.0"}) + .to_string(), + resp.to_string() + ); + + let resp: Value = serde_json::from_str( + &route_jsonrpc( + route.clone(), + &json!([{ + "jsonrpc": "2.0", + "method": "route_b", + "params": {"err_param": 1}, + "id": 91, + },{ + "jsonrpc": "2.0", + "method": "route_b", + "params": {"a": 8888u64, "b":"_8888_", "c":["c","_string_","_8888_"]}, + "id": 92, + },{ + "jsonrpc": "2.0", + "method": "route_b", + "params": {"a": 8888u64, "b":"_8888_", "c":["c","_string_","_8888_"]}, + "id": 93, + }]) + .to_string(), + ) + .await, + ) + .unwrap(); + + let resp_vec = match resp { + Value::Array(array) => array, + _ => panic!("unexpect error"), + }; + + let ans_91: Vec<&Value> = resp_vec + .iter() + .filter(|&resp| resp["id"].as_u64().unwrap() == 91) + .collect(); + + assert_eq!(1, ans_91.len()); + assert_eq!( + "Invalid params", + ans_91[0]["error"]["message"].as_str().unwrap() + ); + + let ans_9293_a_sum = resp_vec + .iter() + .filter(|&resp| { + resp["id"].as_u64().unwrap() == 92 || resp["id"].as_u64().unwrap() == 93 + }) + .fold(0, |sum, resp| { + sum + serde_json::from_value::(resp["result"].clone()) + .unwrap() + .a + }); + + assert_eq!(8888u64 * 2 + 101 + 102, ans_9293_a_sum); + }; + + tokio::spawn(tasks).await.unwrap(); +}