Skip to content

Commit dfe2bf5

Browse files
Merge pull request #14 from rohas-dev/feat/middleware
feat(middleware): implement middleware support for APIs and WebSockets, including generation and execution logic
2 parents 6c8a197 + a6498b9 commit dfe2bf5

18 files changed

Lines changed: 550 additions & 14 deletions

File tree

crates/rohas-codegen/src/generator.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl Generator {
4848
"handlers/events",
4949
"handlers/cron",
5050
"handlers/websockets",
51+
"middlewares",
5152
];
5253

5354
for dir in &dirs {
@@ -77,6 +78,7 @@ impl Generator {
7778
typescript::generate_events(schema, output_dir)?;
7879
typescript::generate_crons(schema, output_dir)?;
7980
typescript::generate_websockets(schema, output_dir)?;
81+
typescript::generate_middlewares(schema, output_dir)?;
8082
typescript::generate_index(schema, output_dir)?;
8183

8284
info!("Generating TypeScript configuration files");
@@ -98,6 +100,7 @@ impl Generator {
98100
python::generate_events(schema, output_dir)?;
99101
python::generate_crons(schema, output_dir)?;
100102
python::generate_websockets(schema, output_dir)?;
103+
python::generate_middlewares(schema, output_dir)?;
101104
python::generate_init(schema, output_dir)?;
102105

103106
info!("Generating Python configuration files");

crates/rohas-codegen/src/python.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,88 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
344344
Ok(())
345345
}
346346

347+
pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
348+
use std::collections::HashSet;
349+
350+
let mut middleware_names = HashSet::new();
351+
352+
for api in &schema.apis {
353+
for middleware in &api.middlewares {
354+
middleware_names.insert(middleware.clone());
355+
}
356+
}
357+
358+
for ws in &schema.websockets {
359+
for middleware in &ws.middlewares {
360+
middleware_names.insert(middleware.clone());
361+
}
362+
}
363+
364+
if middleware_names.is_empty() {
365+
return Ok(());
366+
}
367+
368+
let middlewares_dir = output_dir.join("middlewares");
369+
for middleware_name in middleware_names {
370+
let file_name = format!("{}.py", templates::to_snake_case(&middleware_name));
371+
let middleware_path = middlewares_dir.join(&file_name);
372+
373+
if !middleware_path.exists() {
374+
let content = generate_middleware_stub(&middleware_name);
375+
fs::write(middleware_path, content)?;
376+
}
377+
}
378+
379+
Ok(())
380+
}
381+
382+
fn generate_middleware_stub(middleware_name: &str) -> String {
383+
let mut content = String::new();
384+
385+
content.push_str("from typing import Dict, Any, Optional\n");
386+
content.push_str("from generated.state import State\n\n");
387+
388+
content.push_str(&format!(
389+
"async def {}_middleware(context: Dict[str, Any], state: State) -> Optional[Dict[str, Any]]:\n",
390+
templates::to_snake_case(middleware_name)
391+
));
392+
content.push_str(" \"\"\"\n");
393+
content.push_str(&format!(" Middleware function for {}.\n\n", middleware_name));
394+
content.push_str(" Args:\n");
395+
content.push_str(" context: Request context containing:\n");
396+
content.push_str(" - payload: Request payload (for APIs)\n");
397+
content.push_str(" - query_params: Query parameters (for APIs)\n");
398+
content.push_str(" - connection: WebSocket connection info (for WebSockets)\n");
399+
content.push_str(" - websocket_name: WebSocket name (for WebSockets)\n");
400+
content.push_str(" - api_name: API name (for APIs)\n");
401+
content.push_str(" - trace_id: Trace ID\n");
402+
content.push_str(" state: State object for logging and triggering events\n\n");
403+
content.push_str(" Returns:\n");
404+
content.push_str(" Optional[Dict[str, Any]]: Modified context with 'payload' and/or 'query_params' keys,\n");
405+
content.push_str(" or None to pass through unchanged. Return a dict with 'error' key to reject the request.\n\n");
406+
content.push_str(" To reject the request, raise an exception \n");
407+
content.push_str(" \"\"\"\n");
408+
content.push_str(" # TODO: Implement middleware logic\n");
409+
content.push_str(" # Example: Validate authentication\n");
410+
content.push_str(" # Example: Rate limiting\n");
411+
content.push_str(" # Example: Logging\n");
412+
content.push_str(" # Example: Modify payload/query_params\n");
413+
content.push_str(" # \n");
414+
content.push_str(" # To modify the request:\n");
415+
content.push_str(" # return {\n");
416+
content.push_str(" # 'payload': modified_payload,\n");
417+
content.push_str(" # 'query_params': modified_query_params\n");
418+
content.push_str(" # }\n");
419+
content.push_str(" # \n");
420+
content.push_str(" # To reject the request:\n");
421+
content.push_str(" # raise Exception('Access denied')\n");
422+
content.push_str(" \n");
423+
content.push_str(" # Pass through unchanged\n");
424+
content.push_str(" return None\n");
425+
426+
content
427+
}
428+
347429
fn generate_websocket_content(ws: &WebSocket) -> String {
348430
let mut content = String::new();
349431

crates/rohas-codegen/src/typescript.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,99 @@ pub fn generate_websockets(schema: &Schema, output_dir: &Path) -> Result<()> {
504504
Ok(())
505505
}
506506

507+
pub fn generate_middlewares(schema: &Schema, output_dir: &Path) -> Result<()> {
508+
use std::collections::HashSet;
509+
510+
let mut middleware_names = HashSet::new();
511+
512+
for api in &schema.apis {
513+
for middleware in &api.middlewares {
514+
middleware_names.insert(middleware.clone());
515+
}
516+
}
517+
518+
for ws in &schema.websockets {
519+
for middleware in &ws.middlewares {
520+
middleware_names.insert(middleware.clone());
521+
}
522+
}
523+
524+
if middleware_names.is_empty() {
525+
return Ok(());
526+
}
527+
528+
let middlewares_dir = output_dir.join("middlewares");
529+
for middleware_name in middleware_names {
530+
let file_name = format!("{}.ts", middleware_name);
531+
let middleware_path = middlewares_dir.join(&file_name);
532+
533+
if !middleware_path.exists() {
534+
let content = generate_middleware_stub(&middleware_name);
535+
fs::write(middleware_path, content)?;
536+
}
537+
}
538+
539+
Ok(())
540+
}
541+
542+
fn generate_middleware_stub(middleware_name: &str) -> String {
543+
let mut content = String::new();
544+
545+
content.push_str("import { State } from '@generated/state';\n\n");
546+
547+
content.push_str("export interface MiddlewareContext {\n");
548+
content.push_str(" payload?: any;\n");
549+
content.push_str(" query_params?: Record<string, string>;\n");
550+
content.push_str(" connection?: any;\n");
551+
content.push_str(" websocket_name?: string;\n");
552+
content.push_str(" api_name?: string;\n");
553+
content.push_str(" trace_id?: string;\n");
554+
content.push_str("}\n\n");
555+
556+
content.push_str(&format!(
557+
"export async function {}Middleware(\n",
558+
middleware_name
559+
));
560+
content.push_str(" context: MiddlewareContext,\n");
561+
content.push_str(" state: State\n");
562+
content.push_str("): Promise<MiddlewareContext | null> {\n");
563+
content.push_str(" /**\n");
564+
content.push_str(&format!(" * Middleware function for {}.\n", middleware_name));
565+
content.push_str(" * \n");
566+
content.push_str(" * @param context - Request context containing:\n");
567+
content.push_str(" * - payload: Request payload (for APIs)\n");
568+
content.push_str(" * - query_params: Query parameters (for APIs)\n");
569+
content.push_str(" * - connection: WebSocket connection info (for WebSockets)\n");
570+
content.push_str(" * - websocket_name: WebSocket name (for WebSockets)\n");
571+
content.push_str(" * - api_name: API name (for APIs)\n");
572+
content.push_str(" * - trace_id: Trace ID\n");
573+
content.push_str(" * @param state - State object for logging and triggering events\n");
574+
content.push_str(" * @returns Modified context with 'payload' and/or 'query_params' keys,\n");
575+
content.push_str(" * or null to pass through unchanged. Throw an error to reject the request.\n");
576+
content.push_str(" * \n");
577+
content.push_str(" * To reject the request, throw an error:\n");
578+
content.push_str(" * throw new Error('Access denied');\n");
579+
content.push_str(" * \n");
580+
content.push_str(" * To modify the request:\n");
581+
content.push_str(" * return {\n");
582+
content.push_str(" * ...context,\n");
583+
content.push_str(" * payload: modifiedPayload,\n");
584+
content.push_str(" * query_params: modifiedQueryParams\n");
585+
content.push_str(" * };\n");
586+
content.push_str(" */\n");
587+
content.push_str(" // TODO: Implement middleware logic\n");
588+
content.push_str(" // Example: Validate authentication\n");
589+
content.push_str(" // Example: Rate limiting\n");
590+
content.push_str(" // Example: Logging\n");
591+
content.push_str(" // Example: Modify payload/query_params\n");
592+
content.push_str(" \n");
593+
content.push_str(" // Pass through unchanged\n");
594+
content.push_str(" return null;\n");
595+
content.push_str("}\n");
596+
597+
content
598+
}
599+
507600
fn generate_websocket_content(ws: &WebSocket) -> String {
508601
let mut content = String::new();
509602

crates/rohas-engine/src/api.rs

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use chrono::Utc;
1212
use rohas_codegen::templates;
1313
use rohas_parser::{HttpMethod, Schema};
1414
use rohas_runtime::Executor;
15-
use serde_json::Value;
15+
use serde_json::{json, Value};
1616
use std::{collections::HashMap, sync::Arc};
1717
use tracing::{debug, info_span};
1818

@@ -307,11 +307,31 @@ async fn api_handler(
307307
}
308308
}
309309

310+
let middleware_result = execute_middlewares(
311+
state.clone(),
312+
&api.middlewares,
313+
payload.clone(),
314+
query_params.clone(),
315+
&trace_id,
316+
&api_name,
317+
)
318+
.await;
319+
320+
if let Err(e) = middleware_result {
321+
state
322+
.trace_store
323+
.complete_trace(&trace_id, crate::trace::TraceStatus::Failed, Some(e.clone()))
324+
.await;
325+
return Err(ApiError::BadRequest(e));
326+
}
327+
328+
let (final_payload, final_query_params) = middleware_result.unwrap();
329+
310330
let result = execute_handler(
311331
state.clone(),
312332
handler_name.clone(),
313-
payload,
314-
query_params,
333+
final_payload,
334+
final_query_params,
315335
api_triggers,
316336
api_name,
317337
trace_id.clone(),
@@ -383,6 +403,98 @@ fn parse_query_string(query: &str) -> HashMap<String, String> {
383403
.collect()
384404
}
385405

406+
async fn execute_middlewares(
407+
state: ApiState,
408+
middlewares: &[String],
409+
mut payload: Value,
410+
mut query_params: HashMap<String, String>,
411+
trace_id: &str,
412+
api_name: &str,
413+
) -> Result<(Value, HashMap<String, String>), String> {
414+
if middlewares.is_empty() {
415+
return Ok((payload, query_params));
416+
}
417+
418+
debug!("Executing {} middlewares for API: {}", middlewares.len(), api_name);
419+
420+
for middleware_name in middlewares {
421+
let middleware_handler_name = match state.config.language {
422+
config::Language::TypeScript => middleware_name.clone(),
423+
config::Language::Python => templates::to_snake_case(middleware_name.as_str()),
424+
};
425+
426+
debug!("Executing middleware: {}", middleware_handler_name);
427+
428+
let middleware_context = json!({
429+
"payload": payload,
430+
"query_params": query_params,
431+
"api_name": api_name,
432+
"trace_id": trace_id,
433+
});
434+
435+
let mut context = rohas_runtime::HandlerContext::new(&middleware_handler_name, middleware_context);
436+
context.metadata.insert("middleware".to_string(), "true".to_string());
437+
context.metadata.insert("api_name".to_string(), api_name.to_string());
438+
439+
let start = std::time::Instant::now();
440+
let result = state.executor.execute_with_context(context).await;
441+
let duration_ms = start.elapsed().as_millis() as u64;
442+
443+
if let Ok(ref exec_result) = result {
444+
state
445+
.trace_store
446+
.add_step(
447+
trace_id,
448+
format!("middleware:{}", middleware_handler_name),
449+
duration_ms.max(exec_result.execution_time_ms),
450+
exec_result.success,
451+
exec_result.error.clone(),
452+
)
453+
.await;
454+
}
455+
456+
match result {
457+
Ok(exec_result) => {
458+
if !exec_result.success {
459+
let error_msg = exec_result.error.unwrap_or_else(|| {
460+
format!("Middleware '{}' rejected the request", middleware_name)
461+
});
462+
return Err(error_msg);
463+
}
464+
465+
if let Some(data) = exec_result.data {
466+
if let Value::Object(middleware_data) = data {
467+
if let Some(new_payload) = middleware_data.get("payload") {
468+
payload = new_payload.clone();
469+
}
470+
471+
if let Some(new_query_params) = middleware_data.get("query_params") {
472+
if let Value::Object(params_obj) = new_query_params {
473+
query_params = params_obj
474+
.iter()
475+
.filter_map(|(k, v)| {
476+
if let Value::String(s) = v {
477+
Some((k.clone(), s.clone()))
478+
} else {
479+
None
480+
}
481+
})
482+
.collect();
483+
}
484+
}
485+
}
486+
}
487+
}
488+
Err(e) => {
489+
let error_msg = format!("Middleware '{}' execution failed: {}", middleware_name, e);
490+
return Err(error_msg);
491+
}
492+
}
493+
}
494+
495+
Ok((payload, query_params))
496+
}
497+
386498
async fn execute_handler(
387499
state: ApiState,
388500
handler_name: String,

crates/rohas-engine/src/workbench.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ pub struct WebSocketEndpoint {
859859
pub on_disconnect: Vec<String>,
860860
pub triggers: Vec<String>,
861861
pub broadcast: bool,
862+
pub middlewares: Vec<String>,
862863
}
863864

864865
#[derive(Serialize, Deserialize)]
@@ -913,6 +914,7 @@ async fn get_endpoints(State(state): State<ApiState>) -> Result<Response, Workbe
913914
on_disconnect: ws.on_disconnect.clone(),
914915
triggers: ws.triggers.clone(),
915916
broadcast: ws.broadcast,
917+
middlewares: ws.middlewares.clone(),
916918
})
917919
.collect();
918920

0 commit comments

Comments
 (0)