diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 2ff4ef0ac6..49d1335194 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -38,5 +38,5 @@ jobs: - name: Run lychee uses: lycheeverse/lychee-action@v2 with: - args: --base . --config ./lychee.toml './**/*.md' + args: --config ./lychee.toml './**/*.md' fail: true diff --git a/Cargo.toml b/Cargo.toml index 65d0c135ef..645cf1ab52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -703,6 +703,11 @@ name = "flat_router" path = "examples/06-routing/flat_router.rs" doc-scrape-examples = true +[[example]] +name = "query_params" +path = "examples/07-fullstack/query_params.rs" +doc-scrape-examples = true + [[example]] name = "server_functions" path = "examples/07-fullstack/server_functions.rs" diff --git a/examples/07-fullstack/fullstack_hello_world.rs b/examples/07-fullstack/fullstack_hello_world.rs index db0ecfdb00..5c12d6e66d 100644 --- a/examples/07-fullstack/fullstack_hello_world.rs +++ b/examples/07-fullstack/fullstack_hello_world.rs @@ -18,7 +18,7 @@ fn main() { }); } -#[get("/api/{name}/?age")] +#[get("/api/:name/?age")] async fn get_message(name: String, age: i32) -> Result { Ok(format!("Hello {}, you are {} years old!", name, age)) } diff --git a/examples/07-fullstack/query_params.rs b/examples/07-fullstack/query_params.rs new file mode 100644 index 0000000000..177af625ac --- /dev/null +++ b/examples/07-fullstack/query_params.rs @@ -0,0 +1,55 @@ +//! An example showcasing query parameters in Dioxus Fullstack server functions. +//! +//! The query parameter syntax mostly follows axum, but with a few extra conveniences. +//! - can rename parameters in the function signature with `?age=age_in_years` where `age_in_years` is Rust variable name +//! - can absorb all query params with `?{object}` directly into a struct implementing `Deserialize` + +use dioxus::prelude::*; + +fn main() { + dioxus::launch(|| { + let mut message = use_action(get_message); + let mut message_rebind = use_action(get_message_rebind); + let mut message_all = use_action(get_message_all); + + rsx! { + h1 { "Server says: "} + div { + button { onclick: move |_| message.call(22), "Single" } + pre { "{message:?}"} + } + div { + button { onclick: move |_| message_rebind.call(25), "Rebind" } + pre { "{message_rebind:?}"} + } + div { + button { onclick: move |_| message_all.call(Params { age: 30, name: "world".into() }), "Bind all" } + pre { "{message_all:?}"} + } + } + }); +} + +#[get("/api/message/?age")] +async fn get_message(age: i32) -> Result { + Ok(format!("You are {} years old!", age)) +} + +#[get("/api/rebind/?age=age_in_years")] +async fn get_message_rebind(age_in_years: i32) -> Result { + Ok(format!("You are {} years old!", age_in_years)) +} + +#[derive(serde::Deserialize, serde::Serialize, Debug)] +struct Params { + age: i32, + name: String, +} + +#[get("/api/all/?{query}")] +async fn get_message_all(query: Params) -> Result { + Ok(format!( + "Hello {}, you are {} years old!", + query.name, query.age + )) +} diff --git a/examples/07-fullstack/server_functions.rs b/examples/07-fullstack/server_functions.rs index 9811376dac..52f47843e4 100644 --- a/examples/07-fullstack/server_functions.rs +++ b/examples/07-fullstack/server_functions.rs @@ -92,8 +92,8 @@ //! take a `State` extractor cannot be automatically added to the router since the dioxus router //! type does not know how to construct the `T` type. //! -//! These server functions will be registered once the `ServerFnState` layer is added to the app with -//! `router = router.layer(ServerFnState::new(your_state))`. +//! These server functions will be registered once the `ServerState` layer is added to the app with +//! `router = router.layer(ServerState::new(your_state))`. //! //! ## Middleware //! diff --git a/examples/07-fullstack/server_state.rs b/examples/07-fullstack/server_state.rs index 96ba6e42c2..b0975ab23b 100644 --- a/examples/07-fullstack/server_state.rs +++ b/examples/07-fullstack/server_state.rs @@ -1,6 +1,13 @@ //! This example shows how to use global state to maintain state between server functions. -use dioxus::prelude::*; +use std::rc::Rc; + +use axum_core::extract::{FromRef, FromRequest}; +use dioxus::{ + fullstack::{FullstackContext, extract::State}, + prelude::*, +}; +use reqwest::header::HeaderMap; #[cfg(feature = "server")] use { @@ -77,7 +84,68 @@ type BroadcastExtension = axum::Extension #[post("/api/broadcast", ext: BroadcastExtension)] async fn broadcast_message() -> Result<()> { + let rt = Rc::new("asdasd".to_string()); ext.send("New broadcast message".to_string())?; + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + println!("rt: {}", rt); + + Ok(()) +} + +/* +Option 4: + +You can use Axum's `State` extractor to provide custom application state to your server functions. + +All ServerFunctions pull in `FullstackContext`, so you need to implement `FromRef` for your +custom state type. To add your state to your app, you can use `.register_server_functions()` on a router +for a given state type, which will automatically add your state into the `FullstackContext` used by your server functions. + +There are two details to note here: + +- You need to implement `FromRef` for your custom state type. +- Custom extractors need to implement `FromRequest` where `S` is the state type that implements `FromRef`. +*/ +#[derive(Clone)] +struct MyAppState { + abc: i32, +} + +impl FromRef for MyAppState { + fn from_ref(state: &FullstackContext) -> Self { + state.extension::().unwrap() + } +} + +struct CustomExtractor { + abc: i32, + headermap: HeaderMap, +} + +impl FromRequest for CustomExtractor +where + MyAppState: FromRef, + S: Send + Sync, +{ + type Rejection = (); + + async fn from_request( + _req: axum::extract::Request, + state: &S, + ) -> std::result::Result { + let state = MyAppState::from_ref(state); + Ok(CustomExtractor { + abc: state.abc, + headermap: HeaderMap::new(), + }) + } +} + +#[post("/api/stateful", state: State, ex: CustomExtractor)] +async fn app_state() -> Result<()> { + println!("abc: {}", state.abc); + println!("state abc: {:?}", ex.abc); + println!("headermap: {:?}", ex.headermap); Ok(()) } @@ -95,6 +163,10 @@ fn main() { let router = dioxus::server::router(app) .layer(Extension(tokio::sync::broadcast::channel::(16).0)); + // To use our custom app state with `State`, we need to register it + // as an extension since our `FromRef` implementation relies on it. + let router = router.layer(Extension(MyAppState { abc: 42 })); + Ok(router) }); } diff --git a/packages/dioxus/src/lib.rs b/packages/dioxus/src/lib.rs index 7f09576236..9199e75ae7 100644 --- a/packages/dioxus/src/lib.rs +++ b/packages/dioxus/src/lib.rs @@ -211,9 +211,7 @@ pub mod prelude { #[cfg(feature = "server")] #[cfg_attr(docsrs, doc(cfg(feature = "server")))] #[doc(inline)] - pub use dioxus_server::{ - self, serve, DioxusRouterExt, DioxusRouterFnExt, ServeConfig, ServerFunction, - }; + pub use dioxus_server::{self, serve, DioxusRouterExt, ServeConfig, ServerFunction}; #[cfg(feature = "router")] #[cfg_attr(docsrs, doc(cfg(feature = "router")))] diff --git a/packages/fullstack-core/src/error.rs b/packages/fullstack-core/src/error.rs index ae4499f657..cb95da6ce7 100644 --- a/packages/fullstack-core/src/error.rs +++ b/packages/fullstack-core/src/error.rs @@ -162,6 +162,20 @@ impl From for ServerFnError { } } +impl From for HttpError { + fn from(value: ServerFnError) -> Self { + let status = StatusCode::from_u16(match &value { + ServerFnError::ServerError { code, .. } => *code, + _ => 500, + }) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + HttpError { + status, + message: Some(value.to_string()), + } + } +} + impl From for ServerFnError { fn from(value: HttpError) -> Self { ServerFnError::ServerError { diff --git a/packages/fullstack-core/src/lib.rs b/packages/fullstack-core/src/lib.rs index e408a7bb3e..6ace2e449c 100644 --- a/packages/fullstack-core/src/lib.rs +++ b/packages/fullstack-core/src/lib.rs @@ -25,6 +25,3 @@ pub use error::*; pub mod httperror; pub use httperror::*; - -#[derive(Clone, Default)] -pub struct DioxusServerState {} diff --git a/packages/fullstack-core/src/streaming.rs b/packages/fullstack-core/src/streaming.rs index 481db186e1..bb7052d673 100644 --- a/packages/fullstack-core/src/streaming.rs +++ b/packages/fullstack-core/src/streaming.rs @@ -9,10 +9,6 @@ use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; -tokio::task_local! { - static FULLSTACK_CONTEXT: FullstackContext; -} - /// The context provided by dioxus fullstack for server-side rendering. /// /// This context will only be set on the server during the initial streaming response @@ -21,10 +17,18 @@ tokio::task_local! { pub struct FullstackContext { // We expose the lock for request headers directly so it needs to be in a separate lock request_headers: Arc>, + // The rest of the fields are only held internally, so we can group them together lock: Arc>, } +// `FullstackContext` is always set when either +// 1. rendering the app via SSR +// 2. handling a server function request +tokio::task_local! { + static FULLSTACK_CONTEXT: FullstackContext; +} + pub struct FullstackContextInner { current_status: StreamingStatus, current_status_subscribers: HashSet, @@ -78,6 +82,7 @@ impl FullstackContext { pub fn commit_initial_chunk(&mut self) { let mut lock = self.lock.write(); lock.current_status = StreamingStatus::InitialChunkCommitted; + // The key type is mutable, but the hash is stable through mutations because we hash by pointer #[allow(clippy::mutable_key_type)] let subscribers = std::mem::take(&mut lock.current_status_subscribers); @@ -110,17 +115,23 @@ impl FullstackContext { FULLSTACK_CONTEXT.scope(self, fut).await } + /// Extract an extension from the current request. + pub fn extension(&self) -> Option { + let lock = self.request_headers.read(); + lock.extensions.get::().cloned() + } + /// Extract an axum extractor from the current request. /// /// The body of the request is always empty when using this method, as the body can only be consumed once in the server /// function extractors. - pub async fn extract, M>() -> Result { + pub async fn extract, M>() -> Result { let this = Self::current() .ok_or_else(|| ServerFnError::new("No FullstackContext found".to_string()))?; let parts = this.request_headers.read().clone(); let request = axum_core::extract::Request::from_parts(parts, Default::default()); - match T::from_request(request, &()).await { + match T::from_request(request, &this).await { Ok(res) => Ok(res), Err(err) => { let resp = err.into_response(); @@ -255,7 +266,7 @@ pub fn commit_initial_chunk() { /// Extract an axum extractor from the current request. #[deprecated(note = "Use FullstackContext::extract instead", since = "0.7.0")] -pub fn extract, M>( +pub fn extract, M>( ) -> impl std::future::Future> { FullstackContext::extract::() } diff --git a/packages/fullstack-macro/src/lib.rs b/packages/fullstack-macro/src/lib.rs index 736f4dced0..325ef77485 100644 --- a/packages/fullstack-macro/src/lib.rs +++ b/packages/fullstack-macro/src/lib.rs @@ -13,8 +13,7 @@ use syn::{ parse::ParseStream, punctuated::Punctuated, token::{Comma, Slash}, - Error, ExprTuple, FnArg, GenericArgument, Meta, PathArguments, PathSegment, Signature, Token, - Type, TypePath, + Error, ExprTuple, FnArg, Meta, PathArguments, PathSegment, Token, Type, TypePath, }; use syn::{parse::Parse, parse_quote, Ident, ItemFn, LitStr, Path}; use syn::{spanned::Spanned, LitBool, LitInt, Pat, PatType}; @@ -104,7 +103,6 @@ pub fn server(attr: proc_macro::TokenStream, mut item: TokenStream) -> TokenStre method: None, path_params: vec![], query_params: vec![], - state: None, route_lit: args.fn_path, oapi_options: None, server_args: Default::default(), @@ -204,17 +202,16 @@ fn route_impl_with_route( // Parse the route and function let mut function = syn::parse::(item)?; - let middleware_attrs = function + // Collect the middleware initializers + let middleware_layers = function .attrs .iter() .filter(|attr| attr.path().is_ident("middleware")) - .cloned() - .collect::>(); - - let middleware_inits = middleware_attrs - .into_iter() - .map(|f| match f.meta { - Meta::List(meta_list) => Ok(meta_list.tokens), + .map(|f| match &f.meta { + Meta::List(meta_list) => Ok({ + let tokens = &meta_list.tokens; + quote! { .layer(#tokens) } + }), _ => Err(Error::new( f.span(), "Expected middleware attribute to be a list, e.g. #[middleware(MyLayer::new())]", @@ -227,54 +224,48 @@ fn route_impl_with_route( .attrs .retain(|attr| !attr.path().is_ident("middleware")); - let server_args = route.server_args.clone(); - let mut function_on_server = function.clone(); - function_on_server.sig.inputs.extend(server_args.clone()); - - // Now we can compile the route - let original_inputs = function + // Attach `#[allow(unused_mut)]` to all original inputs to avoid warnings + let outer_inputs = function .sig .inputs .iter() - .map(|arg| match arg { + .enumerate() + .map(|(i, arg)| match arg { FnArg::Receiver(_receiver) => panic!("Self type is not supported"), - FnArg::Typed(pat_type) => { - quote! { - #[allow(unused_mut)] - #pat_type + FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { + Pat::Ident(_) => { + quote! { #[allow(unused_mut)] #pat_type } } - } + _ => { + let ident = format_ident!("___Arg{}", i); + let ty = &pat_type.ty; + quote! { #[allow(unused_mut)] #ident: #ty } + } + }, }) .collect::>(); + // .collect::>(); let route = CompiledRoute::from_route(route, &function, false, method_from_macro)?; - let path_extractor = route.path_extractor(); - let query_extractor = route.query_extractor(); let query_params_struct = route.query_params_struct(false); - let _state_type = &route.state; let method_ident = &route.method; - let http_method = route.method.to_axum_method_name(); - let _remaining_numbered_pats = route.remaining_pattypes_numbered(&function.sig.inputs); let body_json_args = route.remaining_pattypes_named(&function.sig.inputs); let body_json_names = body_json_args .iter() - .enumerate() .map(|(i, pat_type)| match &*pat_type.pat { Pat::Ident(ref pat_ident) => pat_ident.ident.clone(), - _ => format_ident!("___arg{}", i), + _ => format_ident!("___Arg{}", i), }) .collect::>(); let body_json_types = body_json_args .iter() - .map(|pat_type| &pat_type.ty) + .map(|pat_type| &pat_type.1.ty) .collect::>(); - let extracted_idents = route.extracted_idents(); let route_docs = route.to_doc_comments(); // Get the variables we need for code generation - let fn_name = &function.sig.ident; + let fn_on_server_name = &function.sig.ident; let vis = &function.vis; - let asyncness = &function.sig.asyncness; let (impl_generics, ty_generics, where_clause) = &function.sig.generics.split_for_impl(); let ty_generics = ty_generics.as_turbofish(); let fn_docs = function @@ -289,7 +280,11 @@ fn route_impl_with_route( syn::ReturnType::Type(_, ty) => (*ty).clone(), }; - let query_param_names = route.query_params.iter().map(|(ident, _)| ident); + let query_param_names = route + .query_params + .iter() + .filter(|c| !c.catch_all) + .map(|param| ¶m.binding); let path_param_args = route.path_params.iter().map(|(_slash, param)| match param { PathParam::Capture(_lit, _brace_1, ident, _ty, _brace_2) => { @@ -306,21 +301,24 @@ fn route_impl_with_route( _ => output_type.clone(), }; - let server_names = server_args + let mut function_on_server = function.clone(); + function_on_server + .sig + .inputs + .extend(route.server_args.clone()); + + let server_names = route + .server_args .iter() - .map(|pat_type| match pat_type { - FnArg::Receiver(_) => quote! { () }, - FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { - Pat::Ident(pat_ident) => { - let name = &pat_ident.ident; - quote! { #name } - } - _ => panic!("Expected Pat::Ident"), - }, + .enumerate() + .map(|(i, pat_type)| match pat_type { + FnArg::Typed(_pat_type) => format_ident!("___sarg___{}", i), + FnArg::Receiver(_) => panic!("Self type is not supported"), }) .collect::>(); - let server_types = server_args + let server_types = route + .server_args .iter() .map(|pat_type| match pat_type { FnArg::Receiver(_) => parse_quote! { () }, @@ -349,42 +347,16 @@ fn route_impl_with_route( }; // This unpacks the body struct into the individual variables that get scoped - let unpack = { + let unpack_closure = { let unpack_args = body_json_names.iter().map(|name| quote! { data.#name }); quote! { |data| { ( #(#unpack_args,)* ) } } }; - // there's no active request on the server, so we just create a dummy one - let server_defaults = if server_args.is_empty() { - quote! {} - } else { - quote! { - let (#(#server_names,)*) = dioxus_fullstack::FullstackContext::extract::<(#(#server_types,)*), _>().await?; - } - }; - let as_axum_path = route.to_axum_path_string(); - let query_endpoint = if let Some(route_lit) = route.route_lit.as_ref() { - let prefix = route - .prefix - .as_ref() - .cloned() - .unwrap_or_else(|| LitStr::new("", Span::call_site())) - .value(); - let url_without_queries = route_lit.value().split('?').next().unwrap().to_string(); - let full_url = format!( - "{}{}{}", - prefix, - if url_without_queries.starts_with("/") { - "" - } else { - "/" - }, - url_without_queries - ); + let query_endpoint = if let Some(full_url) = route.url_without_queries_for_format() { quote! { format!(#full_url, #( #path_param_args)*) } } else { quote! { __ENDPOINT_PATH.to_string() } @@ -403,7 +375,7 @@ fn route_impl_with_route( quote! { concat!( "/", - stringify!(#fn_name) + stringify!(#fn_on_server_name) ) } }; @@ -433,14 +405,32 @@ fn route_impl_with_route( } }; - let middleware_extra = middleware_inits - .iter() - .map(|init| { - quote! { - .layer(#init) - } - }) - .collect::>(); + let extracted_idents = route.extracted_idents(); + + let query_tokens = if route.query_is_catchall() { + let query = route + .query_params + .iter() + .find(|param| param.catch_all) + .unwrap(); + let input = &function.sig.inputs[query.arg_idx]; + let name = match input { + FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { + Pat::Ident(ref pat_ident) => pat_ident.ident.clone(), + _ => format_ident!("___Arg{}", query.arg_idx), + }, + FnArg::Receiver(_receiver) => panic!(), + }; + quote! { + #name + } + } else { + quote! { + __QueryParams__ { #(#query_param_names,)* } + } + }; + + let extracted_as_server_headers = route.extracted_as_server_headers(query_tokens.clone()); Ok(quote! { #(#fn_docs)* @@ -467,13 +457,11 @@ fn route_impl_with_route( ========================================================================================== " )] - #vis async fn #fn_name #impl_generics( - #original_inputs - ) -> #out_ty #where_clause { + #vis async fn #fn_on_server_name #impl_generics( #outer_inputs ) -> #out_ty #where_clause { use dioxus_fullstack::serde as serde; use dioxus_fullstack::{ // concrete types - ServerFnEncoder, ServerFnDecoder, DioxusServerState, + ServerFnEncoder, ServerFnDecoder, FullstackContext, // "magic" traits for encoding/decoding on the client ExtractRequest, EncodeRequest, RequestDecodeResult, RequestDecodeErr, @@ -482,14 +470,27 @@ fn route_impl_with_route( MakeAxumResponse, MakeAxumError, }; - _ = dioxus_fullstack::assert_is_result::<#out_ty>(); - #query_params_struct #body_struct_impl const __ENDPOINT_PATH: &str = #endpoint_path; + { + _ = dioxus_fullstack::assert_is_result::<#out_ty>(); + + let verify_token = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()) + .verify_can_serialize(); + + dioxus_fullstack::assert_can_encode(verify_token); + + let decode_token = (&&&&&ServerFnDecoder::<#out_ty>::new()) + .verify_can_deserialize(); + + dioxus_fullstack::assert_can_decode(decode_token); + }; + + // On the client, we make the request to the server // We want to support extremely flexible error types and return types, making this more complex than it should #[allow(clippy::unused_unit)] @@ -498,16 +499,11 @@ fn route_impl_with_route( let client = dioxus_fullstack::ClientRequest::new( dioxus_fullstack::http::Method::#method_ident, #query_endpoint, - &__QueryParams__ { #(#query_param_names,)* }, + &#query_tokens, ); - let verify_token = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()) - .verify_can_serialize(); - - dioxus_fullstack::assert_can_encode(verify_token); - let response = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()) - .fetch_client(client, ___Body_Serialize___ { #(#body_json_names,)* }, #unpack) + .fetch_client(client, ___Body_Serialize___ { #(#body_json_names,)* }, #unpack_closure) .await; let decoded = (&&&&&ServerFnDecoder::<#out_ty>::new()) @@ -523,45 +519,46 @@ fn route_impl_with_route( // On the server, we expand the tokens and submit the function to inventory #[cfg(feature = "server")] { - use #__axum::response::IntoResponse; - use dioxus_server::ServerFunction; - #function_on_server #[allow(clippy::unused_unit)] - #asyncness fn __inner__function__ #impl_generics( - ___state: #__axum::extract::State, - #path_extractor - #query_extractor - request: #__axum::extract::Request, - ) -> Result<#__axum::response::Response, #__axum::response::Response> #where_clause { - let ((#(#server_names,)*), ( #(#body_json_names,)* )) = (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()) - .extract_axum(___state.0, request, #unpack).await?; - - let encoded = (&&&&&&ServerFnDecoder::<#out_ty>::new()) - .make_axum_response( - #fn_name #ty_generics(#(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)*).await - ); - - let response = (&&&&&ServerFnDecoder::<#out_ty>::new()) - .make_axum_error(encoded); - - return response; + fn __inner__function__ #impl_generics( + ___state: #__axum::extract::State, + ___request: #__axum::extract::Request, + ) -> std::pin::Pin>> #where_clause { + Box::pin(async move { + match (&&&&&&&&&&&&&&ServerFnEncoder::<___Body_Serialize___<#(#body_json_types,)*>, (#(#body_json_types,)*)>::new()).extract_axum(___state.0, ___request, #unpack_closure).await { + Ok(((#(#body_json_names,)* ), (#(#extracted_as_server_headers,)* #(#server_names,)*) )) => { + // Call the user function + let res = #fn_on_server_name #ty_generics(#(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)*).await; + + // Encode the response Into a `Result` + let encoded = (&&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_response(res); + + // And then encode `Result` into `Response` + (&&&&&ServerFnDecoder::<#out_ty>::new()).make_axum_error(encoded) + }, + Err(res) => res, + } + }) } dioxus_server::inventory::submit! { - ServerFunction::new( + dioxus_server::ServerFunction::new( dioxus_server::http::Method::#method_ident, __ENDPOINT_PATH, || { - #__axum::routing::#http_method(__inner__function__ #ty_generics) #(#middleware_extra)* + dioxus_server::ServerFunction::make_handler(dioxus_server::http::Method::#method_ident, __inner__function__ #ty_generics) + #(#middleware_layers)* } ) } - #server_defaults + // Extract the server arguments from the context + let (#(#server_names,)*) = dioxus_fullstack::FullstackContext::extract::<(#(#server_types,)*), _>().await?; - return #fn_name #ty_generics( + // Call the function directly + return #fn_on_server_name #ty_generics( #(#extracted_idents,)* #(#body_json_names,)* #(#server_names,)* @@ -580,11 +577,19 @@ struct CompiledRoute { method: Method, #[allow(clippy::type_complexity)] path_params: Vec<(Slash, PathParam)>, - query_params: Vec<(Ident, Box)>, - state: Type, + query_params: Vec, route_lit: Option, prefix: Option, oapi_options: Option, + server_args: Punctuated, +} + +struct QueryParam { + arg_idx: usize, + name: String, + binding: Ident, + catch_all: bool, + ty: Box, } impl CompiledRoute { @@ -607,10 +612,6 @@ impl CompiledRoute { } PathParam::Static(lit) => path.push_str(&lit.value()), } - // if colon.is_some() { - // path.push(':'); - // } - // path.push_str(&ident.value()); } path @@ -645,12 +646,13 @@ impl CompiledRoute { let mut arg_map = sig .inputs .iter() - .filter_map(|item| match item { + .enumerate() + .filter_map(|(i, item)| match item { syn::FnArg::Receiver(_) => None, - syn::FnArg::Typed(pat_type) => Some(pat_type), + syn::FnArg::Typed(pat_type) => Some((i, pat_type)), }) - .filter_map(|pat_type| match &*pat_type.pat { - syn::Pat::Ident(ident) => Some((ident.ident.clone(), pat_type.ty.clone())), + .filter_map(|(i, pat_type)| match &*pat_type.pat { + syn::Pat::Ident(ident) => Some((ident.ident.clone(), (pat_type.ty.clone(), i))), _ => None, }) .collect::>(); @@ -665,7 +667,7 @@ impl CompiledRoute { ) })?; *ident = new_ident; - *ty = new_ty; + *ty = new_ty.0; } PathParam::WildCard(_lit, _, _star, ident, ty, _) => { let (new_ident, new_ty) = arg_map.remove_entry(ident).ok_or_else(|| { @@ -675,24 +677,38 @@ impl CompiledRoute { ) })?; *ident = new_ident; - *ty = new_ty; + *ty = new_ty.0; } PathParam::Static(_lit) => {} } } let mut query_params = Vec::new(); - for ident in route.query_params { - let (ident, ty) = arg_map.remove_entry(&ident).ok_or_else(|| { + for param in route.query_params { + let (ident, ty) = arg_map.remove_entry(¶m.binding).ok_or_else(|| { syn::Error::new( - ident.span(), + param.binding.span(), format!( "query parameter `{}` not found in function arguments", - ident + param.binding ), ) })?; - query_params.push((ident, ty)); + query_params.push(QueryParam { + binding: ident, + name: param.name, + catch_all: param.catch_all, + ty: ty.0, + arg_idx: ty.1, + }); + } + + // Disallow multiple query params if one is a catch-all + if query_params.iter().any(|param| param.catch_all) && query_params.len() > 1 { + return Err(syn::Error::new( + Span::call_site(), + "Cannot have multiple query parameters when one is a catch-all", + )); } if let Some(options) = route.oapi_options.as_mut() { @@ -721,34 +737,54 @@ impl CompiledRoute { route_lit: route.route_lit, path_params: route.path_params, query_params, - state: route.state.unwrap_or_else(|| guess_state_type(sig)), oapi_options: route.oapi_options, prefix: route.prefix, + server_args: route.server_args, }) } - pub fn path_extractor(&self) -> TokenStream2 { - let path_iter = self - .path_params - .iter() - .filter_map(|(_slash, path_param)| path_param.capture()); - let idents = path_iter.clone().map(|item| item.0); - let types = path_iter.clone().map(|item| item.1); - quote! { - dioxus_server::axum::extract::Path((#(#idents,)*)): dioxus_server::axum::extract::Path<(#(#types,)*)>, - } + pub fn query_is_catchall(&self) -> bool { + self.query_params.iter().any(|param| param.catch_all) } - pub fn query_extractor(&self) -> TokenStream2 { - let idents = self.query_params.iter().map(|item| &item.0); - quote! { - dioxus_server::axum::extract::Query(__QueryParams__ { #(#idents,)* }): dioxus_server::axum::extract::Query<__QueryParams__>, - } + pub fn extracted_as_server_headers(&self, query_tokens: TokenStream2) -> Vec { + let mut out = vec![]; + + // Add the path extractor + out.push({ + let path_iter = self + .path_params + .iter() + .filter_map(|(_slash, path_param)| path_param.capture()); + let idents = path_iter.clone().map(|item| item.0); + parse_quote! { + dioxus_server::axum::extract::Path((#(#idents,)*)) + } + }); + + out.push(parse_quote!( + dioxus_server::axum::extract::Query(#query_tokens) + )); + + out } pub fn query_params_struct(&self, with_aide: bool) -> TokenStream2 { - let idents = self.query_params.iter().map(|item| &item.0); - let types = self.query_params.iter().map(|item| &item.1); + let fields = self.query_params.iter().map(|item| { + let name = &item.name; + let binding = &item.binding; + let ty = &item.ty; + if item.catch_all { + quote! {} + } else if item.binding != item.name { + quote! { + #[serde(rename = #name)] + #binding: #ty, + } + } else { + quote! { #binding: #ty, } + } + }); let derive = match with_aide { true => quote! { #[derive(serde::Deserialize, serde::Serialize, ::schemars::JsonSchema)] @@ -762,7 +798,7 @@ impl CompiledRoute { quote! { #derive struct __QueryParams__ { - #(#idents: #types,)* + #(#fields)* } } } @@ -774,49 +810,13 @@ impl CompiledRoute { idents.push(ident.clone()); } } - for (ident, _ty) in &self.query_params { - idents.push(ident.clone()); + for param in &self.query_params { + idents.push(param.binding.clone()); } idents } - fn remaining_pattypes_named( - &self, - args: &Punctuated, - ) -> Punctuated { - args.iter() - .filter_map(|item| { - if let FnArg::Typed(pat_type) = item { - if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { - if self.path_params.iter().any(|(_slash, path_param)| { - if let Some((path_ident, _ty)) = path_param.capture() { - path_ident == &pat_ident.ident - } else { - false - } - }) || self - .query_params - .iter() - .any(|(query_ident, _)| query_ident == &pat_ident.ident) - { - return None; - } - } - - Some(pat_type.clone()) - } else { - unimplemented!("Self type is not supported") - } - }) - .collect() - } - - /// The arguments not used in the route. - /// Map the identifier to `___arg___{i}: Type`. - pub fn remaining_pattypes_numbered( - &self, - args: &Punctuated, - ) -> Punctuated { + fn remaining_pattypes_named(&self, args: &Punctuated) -> Vec<(usize, PatType)> { args.iter() .enumerate() .filter_map(|(i, item)| { @@ -831,16 +831,13 @@ impl CompiledRoute { }) || self .query_params .iter() - .any(|(query_ident, _)| query_ident == &pat_ident.ident) + .any(|query| query.binding == pat_ident.ident) { return None; } } - let mut new_pat_type = pat_type.clone(); - let ident = format_ident!("___arg___{}", i); - new_pat_type.pat = Box::new(parse_quote!(#ident)); - Some(new_pat_type) + Some((i, pat_type.clone())) } else { unimplemented!("Self type is not supported") } @@ -848,224 +845,16 @@ impl CompiledRoute { .collect() } - #[allow(dead_code)] - fn aide() { - // let http_method = format_ident!("{}_with", http_method); - // let summary = route - // .get_oapi_summary() - // .map(|summary| quote! { .summary(#summary) }); - // let description = route - // .get_oapi_description() - // .map(|description| quote! { .description(#description) }); - // let hidden = route - // .get_oapi_hidden() - // .map(|hidden| quote! { .hidden(#hidden) }); - // let tags = route.get_oapi_tags(); - // let id = route - // .get_oapi_id(&function.sig) - // .map(|id| quote! { .id(#id) }); - // let transform = route.get_oapi_transform()?; - // let responses = route.get_oapi_responses(); - // let response_code = responses.iter().map(|response| &response.0); - // let response_type = responses.iter().map(|response| &response.1); - // let security = route.get_oapi_security(); - // let schemes = security.iter().map(|sec| &sec.0); - // let scopes = security.iter().map(|sec| &sec.1); - - // ( - // route.ide_documentation_for_aide_methods(), - // quote! { - // ::aide::axum::routing::#http_method( - // __inner__function__ #ty_generics, - // |__op__| { - // let __op__ = __op__ - // #summary - // #description - // #hidden - // #id - // #(.tag(#tags))* - // #(.security_requirement_scopes::, _>(#schemes, vec![#(#scopes),*]))* - // #(.response::<#response_code, #response_type>())* - // ; - // #transform - // __op__ - // } - // ) - // }, - // quote! { ::aide::axum::routing::ApiMethodRouter }, - // ) - } - - #[allow(dead_code)] - pub fn ide_documentation_for_aide_methods(&self) -> TokenStream2 { - let Some(options) = &self.oapi_options else { - return quote! {}; - }; - let summary = options.summary.as_ref().map(|(ident, _)| { - let method = Ident::new("summary", ident.span()); - quote!( let x = x.#method(""); ) - }); - let description = options.description.as_ref().map(|(ident, _)| { - let method = Ident::new("description", ident.span()); - quote!( let x = x.#method(""); ) - }); - let id = options.id.as_ref().map(|(ident, _)| { - let method = Ident::new("id", ident.span()); - quote!( let x = x.#method(""); ) - }); - let hidden = options.hidden.as_ref().map(|(ident, _)| { - let method = Ident::new("hidden", ident.span()); - quote!( let x = x.#method(false); ) - }); - let tags = options.tags.as_ref().map(|(ident, _)| { - let method = Ident::new("tag", ident.span()); - quote!( let x = x.#method(""); ) - }); - let security = options.security.as_ref().map(|(ident, _)| { - let method = Ident::new("security_requirement_scopes", ident.span()); - quote!( let x = x.#method("", [""]); ) - }); - let responses = options.responses.as_ref().map(|(ident, _)| { - let method = Ident::new("response", ident.span()); - quote!( let x = x.#method::<0, String>(); ) - }); - let transform = options.transform.as_ref().map(|(ident, _)| { - let method = Ident::new("with", ident.span()); - quote!( let x = x.#method(|x|x); ) - }); - - quote! { - #[allow(unused)] - #[allow(clippy::no_effect)] - fn ____ide_documentation_for_aide____(x: ::aide::transform::TransformOperation) { - #summary - #description - #id - #hidden - #tags - #security - #responses - #transform - } - } - } - - #[allow(dead_code)] - pub fn get_oapi_summary(&self) -> Option { - if let Some(oapi_options) = &self.oapi_options { - if let Some(summary) = &oapi_options.summary { - return Some(summary.1.clone()); - } - } - None - } - - #[allow(dead_code)] - pub fn get_oapi_description(&self) -> Option { - if let Some(oapi_options) = &self.oapi_options { - if let Some(description) = &oapi_options.description { - return Some(description.1.clone()); - } - } - None - } - - #[allow(dead_code)] - pub fn get_oapi_hidden(&self) -> Option { - if let Some(oapi_options) = &self.oapi_options { - if let Some(hidden) = &oapi_options.hidden { - return Some(hidden.1.clone()); - } - } - None - } - - #[allow(dead_code)] - pub fn get_oapi_tags(&self) -> Vec { - if let Some(oapi_options) = &self.oapi_options { - if let Some(tags) = &oapi_options.tags { - return tags.1 .0.clone(); - } - } - Vec::new() - } - - #[allow(dead_code)] - pub fn get_oapi_id(&self, sig: &Signature) -> Option { - if let Some(oapi_options) = &self.oapi_options { - if let Some(id) = &oapi_options.id { - return Some(id.1.clone()); - } - } - Some(LitStr::new(&sig.ident.to_string(), sig.ident.span())) - } - - #[allow(dead_code)] - pub fn get_oapi_transform(&self) -> syn::Result> { - if let Some(oapi_options) = &self.oapi_options { - if let Some(transform) = &oapi_options.transform { - if transform.1.inputs.len() != 1 { - return Err(syn::Error::new( - transform.1.span(), - "expected a single identifier", - )); - } - - let pat = transform.1.inputs.first().unwrap(); - let body = &transform.1.body; - - if let Pat::Ident(pat_ident) = pat { - let ident = &pat_ident.ident; - return Ok(Some(quote! { - let #ident = __op__; - let __op__ = #body; - })); - } else { - return Err(syn::Error::new( - pat.span(), - "expected a single identifier without type", - )); - } - } - } - Ok(None) - } - - #[allow(dead_code)] - pub fn get_oapi_responses(&self) -> Vec<(LitInt, Type)> { - if let Some(oapi_options) = &self.oapi_options { - if let Some((_ident, Responses(responses))) = &oapi_options.responses { - return responses.clone(); - } - } - Default::default() - } - - #[allow(dead_code)] - pub fn get_oapi_security(&self) -> Vec<(LitStr, Vec)> { - if let Some(oapi_options) = &self.oapi_options { - if let Some((_ident, Security(security))) = &oapi_options.security { - return security - .iter() - .map(|(scheme, StrArray(scopes))| (scheme.clone(), scopes.clone())) - .collect(); - } - } - Default::default() - } - pub(crate) fn to_doc_comments(&self) -> TokenStream2 { let mut doc = format!( "# Handler information - Method: `{}` -- Path: `{}` -- State: `{}`", +- Path: `{}`", self.method.to_axum_method_name(), self.route_lit .as_ref() .map(|lit| lit.value()) .unwrap_or_else(|| "".into()), - self.state.to_token_stream(), ); if let Some(options) = &self.oapi_options { @@ -1118,33 +907,56 @@ impl CompiledRoute { #[doc = #doc] ) } -} -fn guess_state_type(sig: &syn::Signature) -> Type { - for arg in &sig.inputs { - if let FnArg::Typed(pat_type) = arg { - // Returns `T` if the type of the last segment is exactly `State`. - if let Type::Path(ty) = &*pat_type.ty { - let last_segment = ty.path.segments.last().unwrap(); - if last_segment.ident == "State" { - if let PathArguments::AngleBracketed(args) = &last_segment.arguments { - if args.args.len() == 1 { - if let GenericArgument::Type(ty) = args.args.first().unwrap() { - return ty.clone(); - } + fn url_without_queries_for_format(&self) -> Option { + // If there's no explicit route, we can't generate a format string this way. + let _lit = self.route_lit.as_ref()?; + + let url_without_queries = + self.path_params + .iter() + .fold(String::new(), |mut acc, (_slash, param)| { + acc.push('/'); + match param { + PathParam::Capture(lit, _brace_1, _, _, _brace_2) => { + acc.push_str(&format!("{{{}}}", lit.value())); + } + PathParam::WildCard(lit, _brace_1, _, _, _, _brace_2) => { + // no `*` since we want to use the argument *as the wildcard* when making requests + // it's not super applicable to server functions, more for general route generation + acc.push_str(&format!("{{{}}}", lit.value())); + } + PathParam::Static(lit) => { + acc.push_str(&lit.value()); } } - } - } - } - } + acc + }); - parse_quote! { () } + let prefix = self + .prefix + .as_ref() + .cloned() + .unwrap_or_else(|| LitStr::new("", Span::call_site())) + .value(); + let full_url = format!( + "{}{}{}", + prefix, + if url_without_queries.starts_with("/") { + "" + } else { + "/" + }, + url_without_queries + ); + + Some(full_url) + } } struct RouteParser { path_params: Vec<(Slash, PathParam)>, - query_params: Vec, + query_params: Vec, } impl RouteParser { @@ -1198,7 +1010,53 @@ impl RouteParser { if split_route.len() == 2 { let query = split_route[1]; for query_param in query.split('&') { - query_params.push(Ident::new(query_param, span)); + if query_param.starts_with(":") { + let ident = Ident::new(query_param.strip_prefix(":").unwrap(), span); + + query_params.push(QueryParam { + name: ident.to_string(), + binding: ident, + catch_all: true, + ty: parse_quote!(()), + arg_idx: usize::MAX, + }); + } else if query_param.starts_with("{") && query_param.ends_with("}") { + let ident = Ident::new( + query_param + .strip_prefix("{") + .unwrap() + .strip_suffix("}") + .unwrap(), + span, + ); + + query_params.push(QueryParam { + name: ident.to_string(), + binding: ident, + catch_all: true, + ty: parse_quote!(()), + arg_idx: usize::MAX, + }); + } else { + // if there's an `=` in the query param, we only take the left side as the name, and the right side is the binding + let name; + let binding; + if let Some((n, b)) = query_param.split_once('=') { + name = n; + binding = Ident::new(b, span); + } else { + name = query_param; + binding = Ident::new(query_param, span); + } + + query_params.push(QueryParam { + name: name.to_string(), + binding, + catch_all: false, + ty: parse_quote!(()), + arg_idx: usize::MAX, + }); + } } } @@ -1459,8 +1317,7 @@ fn doc_iter(attrs: &[Attribute]) -> impl Iterator + '_ { struct Route { method: Option, path_params: Vec<(Slash, PathParam)>, - query_params: Vec, - state: Option, + query_params: Vec, route_lit: Option, prefix: Option, oapi_options: Option, @@ -1485,13 +1342,6 @@ impl Parse for Route { query_params, } = RouteParser::new(route_lit.clone())?; - // todo: maybe let the user include `State` here, eventually? - // let state = match input.parse::() { - // Ok(_) => Some(input.parse::()?), - // Err(_) => None, - // }; - - let state = None; let oapi_options = input .peek(Brace) .then(|| { @@ -1512,7 +1362,6 @@ impl Parse for Route { method, path_params, query_params, - state, route_lit: Some(route_lit), oapi_options, server_args, diff --git a/packages/fullstack-server/src/launch.rs b/packages/fullstack-server/src/launch.rs index 0c805edf96..01624bed07 100644 --- a/packages/fullstack-server/src/launch.rs +++ b/packages/fullstack-server/src/launch.rs @@ -1,6 +1,6 @@ //! A launch function that creates an axum router for the LaunchBuilder -use crate::{server::DioxusRouterExt, RenderHandleState, ServeConfig}; +use crate::{server::DioxusRouterExt, FullstackState, ServeConfig}; use anyhow::Context; use axum::{ body::Body, @@ -9,10 +9,9 @@ use axum::{ Router, }; use dioxus_cli_config::base_path; -use dioxus_core::Element; -#[cfg(not(target_arch = "wasm32"))] -use dioxus_core::{RenderError, VNode}; -use dioxus_devtools::DevserverMsg; +use dioxus_core::{ComponentFunction, Element}; + +use dioxus_devtools::{DevserverMsg, HotReloadMsg}; use futures_util::{stream::FusedStream, StreamExt}; use hyper::body::Incoming; use hyper_util::server::conn::auto::Builder as HyperBuilder; @@ -20,14 +19,19 @@ use hyper_util::{ rt::{TokioExecutor, TokioIo}, service::TowerToHyperService, }; -use std::sync::Arc; -use std::{any::Any, collections::HashMap, net::SocketAddr, prelude::rust_2024::Future}; -use tokio::net::TcpStream; -use tokio_util::task::LocalPoolHandle; +use std::{any::Any, net::SocketAddr, prelude::rust_2024::Future}; +use std::{pin::Pin, sync::Arc}; +use subsecond::HotFn; +use tokio_util::either::Either; use tower::{Service, ServiceExt as _}; -type ContextList = Vec Box + Send + Sync>>; +#[cfg(not(target_arch = "wasm32"))] +use { + dioxus_core::{RenderError, VNode}, + tokio::net::TcpListener, +}; +type ContextList = Vec Box + Send + Sync>>; type BaseComp = fn() -> Element; /// Launch a fullstack app with the given root component. @@ -52,10 +56,6 @@ async fn serve_server( contexts: Vec Box + Send + Sync>>, platform_config: Vec>, ) { - let (devtools_tx, mut devtools_rx) = futures_channel::mpsc::unbounded(); - - dioxus_devtools::connect(move |msg| _ = devtools_tx.unbounded_send(msg)); - let mut cfg = platform_config .into_iter() .find_map(|cfg| cfg.downcast::().ok().map(|b| *b)) @@ -67,56 +67,117 @@ async fn serve_server( cfg.context_providers.push(arced); } - // Get the address the server should run on. If the CLI is running, the CLI proxies fullstack into the main address - // and we use the generated address the CLI gives us - let address = dioxus_cli_config::fullstack_address_or_localhost(); + let cb = move || { + let cfg = cfg.clone(); + Box::pin(async move { + Ok(apply_base_path( + Router::new().serve_dioxus_application(cfg.clone(), original_root), + original_root, + cfg.clone(), + base_path().map(|s| s.to_string()), + )) + }) as _ + }; + + serve_router(cb, dioxus_cli_config::fullstack_address_or_localhost()).await; +} - // Create the router and register the server functions under the basepath. - let router = apply_base_path( - Router::new().serve_dioxus_application(cfg.clone(), original_root), - original_root, - cfg.clone(), +/// Create a router that serves the dioxus application at the appropriate base path. +/// +/// This method automatically setups up: +/// - Static asset serving +/// - Mapping of base paths +/// - Automatic registration of server functions +/// - Handler to render the dioxus application +/// - WebSocket handling for live reload and devtools +/// - Hot-reloading +/// - Async Runtime +/// - Logging +pub fn router(app: fn() -> Element) -> Router { + let cfg = ServeConfig::new(); + apply_base_path( + Router::new().serve_dioxus_application(cfg.clone(), app), + app, + cfg, base_path().map(|s| s.to_string()), - ); + ) +} - let task_pool = LocalPoolHandle::new( - std::thread::available_parallelism() - .map(usize::from) - .unwrap_or(1), +/// Serve a fullstack dioxus application with a custom axum router. +/// +/// This function sets up an async runtime, enables the default dioxus logger, runs the provided initializer, +/// and then starts an axum server with the returned router. +/// +/// The axum router will be bound to the address specified by the `IP` and `PORT` environment variables, +/// defaulting to `127.0.0.1:8080` if not set. +/// +/// This function uses axum to block on serving the application, and will not return. +pub fn serve(mut serve_it: impl FnMut() -> F) -> ! +where + F: Future> + 'static, +{ + let cb = move || Box::pin(serve_it()) as _; + + block_on( + async move { serve_router(cb, dioxus_cli_config::fullstack_address_or_localhost()).await }, ); - let mut make_service = router.into_make_service(); + unreachable!("Serving a fullstack app should never return") +} - let listener = tokio::net::TcpListener::bind(address).await.unwrap(); +/// Serve a fullstack dioxus application with a custom axum router. +/// +/// This function enables the dioxus logger and then serves the axum server with hot-reloading support. +/// +/// To enable hot-reloading of the router, the provided `serve_callback` should return a new `Router` +/// each time it is called. +pub async fn serve_router( + mut serve_callback: impl FnMut() -> Pin>>>, + addr: SocketAddr, +) { + dioxus_logger::initialize_default(); + + let listener = TcpListener::bind(addr) + .await + .with_context(|| format!("Failed to bind to address {addr}")) + .unwrap(); - enum Msg { - TcpStream(std::io::Result<(TcpStream, SocketAddr)>), - Devtools(DevserverMsg), + // If we're not in debug mode, just serve the app normally + if !cfg!(debug_assertions) { + axum::serve(listener, serve_callback().await.unwrap()) + .await + .unwrap(); + return; } - let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(0); - let mut hr_idx = 0; + // Wire up the devtools connection. The sender only sends messages in dev. + let (devtools_tx, mut devtools_rx) = futures_channel::mpsc::unbounded(); + dioxus_devtools::connect(move |msg| _ = devtools_tx.unbounded_send(msg)); + + let mut hot_serve_callback = HotFn::current(serve_callback); + let mut make_service = hot_serve_callback + .call(()) + .await + .map(|router| router.into_make_service()) + .unwrap(); + + let (shutdown_tx, _) = tokio::sync::broadcast::channel(1); + let our_build_id = Some(dioxus_cli_config::build_id()); // Manually loop on accepting connections so we can also respond to devtools messages loop { let res = tokio::select! { - res = listener.accept() => Msg::TcpStream(res), - msg = devtools_rx.next(), if !devtools_rx.is_terminated() => { - if let Some(msg) = msg { - Msg::Devtools(msg) - } else { - continue; - } - } + res = listener.accept() => Either::Left(res), + Some(msg) = devtools_rx.next(), if !devtools_rx.is_terminated() => Either::Right(msg), + else => continue }; match res { - Msg::TcpStream(Ok((tcp_stream, _remote_addr))) => { - let this_hr_index = hr_idx; + Either::Left(Ok((tcp_stream, _remote_addr))) => { let mut make_service = make_service.clone(); - let mut shutdown_rx = shutdown_rx.clone(); + let mut shutdown_rx = shutdown_tx.subscribe(); - task_pool.spawn_pinned(move || async move { + tokio::task::spawn(async move { let tcp_stream = TokioIo::new(tcp_stream); std::future::poll_fn(|cx| { @@ -126,19 +187,19 @@ async fn serve_server( ) }) .await - .unwrap(); - - let tower_service = make_service - .call(()) - .await - .unwrap() - .map_request(|req: Request| req.map(Body::new)); + .expect("Infallible"); // upgrades needed for websockets let builder = HyperBuilder::new(TokioExecutor::new()); let connection = builder.serve_connection_with_upgrades( tcp_stream, - TowerToHyperService::new(tower_service), + TowerToHyperService::new( + make_service + .call(()) + .await + .unwrap() + .map_request(|req: Request| req.map(Body::new)), + ), ); tokio::select! { @@ -151,91 +212,59 @@ async fn serve_server( // appear. } } - _res = shutdown_rx.wait_for(|i| *i == this_hr_index + 1) => {} + _res = shutdown_rx.recv() => {} } }); } - Msg::TcpStream(Err(_)) => {} - // We need to delete our old router and build a new one - // - // one challenge is that the server functions are sitting in the dlopened lib and no longer - // accessible by us (the original process) - // - // We need to somehow get them out... ? + + // Handle just hot-patches for now. + // We don't do RSX hot-reload since usually the client handles that once the page is loaded. // - // for now we just support editing existing server functions - Msg::Devtools(devserver_msg) => { - match devserver_msg { - DevserverMsg::HotReload(hot_reload_msg) => { - if hot_reload_msg.for_build_id == Some(dioxus_cli_config::build_id()) { - if let Some(table) = hot_reload_msg.jump_table { - use crate::ServerFunction; - - unsafe { dioxus_devtools::subsecond::apply_patch(table).unwrap() }; - - let mut new_router = Router::new().serve_static_assets(); - let new_cfg = ServeConfig::new(); - - let server_fn_iter = ServerFunction::collect(); - - // de-duplicate iteratively by preferring the most recent (first, since it's linked) - let mut server_fn_map: HashMap<_, _> = HashMap::new(); - for f in server_fn_iter.into_iter().rev() { - server_fn_map.insert(f.path(), f); - } - - for (_, fn_) in server_fn_map { - tracing::trace!( - "Registering server function: {:?} {:?}", - fn_.path(), - fn_.method() - ); - new_router = fn_.register_server_fn_on_router(new_router); - } - - let hot_root = subsecond::HotFn::current(original_root); - let new_root_addr = hot_root.ptr_address().0 as usize as *const (); - let new_root = unsafe { - std::mem::transmute::<*const (), fn() -> Element>(new_root_addr) - }; - - crate::document::reset_renderer(); - - let state = RenderHandleState::new(new_cfg.clone(), new_root); - - let fallback_handler = - axum::routing::get(RenderHandleState::render_handler) - .with_state(state); - - make_service = apply_base_path( - new_router.fallback(fallback_handler), - new_root, - new_cfg.clone(), - base_path().map(|s| s.to_string()), - ) - .into_make_service(); - - shutdown_tx.send_modify(|i| { - *i += 1; - hr_idx += 1; - }); - } - } - } - DevserverMsg::FullReloadStart => {} - DevserverMsg::FullReloadFailed => {} - DevserverMsg::FullReloadCommand => {} - DevserverMsg::Shutdown => {} - _ => {} - } + // todo(jon): I *believe* SSR is resilient to RSX changes, but we should verify that... + Either::Right(DevserverMsg::HotReload(HotReloadMsg { + jump_table: Some(table), + for_build_id, + .. + })) if for_build_id == our_build_id => { + // Apply the hot-reload patch to the dioxus devtools first + unsafe { dioxus_devtools::subsecond::apply_patch(table).unwrap() }; + + // Now recreate the router + // We panic here because we don't want their app to continue in a maybe-corrupted state + make_service = hot_serve_callback + .call(()) + .await + .expect("Failed to create new router after hot-patch!") + .into_make_service(); + + // Make sure to wipe out the renderer state so we don't have stale elements + crate::document::reset_renderer(); + + _ = shutdown_tx.send(()); } + + // Explicitly don't handle RSX hot-reloads on the server + // The client will handle that once the page is loaded. If we handled it here, + _ => {} } } } -fn apply_base_path( +fn block_on(app_future: impl Future) { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.block_on(app_future); + } else { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(app_future); + } +} + +fn apply_base_path( mut router: Router, - root: fn() -> Element, + root: impl ComponentFunction<(), M> + Send + Sync, cfg: ServeConfig, base_path: Option, ) -> Router { @@ -247,75 +276,15 @@ fn apply_base_path( router = Router::new().nest(&format!("/{base_path}/"), router).route( &format!("/{base_path}"), axum::routing::method_routing::get( - |state: State, mut request: Request| async move { + |state: State, mut request: Request| async move { // The root of the base path always looks like the root from dioxus fullstack *request.uri_mut() = "/".parse().unwrap(); - RenderHandleState::render_handler(state, request).await + FullstackState::render_handler(state, request).await }, ) - .with_state(RenderHandleState::new(cfg, root)), + .with_state(FullstackState::new(cfg, root)), ) } router } - -/// Serve a fullstack dioxus application with a custom axum router. -/// -/// This function sets up an async runtime, enables the default dioxus logger, runs the provided initializer, -/// and then starts an axum server with the returned router. -/// -/// The axum router will be bound to the address specified by the `IP` and `PORT` environment variables, -/// defaulting to `127.0.0.1:8080` if not set. -pub fn serve(mut serve_it: impl FnMut() -> F) -> ! -where - F: Future>, -{ - dioxus_logger::initialize_default(); - - let app_future = async move { - let router = serve_it().await.expect("Failed to create axum router"); - let address = dioxus_cli_config::fullstack_address_or_localhost(); - let listener = tokio::net::TcpListener::bind(address) - .await - .with_context(|| format!("Failed to bind app to given address: {address}")) - .unwrap(); - tracing::trace!("Listening on {address}"); - axum::serve::serve(listener, router) - .await - .expect("Failed to serve axum app"); - }; - - if let Ok(handle) = tokio::runtime::Handle::try_current() { - handle.block_on(app_future); - } else { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(app_future); - } - - unreachable!("Serving a fullstack app should never return") -} - -/// Create a router that serves the dioxus application at the appropriate base path. -/// -/// This method automatically setups up: -/// - Static asset serving -/// - Mapping of base paths -/// - Automatic registration of server functions -/// - Handler to render the dioxus application -/// - WebSocket handling for live reload and devtools -/// - Hot-reloading -/// - Async Runtime -/// - Logging -pub fn router(app: fn() -> Element) -> Router { - let cfg = ServeConfig::new(); - apply_base_path( - Router::new().serve_dioxus_application(cfg.clone(), app), - app, - cfg, - base_path().map(|s| s.to_string()), - ) -} diff --git a/packages/fullstack-server/src/server.rs b/packages/fullstack-server/src/server.rs index ad55b3d8b7..a8a1678ef8 100644 --- a/packages/fullstack-server/src/server.rs +++ b/packages/fullstack-server/src/server.rs @@ -6,66 +6,27 @@ use axum::{ body::Body, extract::State, http::{Request, StatusCode}, - response::IntoResponse, - response::Response, + response::{IntoResponse, Response}, routing::*, }; -use dioxus_core::{Element, VirtualDom}; +use dioxus_core::{ComponentFunction, VirtualDom}; use http::header::*; use std::path::{Path, PathBuf}; use std::sync::Arc; +use tokio_util::task::LocalPoolHandle; use tower::util::MapResponse; use tower::ServiceExt; use tower_http::services::fs::ServeFileSystemResponseBody; -/// SSR renderer handler for Axum with added context injection. -/// -/// # Example -/// ```rust,no_run -/// #![allow(non_snake_case)] -/// use std::sync::{Arc, Mutex}; -/// -/// use axum::routing::get; -/// use dioxus::prelude::*; -/// use dioxus_server::{RenderHandleState, render_handler, ServeConfig}; -/// -/// fn app() -> Element { -/// rsx! { -/// "hello!" -/// } -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let addr = dioxus::cli_config::fullstack_address_or_localhost(); -/// let router = axum::Router::new() -/// // Register server functions, etc. -/// // Note you can use `register_server_functions_with_context` -/// // to inject the context into server functions running outside -/// // of an SSR render context. -/// .fallback(get(render_handler)) -/// .with_state(RenderHandleState::new(ServeConfig::new(), app)); -/// -/// let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); -/// axum::serve(listener, router).await.unwrap(); -/// } -/// ``` -pub async fn render_handler( - State(state): State, - request: Request, -) -> impl IntoResponse { - RenderHandleState::render_handler(State(state), request).await -} - /// A extension trait with utilities for integrating Dioxus with your Axum router. -pub trait DioxusRouterExt: DioxusRouterFnExt { +pub trait DioxusRouterExt { /// Serves the static WASM for your Dioxus application (except the generated index.html). /// /// # Example /// ```rust, no_run /// # #![allow(non_snake_case)] /// # use dioxus::prelude::*; - /// use dioxus_server::{DioxusRouterExt, DioxusRouterFnExt}; + /// use dioxus_server::DioxusRouterExt; /// /// #[tokio::main] /// async fn main() -> anyhow::Result<()> { @@ -75,15 +36,13 @@ pub trait DioxusRouterExt: DioxusRouterFnExt { /// .serve_static_assets() /// // Server render the application /// // ... - /// .into_make_service(); + /// .with_state(dioxus_server::FullstackState::headless()); /// let listener = tokio::net::TcpListener::bind(addr).await?; /// axum::serve(listener, router).await?; /// Ok(()) /// } /// ``` - fn serve_static_assets(self) -> Self - where - Self: Sized; + fn serve_static_assets(self) -> Router; /// Serves the Dioxus application. This will serve a complete server side rendered application. /// This will serve static assets, server render the application, register server functions, and integrate with hot reloading. @@ -92,15 +51,14 @@ pub trait DioxusRouterExt: DioxusRouterFnExt { /// ```rust, no_run /// # #![allow(non_snake_case)] /// # use dioxus::prelude::*; - /// use dioxus_server::{DioxusRouterExt, DioxusRouterFnExt, ServeConfig}; + /// use dioxus_server::{DioxusRouterExt, ServeConfig}; /// /// #[tokio::main] /// async fn main() { /// let addr = dioxus::cli_config::fullstack_address_or_localhost(); /// let router = axum::Router::new() /// // Server side render the application, serve static assets, and register server functions - /// .serve_dioxus_application(ServeConfig::new(), app) - /// .into_make_service(); + /// .serve_dioxus_application(ServeConfig::new(), app); /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); /// axum::serve(listener, router).await.unwrap(); /// } @@ -109,57 +67,31 @@ pub trait DioxusRouterExt: DioxusRouterFnExt { /// rsx! { "Hello World" } /// } /// ``` - fn serve_dioxus_application(self, cfg: ServeConfig, app: fn() -> Element) -> Self - where - Self: Sized; -} - -#[cfg(not(target_arch = "wasm32"))] -impl DioxusRouterExt for Router -where - S: Send + Sync + Clone + 'static, -{ - fn serve_static_assets(self) -> Self { - let Some(public_path) = public_path() else { - return self; - }; - - // Serve all files in public folder except index.html - serve_dir_cached(self, &public_path, &public_path) - } - - fn serve_dioxus_application(self, cfg: ServeConfig, app: fn() -> Element) -> Self { - self.register_server_functions() - .serve_static_assets() - .fallback( - get(RenderHandleState::render_handler).with_state(RenderHandleState::new(cfg, app)), - ) - } -} + fn serve_dioxus_application( + self, + cfg: ServeConfig, + app: impl ComponentFunction<(), M> + Send + Sync, + ) -> Router<()>; -/// A extension trait with server function utilities for integrating Dioxus with your Axum router. -pub trait DioxusRouterFnExt { /// Registers server functions with the default handler. /// /// # Example /// ```rust, no_run /// # use dioxus::prelude::*; - /// # use dioxus_server::DioxusRouterFnExt; + /// # use dioxus_server::DioxusRouterExt; /// #[tokio::main] /// async fn main() { /// let addr = dioxus::cli_config::fullstack_address_or_localhost(); /// let router = axum::Router::new() /// // Register server functions routes with the default handler /// .register_server_functions() - /// .into_make_service(); + /// .with_state(dioxus_server::FullstackState::headless()); /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); /// axum::serve(listener, router).await.unwrap(); /// } /// ``` #[allow(dead_code)] - fn register_server_functions(self) -> Self - where - Self: Sized; + fn register_server_functions(self) -> Router; /// Serves a Dioxus application without static assets. /// Sets up server function routes and rendering endpoints only. @@ -170,7 +102,7 @@ pub trait DioxusRouterFnExt { /// # Example /// ```rust, no_run /// # use dioxus::prelude::*; - /// # use dioxus_server::{DioxusRouterFnExt, ServeConfig}; + /// # use dioxus_server::{DioxusRouterExt, ServeConfig}; /// #[tokio::main] /// async fn main() { /// let router = axum::Router::new() @@ -183,67 +115,177 @@ pub trait DioxusRouterFnExt { /// rsx! { "Hello World" } /// } /// ``` - fn serve_api_application(self, cfg: ServeConfig, app: fn() -> Element) -> Self + fn serve_api_application( + self, + cfg: ServeConfig, + app: impl ComponentFunction<(), M> + Send + Sync, + ) -> Router<()> where Self: Sized; } -impl DioxusRouterFnExt for Router { - fn register_server_functions(mut self) -> Self { - for func in ServerFunction::collect() { - tracing::info!( - "Registering server function: {} {}", - func.method(), - func.path() - ); +#[cfg(not(target_arch = "wasm32"))] +impl DioxusRouterExt for Router { + fn register_server_functions(mut self) -> Router { + use std::collections::HashSet; + + let mut seen = HashSet::new(); - self = func.register_server_fn_on_router(self); + for func in ServerFunction::collect() { + if seen.insert(format!("{} {}", func.method(), func.path())) { + tracing::info!( + "Registering server function: {} {}", + func.method(), + func.path() + ); + + self = self.route(func.path(), func.method_router()) + } } + self } - fn serve_api_application(self, cfg: ServeConfig, app: fn() -> Element) -> Self - where - Self: Sized, - { - self.register_server_functions().fallback( - get(RenderHandleState::render_handler).with_state(RenderHandleState::new(cfg, app)), - ) + fn serve_static_assets(self) -> Router { + let Some(public_path) = public_path() else { + return self; + }; + + // Serve all files in public folder except index.html + serve_dir_cached(self, &public_path, &public_path) + } + + fn serve_api_application( + self, + cfg: ServeConfig, + app: impl ComponentFunction<(), M> + Send + Sync, + ) -> Router<()> { + self.register_server_functions() + .fallback(get(FullstackState::render_handler)) + .with_state(FullstackState::new(cfg, app)) + } + + fn serve_dioxus_application( + self, + cfg: ServeConfig, + app: impl ComponentFunction<(), M> + Send + Sync, + ) -> Router<()> { + self.register_server_functions() + .serve_static_assets() + .fallback(get(FullstackState::render_handler)) + .with_state(FullstackState::new(cfg, app)) } } -/// State used by [`RenderHandleState::render_handler`] to render a dioxus component with axum +/// SSR renderer handler for Axum with added context injection. +/// +/// # Example +/// ```rust,no_run +/// #![allow(non_snake_case)] +/// use std::sync::{Arc, Mutex}; +/// +/// use axum::routing::get; +/// use dioxus::prelude::*; +/// use dioxus_server::{FullstackState, render_handler, ServeConfig}; +/// +/// fn app() -> Element { +/// rsx! { +/// "hello!" +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let addr = dioxus::cli_config::fullstack_address_or_localhost(); +/// let router = axum::Router::new() +/// // Register server functions, etc. +/// // Note you can use `register_server_functions_with_context` +/// // to inject the context into server functions running outside +/// // of an SSR render context. +/// .fallback(get(render_handler)) +/// .with_state(FullstackState::new(ServeConfig::new(), app)); +/// +/// let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); +/// axum::serve(listener, router).await.unwrap(); +/// } +/// ``` +pub async fn render_handler( + State(state): State, + request: Request, +) -> impl IntoResponse { + FullstackState::render_handler(State(state), request).await +} + +/// State used by [`FullstackState::render_handler`] to render a dioxus component with axum #[derive(Clone)] -pub struct RenderHandleState { +pub struct FullstackState { config: ServeConfig, build_virtual_dom: Arc VirtualDom + Send + Sync>, renderers: Arc, + pub(crate) rt: LocalPoolHandle, } -impl RenderHandleState { - /// Create a new [`RenderHandleState`] - pub fn new(config: ServeConfig, root: fn() -> Element) -> Self { +impl FullstackState { + /// Create a headless [`FullstackState`] without a root component. + /// + /// This won't render pages, but can still be used to register server functions and serve static assets. + pub fn headless() -> Self { + let rt = LocalPoolHandle::new( + std::thread::available_parallelism() + .map(usize::from) + .unwrap_or(1), + ); + + Self { + renderers: Arc::new(SsrRendererPool::new(4, None)), + build_virtual_dom: Arc::new(|| { + panic!("No root component provided for headless FullstackState") + }), + config: ServeConfig::new(), + rt, + } + } + + /// Create a new [`FullstackState`] + pub fn new( + config: ServeConfig, + root: impl ComponentFunction<(), M> + Send + Sync + 'static, + ) -> Self { + let rt = LocalPoolHandle::new( + std::thread::available_parallelism() + .map(usize::from) + .unwrap_or(1), + ); + Self { renderers: Arc::new(SsrRendererPool::new(4, config.incremental.clone())), - build_virtual_dom: Arc::new(move || VirtualDom::new(root)), + build_virtual_dom: Arc::new(move || VirtualDom::new_with_props(root.clone(), ())), config, + rt, } } - /// Create a new [`RenderHandleState`] with a custom [`VirtualDom`] factory. This method can be + /// Create a new [`FullstackState`] with a custom [`VirtualDom`] factory. This method can be /// used to pass context into the root component of your application. pub fn new_with_virtual_dom_factory( config: ServeConfig, build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static, ) -> Self { + let rt = LocalPoolHandle::new( + std::thread::available_parallelism() + .map(usize::from) + .unwrap_or(1), + ); + Self { renderers: Arc::new(SsrRendererPool::new(4, config.incremental.clone())), config, build_virtual_dom: Arc::new(build_virtual_dom), + rt, } } - /// Set the [`ServeConfig`] for this [`RenderHandleState`] + /// Set the [`ServeConfig`] for this [`FullstackState`] pub fn with_config(mut self, config: ServeConfig) -> Self { self.config = config; self @@ -258,7 +300,7 @@ impl RenderHandleState { /// /// use axum::routing::get; /// use dioxus::prelude::*; - /// use dioxus_server::{RenderHandleState, render_handler, ServeConfig}; + /// use dioxus_server::{FullstackState, render_handler, ServeConfig}; /// /// fn app() -> Element { /// rsx! { @@ -275,7 +317,7 @@ impl RenderHandleState { /// // to inject the context into server functions running outside /// // of an SSR render context. /// .fallback(get(render_handler)) - /// .with_state(RenderHandleState::new(ServeConfig::new(), app)) + /// .with_state(FullstackState::new(ServeConfig::new(), app)) /// .into_make_service(); /// /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); @@ -288,7 +330,7 @@ impl RenderHandleState { let response = state .renderers .clone() - .render_to(parts, &state.config, { + .render_to(parts, &state.config, &state.rt, { let build_virtual_dom = state.build_virtual_dom.clone(); let context_providers = state.config.context_providers.clone(); move || { diff --git a/packages/fullstack-server/src/serverfn.rs b/packages/fullstack-server/src/serverfn.rs index fe9f8a5795..ae01bd87a0 100644 --- a/packages/fullstack-server/src/serverfn.rs +++ b/packages/fullstack-server/src/serverfn.rs @@ -1,35 +1,33 @@ -use axum::body::Body; -use axum::routing::MethodRouter; -use axum::Router; -use dashmap::DashMap; -use dioxus_fullstack_core::{DioxusServerState, FullstackContext}; -use http::Method; -use std::{marker::PhantomData, sync::LazyLock}; - -pub type AxumRequest = http::Request; -pub type AxumResponse = http::Response; +use crate::FullstackState; +use axum::{ + body::Body, + extract::{Request, State}, + response::Response, + routing::MethodRouter, +}; +use dioxus_fullstack_core::FullstackContext; +use http::{Method, StatusCode}; +use std::{pin::Pin, prelude::rust_2024::Future}; /// A function endpoint that can be called from the client. #[derive(Clone)] -pub struct ServerFunction { +pub struct ServerFunction { path: &'static str, method: Method, - handler: fn() -> MethodRouter, - _phantom: PhantomData, + handler: fn() -> MethodRouter, } impl ServerFunction { - /// Create a new server function object. + /// Create a new server function object from a MethodRouter pub const fn new( method: Method, path: &'static str, - handler: fn() -> MethodRouter, + handler: fn() -> MethodRouter, ) -> Self { Self { path, method, handler, - _phantom: PhantomData, } } @@ -43,72 +41,103 @@ impl ServerFunction { self.method.clone() } + /// Collect all globally registered server functions pub fn collect() -> Vec<&'static ServerFunction> { inventory::iter::().collect() } - pub fn handler(&self) -> fn() -> MethodRouter { - self.handler + /// Create a `MethodRouter` for this server function that can be mounted on an `axum::Router`. + /// + /// This runs the handler inside the required `FullstackContext` scope and populates + /// `FullstackContext` so that the handler can use those features. + /// + /// It also runs the server function inside a tokio `LocalPool` to allow !Send futures. + pub fn method_router(&self) -> MethodRouter { + (self.handler)() } - pub fn register_server_fn_on_router(&'static self, router: Router) -> Router - where - S: Send + Sync + Clone + 'static, - { - // // store Accepts and Referrer in case we need them for redirect (below) - // let referrer = req.headers().get(REFERER).cloned(); - // let accepts_html = req - // .headers() - // .get(ACCEPT) - // .and_then(|v| v.to_str().ok()) - // .map(|v| v.contains("text/html")) - // .unwrap_or(false); - - // // it it accepts text/html (i.e., is a plain form post) and doesn't already have a - // // Location set, then redirect to Referer - // if accepts_html { - // if let Some(referrer) = referrer { - // let has_location = res.headers().get(LOCATION).is_some(); - // if !has_location { - // *res.status_mut() = StatusCode::FOUND; - // res.headers_mut().insert(LOCATION, referrer); - // } - // } - // } - - async fn server_context_middleware( - request: axum::extract::Request, - next: axum::middleware::Next, - ) -> axum::response::Response { - let (parts, body) = request.into_parts(); - let server_context = FullstackContext::new(parts.clone()); - let request = axum::extract::Request::from_parts(parts, body); - - server_context - .scope(async move { - // Run the next middleware / handler inside the server context - let mut response = next.run(request).await; - - let server_context = FullstackContext::current().expect( - "Server context should be available inside the server context scope", - ); - - // Get the extra response headers set during the handler and add them to the response - let headers = server_context.take_response_headers(); - if let Some(headers) = headers { - response.headers_mut().extend(headers); - } - - response - }) - .await - } - - router.route( - self.path(), - ((self.handler)()) - .with_state(DioxusServerState {}) - .layer(axum::middleware::from_fn(server_context_middleware)), + /// Creates a new `MethodRouter` for the given method and !Send handler. + /// + /// This is used internally by the `ServerFunction` to create the method router that this + /// server function uses. + #[allow(clippy::type_complexity)] + pub fn make_handler( + method: Method, + handler: fn(State, Request) -> Pin>>, + ) -> MethodRouter { + axum::routing::method_routing::on( + method + .try_into() + .expect("MethodFilter only supports standard HTTP methods"), + move |state: State, request: Request| async move { + // Allow !Send futures by running in the render handlers pinned local pool + let result = state.rt.spawn_pinned(move || async move { + use dioxus_fullstack_core::FullstackContext; + use http::header::{ACCEPT, LOCATION, REFERER}; + use http::StatusCode; + + // todo: we're copying the parts here, but it'd be ideal if we didn't. + // We can probably just pass the URI in so the matching logic can work and then + // in the server function, do all extraction via FullstackContext. This ensures + // calls to `.remove()` work as expected. + let (parts, body) = request.into_parts(); + let server_context = FullstackContext::new(parts.clone()); + let request = axum::extract::Request::from_parts(parts, body); + + // store Accepts and Referrer in case we need them for redirect (below) + let referrer = request.headers().get(REFERER).cloned(); + let accepts_html = request + .headers() + .get(ACCEPT) + .and_then(|v| v.to_str().ok()) + .map(|v| v.contains("text/html")) + .unwrap_or(false); + + server_context + .clone() + .scope(async move { + // Run the next middleware / handler inside the server context + let mut response = handler(State(server_context), request).await; + + let server_context = FullstackContext::current().expect( + "Server context should be available inside the server context scope", + ); + + // Get the extra response headers set during the handler and add them to the response + let headers = server_context.take_response_headers(); + if let Some(headers) = headers { + response.headers_mut().extend(headers); + } + + // it it accepts text/html (i.e., is a plain form post) and doesn't already have a + // Location set, then redirect to Referer + if accepts_html { + if let Some(referrer) = referrer { + let has_location = response.headers().get(LOCATION).is_some(); + if !has_location { + *response.status_mut() = StatusCode::FOUND; + response.headers_mut().insert(LOCATION, referrer); + } + } + } + + response + }) + .await + }).await; + + match result { + Ok(response) => response, + Err(err) => Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::new(if cfg!(debug_assertions) { + format!("Server function panicked: {}", err) + } else { + "Internal Server Error".to_string() + })) + .unwrap(), + } + }, ) } } @@ -120,18 +149,3 @@ impl inventory::Collect for ServerFunction { ®ISTRY } } - -/// The set of all registered server function paths. -pub fn server_fn_paths() -> impl Iterator { - REGISTERED_SERVER_FUNCTIONS - .iter() - .map(|item| (item.path(), item.method())) -} - -type LazyServerFnMap = LazyLock>; -static REGISTERED_SERVER_FUNCTIONS: LazyServerFnMap = std::sync::LazyLock::new(|| { - inventory::iter:: - .into_iter() - .map(|obj| ((obj.path().to_string(), obj.method()), obj.clone())) - .collect() -}); diff --git a/packages/fullstack-server/src/ssr.rs b/packages/fullstack-server/src/ssr.rs index 3f13a46de6..11c874e307 100644 --- a/packages/fullstack-server/src/ssr.rs +++ b/packages/fullstack-server/src/ssr.rs @@ -21,12 +21,11 @@ use http::{request::Parts, HeaderMap, StatusCode}; use std::{ collections::HashMap, fmt::Write, - future::Future, iter::Peekable, rc::Rc, sync::{Arc, RwLock}, }; -use tokio::task::JoinHandle; +use tokio_util::task::LocalPoolHandle; use crate::StreamingMode; @@ -99,6 +98,7 @@ impl SsrRendererPool { self: Arc, parts: Parts, cfg: &ServeConfig, + rt: &LocalPoolHandle, virtual_dom_factory: impl FnOnce() -> VirtualDom + Send + Sync + 'static, ) -> Result< ( @@ -420,7 +420,8 @@ impl SsrRendererPool { myself.renderers.write().unwrap().push(renderer); }; - let join_handle = Self::spawn_platform(create_render_future); + // Spawn the render future onto the local pool + let join_handle = rt.spawn_pinned(create_render_future); // Wait for the initial result which determines the status code let (status, headers) = initial_result_rx @@ -763,32 +764,4 @@ impl SsrRendererPool { Ok(()) } - - /// Spawn a task in the background. If wasm is enabled, this will use the single threaded tokio runtime - fn spawn_platform(f: impl FnOnce() -> Fut + Send + 'static) -> JoinHandle - where - Fut: Future + 'static, - Fut::Output: Send + 'static, - { - #[cfg(not(target_arch = "wasm32"))] - { - use tokio_util::task::LocalPoolHandle; - static TASK_POOL: std::sync::OnceLock = std::sync::OnceLock::new(); - - let pool = TASK_POOL.get_or_init(|| { - LocalPoolHandle::new( - std::thread::available_parallelism() - .map(usize::from) - .unwrap_or(1), - ) - }); - - pool.spawn_pinned(f) - } - - #[cfg(target_arch = "wasm32")] - { - tokio::task::spawn_local(f()) - } - } } diff --git a/packages/fullstack/Cargo.toml b/packages/fullstack/Cargo.toml index 0060fa726a..765c6e2489 100644 --- a/packages/fullstack/Cargo.toml +++ b/packages/fullstack/Cargo.toml @@ -66,7 +66,7 @@ tower-http = { workspace = true, features = ["fs", "limit"], optional = true } tower-layer = { version = "0.3.3", optional = true } # payloads -postcard = { features = ["alloc"], optional = true, workspace = true, default-features = true } +postcard = { features = ["alloc", "use-std"], optional = true, workspace = true, default-features = true } rmp-serde = { version = "1.3", optional = true } async-stream = "0.3.6" diff --git a/packages/fullstack/src/client.rs b/packages/fullstack/src/client.rs index 8f4a69aaef..c90a21184a 100644 --- a/packages/fullstack/src/client.rs +++ b/packages/fullstack/src/client.rs @@ -129,10 +129,14 @@ impl ClientRequest { } /// Creates a new reqwest request builder with the method, url, and headers set from this ClientRequest + /// + /// Using this method attaches `X-Request-Client: dioxus` header to the request. pub fn new_reqwest_request(&self) -> reqwest::RequestBuilder { let client = GLOBAL_REQUEST_CLIENT.get_or_init(Self::new_reqwest_client); - let mut req = client.request(self.method.clone(), self.url.clone()); + let mut req = client + .request(self.method.clone(), self.url.clone()) + .header("X-Request-Client", "dioxus"); for (key, value) in self.headers.iter() { req = req.header(key, value); @@ -141,6 +145,7 @@ impl ClientRequest { req } + /// Using this method attaches `X-Request-Client-Dioxus` header to the request. #[cfg(feature = "web")] pub fn new_gloo_request(&self) -> gloo_net::http::RequestBuilder { let mut builder = gloo_net::http::RequestBuilder::new( @@ -155,6 +160,7 @@ impl ClientRequest { ) .as_str(), ) + .header("X-Request-Client", "dioxus") .method(self.method.clone()); for (key, value) in self.headers.iter() { diff --git a/packages/fullstack/src/encoding.rs b/packages/fullstack/src/encoding.rs index 22408b48cf..6e7557800c 100644 --- a/packages/fullstack/src/encoding.rs +++ b/packages/fullstack/src/encoding.rs @@ -7,8 +7,13 @@ use serde::{de::DeserializeOwned, Serialize}; pub trait Encoding { fn content_type() -> &'static str; fn stream_content_type() -> &'static str; - fn to_bytes(data: impl Serialize) -> Option; - fn from_bytes(bytes: Bytes) -> Option; + fn to_bytes(data: impl Serialize) -> Option { + let mut buf = Vec::new(); + Self::encode(data, &mut buf)?; + Some(buf.into()) + } + fn encode(data: impl Serialize, buf: &mut Vec) -> Option; + fn decode(bytes: Bytes) -> Option; } pub struct JsonEncoding; @@ -19,11 +24,14 @@ impl Encoding for JsonEncoding { fn stream_content_type() -> &'static str { "application/stream+json" } - fn to_bytes(data: impl Serialize) -> Option { - serde_json::to_vec(&data).ok().map(Into::into) + + fn encode(data: impl Serialize, mut buf: &mut Vec) -> Option { + let len = buf.len(); + serde_json::to_writer(&mut buf, &data).ok()?; + Some(buf.len() - len) } - fn from_bytes(bytes: Bytes) -> Option { + fn decode(bytes: Bytes) -> Option { serde_json::from_slice(&bytes).ok() } } @@ -36,15 +44,16 @@ impl Encoding for CborEncoding { fn stream_content_type() -> &'static str { "application/stream+cbor" } - fn to_bytes(data: impl Serialize) -> Option { - let mut buf = Vec::new(); - ciborium::into_writer(&data, &mut buf).ok()?; - Some(buf.into()) - } - fn from_bytes(bytes: Bytes) -> Option { + fn decode(bytes: Bytes) -> Option { ciborium::de::from_reader(bytes.as_ref()).ok() } + + fn encode(data: impl Serialize, mut buf: &mut Vec) -> Option { + let len = buf.len(); + ciborium::into_writer(&data, &mut buf).ok()?; + Some(buf.len() - len) + } } #[cfg(feature = "postcard")] @@ -57,11 +66,14 @@ impl Encoding for PostcardEncoding { fn stream_content_type() -> &'static str { "application/stream+postcard" } - fn to_bytes(data: impl Serialize) -> Option { - postcard::to_allocvec(&data).ok().map(Into::into) + + fn encode(data: impl Serialize, mut buf: &mut Vec) -> Option { + let len = buf.len(); + postcard::to_io(&data, &mut buf).ok()?; + Some(buf.len() - len) } - fn from_bytes(bytes: Bytes) -> Option { + fn decode(bytes: Bytes) -> Option { postcard::from_bytes(bytes.as_ref()).ok() } } @@ -76,11 +88,13 @@ impl Encoding for MsgPackEncoding { fn stream_content_type() -> &'static str { "application/stream+msgpack" } - fn to_bytes(data: impl Serialize) -> Option { - rmp_serde::to_vec(&data).ok().map(Into::into) + fn encode(data: impl Serialize, buf: &mut Vec) -> Option { + let len = buf.len(); + rmp_serde::encode::write(buf, &data).ok()?; + Some(buf.len() - len) } - fn from_bytes(bytes: Bytes) -> Option { + fn decode(bytes: Bytes) -> Option { rmp_serde::from_slice(&bytes).ok() } } diff --git a/packages/fullstack/src/lazy.rs b/packages/fullstack/src/lazy.rs index 9ef68d770d..e152ba8c75 100644 --- a/packages/fullstack/src/lazy.rs +++ b/packages/fullstack/src/lazy.rs @@ -88,6 +88,20 @@ impl Lazy { } Ok(()) } + + /// Get a reference to the value of the `Lazy` instance. This will block the current thread if the + /// value is not yet initialized. + pub fn get(&self) -> &T { + if self.constructor.is_none() { + return self.value.get().expect("Lazy value is not initialized. Make sure to call `initialize` before dereferencing."); + }; + + if self.value.get().is_none() { + self.initialize().expect("Failed to initialize lazy value"); + } + + self.value.get().unwrap() + } } impl Default for Lazy { @@ -100,15 +114,13 @@ impl std::ops::Deref for Lazy { type Target = T; fn deref(&self) -> &Self::Target { - if self.constructor.is_none() { - return self.value.get().expect("Lazy value is not initialized. Make sure to call `initialize` before dereferencing."); - }; - - if self.value.get().is_none() { - self.initialize().expect("Failed to initialize lazy value"); - } + self.get() + } +} - self.value.get().unwrap() +impl std::fmt::Debug for Lazy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Lazy").field("value", self.get()).finish() } } @@ -129,10 +141,16 @@ where { let ptr: F = unsafe { std::mem::zeroed() }; let fut = ptr(); - let rt = tokio::runtime::Handle::current(); - return std::thread::spawn(move || rt.block_on(fut).map_err(|e| e.into())) - .join() - .unwrap(); + return std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(fut) + .map_err(|e| e.into()) + }) + .join() + .unwrap(); } // todo: technically we can support constructors in wasm with the same tricks inventory uses with `__wasm_call_ctors` diff --git a/packages/fullstack/src/lib.rs b/packages/fullstack/src/lib.rs index f7a6b8ad5d..c497af85dc 100644 --- a/packages/fullstack/src/lib.rs +++ b/packages/fullstack/src/lib.rs @@ -107,6 +107,9 @@ pub mod payloads { pub mod header; pub use header::*; + pub mod query; + pub use query::*; + #[cfg(feature = "ws")] pub mod websocket; #[cfg(feature = "ws")] diff --git a/packages/fullstack/src/magic.rs b/packages/fullstack/src/magic.rs index 2b8851fe58..02c4b8cad0 100644 --- a/packages/fullstack/src/magic.rs +++ b/packages/fullstack/src/magic.rs @@ -42,7 +42,7 @@ use crate::{ use axum::response::IntoResponse; use axum_core::extract::{FromRequest, Request}; use bytes::Bytes; -use dioxus_fullstack_core::{DioxusServerState, RequestError}; +use dioxus_fullstack_core::RequestError; use http::StatusCode; use send_wrapper::SendWrapper; use serde::Serialize; @@ -172,7 +172,7 @@ pub mod req_to { impl EncodeRequest for &&&&&&&&&ServerFnEncoder where T: 'static, - O: FromRequest + IntoRequest, + O: IntoRequest, { type VerifyEncode = EncodeIsVerified; fn fetch_client( @@ -214,6 +214,8 @@ pub mod req_to { pub use decode_ok::*; mod decode_ok { + use crate::{CantDecode, DecodeIsVerified}; + use super::*; /// Convert the reqwest response into the desired type, in place. @@ -222,13 +224,16 @@ mod decode_ok { /// This is because FromResponse types are more specialized and can handle things like websockets and files. /// DeserializeOwned types are more general and can handle things like JSON responses. pub trait RequestDecodeResult { + type VerifyDecode; fn decode_client_response( &self, res: Result, ) -> impl Future, RequestError>> + Send; + fn verify_can_deserialize(&self) -> Self::VerifyDecode; } impl, E, R> RequestDecodeResult for &&&ServerFnDecoder> { + type VerifyDecode = DecodeIsVerified; fn decode_client_response( &self, res: Result, @@ -240,11 +245,15 @@ mod decode_ok { } }) } + fn verify_can_deserialize(&self) -> Self::VerifyDecode { + DecodeIsVerified + } } impl RequestDecodeResult for &&ServerFnDecoder> { + type VerifyDecode = DecodeIsVerified; fn decode_client_response( &self, res: Result, @@ -304,6 +313,24 @@ mod decode_ok { } }) } + fn verify_can_deserialize(&self) -> Self::VerifyDecode { + DecodeIsVerified + } + } + + impl RequestDecodeResult for &ServerFnDecoder> { + type VerifyDecode = CantDecode; + + fn decode_client_response( + &self, + _res: Result, + ) -> impl Future, RequestError>> + Send { + async move { unimplemented!() } + } + + fn verify_can_deserialize(&self) -> Self::VerifyDecode { + CantDecode + } } pub trait RequestDecodeErr { @@ -460,33 +487,34 @@ pub use req_from::*; pub mod req_from { use super::*; use axum::{extract::FromRequestParts, response::Response}; + use dioxus_fullstack_core::FullstackContext; pub trait ExtractRequest { fn extract_axum( &self, - state: DioxusServerState, + state: FullstackContext, request: Request, map: fn(In) -> Out, - ) -> impl Future> + 'static; + ) -> impl Future> + 'static; } /// When you're extracting entirely on the server, we need to reject client-consuning request bodies /// This sits above priority in the combined headers on server / body on client case. impl ExtractRequest for &&&&&&&&&&&ServerFnEncoder where - H: FromRequest + 'static, + H: FromRequest + 'static, { fn extract_axum( &self, - state: DioxusServerState, + state: FullstackContext, request: Request, _map: fn(In) -> (), - ) -> impl Future> + 'static { + ) -> impl Future> + 'static { async move { H::from_request(request, &state) .await .map_err(|e| e.into_response()) - .map(|out| (out, ())) + .map(|out| ((), out)) } } } @@ -496,14 +524,14 @@ pub mod req_from { where In: DeserializeOwned + 'static, Out: 'static, - H: FromRequestParts, + H: FromRequestParts, { fn extract_axum( &self, - _state: DioxusServerState, + _state: FullstackContext, request: Request, map: fn(In) -> Out, - ) -> impl Future> + 'static { + ) -> impl Future> + 'static { async move { let (mut parts, body) = request.into_parts(); let headers = H::from_request_parts(&mut parts, &_state) @@ -525,7 +553,7 @@ pub mod req_from { .map_err(|e| ServerFnError::from(e).into_response()) .unwrap(); - Ok((headers, out)) + Ok((out, headers)) } } } @@ -533,15 +561,15 @@ pub mod req_from { /// We skip the BodySerialize wrapper and just go for the output type directly. impl ExtractRequest for &&&&&&&&&ServerFnEncoder where - Out: FromRequest + 'static, - H: FromRequestParts, + Out: FromRequest + 'static, + H: FromRequestParts, { fn extract_axum( &self, - state: DioxusServerState, + state: FullstackContext, request: Request, _map: fn(In) -> Out, - ) -> impl Future> + 'static { + ) -> impl Future> + 'static { async move { let (mut parts, body) = request.into_parts(); let headers = H::from_request_parts(&mut parts, &state) @@ -554,7 +582,7 @@ pub mod req_from { .await .map_err(|e| e.into_response()); - res.map(|out| (headers, out)) + res.map(|out| (out, headers)) } } } @@ -616,7 +644,7 @@ mod resp { #[allow(clippy::result_large_err)] pub trait MakeAxumError { - fn make_axum_error(self, result: Result) -> Result; + fn make_axum_error(self, result: Result) -> Response; } /// Get the status code from the error type if possible. @@ -639,9 +667,9 @@ mod resp { where E: AsStatusCode + From + Serialize + DeserializeOwned + Display, { - fn make_axum_error(self, result: Result) -> Result { + fn make_axum_error(self, result: Result) -> Response { match result { - Ok(res) => Ok(res), + Ok(res) => res, Err(err) => { let status_code = err.as_status_code(); let err = ErrorPayload { @@ -656,19 +684,16 @@ mod resp { HeaderValue::from_static("application/json"), ); *resp.status_mut() = status_code; - Err(resp) + resp } } } } impl MakeAxumError for &&ServerFnDecoder> { - fn make_axum_error( - self, - result: Result, - ) -> Result { + fn make_axum_error(self, result: Result) -> Response { match result { - Ok(res) => Ok(res), + Ok(res) => res, // Optimize the case where we have sole ownership of the error Err(errr) if errr._strong_count() == 1 => { @@ -721,19 +746,16 @@ mod resp { HeaderValue::from_static("application/json"), ); *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Err(resp) + resp } } } } impl MakeAxumError for &&ServerFnDecoder> { - fn make_axum_error( - self, - result: Result, - ) -> Result { + fn make_axum_error(self, result: Result) -> Response { match result { - Ok(res) => Ok(res), + Ok(res) => res, Err(errr) => { // The `WithHttpError` trait emits ServerFnErrors so we can downcast them here // to create richer responses. @@ -775,19 +797,16 @@ mod resp { HeaderValue::from_static("application/json"), ); *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Err(resp) + resp } } } } impl MakeAxumError for &&ServerFnDecoder> { - fn make_axum_error( - self, - result: Result, - ) -> Result { + fn make_axum_error(self, result: Result) -> Response { match result { - Ok(resp) => Ok(resp), + Ok(resp) => resp, Err(status) => { let body = serde_json::to_string(&ErrorPayload::<()> { code: status.as_u16(), @@ -801,19 +820,16 @@ mod resp { HeaderValue::from_static("application/json"), ); *resp.status_mut() = status; - Err(resp) + resp } } } } impl MakeAxumError for &ServerFnDecoder> { - fn make_axum_error( - self, - result: Result, - ) -> Result { + fn make_axum_error(self, result: Result) -> Response { match result { - Ok(resp) => Ok(resp), + Ok(resp) => resp, Err(http_err) => { let body = serde_json::to_string(&ErrorPayload::<()> { code: http_err.status.as_u16(), @@ -829,7 +845,7 @@ mod resp { HeaderValue::from_static("application/json"), ); *resp.status_mut() = http_err.status; - Err(resp) + resp } } } diff --git a/packages/fullstack/src/payloads/axum_types.rs b/packages/fullstack/src/payloads/axum_types.rs index 8fce7c6b7f..5a7e1d5dac 100644 --- a/packages/fullstack/src/payloads/axum_types.rs +++ b/packages/fullstack/src/payloads/axum_types.rs @@ -18,7 +18,7 @@ impl> FromResponse for Html { impl IntoRequest for Json where - T: Serialize + 'static, + T: Serialize + 'static + DeserializeOwned, { fn into_request(self, request: ClientRequest) -> impl Future + 'static { async move { request.send_json(&self.0).await } diff --git a/packages/fullstack/src/payloads/form.rs b/packages/fullstack/src/payloads/form.rs index 805e2305a1..2e6b311b4b 100644 --- a/packages/fullstack/src/payloads/form.rs +++ b/packages/fullstack/src/payloads/form.rs @@ -4,7 +4,7 @@ pub use axum::extract::Form; impl IntoRequest for Form where - T: Serialize + 'static, + T: Serialize + 'static + DeserializeOwned, { fn into_request(self, req: ClientRequest) -> impl Future + 'static { async move { req.send_form(&self.0).await } diff --git a/packages/fullstack/src/payloads/query.rs b/packages/fullstack/src/payloads/query.rs new file mode 100644 index 0000000000..b487ea100a --- /dev/null +++ b/packages/fullstack/src/payloads/query.rs @@ -0,0 +1,34 @@ +use std::ops::Deref; + +use crate::ServerFnError; +use axum::extract::FromRequestParts; +use http::request::Parts; +use serde::de::DeserializeOwned; + +/// An extractor that deserializes query parameters into the given type `T`. +/// +/// This uses `serde_qs` under the hood to support complex query parameter structures. +#[derive(Debug, Clone, Copy, Default)] +pub struct Query(pub T); + +impl FromRequestParts for Query +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = ServerFnError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let inner: T = serde_qs::from_str(parts.uri.query().unwrap_or_default()) + .map_err(|e| ServerFnError::Deserialization(e.to_string()))?; + Ok(Self(inner)) + } +} + +impl Deref for Query { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/packages/fullstack/src/payloads/stream.rs b/packages/fullstack/src/payloads/stream.rs index 1e0af3c039..7ab59ffeb6 100644 --- a/packages/fullstack/src/payloads/stream.rs +++ b/packages/fullstack/src/payloads/stream.rs @@ -9,18 +9,83 @@ use axum_core::response::IntoResponse; use bytes::Bytes; use dioxus_fullstack_core::{HttpError, RequestError}; use futures::{Stream, StreamExt}; +#[cfg(feature = "server")] +use futures_channel::mpsc::UnboundedSender; use headers::{ContentType, Header}; use send_wrapper::SendWrapper; use serde::{de::DeserializeOwned, Serialize}; use std::{future::Future, marker::PhantomData, pin::Pin}; +/// A stream of text data. +/// +/// # Chunking +/// +/// Note that strings sent by the server might not arrive in the same chunking as they were sent. +/// +/// This is because the underlying transport layer (HTTP/2 or HTTP/3) may choose to split or combine +/// chunks for efficiency. +/// +/// If you need to preserve individual string boundaries, consider using `ChunkedTextStream` or another +/// encoding that preserves chunk boundaries. pub type TextStream = Streaming; + +/// A stream of binary data. +/// +/// # Chunking +/// +/// Note that bytes sent by the server might not arrive in the same chunking as they were sent. +/// This is because the underlying transport layer (HTTP/2 or HTTP/3) may choose to split or combine +/// chunks for efficiency. +/// +/// If you need to preserve individual byte boundaries, consider using `ChunkedByteStream` or another +/// encoding that preserves chunk boundaries. pub type ByteStream = Streaming; + +/// A stream of JSON-encoded data. +/// +/// # Chunking +/// +/// Normally, it's not possible to stream JSON over HTTP because browsers are free to re-chunk +/// data as they see fit. However, this implementation manually frames each JSON as if it were an unmasked +/// websocket message. +/// +/// If you need to send a stream of JSON data without framing, consider using TextStream instead and +/// manually handling JSON buffering. pub type JsonStream = Streaming; + +/// A stream of Cbor-encoded data. +/// +/// # Chunking +/// +/// Normally, it's not possible to stream JSON over HTTP because browsers are free to re-chunk +/// data as they see fit. However, this implementation manually frames each item as if it were an unmasked +/// websocket message. pub type CborStream = Streaming; +/// A stream of manually chunked binary data. +/// +/// This encoding preserves chunk boundaries by framing each chunk with its length, using Websocket +/// Framing. +pub type ChunkedByteStream = Streaming; + +/// A stream of manually chunked text data. +/// +/// This encoding preserves chunk boundaries by framing each chunk with its length, using Websocket +/// Framing. +pub type ChunkedTextStream = Streaming; + /// A streaming payload. /// +/// ## Frames and Chunking +/// +/// The streaming payload sends and receives data in discrete chunks or "frames". The size is converted +/// to hex and sent before each chunk, followed by a CRLF, the chunk data, and another CRLF. +/// +/// This mimics actual HTTP chunked transfer encoding, but allows us to define our own framing +/// protocol on top of it. +/// +/// Arbitrary bytes can be encoded between these frames, but the frames do come with some overhead. +/// /// ## Browser Support for Streaming Input /// /// Browser fetch requests do not currently support full request duplexing, which @@ -58,9 +123,9 @@ impl Streaming { pub fn new(value: impl Stream + Send + 'static) -> Self { // Box and pin the incoming stream and store as a trait object Self { - output_stream: Box::pin(futures::stream::empty()) as _, input_stream: Box::pin(value.map(|item| Ok(item))) as Pin> + Send>>, + output_stream: Box::pin(futures::stream::empty()) as _, encoding: PhantomData, } } @@ -69,9 +134,7 @@ impl Streaming { /// /// The callback is provided an `UnboundedSender` that can be used to send items to the stream. #[cfg(feature = "server")] - pub fn spawn( - callback: impl FnOnce(futures_channel::mpsc::UnboundedSender) -> F + Send + 'static, - ) -> Self + pub fn spawn(callback: impl FnOnce(UnboundedSender) -> F + Send + 'static) -> Self where F: Future + 'static, T: Send, @@ -92,6 +155,17 @@ impl Streaming { pub fn into_inner(self) -> impl Stream> + Send { self.input_stream } + + /// Creates a streaming payload from an existing stream of bytes. + /// + /// This uses the internal framing mechanism to decode the stream into items of type `T`. + fn from_bytes(stream: impl Stream> + Send + 'static) -> Self { + Self { + input_stream: Box::pin(stream), + output_stream: Box::pin(futures::stream::empty()) as _, + encoding: PhantomData, + } + } } impl From for TextStream @@ -100,11 +174,7 @@ where U: Into, { fn from(value: S) -> Self { - Self { - input_stream: Box::pin(value.map(|data| Ok(data.into()))), - output_stream: Box::pin(futures::stream::empty()) as _, - encoding: PhantomData, - } + Self::new(value.map(|data| data.into())) } } @@ -129,11 +199,7 @@ where E: Encoding, { fn from(value: S) -> Self { - Self { - input_stream: Box::pin(value.map(|data| Ok(data.into()))), - output_stream: Box::pin(futures::stream::empty()) as _, - encoding: PhantomData, - } + Self::from_bytes(value.map(|data| Ok(data.into()))) } } @@ -158,7 +224,7 @@ impl IntoResponse for Streaming { impl IntoResponse for Streaming { fn into_response(self) -> axum_core::response::Response { let res = self.input_stream.map(|r| match r { - Ok(res) => match E::to_bytes(&res) { + Ok(res) => match encode_stream_frame::(res) { Some(bytes) => Ok(bytes), None => Err(StreamingError::Failed), }, @@ -218,7 +284,7 @@ impl FromResponse SendWrapper::new(async move { let client_stream = Box::pin(SendWrapper::new(res.bytes_stream().map( |byte| match byte { - Ok(bytes) => match E::from_bytes(bytes) { + Ok(bytes) => match decode_stream_frame::(bytes) { Some(res) => Ok(res), None => Err(StreamingError::Decoding), }, @@ -330,7 +396,7 @@ impl FromReque Ok(Self { input_stream: Box::pin(futures::stream::empty()), output_stream: Box::pin(stream.map(|byte| match byte { - Ok(bytes) => match E::from_bytes(bytes) { + Ok(bytes) => match decode_stream_frame::(bytes) { Some(res) => Ok(res), None => Err(StreamingError::Decoding), }, @@ -380,11 +446,11 @@ impl IntoRequest async move { builder .header("Content-Type", E::stream_content_type())? - .send_body_stream( - self.input_stream.map(|r| { - r.and_then(|item| E::to_bytes(&item).ok_or(StreamingError::Failed)) - }), - ) + .send_body_stream(self.input_stream.map(|r| { + r.and_then(|item| { + encode_stream_frame::(item).ok_or(StreamingError::Failed) + }) + })) .await } } @@ -403,3 +469,113 @@ impl std::fmt::Debug for Streaming { .finish() } } + +/// This function encodes a single frame of a streaming payload using the specified encoding. +/// +/// The resulting `Bytes` object is encoded as a websocket frame, so you can send it over a streaming +/// HTTP response or even a websocket connection. +/// +/// Note that the packet is not masked, as it is assumed to be sent over a trusted connection. +pub fn encode_stream_frame(data: T) -> Option { + // We use full advantage of `BytesMut` here, writing a maximally full frame and then shrinking it + // down to size at the end. + // + // Also note we don't do any masking over this data since it's not going over an untrusted + // network like a websocket would. + // + // We allocate 10 extra bytes to account for framing overhead, which we'll shrink after + let mut bytes = vec![0u8; 10]; + + E::encode(data, &mut bytes)?; + + let len = (bytes.len() - 10) as u64; + let opcode = 0x82; // FIN + binary opcode + + // Write the header directly into the allocated space. + let offset = if len <= 125 { + bytes[8] = opcode; + bytes[9] = len as u8; + 8 + } else if len <= u16::MAX as u64 { + bytes[6] = opcode; + bytes[7] = 126; + let len_bytes = (len as u16).to_be_bytes(); + bytes[8] = len_bytes[0]; + bytes[9] = len_bytes[1]; + 6 + } else { + bytes[0] = opcode; + bytes[1] = 127; + bytes[2..10].copy_from_slice(&len.to_be_bytes()); + 0 + }; + + // Shrink down to the actual used size - is zero copy! + Some(Bytes::from(bytes).slice(offset..)) +} + +/// Decode a websocket-framed streaming payload produced by [`encode_stream_frame`]. +/// +/// This function returns `None` if the frame is invalid or cannot be decoded. +/// +/// It cannot handle masked frames, as those are not produced by our encoding function. +pub fn decode_stream_frame(frame: Bytes) -> Option +where + E: Encoding, + T: DeserializeOwned, +{ + let data = frame.as_ref(); + + if data.len() < 2 { + return None; + } + + let first = data[0]; + let second = data[1]; + + // Require FIN with binary opcode and no RSV bits + let fin = first & 0x80 != 0; + let opcode = first & 0x0F; + let rsv = first & 0x70; + if !fin || opcode != 0x02 || rsv != 0 { + return None; + } + + // Mask bit must be zero for our framing + if second & 0x80 != 0 { + return None; + } + + let mut offset = 2usize; + let mut payload_len = (second & 0x7F) as usize; + + if payload_len == 126 { + if data.len() < offset + 2 { + return None; + } + + payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; + offset += 2; + } else if payload_len == 127 { + if data.len() < offset + 8 { + return None; + } + + let mut len_bytes = [0u8; 8]; + len_bytes.copy_from_slice(&data[offset..offset + 8]); + let len_u64 = u64::from_be_bytes(len_bytes); + + if len_u64 > usize::MAX as u64 { + return None; + } + + payload_len = len_u64 as usize; + offset += 8; + } + + if data.len() < offset + payload_len { + return None; + } + + E::decode(frame.slice(offset..offset + payload_len)) +} diff --git a/packages/fullstack/src/payloads/websocket.rs b/packages/fullstack/src/payloads/websocket.rs index 3c65bb97fd..5fab15ebb3 100644 --- a/packages/fullstack/src/payloads/websocket.rs +++ b/packages/fullstack/src/payloads/websocket.rs @@ -307,11 +307,11 @@ impl Websocket { match msg { Message::Text(text) => { let e: O = - E::from_bytes(text.into()).ok_or_else(WebsocketError::deserialization)?; + E::decode(text.into()).ok_or_else(WebsocketError::deserialization)?; return Ok(e); } Message::Binary(bytes) => { - let e: O = E::from_bytes(bytes).ok_or_else(WebsocketError::deserialization)?; + let e: O = E::decode(bytes).ok_or_else(WebsocketError::deserialization)?; return Ok(e); } Message::Close { code, reason } => { @@ -707,13 +707,12 @@ impl TypedWebsocket match res { AxumMessage::Text(utf8_bytes) => { - let e: In = E::from_bytes(utf8_bytes.into()) + let e: In = E::decode(utf8_bytes.into()) .ok_or_else(WebsocketError::deserialization)?; return Ok(e); } AxumMessage::Binary(bytes) => { - let e: In = - E::from_bytes(bytes).ok_or_else(WebsocketError::deserialization)?; + let e: In = E::decode(bytes).ok_or_else(WebsocketError::deserialization)?; return Ok(e); } diff --git a/packages/fullstack/src/request.rs b/packages/fullstack/src/request.rs index 8e214e202d..34658c4c1e 100644 --- a/packages/fullstack/src/request.rs +++ b/packages/fullstack/src/request.rs @@ -34,7 +34,7 @@ pub trait IntoRequest: Sized { impl IntoRequest for (A,) where - A: IntoRequest + 'static, + A: IntoRequest + 'static + Send, { fn into_request( self, @@ -157,9 +157,20 @@ impl AssertIsResult for Result {} #[doc(hidden)] pub fn assert_is_result() {} -#[diagnostic::on_unimplemented( - message = "The arguments to the server function must either be a single `impl FromRequest + IntoRequest` argument, or multiple `DeserializeOwned` arguments." -)] +#[diagnostic::on_unimplemented(message = r#"❌ Invalid Arguments to ServerFn ❌ + +The arguments to the server function must be either: + +- a single `impl FromRequest + IntoRequest` argument +- or multiple `DeserializeOwned` arguments. + +Did you forget to implement `IntoRequest` or `Deserialize` for one of the arguments? + +`IntoRequest` is a trait that allows payloads to be sent to the server function. + +> See https://dioxuslabs.com/learn/0.7/essentials/fullstack/server_functions for more details. + +"#)] pub trait AssertCanEncode {} pub struct CantEncode; @@ -167,5 +178,27 @@ pub struct CantEncode; pub struct EncodeIsVerified; impl AssertCanEncode for EncodeIsVerified {} +#[diagnostic::on_unimplemented(message = r#"❌ Invalid return type from ServerFn ❌ + +The arguments to the server function must be either: + +- a single `impl FromResponse` return type +- a single `impl Serialize + DeserializedOwned` return type + +Did you forget to implement `FromResponse` or `DeserializeOwned` for one of the arguments? + +`FromResponse` is a trait that allows payloads to be decoded from the server function response. + +> See https://dioxuslabs.com/learn/0.7/essentials/fullstack/server_functions for more details. + +"#)] +pub trait AssertCanDecode {} +pub struct CantDecode; +pub struct DecodeIsVerified; +impl AssertCanDecode for DecodeIsVerified {} + #[doc(hidden)] pub fn assert_can_encode(_t: impl AssertCanEncode) {} + +#[doc(hidden)] +pub fn assert_can_decode(_t: impl AssertCanDecode) {}