Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ pub mod a2a_types;
pub mod client;
pub mod config;
pub mod server;
pub mod task_handler;

pub use client::{A2AClient, HealthStatus};
pub use config::{AgentConfig, ClientConfig, Config};
pub use server::{A2AServer, A2AServerBuilder, Agent, AgentBuilder, AgentCardOverrides};
pub use task_handler::{
BackgroundTaskHandlerConfig, BackgroundTaskQueue, DefaultBackgroundTaskHandler,
ManagedTask, QueuedTask, TaskHandler, TaskQueueStats
};

#[cfg(test)]
mod tests {
Expand Down
193 changes: 189 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::a2a_types::AgentCard;
use crate::a2a_types::{AgentCard, SendMessageRequest, SendMessageSuccessResponse, SendMessageRequestId, SendMessageSuccessResponseId, SendMessageSuccessResponseResult, TaskState};
use crate::client::HealthStatus;
use crate::config::{AgentConfig, Config};
use crate::task_handler::{
BackgroundTaskHandlerConfig, BackgroundTaskQueue, DefaultBackgroundTaskHandler, TaskHandler
};
use anyhow::{Result, anyhow};
use axum::{
Router,
extract::State,
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, post},
Expand All @@ -19,7 +22,7 @@ use std::sync::Arc;
use tokio::net::TcpListener;
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};

/// Agent card field overrides
#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -144,8 +147,10 @@ pub struct A2AServer {
agent_card: Option<AgentCard>,
agent: Option<Agent>,
gateway_url: String,
background_task_queue: Option<Arc<BackgroundTaskQueue>>,
}

#[derive(Clone)]
pub struct Agent {
#[allow(dead_code)]
config: AgentConfig,
Expand All @@ -158,7 +163,7 @@ pub struct Agent {
max_conversation_history: u32,
#[allow(dead_code)]
toolbox: Option<Vec<Tool>>,
tool_handlers: HashMap<String, Box<dyn ToolHandler>>,
tool_handlers: HashMap<String, Arc<dyn ToolHandler>>,
}

impl std::fmt::Debug for Agent {
Expand Down Expand Up @@ -186,6 +191,8 @@ pub struct A2AServerBuilder {
agent_card_overrides: Option<AgentCardOverrides>,
agent: Option<Agent>,
gateway_url: Option<String>,
background_task_handler: Option<Arc<dyn TaskHandler>>,
background_task_handler_config: Option<BackgroundTaskHandlerConfig>,
}

pub struct AgentBuilder {
Expand All @@ -202,6 +209,16 @@ struct AppState {
server: A2AServer,
}

impl AppState {
fn has_background_task_handler(&self) -> bool {
self.server.background_task_queue.is_some()
}

fn get_background_task_queue(&self) -> Option<&BackgroundTaskQueue> {
self.server.background_task_queue.as_deref()
}
}

impl A2AServerBuilder {
pub fn new() -> Self {
Self {
Expand All @@ -211,6 +228,8 @@ impl A2AServerBuilder {
agent_card_overrides: None,
agent: None,
gateway_url: None,
background_task_handler: None,
background_task_handler_config: None,
}
}

Expand Down Expand Up @@ -244,6 +263,21 @@ impl A2AServerBuilder {
self
}

pub fn with_background_task_handler(mut self, handler: Arc<dyn TaskHandler>) -> Self {
self.background_task_handler = Some(handler);
self
}

pub fn with_default_background_task_handler(self) -> Self {
// This will be set up in build() method when we have access to config and agent
self
}

pub fn with_background_task_handler_config(mut self, config: BackgroundTaskHandlerConfig) -> Self {
self.background_task_handler_config = Some(config);
self
}

pub async fn build(self) -> Result<A2AServer> {
let config = self.config.unwrap_or_default();

Expand Down Expand Up @@ -301,11 +335,39 @@ impl A2AServerBuilder {
.gateway_url
.unwrap_or_else(|| "http://localhost:8080/v1".to_string());

// Set up background task handler if configured
let background_task_queue = if let Some(handler) = self.background_task_handler {
let handler_config = self.background_task_handler_config
.unwrap_or_else(|| BackgroundTaskHandlerConfig::from(&config));
Some(Arc::new(BackgroundTaskQueue::new(handler_config, handler)))
} else if self.agent.is_some() {
// Check if user wants default background handler
// For now, we'll make this explicit - user must call with_default_background_task_handler()
None
} else {
None
};

// If with_default_background_task_handler() was called, create the default handler
let background_task_queue = if background_task_queue.is_none() && self.agent.is_some() {
let handler_config = self.background_task_handler_config
.unwrap_or_else(|| BackgroundTaskHandlerConfig::from(&config));
let default_handler = Arc::new(DefaultBackgroundTaskHandler::new(
handler_config.clone(),
self.agent.clone().map(|a| Arc::new(a)),
gateway_url.clone(),
));
Some(Arc::new(BackgroundTaskQueue::new(handler_config, default_handler)))
} else {
background_task_queue
};

Ok(A2AServer {
config,
agent_card,
agent: self.agent,
gateway_url,
background_task_queue,
})
}
}
Expand Down Expand Up @@ -412,6 +474,22 @@ impl Agent {
pub fn toolbox(&self) -> Option<&Vec<Tool>> {
self.toolbox.as_ref()
}

pub fn get_system_prompt(&self) -> Option<String> {
self.system_prompt.clone()
}

pub fn get_provider(&self) -> Provider {
self.provider
}

pub fn get_model(&self) -> &str {
&self.model
}

pub fn get_tool_handlers(&self) -> &HashMap<String, Arc<dyn ToolHandler>> {
&self.tool_handlers
}
}

impl A2AServer {
Expand All @@ -422,6 +500,9 @@ impl A2AServer {
.route("/health", get(health_handler))
.route("/.well-known/agent.json", get(agent_card_handler))
.route("/a2a", post(a2a_handler))
.route("/tasks", post(submit_task_handler))
.route("/tasks/:task_id", get(get_task_handler))
.route("/tasks/stats", get(get_task_stats_handler))
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
Expand Down Expand Up @@ -750,6 +831,110 @@ async fn a2a_handler(
}
}

async fn submit_task_handler(
State(state): State<Arc<AppState>>,
Json(payload): Json<SendMessageRequest>,
) -> Result<Json<SendMessageSuccessResponse>, StatusCode> {
debug!("Task submission request received");

if !state.has_background_task_handler() {
warn!("Background task handler not configured, rejecting task submission");
return Err(StatusCode::NOT_IMPLEMENTED);
}

let queue = state.get_background_task_queue().unwrap();

// Extract message from A2A request format
let a2a_message = &payload.params.message;
let messages = vec![inference_gateway_sdk::Message {
role: match a2a_message.role {
crate::a2a_types::MessageRole::User => inference_gateway_sdk::MessageRole::User,
crate::a2a_types::MessageRole::Agent => inference_gateway_sdk::MessageRole::Assistant,
},
content: a2a_message.parts.iter()
.find_map(|part| match part {
crate::a2a_types::Part::TextPart(text_part) => Some(text_part.text.clone()),
_ => None,
}),
..Default::default()
}];

let context_id = a2a_message.context_id.clone();
let metadata: std::collections::HashMap<String, serde_json::Value> =
payload.params.metadata.into_iter().collect();

match queue.submit_task(messages, context_id, metadata).await {
Ok(task_id) => {
let response = SendMessageSuccessResponse {
id: match payload.id {
SendMessageRequestId::String(s) => SendMessageSuccessResponseId::String(s),
SendMessageRequestId::Integer(i) => SendMessageSuccessResponseId::Integer(i),
},
jsonrpc: "2.0".to_string(),
result: SendMessageSuccessResponseResult::Task(crate::a2a_types::Task {
id: task_id.clone(),
kind: "task".to_string(),
context_id: context_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
status: crate::a2a_types::TaskStatus {
state: TaskState::Submitted,
message: None,
timestamp: Some(chrono::Utc::now().to_rfc3339()),
},
history: vec![],
artifacts: vec![],
metadata: serde_json::Map::new(),
}),
};
debug!("Task {} submitted successfully", task_id);
Ok(Json(response))
}
Err(e) => {
error!("Failed to submit task: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

async fn get_task_handler(
State(state): State<Arc<AppState>>,
Path(task_id): Path<String>,
) -> Result<Json<crate::a2a_types::Task>, StatusCode> {
debug!("Task status request for task: {}", task_id);

if !state.has_background_task_handler() {
return Err(StatusCode::NOT_IMPLEMENTED);
}

let queue = state.get_background_task_queue().unwrap();

match queue.get_task(&task_id).await {
Some(managed_task) => {
debug!("Returning status for task: {}", task_id);
Ok(Json(managed_task.task))
}
None => {
debug!("Task not found: {}", task_id);
Err(StatusCode::NOT_FOUND)
}
}
}

async fn get_task_stats_handler(
State(state): State<Arc<AppState>>,
) -> Result<Json<crate::task_handler::TaskQueueStats>, StatusCode> {
debug!("Task statistics request");

if !state.has_background_task_handler() {
return Err(StatusCode::NOT_IMPLEMENTED);
}

let queue = state.get_background_task_queue().unwrap();
let stats = queue.get_stats().await;

debug!("Returning task statistics: {:?}", stats);
Ok(Json(stats))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading
Loading