diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..b7c2f10b --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,48 @@ +codecov: + require_ci_to_pass: true + +coverage: + precision: 2 + round: down + range: "70...100" + + status: + project: + default: + target: auto + threshold: 1% + base: auto + branches: + - main + if_ci_failed: error + informational: false + only_pulls: false + + patch: + default: + target: auto + threshold: 1% + base: auto + if_ci_failed: error + only_pulls: true + + changes: false + +comment: + layout: "diff, flags, files" + behavior: default + require_changes: false + require_base: false + require_head: true + hide_project_coverage: false + +parsers: + gcov: + branch_detection: + conditional: true + loop: true + macro: false + method: false + +github_checks: + annotations: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a3255356..59992aba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: - name: checkout uses: actions/checkout@v4 - name: toolchain - uses: dtolnay/rust-toolchain@master + uses: dtolnay/rust-toolchain@stable with: toolchain: ${{ env.toolchain }} components: rustfmt @@ -43,7 +43,7 @@ jobs: - name: checkout uses: actions/checkout@v4 - name: toolchain - uses: dtolnay/rust-toolchain@master + uses: dtolnay/rust-toolchain@stable with: toolchain: ${{ env.toolchain }} components: clippy @@ -51,7 +51,15 @@ jobs: run: | cargo clippy \ --all-targets \ - -- -D warnings + -- -D warnings \ + -W clippy::pedantic \ + -W clippy::nursery \ + -W clippy::style \ + -W clippy::complexity \ + -W clippy::perf \ + -W clippy::suspicious \ + -W clippy::correctness + test: name: test runs-on: ubuntu-22.04 @@ -73,7 +81,7 @@ jobs: - name: checkout uses: actions/checkout@v4 - name: toolchain - uses: dtolnay/rust-toolchain@master + uses: dtolnay/rust-toolchain@stable with: toolchain: ${{ env.toolchain }} - name: Unit tests diff --git a/atoma-bin/atoma_daemon.rs 
b/atoma-bin/atoma_daemon.rs index 48b30e1d..fe16cbb2 100644 --- a/atoma-bin/atoma_daemon.rs +++ b/atoma-bin/atoma_daemon.rs @@ -6,7 +6,7 @@ use atoma_daemon::{ server::{run_server, DaemonState}, }; use atoma_state::{config::AtomaStateManagerConfig, AtomaState}; -use atoma_sui::client::AtomaSuiClient; +use atoma_sui::client::Client; use atoma_utils::spawn_with_shutdown; use clap::Parser; use sui_sdk::types::base_types::ObjectID; @@ -39,13 +39,14 @@ struct DaemonArgs { } #[tokio::main] +#[allow(clippy::redundant_pub_crate)] async fn main() -> Result<()> { - setup_logging()?; + setup_logging(); let args = DaemonArgs::parse(); let daemon_config = AtomaDaemonConfig::from_file_path(args.config_path.clone()); let state_manager_config = AtomaStateManagerConfig::from_file_path(args.config_path.clone()); let client = Arc::new(RwLock::new( - AtomaSuiClient::new_from_config(args.config_path).await?, + Client::new_from_config(args.config_path).await?, )); info!( @@ -86,7 +87,7 @@ async fn main() -> Result<()> { let ctrl_c = tokio::task::spawn(async move { tokio::select! 
{ - _ = tokio::signal::ctrl_c() => { + result = tokio::signal::ctrl_c() => { info!( target = "atoma_daemon", event = "atoma-daemon-stop", @@ -95,10 +96,10 @@ async fn main() -> Result<()> { shutdown_sender .send(true) .context("Failed to send shutdown signal")?; - Ok::<(), anyhow::Error>(()) + result.map_err(anyhow::Error::from) } _ = shutdown_receiver.changed() => { - Ok::<(), anyhow::Error>(()) + Ok(()) } } }); @@ -108,7 +109,7 @@ async fn main() -> Result<()> { daemon_result } -fn setup_logging() -> Result<()> { +fn setup_logging() { let log_dir = Path::new(LOGS); let file_appender = RollingFileAppender::new(Rotation::DAILY, log_dir, LOG_FILE); let (non_blocking_appender, _guard) = non_blocking(file_appender); @@ -133,5 +134,4 @@ fn setup_logging() -> Result<()> { .with(console_layer) .with(file_layer) .init(); - Ok(()) } diff --git a/atoma-bin/atoma_node.rs b/atoma-bin/atoma_node.rs index f40c9913..a52362eb 100644 --- a/atoma-bin/atoma_node.rs +++ b/atoma-bin/atoma_node.rs @@ -5,7 +5,7 @@ use std::{ }; use anyhow::{Context, Result}; -use atoma_confidential::AtomaConfidentialComputeService; +use atoma_confidential::AtomaConfidentialCompute; use atoma_daemon::{AtomaDaemonConfig, DaemonState}; use atoma_service::{ config::AtomaServiceConfig, @@ -13,7 +13,7 @@ use atoma_service::{ server::AppState, }; use atoma_state::{config::AtomaStateManagerConfig, AtomaState, AtomaStateManager}; -use atoma_sui::{client::AtomaSuiClient, AtomaSuiConfig, SuiEventSubscriber}; +use atoma_sui::{client::Client, config::Config, subscriber::Subscriber}; use atoma_utils::spawn_with_shutdown; use clap::Parser; use dotenv::dotenv; @@ -66,9 +66,9 @@ struct Args { /// This struct holds the configuration settings for various components /// of the Atoma node, including the Sui, service, and state manager configurations. #[derive(Debug)] -struct Config { +struct NodeConfig { /// Configuration for the Sui component. 
- sui: AtomaSuiConfig, + sui: Config, /// Configuration for the service component. service: AtomaServiceConfig, @@ -83,9 +83,9 @@ struct Config { proxy: ProxyConfig, } -impl Config { - async fn load(path: &str) -> Result { - let sui = AtomaSuiConfig::from_file_path(path); +impl NodeConfig { + fn load(path: &str) -> Result { + let sui = Config::from_file_path(path); let service = AtomaServiceConfig::from_file_path(path); let state = AtomaStateManagerConfig::from_file_path(path); let daemon = AtomaDaemonConfig::from_file_path(path); @@ -127,7 +127,7 @@ impl Config { /// async fn example() -> Result<()> { /// let models = vec!["facebook/opt-125m".to_string()]; /// let revisions = vec!["main".to_string()]; -/// +/// /// let tokenizers = initialize_tokenizers(&models, &revisions).await?; /// Ok(()) /// } @@ -174,13 +174,15 @@ async fn initialize_tokenizers( } #[tokio::main] +#[allow(clippy::too_many_lines)] +#[allow(clippy::redundant_pub_crate)] async fn main() -> Result<()> { let _log_guards = setup_logging(LOGS).context("Failed to setup logging")?; dotenv().ok(); let args = Args::parse(); - let config = Config::load(&args.config_path).await?; + let config = NodeConfig::load(&args.config_path)?; info!("Starting Atoma node service"); @@ -203,13 +205,13 @@ async fn main() -> Result<()> { config.sui.max_concurrent_requests(), )?; let address = wallet_ctx.active_address()?; - let address_index = args.address_index.unwrap_or( + let address_index = args.address_index.unwrap_or_else(|| { wallet_ctx .get_addresses() .iter() .position(|a| a == &address) - .unwrap(), - ); + .unwrap() + }); info!( target = "atoma-node-service", @@ -232,14 +234,14 @@ async fn main() -> Result<()> { shutdown_sender.clone(), ); - let (subscriber_confidential_compute_sender, _subscriber_confidential_compute_receiver) = + let (subscriber_confidential_compute_sender, subscriber_confidential_compute_receiver) = tokio::sync::mpsc::unbounded_channel(); - let (app_state_decryption_sender, 
_app_state_decryption_receiver) = + let (app_state_decryption_sender, app_state_decryption_receiver) = tokio::sync::mpsc::unbounded_channel(); - let (app_state_encryption_sender, _app_state_encryption_receiver) = + let (app_state_encryption_sender, app_state_encryption_receiver) = tokio::sync::mpsc::unbounded_channel(); - for (_, node_small_id) in config.daemon.node_badges.iter() { + for (_, node_small_id) in &config.daemon.node_badges { if let Err(e) = register_on_proxy(&config.proxy, *node_small_id, &keystore, address_index).await { @@ -259,19 +261,19 @@ async fn main() -> Result<()> { ); let client = Arc::new(RwLock::new( - AtomaSuiClient::new_from_config(args.config_path).await?, + Client::new_from_config(args.config_path).await?, )); - let (compute_shared_secret_sender, _compute_shared_secret_receiver) = + let (compute_shared_secret_sender, compute_shared_secret_receiver) = tokio::sync::mpsc::unbounded_channel(); let confidential_compute_service_handle = spawn_with_shutdown( - AtomaConfidentialComputeService::start_confidential_compute_service( + AtomaConfidentialCompute::start_confidential_compute_service( client.clone(), - _subscriber_confidential_compute_receiver, - _app_state_decryption_receiver, - _app_state_encryption_receiver, - _compute_shared_secret_receiver, + subscriber_confidential_compute_receiver, + app_state_decryption_receiver, + app_state_encryption_receiver, + compute_shared_secret_receiver, shutdown_receiver.clone(), ), shutdown_sender.clone(), @@ -286,7 +288,7 @@ async fn main() -> Result<()> { "Spawning subscriber service" ); - let subscriber = SuiEventSubscriber::new( + let subscriber = Subscriber::new( config.sui, event_subscriber_sender, stack_retrieve_receiver, @@ -320,8 +322,8 @@ async fn main() -> Result<()> { shutdown_sender.clone(), ); - let hf_token = std::env::var(HF_TOKEN) - .context(format!("Variable {} not set in the .env file", HF_TOKEN))?; + let hf_token = + std::env::var(HF_TOKEN).context(format!("Variable {HF_TOKEN} not 
set in the .env file"))?; let tokenizers = initialize_tokenizers(&config.service.models, &config.service.revisions, hf_token).await?; diff --git a/atoma-confidential/src/key_management.rs b/atoma-confidential/src/key_management.rs index 7a1066e7..5cdb52a1 100644 --- a/atoma-confidential/src/key_management.rs +++ b/atoma-confidential/src/key_management.rs @@ -1,6 +1,4 @@ -use atoma_utils::encryption::{ - decrypt_ciphertext, encrypt_plaintext, EncryptionError, NONCE_BYTE_SIZE, -}; +use atoma_utils::encryption::{decrypt_ciphertext, encrypt_plaintext, Error, NONCE_BYTE_SIZE}; use thiserror::Error; use x25519_dalek::{PublicKey, SharedSecret, StaticSecret}; @@ -21,6 +19,15 @@ pub struct X25519KeyPairManager { impl X25519KeyPairManager { /// Constructor + /// + /// # Returns + /// A new `X25519KeyPairManager` instance if successful + /// + /// # Errors + /// Returns an error if: + /// - Failed to initialize the key store + /// - Failed to load existing keys + /// - Failed to generate initial keys #[allow(clippy::new_without_default)] pub fn new() -> Result { let mut rng = rand::thread_rng(); @@ -36,6 +43,7 @@ impl X25519KeyPairManager { /// - Identity verification /// /// The public key will change when `rotate_keys()` is called. 
+ #[must_use] pub fn get_public_key(&self) -> PublicKey { PublicKey::from(&self.secret_key) } @@ -72,6 +80,7 @@ impl X25519KeyPairManager { /// /// # Returns /// - `SharedSecret` - The shared secret + #[must_use] pub fn compute_shared_secret(&self, public_key: &PublicKey) -> SharedSecret { self.secret_key.diffie_hellman(public_key) } @@ -91,7 +100,12 @@ impl X25519KeyPairManager { /// /// # Returns /// * `Ok(Vec)` - The decrypted plaintext as a byte vector - /// * `Err(KeyManagementError)` - If decryption fails + /// + /// # Errors + /// Returns a `KeyManagementError` if: + /// - Decryption fails due to invalid key material + /// - The ciphertext is malformed + /// - The authentication tag is invalid /// /// # Example /// ```rust,ignore @@ -132,7 +146,12 @@ impl X25519KeyPairManager { /// * `Ok((Vec, [u8; NONCE_BYTE_SIZE]))` - A tuple containing: /// - The encrypted ciphertext as a byte vector /// - A randomly generated nonce used in the encryption - /// * `Err(KeyManagementError)` - If encryption fails + /// + /// # Errors + /// Returns a `KeyManagementError` if: + /// - Encryption fails due to invalid key material + /// - Random number generation fails + /// - The plaintext is too large /// /// # Example /// ```rust,ignore @@ -156,10 +175,14 @@ impl X25519KeyPairManager { } } +/// Errors that can occur during key management operations #[derive(Debug, Error)] pub enum KeyManagementError { + /// Error during encryption/decryption operations #[error("Encryption error: `{0}`")] - EncryptionError(#[from] EncryptionError), + EncryptionError(#[from] Error), + + /// Error during file I/O operations #[error("IO error: `{0}`")] IoError(#[from] std::io::Error), } diff --git a/atoma-confidential/src/lib.rs b/atoma-confidential/src/lib.rs index c55d632b..9eccf266 100644 --- a/atoma-confidential/src/lib.rs +++ b/atoma-confidential/src/lib.rs @@ -1,10 +1,13 @@ +#![allow(clippy::doc_markdown)] +#![allow(clippy::module_name_repetitions)] + pub mod key_management; pub mod service; 
#[cfg(feature = "tdx")] pub mod tdx; pub mod types; -pub use service::AtomaConfidentialComputeService; +pub use service::AtomaConfidentialCompute; /// Trait for converting types into a byte representation /// diff --git a/atoma-confidential/src/service.rs b/atoma-confidential/src/service.rs index 51519d07..630b1c1f 100644 --- a/atoma-confidential/src/service.rs +++ b/atoma-confidential/src/service.rs @@ -11,7 +11,7 @@ use crate::{ tdx::{get_compute_data_attestation, TdxError}, ToBytes, }; -use atoma_sui::client::AtomaSuiClient; +use atoma_sui::client::Client; use atoma_sui::{client::AtomaSuiClientError, events::AtomaEvent}; use atoma_utils::constants::NONCE_SIZE; use std::sync::Arc; @@ -46,12 +46,12 @@ type ServiceSharedSecretRequest = ( /// - Managing TDX key rotations and attestations /// - Submitting attestations to the Sui blockchain /// - Graceful shutdown handling -pub struct AtomaConfidentialComputeService { +pub struct AtomaConfidentialCompute { /// Client for interacting with the Sui blockchain to submit attestations and transactions /// NOTE: We disable clippy's `dead_code` lint warning here, as the `sui_client` is used /// in the `submit_node_key_rotation_tdx_attestation` method, when the tdx feature is enabled. 
#[allow(dead_code)] - sui_client: Arc>, + sui_client: Arc>, /// Current key rotation counter key_rotation_counter: Option, /// Manages TDX key operations including key rotation and attestation generation @@ -68,10 +68,25 @@ pub struct AtomaConfidentialComputeService { shutdown_signal: tokio::sync::watch::Receiver, } -impl AtomaConfidentialComputeService { +impl AtomaConfidentialCompute { /// Constructor + /// + /// # Arguments + /// * `sui_client` - Shared Sui client used to submit attestations and transactions + /// * `event_receiver` - Channel receiver for Atoma events + /// * `service_decryption_receiver` - Channel receiver for decryption requests + /// * `service_encryption_receiver` - Channel receiver for encryption requests + /// * `service_shared_secret_receiver` - Channel receiver for shared secret requests + /// * `shutdown_signal` - Channel receiver for shutdown signals + /// + /// # Returns + /// A new service instance + /// + /// # Errors + /// Returns `AtomaConfidentialComputeError` if: + /// - Key manager initialization fails pub fn new( - sui_client: Arc>, + sui_client: Arc>, event_receiver: UnboundedReceiver, service_decryption_receiver: UnboundedReceiver, service_encryption_receiver: UnboundedReceiver, @@ -116,7 +131,7 @@ impl AtomaConfidentialComputeService { /// * `AtomaConfidentialComputeError::SuiClientError` if attestation submission fails #[instrument(level = "info", skip_all)] pub async fn start_confidential_compute_service( - sui_client: Arc>, + sui_client: Arc>, event_receiver: UnboundedReceiver, service_decryption_receiver: UnboundedReceiver, service_encryption_receiver: UnboundedReceiver, @@ -147,6 +162,7 @@ impl AtomaConfidentialComputeService { /// /// # Returns /// - `x25519_dalek::PublicKey`: The current public key from the key manager + #[must_use] pub fn get_public_key(&self) -> x25519_dalek::PublicKey { self.key_manager.get_public_key() } @@ -158,6 +174,7 @@ impl AtomaConfidentialComputeService { /// /// # Returns /// - `x25519_dalek::StaticSecret`: The 
shared secret between the node and the proxy + #[must_use] pub fn compute_shared_secret( &self, client_x25519_public_key: &PublicKey, @@ -337,18 +354,7 @@ impl AtomaConfidentialComputeService { client_dh_public_key, node_dh_public_key, } = decryption_request; - let result = if PublicKey::from(node_dh_public_key) != self.key_manager.get_public_key() { - tracing::error!( - target = "atoma-confidential-compute-service", - event = "confidential_compute_service_decryption_error", - "Node X25519 public key does not match the expected key: {:?} != {:?}", - node_dh_public_key, - self.key_manager.get_public_key().as_bytes() - ); - Err(anyhow::anyhow!( - "Node X25519 public key does not match the expected key" - )) - } else { + let result = if PublicKey::from(node_dh_public_key) == self.key_manager.get_public_key() { self.key_manager .decrypt_ciphertext(client_dh_public_key, &ciphertext, &salt, &nonce) .map_err(|e| { @@ -360,6 +366,15 @@ impl AtomaConfidentialComputeService { ); anyhow::anyhow!(e) }) + } else { + tracing::error!( + target = "atoma-confidential-compute-service", + event = "confidential_compute_service_decryption_error", + "Failed to decrypt request: node public key mismatch" + ); + Err(anyhow::anyhow!( + "Node X25519 public key does not match the expected key" + )) }; let message = result .map(|plaintext| ConfidentialComputeDecryptionResponse { plaintext }) @@ -517,8 +532,7 @@ impl AtomaConfidentialComputeService { // for a previous key rotation counter and not for the current one). 
if self .key_rotation_counter - .map(|counter| counter < event.key_rotation_counter) - .unwrap_or(true) + .map_or(true, |counter| counter < event.key_rotation_counter) { self.submit_node_key_rotation_tdx_attestation().await?; } diff --git a/atoma-daemon/src/components/mod.rs b/atoma-daemon/src/components/mod.rs index a342ca7b..7951b476 100644 --- a/atoma-daemon/src/components/mod.rs +++ b/atoma-daemon/src/components/mod.rs @@ -1 +1 @@ -pub(crate) mod openapi; +pub mod openapi; diff --git a/atoma-daemon/src/components/openapi.rs b/atoma-daemon/src/components/openapi.rs index d5db872e..2f90dd6d 100644 --- a/atoma-daemon/src/components/openapi.rs +++ b/atoma-daemon/src/components/openapi.rs @@ -46,7 +46,7 @@ pub fn openapi_routes() -> Router { let spec_path = docs_dir.join("openapi.yml"); fs::write(&spec_path, spec).expect("Failed to write OpenAPI spec to file"); - println!("OpenAPI spec written to: {:?}", spec_path); + println!("OpenAPI spec written to: {spec_path:?}"); } Router::new() diff --git a/atoma-daemon/src/handlers/attestation_disputes.rs b/atoma-daemon/src/handlers/attestation_disputes.rs index a71e22ed..3c929330 100644 --- a/atoma-daemon/src/handlers/attestation_disputes.rs +++ b/atoma-daemon/src/handlers/attestation_disputes.rs @@ -20,7 +20,7 @@ pub const ATTESTATION_DISPUTES_PATH: &str = "/attestation_disputes"; ), components(schemas(StackAttestationDispute)) )] -pub(crate) struct AttestationDisputesOpenApi; +pub struct AttestationDisputesOpenApi; //TODO: this endpoint can be merged into one (I think) through filters diff --git a/atoma-daemon/src/handlers/claimed_stacks.rs b/atoma-daemon/src/handlers/claimed_stacks.rs index 52a1a55b..dc1bbe52 100644 --- a/atoma-daemon/src/handlers/claimed_stacks.rs +++ b/atoma-daemon/src/handlers/claimed_stacks.rs @@ -12,7 +12,7 @@ pub const CLAIMED_STACKS_PATH: &str = "/claimed-stacks"; paths(claimed_stacks_nodes_list), components(schemas(StackSettlementTicket)) )] -pub(crate) struct ClaimedStacksOpenApi; +pub struct 
ClaimedStacksOpenApi; pub fn claimed_stacks_router() -> Router { Router::new().route( diff --git a/atoma-daemon/src/handlers/mod.rs b/atoma-daemon/src/handlers/mod.rs index 4867faa0..9872386d 100644 --- a/atoma-daemon/src/handlers/mod.rs +++ b/atoma-daemon/src/handlers/mod.rs @@ -1,6 +1,6 @@ -pub(crate) mod attestation_disputes; -pub(crate) mod claimed_stacks; -pub(crate) mod nodes; -pub(crate) mod stacks; -pub(crate) mod subscriptions; -pub(crate) mod tasks; +pub mod attestation_disputes; +pub mod claimed_stacks; +pub mod nodes; +pub mod stacks; +pub mod subscriptions; +pub mod tasks; diff --git a/atoma-daemon/src/handlers/nodes.rs b/atoma-daemon/src/handlers/nodes.rs index 994273b0..c1cbc3b9 100644 --- a/atoma-daemon/src/handlers/nodes.rs +++ b/atoma-daemon/src/handlers/nodes.rs @@ -54,7 +54,7 @@ pub const NODES_PATH: &str = "/nodes"; NodeClaimFundsResponse )) )] -pub(crate) struct NodesOpenApi; +pub struct NodesOpenApi; /// Router for handling node-related endpoints /// @@ -118,8 +118,11 @@ pub async fn nodes_register( gas_budget, gas_price, } = value; - let mut tx_client = daemon_state.client.write().await; - let tx_digest = tx_client + + let tx_digest = daemon_state + .client + .write() + .await .submit_node_registration_tx(gas, gas_budget, gas_price) .await .map_err(|_| { diff --git a/atoma-daemon/src/handlers/stacks.rs b/atoma-daemon/src/handlers/stacks.rs index 49eafe9a..f45a5c46 100644 --- a/atoma-daemon/src/handlers/stacks.rs +++ b/atoma-daemon/src/handlers/stacks.rs @@ -23,7 +23,7 @@ pub struct StackQuery { paths(stacks_nodes_list), components(schemas(Stack, StackSettlementTicket, StackQuery)) )] -pub(crate) struct StacksOpenApi; +pub struct StacksOpenApi; pub fn stacks_router() -> Router { Router::new().route( diff --git a/atoma-daemon/src/handlers/subscriptions.rs b/atoma-daemon/src/handlers/subscriptions.rs index ae72baad..97a27ba0 100644 --- a/atoma-daemon/src/handlers/subscriptions.rs +++ b/atoma-daemon/src/handlers/subscriptions.rs @@ -14,7 +14,7 
@@ pub const SUBSCRIPTIONS_PATH: &str = "/subscriptions"; #[derive(OpenApi)] #[openapi(paths(subscriptions_nodes_list), components(schemas(NodeSubscription)))] -pub(crate) struct SubscriptionsOpenApi; +pub struct SubscriptionsOpenApi; /// Router for handling subscription-related endpoints /// diff --git a/atoma-daemon/src/handlers/tasks.rs b/atoma-daemon/src/handlers/tasks.rs index 6896aa42..e3ac2e2d 100644 --- a/atoma-daemon/src/handlers/tasks.rs +++ b/atoma-daemon/src/handlers/tasks.rs @@ -9,7 +9,7 @@ pub const TASKS_PATH: &str = "/tasks"; #[derive(OpenApi)] #[openapi(paths(tasks_list), components(schemas(Task)))] -pub(crate) struct TasksOpenApi; +pub struct TasksOpenApi; /// Router for handling task-related endpoints /// diff --git a/atoma-daemon/src/lib.rs b/atoma-daemon/src/lib.rs index 9a057178..d25b23bd 100644 --- a/atoma-daemon/src/lib.rs +++ b/atoma-daemon/src/lib.rs @@ -1,3 +1,11 @@ +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::cast_precision_loss)] +#![allow(clippy::missing_docs_in_private_items)] +#![allow(clippy::doc_markdown)] +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::cast_sign_loss)] + pub(crate) mod components; pub mod config; pub(crate) mod handlers; diff --git a/atoma-daemon/src/server.rs b/atoma-daemon/src/server.rs index 3b6bf754..fa64741e 100644 --- a/atoma-daemon/src/server.rs +++ b/atoma-daemon/src/server.rs @@ -1,5 +1,5 @@ use atoma_state::state_manager::AtomaState; -use atoma_sui::client::AtomaSuiClient; +use atoma_sui::client::Client; use axum::{http::StatusCode, routing::get, Router}; use std::sync::Arc; use sui_sdk::types::base_types::ObjectID; @@ -37,7 +37,7 @@ use crate::{ /// ```rust,ignore /// // Create a new daemon state instance /// let daemon_state = DaemonState { -/// client: Arc::new(RwLock::new(AtomaSuiClient::new())), +/// client: Arc::new(RwLock::new(Client::new())), /// state_manager: AtomaStateManager::new(), /// node_badges: vec![(ObjectID::new([0; 
32]), 1)], /// }; @@ -50,7 +50,7 @@ pub struct DaemonState { /// Thread-safe reference to the Sui blockchain client that handles all blockchain interactions. /// Wrapped in `Arc` to allow multiple handlers to safely access and modify the client /// state concurrently. - pub client: Arc>, + pub client: Arc>, /// Manages the persistent state of nodes, tasks, and other system components. /// Handles database operations and state synchronization. @@ -62,41 +62,25 @@ pub struct DaemonState { pub node_badges: Vec<(ObjectID, u64)>, } -/// Starts and runs the Atoma daemon service, handling HTTP requests and graceful shutdown. -/// This function initializes and runs the main daemon service that handles node operations, +/// Runs the daemon server, handling incoming connections and graceful shutdown. /// /// # Arguments -/// -/// * `daemon_state` - The shared state container for the daemon service, containing the Sui client, -/// state manager, and node badge information -/// * `tcp_listener` - A pre-configured TCP listener that the HTTP server will bind to +/// * `daemon_state` - The shared state of the daemon +/// * `tcp_listener` - The TCP listener for accepting connections +/// * `shutdown_receiver` - Channel receiver for shutdown signals /// /// # Returns +/// Result indicating success or failure of server operation /// -/// * `anyhow::Result<()>` - Ok(()) on successful shutdown, or an error if -/// server initialization or shutdown fails -/// -/// # Shutdown Behavior +/// # Errors +/// Returns an error if: +/// - The underlying HTTP server fails while serving connections +/// - Binding the server to the provided listener fails /// -/// The server implements graceful shutdown by: -/// 1. Listening for a Ctrl+C signal -/// 2. Logging shutdown initiation -/// 3. 
Waiting for existing connections to complete -/// -/// # Example -/// -/// ```rust,ignore -/// use tokio::net::TcpListener; -/// use tokio::sync::watch; -/// use atoma_daemon::{DaemonState, run_server}; -/// -/// async fn start_server() -> Result<(), Box> { -/// let daemon_state = DaemonState::new(/* ... */); -/// let listener = TcpListener::bind("127.0.0.1:3000").await?; -/// -/// run_server(daemon_state, listener).await -/// } -/// ``` +/// # Panics +/// This function will panic if: +/// - The shutdown signal receiver is closed unexpectedly +/// - Failed to receive shutdown signal pub async fn run_server( daemon_state: DaemonState, tcp_listener: TcpListener, @@ -108,7 +92,7 @@ pub async fn run_server( shutdown_receiver .changed() .await - .expect("Error receiving shutdown signal") + .expect("Error receiving shutdown signal"); }); server.await?; Ok(()) diff --git a/atoma-service/src/components/mod.rs b/atoma-service/src/components/mod.rs index a342ca7b..7951b476 100644 --- a/atoma-service/src/components/mod.rs +++ b/atoma-service/src/components/mod.rs @@ -1 +1 @@ -pub(crate) mod openapi; +pub mod openapi; diff --git a/atoma-service/src/error.rs b/atoma-service/src/error.rs index 8aaf98fc..fc7f32d6 100644 --- a/atoma-service/src/error.rs +++ b/atoma-service/src/error.rs @@ -107,7 +107,7 @@ impl AtomaServiceError { /// - `"MODEL_ERROR"` for ML model errors /// - `"AUTH_ERROR"` for authentication failures /// - `"INTERNAL_ERROR"` for unexpected server errors - fn error_code(&self) -> &'static str { + const fn error_code(&self) -> &'static str { match self { Self::MissingHeader { .. } => "MISSING_HEADER", Self::InvalidHeader { .. } => "INVALID_HEADER", @@ -154,7 +154,8 @@ impl AtomaServiceError { /// # Returns /// /// An [`axum::http::StatusCode`] representing the appropriate HTTP response code for this error - pub fn status_code(&self) -> StatusCode { + #[must_use] + pub const fn status_code(&self) -> StatusCode { match self { Self::MissingHeader { .. 
} | Self::InvalidHeader { .. } @@ -173,14 +174,15 @@ impl AtomaServiceError { /// # Returns /// /// A `String` containing the API endpoint path where the error was encountered. - fn endpoint(&self) -> String { + #[must_use] + pub fn get_endpoint(&self, _endpoint: &str) -> String { match self { - Self::MissingHeader { endpoint, .. } => endpoint.clone(), - Self::InvalidHeader { endpoint, .. } => endpoint.clone(), - Self::InvalidBody { endpoint, .. } => endpoint.clone(), - Self::ModelError { endpoint, .. } => endpoint.clone(), - Self::AuthError { endpoint, .. } => endpoint.clone(), - Self::InternalError { endpoint, .. } => endpoint.clone(), + Self::MissingHeader { endpoint, .. } + | Self::InvalidHeader { endpoint, .. } + | Self::InvalidBody { endpoint, .. } + | Self::ModelError { endpoint, .. } + | Self::AuthError { endpoint, .. } + | Self::InternalError { endpoint, .. } => endpoint.clone(), } } @@ -216,7 +218,7 @@ impl IntoResponse for AtomaServiceError { tracing::error!( target = "atoma-service", event = "error_occurred", - endpoint = self.endpoint(), + endpoint = self.get_endpoint(""), error = %self.message(), ); let error_response = ErrorResponse { diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 69e23a9f..48d58162 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -24,7 +24,15 @@ use serde::{Deserialize, Deserializer, Serialize}; use std::{collections::HashMap, time::Duration}; use utoipa::ToSchema; -use crate::{error::AtomaServiceError, handlers::prometheus::*, middleware::RequestMetadata}; +use crate::{ + error::AtomaServiceError, + handlers::prometheus::{ + CHAT_COMPLETIONS_INPUT_TOKENS_METRICS, CHAT_COMPLETIONS_LATENCY_METRICS, + CHAT_COMPLETIONS_NUM_REQUESTS, CHAT_COMPLETIONS_OUTPUT_TOKENS_METRICS, + CHAT_COMPLETIONS_TIME_TO_FIRST_TOKEN, TOTAL_COMPLETED_REQUESTS, TOTAL_FAILED_REQUESTS, + }, + middleware::RequestMetadata, +}; use 
super::handle_confidential_compute_encryption_response; @@ -94,7 +102,7 @@ const UNKNOWN_MODEL: &str = "unknown"; ChatCompletionsResponse )) )] -pub(crate) struct ChatCompletionsOpenApi; +pub struct ChatCompletionsOpenApi; /// Create chat completion /// @@ -158,7 +166,7 @@ pub async fn chat_completions_handler( let is_stream = payload .get(STREAM_KEY) - .and_then(|s| s.as_bool()) + .and_then(serde_json::Value::as_bool) .unwrap_or_default(); let endpoint = request_metadata.endpoint_path.clone(); @@ -225,13 +233,12 @@ pub async fn chat_completions_handler( /// /// The confidential variant ensures end-to-end encryption of the chat completion responses, /// making it suitable for sensitive or private conversations. - #[derive(OpenApi)] #[openapi( paths(chat_completions_handler), components(schemas(ChatCompletionsRequest, ConfidentialComputeResponse)) )] -pub(crate) struct ConfidentialChatCompletionsOpenApi; +pub struct ConfidentialChatCompletionsOpenApi; /// Handles confidential chat completion requests by providing end-to-end encrypted responses. 
/// @@ -320,7 +327,7 @@ pub async fn confidential_chat_completions_handler( // Check if streaming is requested let is_stream = payload .get(STREAM_KEY) - .and_then(|s| s.as_bool()) + .and_then(serde_json::Value::as_bool) .unwrap_or_default(); let model = payload @@ -431,18 +438,7 @@ async fn handle_response( estimated_total_compute_units: i64, client_encryption_metadata: Option, ) -> Result, AtomaServiceError> { - if !is_stream { - handle_non_streaming_response( - state, - payload, - stack_small_id, - estimated_total_compute_units, - payload_hash, - client_encryption_metadata, - endpoint, - ) - .await - } else { + if is_stream { let streaming_encryption_metadata = utils::get_streaming_encryption_metadata( state, client_encryption_metadata, @@ -462,6 +458,17 @@ async fn handle_response( endpoint, ) .await + } else { + handle_non_streaming_response( + state, + payload, + stack_small_id, + estimated_total_compute_units, + payload_hash, + client_encryption_metadata, + endpoint, + ) + .await } } @@ -645,9 +652,14 @@ async fn handle_streaming_response( let chat_completions_service_url = state .chat_completions_service_urls .get(&model.to_lowercase()) - .ok_or(AtomaServiceError::InternalError { - message: format!("Chat completions service URL not found, likely that model is not supported by the current node: {}", model), - endpoint: endpoint.clone(), + .ok_or_else(|| { + AtomaServiceError::InternalError { + message: format!( + "Chat completions service URL not found, likely that model is not supported by the current node: {}", + model + ), + endpoint: endpoint.clone(), + } })?; let client = Client::new(); let response = client @@ -823,11 +835,11 @@ pub enum MessageContent { impl std::fmt::Display for MessageContent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - MessageContent::Text(text) => write!(f, "{}", text), - MessageContent::Array(parts) => { + Self::Text(text) => write!(f, "{}", text), + Self::Array(parts) => { let mut content 
= String::new(); for part in parts { - content.push_str(&format!("{}\n", part)) + content.push_str(&format!("{}\n", part)); } write!(f, "{}", content) } @@ -844,7 +856,7 @@ impl<'de> Deserialize<'de> for MessageContent { let value: Value = Value::deserialize(deserializer)?; if let Some(s) = value.as_str() { - return Ok(MessageContent::Text(s.to_string())); + return Ok(Self::Text(s.to_string())); } if let Some(arr) = value.as_array() { @@ -852,7 +864,7 @@ impl<'de> Deserialize<'de> for MessageContent { .iter() .map(|v| serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)) .collect(); - return Ok(MessageContent::Array(parts?)); + return Ok(Self::Array(parts?)); } Err(serde::de::Error::custom( @@ -884,10 +896,10 @@ pub enum MessageContentPart { impl std::fmt::Display for MessageContentPart { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - MessageContentPart::Text { r#type, text } => { + Self::Text { r#type, text } => { write!(f, "{}: {}", r#type, text) } - MessageContentPart::Image { r#type, image_url } => { + Self::Image { r#type, image_url } => { write!(f, "{}: [Image URL: {}]", r#type, image_url) } } @@ -1023,10 +1035,9 @@ impl TryFrom> for FinishReason { fn try_from(value: Option<&str>) -> Result { match value { - Some("stopped") => Ok(FinishReason::Stopped), - Some("length_capped") => Ok(FinishReason::LengthCapped), - Some("content_filter") => Ok(FinishReason::ContentFilter), - None => Ok(FinishReason::Stopped), + None | Some("stopped") => Ok(Self::Stopped), + Some("length_capped") => Ok(Self::LengthCapped), + Some("content_filter") => Ok(Self::ContentFilter), _ => Err(format!("Invalid finish reason: {}", value.unwrap())), } } @@ -1051,11 +1062,18 @@ pub struct Usage { pub completion_tokens_details: Option, } -pub(crate) mod utils { +pub mod utils { use atoma_utils::constants::PAYLOAD_HASH_SIZE; use prometheus::HistogramTimer; - use super::*; + use super::{ + handle_confidential_compute_encryption_response, info, 
instrument, + sign_response_and_update_stack_hash, update_stack_num_compute_units, AppState, + AtomaServiceError, Body, Client, ConfidentialComputeSharedSecretRequest, + ConfidentialComputeSharedSecretResponse, EncryptionMetadata, IntoResponse, Json, Response, + StreamingEncryptionMetadata, Value, CHAT_COMPLETIONS_INPUT_TOKENS_METRICS, + CHAT_COMPLETIONS_OUTPUT_TOKENS_METRICS, CHAT_COMPLETIONS_PATH, MODEL_KEY, UNKNOWN_MODEL, + }; /// Retrieves encryption metadata for streaming chat completions when confidential compute is enabled. /// @@ -1100,7 +1118,7 @@ pub(crate) mod utils { endpoint_path = endpoint ) )] - pub(crate) async fn get_streaming_encryption_metadata( + pub async fn get_streaming_encryption_metadata( state: &AppState, client_encryption_metadata: Option, payload_hash: [u8; PAYLOAD_HASH_SIZE], @@ -1208,7 +1226,7 @@ pub(crate) mod utils { skip_all, fields(stack_small_id, payload_hash, endpoint) )] - pub(crate) async fn send_request_to_inference_service( + pub async fn send_request_to_inference_service( state: &AppState, payload: &Value, stack_small_id: i64, @@ -1223,12 +1241,14 @@ pub(crate) mod utils { let chat_completions_service_url = state .chat_completions_service_urls .get(&model.to_lowercase()) - .ok_or(AtomaServiceError::InternalError { - message: format!( - "Chat completions service URL not found, likely that model is not supported by the current node: {}", - model - ), - endpoint: endpoint.to_string(), + .ok_or_else(|| { + AtomaServiceError::InternalError { + message: format!( + "Chat completions service URL not found, likely that model is not supported by the current node: {}", + model + ), + endpoint: endpoint.to_string(), + } })?; let response = client .post(format!( @@ -1322,7 +1342,7 @@ pub(crate) mod utils { /// let total = extract_total_num_tokens(&response_body, "gpt-4"); /// assert_eq!(total, 30); /// ``` - pub(crate) fn extract_total_num_tokens(response_body: &Value, model: &str) -> i64 { + pub fn 
extract_total_num_tokens(response_body: &Value, model: &str) -> i64 { let mut total_compute_units = 0; if let Some(usage) = response_body.get("usage") { if let Some(prompt_tokens) = usage.get("prompt_tokens") { @@ -1406,7 +1426,7 @@ pub(crate) mod utils { fields(stack_small_id, estimated_total_compute_units, payload_hash, endpoint) )] #[allow(clippy::too_many_arguments)] - pub(crate) async fn serve_non_streaming_response( + pub async fn serve_non_streaming_response( state: &AppState, mut response_body: Value, stack_small_id: i64, diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs index e0d779e8..4c9b28a2 100644 --- a/atoma-service/src/handlers/embeddings.rs +++ b/atoma-service/src/handlers/embeddings.rs @@ -35,7 +35,7 @@ pub const MODEL_KEY: &str = "model"; /// the API documentation. #[derive(OpenApi)] #[openapi(paths(embeddings_handler))] -pub(crate) struct EmbeddingsOpenApi; +pub struct EmbeddingsOpenApi; /// Create embeddings /// @@ -139,7 +139,7 @@ pub async fn embeddings_handler( /// the API documentation. #[derive(OpenApi)] #[openapi(paths(confidential_embeddings_handler))] -pub(crate) struct ConfidentialEmbeddingsOpenApi; +pub struct ConfidentialEmbeddingsOpenApi; /// Handler for confidential embeddings requests /// diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs index 2bb7774b..d19bedc5 100644 --- a/atoma-service/src/handlers/image_generations.rs +++ b/atoma-service/src/handlers/image_generations.rs @@ -36,7 +36,7 @@ pub const MODEL_KEY: &str = "model"; /// the API documentation. #[derive(OpenApi)] #[openapi(paths(image_generations_handler))] -pub(crate) struct ImageGenerationsOpenApi; +pub struct ImageGenerationsOpenApi; /// Create image generation /// @@ -143,7 +143,7 @@ pub async fn image_generations_handler( /// computing requirements. 
#[derive(OpenApi)] #[openapi(paths(confidential_image_generations_handler))] -pub(crate) struct ConfidentialImageGenerationsOpenApi; +pub struct ConfidentialImageGenerationsOpenApi; /// Handles confidential image generation requests /// diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index 9c4e005b..4ba00449 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -1,7 +1,7 @@ -pub(crate) mod chat_completions; -pub(crate) mod embeddings; -pub(crate) mod image_generations; -pub(crate) mod prometheus; +pub mod chat_completions; +pub mod embeddings; +pub mod image_generations; +pub mod prometheus; use atoma_confidential::types::{ ConfidentialComputeEncryptionRequest, ConfidentialComputeEncryptionResponse, @@ -130,7 +130,7 @@ async fn sign_response_and_update_stack_hash( skip(state, response_body, client_encryption_metadata), fields(event = "confidential-compute-encryption-response") )] -pub(crate) async fn handle_confidential_compute_encryption_response( +pub async fn handle_confidential_compute_encryption_response( state: &AppState, mut response_body: Value, client_encryption_metadata: Option, @@ -165,18 +165,17 @@ pub(crate) async fn handle_confidential_compute_encryption_response( } let (sender, receiver) = tokio::sync::oneshot::channel(); - let usage = if endpoint != CONFIDENTIAL_IMAGE_GENERATIONS_PATH { - Some( - response_body - .get(USAGE_KEY) - .ok_or(AtomaServiceError::InvalidBody { + let usage = + if endpoint == CONFIDENTIAL_IMAGE_GENERATIONS_PATH { + None + } else { + Some(response_body.get(USAGE_KEY).ok_or_else(|| { + AtomaServiceError::InvalidBody { message: "Usage not found in response body".to_string(), endpoint: endpoint.clone(), - })?, - ) - } else { - None - }; + } + })?) 
+ }; state .encryption_sender .send(( @@ -188,13 +187,13 @@ pub(crate) async fn handle_confidential_compute_encryption_response( sender, )) .map_err(|e| AtomaServiceError::InternalError { - message: format!("Error sending encryption request: {}", e), + message: format!("Error sending encryption request: {e}"), endpoint: endpoint.clone(), })?; let result = receiver .await .map_err(|e| AtomaServiceError::InternalError { - message: format!("Error receiving encryption response: {}", e), + message: format!("Error receiving encryption response: {e}"), endpoint: endpoint.clone(), })?; match result { @@ -288,7 +287,7 @@ pub(crate) async fn handle_confidential_compute_encryption_response( endpoint ) )] -pub(crate) fn update_stack_num_compute_units( +pub fn update_stack_num_compute_units( state_manager_sender: &Sender, stack_small_id: i64, estimated_total_compute_units: i64, @@ -302,7 +301,7 @@ pub(crate) fn update_stack_num_compute_units( estimated_total_compute_units, }) .map_err(|e| AtomaServiceError::InternalError { - message: format!("Error sending update stack num compute units event: {}", e,), + message: format!("Error sending update stack num compute units event: {e}"), endpoint: endpoint.to_string(), }) } diff --git a/atoma-service/src/lib.rs b/atoma-service/src/lib.rs index 4b6ff5d1..aafff919 100644 --- a/atoma-service/src/lib.rs +++ b/atoma-service/src/lib.rs @@ -2,6 +2,16 @@ //! and supports multiple signature schemes (including ed25519, secp256k1, and secp256r1, //! matching SUI's supported cryptography primitives). 
+#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::cast_precision_loss)] +#![allow(clippy::missing_docs_in_private_items)] +#![allow(clippy::doc_markdown)] +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::cast_sign_loss)] +#![allow(clippy::items_after_statements)] +#![allow(clippy::uninlined_format_args)] + pub(crate) mod components; pub mod config; pub mod error; diff --git a/atoma-service/src/middleware.rs b/atoma-service/src/middleware.rs index f83b9e90..aecc6a67 100644 --- a/atoma-service/src/middleware.rs +++ b/atoma-service/src/middleware.rs @@ -98,7 +98,7 @@ pub struct RequestMetadata { /// The type of request /// /// This enum is used to determine the type of request based on the path of the request. -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default, Eq, PartialEq, Copy)] pub enum RequestType { #[default] ChatCompletions, @@ -109,7 +109,8 @@ pub enum RequestType { impl RequestMetadata { /// Create a new `RequestMetadata` with the given stack info - pub fn with_stack_info( + #[must_use] + pub const fn with_stack_info( mut self, stack_small_id: i64, estimated_total_compute_units: i64, @@ -120,7 +121,8 @@ impl RequestMetadata { } /// Create a new `RequestMetadata` with the given payload hash - pub fn with_payload_hash(mut self, payload_hash: [u8; PAYLOAD_HASH_SIZE]) -> Self { + #[must_use] + pub const fn with_payload_hash(mut self, payload_hash: [u8; PAYLOAD_HASH_SIZE]) -> Self { self.payload_hash = payload_hash; self } @@ -140,7 +142,8 @@ impl RequestMetadata { /// let metadata = RequestMetadata::default() /// .with_request_type(RequestType::ChatCompletions); /// ``` - pub fn with_request_type(mut self, request_type: RequestType) -> Self { + #[must_use] + pub const fn with_request_type(mut self, request_type: RequestType) -> Self { self.request_type = request_type; self } @@ -160,7 +163,8 @@ impl RequestMetadata { /// let metadata = RequestMetadata::default() /// 
.with_client_encryption_metadata(client_dh_public_key, salt); /// ``` - pub fn with_client_encryption_metadata( + #[must_use] + pub const fn with_client_encryption_metadata( mut self, client_x25519_public_key: [u8; DH_PUBLIC_KEY_SIZE], salt: [u8; SALT_SIZE], @@ -185,6 +189,7 @@ impl RequestMetadata { /// /// let metadata = RequestMetadata::default().with_endpoint_path(CHAT_COMPLETIONS_PATH.to_string()); /// ``` + #[must_use] pub fn with_endpoint_path(mut self, endpoint_path: String) -> Self { self.endpoint_path = endpoint_path; self @@ -429,13 +434,8 @@ pub async fn verify_stack_permissions( }); } - let total_num_compute_units = utils::calculate_compute_units( - &body_json, - request_type.clone(), - &state, - model, - endpoint.clone(), - )?; + let total_num_compute_units = + utils::calculate_compute_units(&body_json, request_type, &state, model, endpoint.clone())?; let (result_sender, result_receiver) = oneshot::channel(); state @@ -647,7 +647,13 @@ pub async fn confidential_compute_middleware( pub(crate) mod utils { use hyper::HeaderMap; - use super::*; + use super::{ + blake2b_hash, instrument, oneshot, verify_signature, AppState, AtomaServiceError, + ConfidentialComputeDecryptionRequest, ConfidentialComputeRequest, DecryptionMetadata, + Engine, RequestType, TransactionDigest, Value, DEFAULT_MAX_TOKENS_CHAT_COMPLETIONS, + DH_PUBLIC_KEY_SIZE, IMAGE_N, IMAGE_SIZE, INPUT, MAX_TOKENS, MESSAGES, NONCE_SIZE, + PAYLOAD_HASH_SIZE, SALT_SIZE, STANDARD, + }; /// Requests and verifies stack information from the blockchain for a given transaction. 
/// @@ -690,7 +696,7 @@ pub(crate) mod utils { /// 42, // stack_small_id /// "/v1/completions".to_string() /// ).await; - /// + /// /// match result { /// Ok(()) => println!("Stack verified successfully"), /// Err(e) => eprintln!("Stack verification failed: {}", e), @@ -706,7 +712,7 @@ pub(crate) mod utils { /// - Stack exists but has insufficient compute units /// - Stack small ID doesn't match the expected value #[instrument(level = "trace", skip_all)] - pub(crate) async fn request_blockchain_for_stack( + pub async fn request_blockchain_for_stack( state: &AppState, tx_digest: TransactionDigest, estimated_compute_units: i64, @@ -768,7 +774,7 @@ pub(crate) mod utils { /// - `calculate_chat_completion_compute_units` /// - `calculate_embedding_compute_units` /// - `calculate_image_generation_compute_units` - pub(crate) fn calculate_compute_units( + pub fn calculate_compute_units( body_json: &Value, request_type: RequestType, state: &AppState, @@ -829,7 +835,7 @@ pub(crate) mod utils { /// } /// ``` #[instrument(level = "trace", skip_all)] - pub(crate) fn calculate_chat_completion_compute_units( + pub fn calculate_chat_completion_compute_units( body_json: &Value, state: &AppState, model: &str, @@ -887,7 +893,7 @@ pub(crate) mod utils { total_num_compute_units += body_json .get(MAX_TOKENS) - .and_then(|value| value.as_i64()) + .and_then(serde_json::Value::as_i64) .unwrap_or(DEFAULT_MAX_TOKENS_CHAT_COMPLETIONS); Ok(total_num_compute_units) @@ -967,14 +973,12 @@ pub(crate) mod utils { Value::Array(texts) => texts .iter() .map(|v| { - v.as_str() - .map(|s| { - state.tokenizers[tokenizer_index] - .encode(s, true) - .map(|tokens| tokens.get_ids().len() as i64) - .unwrap_or(0) - }) - .unwrap_or(0) + v.as_str().map_or(0, |s| { + state.tokenizers[tokenizer_index] + .encode(s, true) + .map(|tokens| tokens.get_ids().len() as i64) + .unwrap_or(0) + }) }) .sum(), _ => { @@ -1046,7 +1050,7 @@ pub(crate) mod utils { // n is the number of images to generate let n = body_json 
.get(IMAGE_N) - .and_then(|v| v.as_u64()) + .and_then(serde_json::Value::as_u64) .ok_or_else(|| AtomaServiceError::InvalidBody { message: "Invalid or missing image count (n)".to_string(), endpoint: endpoint.clone(), @@ -1093,7 +1097,7 @@ pub(crate) mod utils { /// } /// ``` #[instrument(level = "trace", skip_all)] - pub(crate) fn verify_plaintext_body_hash( + pub fn verify_plaintext_body_hash( plaintext_body_hash: &[u8; PAYLOAD_HASH_SIZE], headers: &HeaderMap, endpoint: &str, @@ -1154,7 +1158,7 @@ pub(crate) mod utils { /// } /// ``` #[instrument(level = "trace", skip_all)] - pub(crate) async fn decrypt_confidential_compute_request( + pub async fn decrypt_confidential_compute_request( state: &AppState, confidential_compute_request: &ConfidentialComputeRequest, endpoint: &str, @@ -1301,7 +1305,7 @@ pub(crate) mod utils { /// } /// ``` #[instrument(level = "trace", skip_all)] - pub(crate) fn check_plaintext_body_hash( + pub fn check_plaintext_body_hash( plaintext_body_hash_bytes: [u8; PAYLOAD_HASH_SIZE], plaintext: &[u8], endpoint: &str, diff --git a/atoma-service/src/proxy/mod.rs b/atoma-service/src/proxy/mod.rs index 04cacd27..8f365004 100644 --- a/atoma-service/src/proxy/mod.rs +++ b/atoma-service/src/proxy/mod.rs @@ -30,6 +30,14 @@ const SIGNATURE: &str = "signature"; /// * `node_small_id` - Small ID of the node /// * `keystore` - Keystore for signing the registration request /// * `address_index` - Index of the address to use for signing +/// +/// # Errors +/// +/// This function will return an error if: +/// - The request to the proxy server fails +/// - The server returns a non-success status code +/// - The signature generation fails +/// - The HTTP request fails to be sent pub async fn register_on_proxy( config: &ProxyConfig, node_small_id: u64, diff --git a/atoma-service/src/server.rs b/atoma-service/src/server.rs index b75e9454..690d793f 100644 --- a/atoma-service/src/server.rs +++ b/atoma-service/src/server.rs @@ -261,8 +261,7 @@ pub fn 
create_router(app_state: AppState) -> Router { /// * `app_state` - The shared application state containing database connections, tokenizers, /// and other configuration. /// * `tcp_listener` - A configured TCP listener that specifies the address and port for the server. -/// * `shutdown_sender` - A channel sender used to communicate the shutdown status to other parts -/// of the application. +/// * `shutdown_receiver` - A channel receiver used to listen for shutdown signals. /// /// # Returns /// @@ -274,6 +273,12 @@ pub fn create_router(app_state: AppState) -> Router { /// - The server fails to start or encounters an error while running /// - The shutdown signal fails to be sent through the channel /// +/// # Panics +/// +/// This function will panic if: +/// - The shutdown receiver channel is closed unexpectedly +/// - The shutdown signal cannot be received due to channel errors +/// /// # Example /// /// ```rust,ignore @@ -281,7 +286,7 @@ pub fn create_router(app_state: AppState) -> Router { /// let listener = TcpListener::bind("127.0.0.1:3000").await?; /// let (shutdown_tx, shutdown_rx) = watch::channel(false); /// -/// run_server(app_state, listener, shutdown_tx).await?; +/// run_server(app_state, listener, shutdown_rx).await?; /// ``` pub async fn run_server( app_state: AppState, @@ -294,7 +299,7 @@ pub async fn run_server( shutdown_receiver .changed() .await - .expect("Error receiving shutdown signal") + .expect("Error receiving shutdown signal"); }); server.await?; @@ -380,7 +385,7 @@ async fn metrics_handler() -> Result { } pub(crate) mod utils { - use super::*; + use super::{FileBasedKeystore, Value}; use atoma_utils::hashing::blake2b_hash; use sui_keys::keystore::AccountKeystore; @@ -408,7 +413,7 @@ pub(crate) mod utils { /// Returns an error if: /// * The keystore fails to sign the hash /// * The SHA-256 hash cannot be converted to a 32-byte array - pub(crate) fn sign_response_body( + pub fn sign_response_body( response_body: &Value, keystore: 
&FileBasedKeystore, address_index: usize, diff --git a/atoma-service/src/streamer.rs b/atoma-service/src/streamer.rs index 9a52738f..6e99e279 100644 --- a/atoma-service/src/streamer.rs +++ b/atoma-service/src/streamer.rs @@ -320,7 +320,7 @@ impl Streamer { "Error signing response: {}", e ); - Error::new(format!("Error signing response: {}", e)) + Error::new(format!("Error signing response: {e}")) }, )?; @@ -388,7 +388,7 @@ impl Streamer { "Error encrypting chunk: {}", e ); - Error::new(format!("Error encrypting chunk: {}", e)) + Error::new(format!("Error encrypting chunk: {e}")) })?; if let Some(usage) = usage { @@ -472,8 +472,7 @@ impl Streamer { e ); return Poll::Ready(Some(Err(Error::new(format!( - "Invalid UTF-8 sequence: {}", - e + "Invalid UTF-8 sequence: {e}", ))))); } }; @@ -517,10 +516,9 @@ impl Streamer { "Error parsing chunk {chunk_str}: {}", e ); - return Poll::Ready(Some(Err(Error::new(format!( - "Error parsing chunk: {}", - e - ))))); + return Poll::Ready(Some(Err(Error::new( + format!("Error parsing chunk: {e}",), + )))); } self.chunk_buffer.push_str(chunk_str); @@ -549,8 +547,7 @@ impl Streamer { ); self.chunk_buffer.clear(); return Poll::Ready(Some(Err(Error::new(format!( - "Error parsing chunk: {}", - e + "Error parsing chunk: {e}", ))))); } } @@ -568,17 +565,14 @@ impl Streamer { let (signature, response_hash) = self.sign_chunk(&chunk)?; - let choices = match chunk.get(CHOICES).and_then(|choices| choices.as_array()) { - Some(choices) => choices, - None => { - error!( - target = "atoma-service", - level = "error", - endpoint = self.endpoint, - "Error getting choices from chunk" - ); - return Poll::Ready(Some(Err(Error::new("Error getting choices from chunk")))); - } + let Some(choices) = chunk.get(CHOICES).and_then(|choices| choices.as_array()) else { + error!( + target = "atoma-service", + level = "error", + endpoint = self.endpoint, + "Error getting choices from chunk" + ); + return Poll::Ready(Some(Err(Error::new("Error getting choices from 
chunk")))); }; if choices.is_empty() { @@ -598,7 +592,7 @@ impl Streamer { chunk.clone() }; self.handle_final_chunk(usage, response_hash)?; - update_chunk(&mut chunk, signature, response_hash); + update_chunk(&mut chunk, &signature, response_hash); Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?))) } else { error!( @@ -618,7 +612,7 @@ impl Streamer { } else { chunk }; - update_chunk(&mut chunk, signature, response_hash); + update_chunk(&mut chunk, &signature, response_hash); Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?))) } } @@ -633,85 +627,90 @@ impl Stream for Streamer { } match self.stream.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - match self.handle_streaming_chunk(chunk) { - Poll::Ready(Some(Ok(event))) => { - // Observe the previous timer if it exists - if let Some(timer) = self.inter_stream_token_latency_timer.take() { - timer.observe_duration(); - } - // Start the timer after we've processed this chunk - self.inter_stream_token_latency_timer = Some( - CHAT_COMPLETIONS_INTER_TOKEN_GENERATION_TIME - .with_label_values(&[&self.model]) - .start_timer(), - ); + Poll::Ready(Some(Ok(chunk))) => self.handle_poll_chunk(chunk), + Poll::Ready(Some(Err(e))) => self.handle_poll_error(&e), + Poll::Ready(None) => self.handle_poll_complete(), + Poll::Pending => Poll::Pending, + } + } +} - Poll::Ready(Some(Ok(event))) - } - Poll::Ready(Some(Err(e))) => { - self.status = StreamStatus::Failed(e.to_string()); - // NOTE: We need to update the stack number of tokens as the service failed to generate - // a proper response. For this reason, we set the total number of tokens to 0. - // This will ensure that the stack number of tokens is not updated, and the stack - // will not be penalized for the failed request. 
- if let Err(e) = update_stack_num_compute_units( - &self.state_manager_sender, - self.stack_small_id, - self.estimated_total_compute_units, - 0, - &self.endpoint, - ) { - error!( - target = "atoma-service-streamer", - level = "error", - "Error updating stack num tokens: {}", - e - ); - } - Poll::Ready(Some(Err(e))) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } - Poll::Ready(Some(Err(e))) => { - self.status = StreamStatus::Failed(e.to_string()); - // NOTE: We need to update the stack number of tokens as the service failed to generate - // a proper response. For this reason, we set the total number of tokens to 0. - // This will ensure that the stack number of tokens is not updated, and the stack - // will not be penalized for the failed request. - if let Err(e) = update_stack_num_compute_units( - &self.state_manager_sender, - self.stack_small_id, - self.estimated_total_compute_units, - 0, - &self.endpoint, - ) { - error!( - target = "atoma-service-streamer", - level = "error", - "Error updating stack num tokens: {}", - e - ); - } - Poll::Ready(None) - } - Poll::Ready(None) => { - if !self.chunk_buffer.is_empty() { - error!( - target = "atoma-service-streamer", - level = "error", - "Stream ended, but the chunk buffer is not empty, this should not happen: {}", - self.chunk_buffer - ); - } - self.status = StreamStatus::Completed; - Poll::Ready(None) - } +impl Streamer { + /// Handles a successful chunk from the stream + fn handle_poll_chunk(&mut self, chunk: Bytes) -> Poll>> { + match self.handle_streaming_chunk(chunk) { + Poll::Ready(Some(Ok(event))) => self.handle_successful_event(event), + Poll::Ready(Some(Err(e))) => self.handle_streaming_error(e), + Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } + + /// Handles a successful event, updating timers + fn handle_successful_event(&mut self, event: Event) -> Poll>> { + // Observe the previous timer if it exists + if let Some(timer) = 
self.inter_stream_token_latency_timer.take() { + timer.observe_duration(); + } + // Start the timer after we've processed this chunk + self.inter_stream_token_latency_timer = Some( + CHAT_COMPLETIONS_INTER_TOKEN_GENERATION_TIME + .with_label_values(&[&self.model]) + .start_timer(), + ); + + Poll::Ready(Some(Ok(event))) + } + + /// Handles errors during streaming + fn handle_streaming_error(&mut self, e: Error) -> Poll>> { + self.status = StreamStatus::Failed(e.to_string()); + self.update_stack_tokens_on_error(); + Poll::Ready(Some(Err(e))) + } + + /// Handles stream poll errors + fn handle_poll_error(&mut self, e: &reqwest::Error) -> Poll>> { + self.status = StreamStatus::Failed(e.to_string()); + self.update_stack_tokens_on_error(); + Poll::Ready(None) + } + + /// Handles stream completion + fn handle_poll_complete(&mut self) -> Poll>> { + if !self.chunk_buffer.is_empty() { + error!( + target = "atoma-service-streamer", + level = "error", + "Stream ended, but the chunk buffer is not empty, this should not happen: {}", + self.chunk_buffer + ); + } + self.status = StreamStatus::Completed; + Poll::Ready(None) + } + + /// Updates stack tokens when an error occurs + fn update_stack_tokens_on_error(&self) { + // NOTE: We need to update the stack number of tokens as the service failed to generate + // a proper response. For this reason, we set the total number of tokens to 0. + // This will ensure that the stack number of tokens is not updated, and the stack + // will not be penalized for the failed request. 
+ if let Err(e) = update_stack_num_compute_units( + &self.state_manager_sender, + self.stack_small_id, + self.estimated_total_compute_units, + 0, + &self.endpoint, + ) { + error!( + target = "atoma-service-streamer", + level = "error", + "Error updating stack num tokens: {}", + e + ); + } + } } /// Updates the final chunk with the signature and response hash @@ -723,7 +722,7 @@ impl Stream for Streamer { /// * `chunk` - The chunk to update (mut ref, as we update the chunk in place) /// * `signature` - The signature to update the chunk with /// * `response_hash` - The response hash to update the chunk with -fn update_chunk(chunk: &mut Value, signature: String, response_hash: [u8; PAYLOAD_HASH_SIZE]) { +fn update_chunk(chunk: &mut Value, signature: &str, response_hash: [u8; PAYLOAD_HASH_SIZE]) { chunk[SIGNATURE_KEY] = json!(signature); chunk[RESPONSE_HASH_KEY] = json!(STANDARD.encode(response_hash)); } diff --git a/atoma-service/src/tests.rs b/atoma-service/src/tests.rs index 7d039a2a..3884a28a 100644 --- a/atoma-service/src/tests.rs +++ b/atoma-service/src/tests.rs @@ -1,10 +1,10 @@ mod middleware { - use atoma_confidential::AtomaConfidentialComputeService; + use atoma_confidential::AtomaConfidentialCompute; use atoma_state::{ types::{AtomaAtomaStateManagerEvent, Stack, Task}, AtomaStateManager, }; - use atoma_sui::{client::AtomaSuiClient, events::AtomaEvent, AtomaSuiConfig}; + use atoma_sui::{client::Client, config::Builder, events::AtomaEvent}; use atoma_utils::{ constants::{self, SALT_SIZE}, encryption::encrypt_plaintext, @@ -100,7 +100,7 @@ mod middleware { .await .expect("Failed to connect to database"); sqlx::query( - "TRUNCATE TABLE + "TRUNCATE TABLE tasks, node_subscriptions, stacks, @@ -122,7 +122,7 @@ mod middleware { Sender, tokio::sync::watch::Receiver, ) { - let (_event_subscriber_sender, event_subscriber_receiver) = flume::unbounded(); + let (event_subscriber_sender, event_subscriber_receiver) = flume::unbounded(); let (state_manager_sender, 
state_manager_receiver) = flume::unbounded(); let state_manager = AtomaStateManager::new_from_url( POSTGRES_TEST_DB_URL, @@ -175,11 +175,12 @@ mod middleware { state_manager_handle, state_manager_sender, shutdown_sender, - _event_subscriber_sender, + event_subscriber_sender, shutdown_signal, ) } + #[allow(clippy::too_many_lines)] async fn setup_app_state() -> ( AppState, PublicKey, @@ -205,7 +206,7 @@ mod middleware { state_manager_handle, state_manager_sender, shutdown_sender, - _event_subscriber_sender, + event_subscriber_sender, shutdown_receiver, ) = setup_database(public_key.clone()).await; let (stack_retrieve_sender, _) = tokio::sync::mpsc::unbounded_channel(); @@ -242,26 +243,21 @@ mod middleware { .expect("Failed to create .sui/keystore directory"); std::fs::write(keystore_path.clone(), sui_keystore_contents) .expect("Failed to write to keystore"); - let client_config = AtomaSuiConfig::new( - "http://localhost:9000".to_string(), - ObjectID::from_str("0x1").unwrap(), - ObjectID::from_str("0x2").unwrap(), - ObjectID::from_str("0x3").unwrap(), - None, - None, - None, - None, - None, - client_yaml_path.to_string_lossy().to_string(), - "./keystore".to_string(), - "./".to_string(), - ); + let client_config = Builder::new() + .http_rpc_node_addr("http://localhost:9000".to_string()) + .atoma_db(ObjectID::from_str("0x1").unwrap()) + .atoma_package_id(ObjectID::from_str("0x2").unwrap()) + .usdc_package_id(ObjectID::from_str("0x3").unwrap()) + .sui_config_path(client_yaml_path.to_string_lossy().to_string()) + .sui_keystore_path("./keystore".to_string()) + .cursor_path("./".to_string()) + .build(); let (compute_shared_secret_sender, compute_shared_secret_receiver) = tokio::sync::mpsc::unbounded_channel(); let _join_handle = tokio::spawn(async move { - let confidential_compute_service = AtomaConfidentialComputeService::new( + let confidential_compute_service = AtomaConfidentialCompute::new( Arc::new(RwLock::new( - AtomaSuiClient::new(client_config) + 
Client::new(client_config) .await .expect("Failed to create Sui client"), )), @@ -286,15 +282,20 @@ mod middleware { .expect("Failed to remove keystore"); ( AppState { - models: Arc::new(models.into_iter().map(|s| s.to_string()).collect()), + models: Arc::new( + models + .into_iter() + .map(std::string::ToString::to_string) + .collect(), + ), tokenizers: Arc::new(vec![Arc::new(tokenizer.clone()), Arc::new(tokenizer)]), state_manager_sender, decryption_sender, encryption_sender, compute_shared_secret_sender, chat_completions_service_urls: HashMap::new(), - embeddings_service_url: "".to_string(), - image_generations_service_url: "".to_string(), + embeddings_service_url: String::new(), + image_generations_service_url: String::new(), keystore: Arc::new(keystore), address_index: 0, stack_retrieve_sender, @@ -303,7 +304,7 @@ mod middleware { signature, shutdown_sender, state_manager_handle, - _event_subscriber_sender, + event_subscriber_sender, dh_public_key, ) } @@ -1470,7 +1471,7 @@ mod middleware { let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key); // Create incorrect hash (hash of different data) - let incorrect_plaintext = "different data".as_bytes(); + let incorrect_plaintext = b"different data"; let incorrect_hash: [u8; 32] = blake2b_hash(incorrect_plaintext).into(); let shared_secret = client_dh_private_key.diffie_hellman(&server_dh_public_key); diff --git a/atoma-state/src/config.rs b/atoma-state/src/config.rs index a01a91ea..01323e3a 100644 --- a/atoma-state/src/config.rs +++ b/atoma-state/src/config.rs @@ -11,7 +11,8 @@ pub struct AtomaStateManagerConfig { impl AtomaStateManagerConfig { /// Constructor - pub fn new(database_url: String) -> Self { + #[must_use] + pub const fn new(database_url: String) -> Self { Self { database_url } } diff --git a/atoma-state/src/handlers.rs b/atoma-state/src/handlers.rs index 9a72f2b2..65a459b1 100644 --- a/atoma-state/src/handlers.rs +++ b/atoma-state/src/handlers.rs @@ -8,8 +8,9 @@ use 
atoma_sui::events::{ use tracing::{info, instrument}; use crate::{ - state_manager::Result, types::AtomaAtomaStateManagerEvent, AtomaStateManager, - AtomaStateManagerError, + state_manager::Result, + types::{AtomaAtomaStateManagerEvent, StackSettlementTicket}, + AtomaStateManager, AtomaStateManagerError, }; #[instrument(level = "info", skip_all)] @@ -435,7 +436,7 @@ pub(crate) async fn handle_stack_try_settle_event( event = "handle-stack-try-settle-event", "Processing stack try settle event" ); - let stack_settlement_ticket = event.into(); + let stack_settlement_ticket = StackSettlementTicket::try_from(event)?; state_manager .state .insert_new_stack_settlement_ticket(stack_settlement_ticket) @@ -695,7 +696,7 @@ pub(crate) async fn handle_state_manager_event( estimated_total_compute_units, total_compute_units, ) - .await? + .await?; } AtomaAtomaStateManagerEvent::UpdateStackTotalHash { stack_small_id, @@ -704,7 +705,7 @@ pub(crate) async fn handle_state_manager_event( state_manager .state .update_stack_total_hash(stack_small_id, total_hash) - .await? 
+ .await?; } } Ok(()) diff --git a/atoma-state/src/lib.rs b/atoma-state/src/lib.rs index 6a9ab790..63bddc1d 100644 --- a/atoma-state/src/lib.rs +++ b/atoma-state/src/lib.rs @@ -1,3 +1,9 @@ +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::cast_precision_loss)] +#![allow(clippy::missing_docs_in_private_items)] +#![allow(clippy::doc_markdown)] +#![allow(clippy::module_name_repetitions)] + pub mod config; pub mod handlers; pub mod state_manager; diff --git a/atoma-state/src/state_manager.rs b/atoma-state/src/state_manager.rs index 56a23477..336f0b55 100644 --- a/atoma-state/src/state_manager.rs +++ b/atoma-state/src/state_manager.rs @@ -29,7 +29,8 @@ pub struct AtomaStateManager { impl AtomaStateManager { /// Constructor - pub fn new( + #[must_use] + pub const fn new( db: PgPool, event_subscriber_receiver: FlumeReceiver, state_manager_receiver: FlumeReceiver, @@ -45,6 +46,20 @@ impl AtomaStateManager { /// /// This method establishes a connection to the Postgres database using the provided URL, /// creates all necessary tables in the database, and returns a new `AtomaStateManager` instance. 
+ /// + /// # Arguments + /// * `database_url` - The URL of the PostgreSQL database to connect to + /// * `event_subscriber_receiver` - Channel receiver for Atoma events + /// * `state_manager_receiver` - Channel receiver for state manager events + /// + /// # Returns + /// A new state manager instance + /// + /// # Errors + /// Returns `AtomaStateManagerError` if: + /// - Failed to connect to database + /// - Failed to create connection pool + /// - Failed to run database migrations pub async fn new_from_url( database_url: &str, event_subscriber_receiver: FlumeReceiver, @@ -181,11 +196,29 @@ pub struct AtomaState { impl AtomaState { /// Constructor - pub fn new(db: PgPool) -> Self { + /// + /// # Arguments + /// * `db` - The Postgres connection pool + /// + /// # Returns + /// A new `AtomaState` instance + #[must_use] + pub const fn new(db: PgPool) -> Self { Self { db } } /// Creates a new `AtomaState` instance from a database URL. + /// + /// # Arguments + /// * `database_url` - The URL of the PostgreSQL database to connect to + /// + /// # Returns + /// A new state instance wrapped in `Result` + /// + /// # Errors + /// Returns `AtomaStateManagerError` if: + /// - Failed to connect to database + /// - Failed to run database migrations pub async fn new_from_url(database_url: &str) -> Result { let db = PgPool::connect(database_url).await?; sqlx::migrate!("./src/migrations").run(&db).await?; @@ -561,8 +594,8 @@ impl AtomaState { max_num_compute_units: i64, ) -> Result<()> { sqlx::query( - "INSERT INTO node_subscriptions - (node_small_id, task_small_id, price_per_one_million_compute_units, max_num_compute_units, valid) + "INSERT INTO node_subscriptions + (node_small_id, task_small_id, price_per_one_million_compute_units, max_num_compute_units, valid) VALUES ($1, $2, $3, $4, TRUE) ON CONFLICT (task_small_id, node_small_id) DO NOTHING", ) @@ -773,11 +806,11 @@ impl AtomaState { ) -> Result<()> { sqlx::query( "INSERT INTO node_public_key_rotations (epoch, 
key_rotation_counter, node_small_id, public_key_bytes, tdx_quote_bytes) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (node_small_id) - DO UPDATE SET + ON CONFLICT (node_small_id) + DO UPDATE SET epoch = $1, key_rotation_counter = $2, - public_key_bytes = $4, + public_key_bytes = $4, tdx_quote_bytes = $5", ) .bind(epoch as i64) @@ -815,7 +848,7 @@ impl AtomaState { /// ```rust,ignore /// use atoma_node::atoma_state::AtomaStateManager; /// - /// async fn get_stack(state_manager: &AtomaStateManager, stack_small_id: i64) -> Result { + /// async fn get_stack(state_manager: &AtomaStateManager, stack_small_id: i64) -> Result { /// state_manager.get_stack(stack_id).await /// } /// ``` @@ -1013,7 +1046,7 @@ impl AtomaState { /// async fn get_filled_stacks(state_manager: &AtomaStateManager) -> Result, AtomaStateManagerError> { /// let node_ids = &[1, 2, 3]; // Check stacks for these nodes /// let threshold = 0.8; // Look for stacks that are 80% or more filled - /// + /// /// state_manager.get_almost_filled_stacks(node_ids, threshold).await /// } /// ``` @@ -1028,15 +1061,15 @@ impl AtomaState { fraction: f64, ) -> Result> { Ok(sqlx::query_as::<_, Stack>( - r#" + r" SELECT * FROM stacks WHERE selected_node_id = ANY($1) - AND CASE + AND CASE WHEN num_compute_units = 0 THEN true ELSE (already_computed_units::float / num_compute_units::float) > $2 END - "#, + ", ) .bind(node_small_ids) .bind(fraction) @@ -1116,7 +1149,7 @@ impl AtomaState { ) -> Result> { // Single query that updates and returns the modified row let maybe_stack = sqlx::query_as::<_, Stack>( - r#" + r" UPDATE stacks SET already_computed_units = already_computed_units + $1 WHERE stack_small_id = $2 @@ -1124,7 +1157,7 @@ impl AtomaState { AND num_compute_units - already_computed_units >= $1 AND in_settle_period = false RETURNING * - "#, + ", ) .bind(num_compute_units) .bind(stack_small_id) @@ -1165,7 +1198,7 @@ impl AtomaState { /// async fn insert_stack(state_manager: &AtomaStateManager, stack: Stack) -> Result<(), 
AtomaStateManagerError> { /// // If a stack with the same stack_small_id exists, this will succeed without modifying it /// state_manager.insert_new_stack(stack).await - /// } + /// } /// ``` #[tracing::instrument( level = "trace", @@ -1179,8 +1212,8 @@ impl AtomaState { )] pub async fn insert_new_stack(&self, stack: Stack) -> Result<()> { sqlx::query( - "INSERT INTO stacks - (owner_address, stack_small_id, stack_id, task_small_id, selected_node_id, num_compute_units, price_per_one_million_compute_units, already_computed_units, in_settle_period, total_hash, num_total_messages) + "INSERT INTO stacks + (owner_address, stack_small_id, stack_id, task_small_id, selected_node_id, num_compute_units, price_per_one_million_compute_units, already_computed_units, in_settle_period, total_hash, num_total_messages) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (stack_small_id) DO UPDATE SET already_computed_units = stacks.already_computed_units + $8 @@ -1290,8 +1323,8 @@ impl AtomaState { total_compute_units: i64, ) -> Result<()> { let result = sqlx::query( - "UPDATE stacks - SET already_computed_units = already_computed_units - ($1 - $2) + "UPDATE stacks + SET already_computed_units = already_computed_units - ($1 - $2) WHERE stack_small_id = $3", ) .bind(estimated_total_compute_units) @@ -1447,19 +1480,19 @@ impl AtomaState { ) -> Result<()> { let mut tx = self.db.begin().await?; sqlx::query( - "INSERT INTO stack_settlement_tickets + "INSERT INTO stack_settlement_tickets ( - stack_small_id, - selected_node_id, - num_claimed_compute_units, - requested_attestation_nodes, - committed_stack_proofs, - stack_merkle_leaves, - dispute_settled_at_epoch, - already_attested_nodes, - is_in_dispute, - user_refund_amount, - is_claimed) + stack_small_id, + selected_node_id, + num_claimed_compute_units, + requested_attestation_nodes, + committed_stack_proofs, + stack_merkle_leaves, + dispute_settled_at_epoch, + already_attested_nodes, + is_in_dispute, + user_refund_amount, 
+ is_claimed) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (stack_small_id, selected_node_id) DO NOTHING", ) @@ -1529,7 +1562,7 @@ impl AtomaState { new_hash: [u8; 32], ) -> Result<()> { let rows_affected = sqlx::query( - "UPDATE stacks + "UPDATE stacks SET total_hash = total_hash || $1, num_total_messages = num_total_messages + 1 WHERE stack_small_id = $2", @@ -1708,7 +1741,7 @@ impl AtomaState { // First query remains the same - get existing data let row = sqlx::query( "SELECT committed_stack_proofs, stack_merkle_leaves, requested_attestation_nodes, already_attested_nodes - FROM stack_settlement_tickets + FROM stack_settlement_tickets WHERE stack_small_id = $1", ) .bind(stack_small_id) @@ -1750,7 +1783,7 @@ impl AtomaState { // Simplified update query sqlx::query( - "UPDATE stack_settlement_tickets + "UPDATE stack_settlement_tickets SET committed_stack_proofs = $1, stack_merkle_leaves = $2, already_attested_nodes = $3 @@ -1861,7 +1894,7 @@ impl AtomaState { user_refund_amount: i64, ) -> Result<()> { sqlx::query( - "UPDATE stack_settlement_tickets + "UPDATE stack_settlement_tickets SET user_refund_amount = $1, is_claimed = true WHERE stack_small_id = $2", @@ -1974,7 +2007,7 @@ impl AtomaState { attestation_node_id: i64, ) -> Result> { let disputes = sqlx::query( - "SELECT * FROM stack_attestation_disputes + "SELECT * FROM stack_attestation_disputes WHERE stack_small_id = $1 AND attestation_node_id = $2", ) .bind(stack_small_id) @@ -2144,8 +2177,8 @@ impl AtomaState { stack_attestation_dispute: StackAttestationDispute, ) -> Result<()> { sqlx::query( - "INSERT INTO stack_attestation_disputes - (stack_small_id, attestation_commitment, attestation_node_id, original_node_id, original_commitment) + "INSERT INTO stack_attestation_disputes + (stack_small_id, attestation_commitment, attestation_node_id, original_node_id, original_commitment) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (stack_small_id, attestation_node_id) DO NOTHING", ) @@ -2199,7 +2232,7 
@@ mod tests { async fn truncate_tables(db: &sqlx::PgPool) { // List all your tables here sqlx::query( - "TRUNCATE TABLE + "TRUNCATE TABLE tasks, node_subscriptions, stacks, @@ -4023,11 +4056,11 @@ mod tests { stack_small_id: 1, selected_node_id: 1, num_claimed_compute_units: 100, - requested_attestation_nodes: "".to_string(), + requested_attestation_nodes: String::new(), committed_stack_proofs: vec![], stack_merkle_leaves: vec![], dispute_settled_at_epoch: None, - already_attested_nodes: "".to_string(), + already_attested_nodes: String::new(), is_in_dispute: false, user_refund_amount: 0, is_claimed: true, @@ -4040,11 +4073,11 @@ mod tests { stack_small_id: 2, selected_node_id: 1, num_claimed_compute_units: 100, - requested_attestation_nodes: "".to_string(), + requested_attestation_nodes: String::new(), committed_stack_proofs: vec![], stack_merkle_leaves: vec![], dispute_settled_at_epoch: None, - already_attested_nodes: "".to_string(), + already_attested_nodes: String::new(), is_in_dispute: false, user_refund_amount: 0, is_claimed: false, @@ -4057,11 +4090,11 @@ mod tests { stack_small_id: 3, selected_node_id: 2, num_claimed_compute_units: 200, - requested_attestation_nodes: "".to_string(), + requested_attestation_nodes: String::new(), committed_stack_proofs: vec![], stack_merkle_leaves: vec![], dispute_settled_at_epoch: None, - already_attested_nodes: "".to_string(), + already_attested_nodes: String::new(), is_in_dispute: false, user_refund_amount: 0, is_claimed: true, @@ -4541,7 +4574,7 @@ mod tests { ]; // Insert rotations for multiple nodes - for (node_id, pub_key, tee_bytes) in nodes.iter() { + for (node_id, pub_key, tee_bytes) in &nodes { state_manager .insert_node_public_key_rotation( 100u64, @@ -4554,7 +4587,7 @@ mod tests { } // Verify all insertions - for (node_id, pub_key, tee_bytes) in nodes.iter() { + for (node_id, pub_key, tee_bytes) in &nodes { let row = sqlx::query("SELECT * FROM node_public_key_rotations WHERE node_small_id = $1") .bind(*node_id as 
i64) diff --git a/atoma-state/src/types.rs b/atoma-state/src/types.rs index 32370bc9..3472b9e2 100644 --- a/atoma-state/src/types.rs +++ b/atoma-state/src/types.rs @@ -1,15 +1,16 @@ -use crate::state_manager::Result; +use crate::AtomaStateManagerError; use atoma_sui::events::{ StackAttestationDisputeEvent, StackCreateAndUpdateEvent, StackCreatedEvent, StackTrySettleEvent, TaskRegisteredEvent, }; use serde::{Deserialize, Serialize}; use sqlx::FromRow; +use std::result::Result; use tokio::sync::oneshot; use utoipa::ToSchema; /// Represents a task in the system -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow, ToSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromRow, ToSchema)] pub struct Task { /// Unique small integer identifier for the task pub task_small_id: i64, @@ -33,22 +34,22 @@ pub struct Task { impl From for Task { fn from(event: TaskRegisteredEvent) -> Self { - Task { + Self { task_id: event.task_id, task_small_id: event.task_small_id.inner as i64, - role: event.role.inner as i64, + role: i64::from(event.role.inner), model_name: event.model_name, is_deprecated: false, valid_until_epoch: None, deprecated_at_epoch: None, - security_level: event.security_level.inner as i64, - minimum_reputation_score: event.minimum_reputation_score.map(|score| score as i64), + security_level: i64::from(event.security_level.inner), + minimum_reputation_score: event.minimum_reputation_score.map(i64::from), } } } /// Represents a stack of compute units for a specific task -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow, ToSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromRow, ToSchema)] pub struct Stack { /// Address of the owner of the stack pub owner_address: String, @@ -77,7 +78,7 @@ pub struct Stack { impl From for Stack { fn from(event: StackCreatedEvent) -> Self { - Stack { + Self { owner_address: event.owner, stack_id: event.stack_id, stack_small_id: event.stack_small_id.inner as i64, 
@@ -95,7 +96,7 @@ impl From for Stack { impl From for Stack { fn from(event: StackCreateAndUpdateEvent) -> Self { - Stack { + Self { owner_address: event.owner, stack_small_id: event.stack_small_id.inner as i64, stack_id: event.stack_id, @@ -112,7 +113,7 @@ impl From for Stack { } /// Represents a settlement ticket for a compute stack -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow, ToSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromRow, ToSchema)] pub struct StackSettlementTicket { /// Unique small integer identifier for the stack pub stack_small_id: i64, @@ -138,8 +139,10 @@ pub struct StackSettlementTicket { pub is_claimed: bool, } -impl From for StackSettlementTicket { - fn from(event: StackTrySettleEvent) -> Self { +impl TryFrom for StackSettlementTicket { + type Error = AtomaStateManagerError; + + fn try_from(event: StackTrySettleEvent) -> Result { let num_attestation_nodes = event.requested_attestation_nodes.len(); let expanded_size = 32 * num_attestation_nodes; @@ -149,7 +152,7 @@ impl From for StackSettlementTicket { let mut expanded_leaves = event.stack_merkle_leaf; expanded_leaves.resize(expanded_size, 0); - StackSettlementTicket { + Ok(Self { stack_small_id: event.stack_small_id.inner as i64, selected_node_id: event.selected_node_id.inner as i64, num_claimed_compute_units: event.num_claimed_compute_units as i64, @@ -160,20 +163,21 @@ impl From for StackSettlementTicket { .map(|id| id.inner) .collect::>(), ) - .unwrap(), + .map_err(AtomaStateManagerError::JsonParseError)?, committed_stack_proofs: expanded_proofs, stack_merkle_leaves: expanded_leaves, dispute_settled_at_epoch: None, - already_attested_nodes: serde_json::to_string(&Vec::::new()).unwrap(), + already_attested_nodes: serde_json::to_string(&Vec::::new()) + .map_err(AtomaStateManagerError::JsonParseError)?, is_in_dispute: false, user_refund_amount: 0, is_claimed: false, - } + }) } } /// Represents a dispute in the stack attestation process 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow, ToSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromRow, ToSchema)] pub struct StackAttestationDispute { /// Unique small integer identifier for the stack involved in the dispute pub stack_small_id: i64, @@ -189,7 +193,7 @@ pub struct StackAttestationDispute { impl From for StackAttestationDispute { fn from(event: StackAttestationDisputeEvent) -> Self { - StackAttestationDispute { + Self { stack_small_id: event.stack_small_id.inner as i64, attestation_commitment: event.attestation_commitment, attestation_node_id: event.attestation_node_id.inner as i64, @@ -200,7 +204,7 @@ impl From for StackAttestationDispute { } /// Represents a node subscription to a task -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, FromRow, ToSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromRow, ToSchema)] pub struct NodeSubscription { /// Unique small integer identifier for the node subscription pub node_small_id: i64, @@ -240,6 +244,6 @@ pub enum AtomaAtomaStateManagerEvent { /// Total number of compute units total_num_compute_units: i64, /// Oneshot channel to send the result back to the sender channel - result_sender: oneshot::Sender>>, + result_sender: oneshot::Sender, AtomaStateManagerError>>, }, } diff --git a/atoma-sui/src/client.rs b/atoma-sui/src/client.rs index c760aba7..c6a866e8 100644 --- a/atoma-sui/src/client.rs +++ b/atoma-sui/src/client.rs @@ -1,12 +1,9 @@ use std::path::Path; -use sui_sdk::{ - json::SuiJsonValue, rpc_types::SuiData, types::base_types::ObjectID, - wallet_context::WalletContext, -}; +use sui_sdk::{json::SuiJsonValue, types::base_types::ObjectID, wallet_context::WalletContext}; use thiserror::Error; use tracing::{error, info, instrument}; -use crate::{config::AtomaSuiConfig, events::NodePublicKeyCommittmentEvent}; +use crate::{config::Config as SuiConfig, events::NodePublicKeyCommittmentEvent}; type Result = std::result::Result; @@ 
-46,31 +43,35 @@ const UPDATE_NODE_TASK_SUBSCRIPTION_METHOD: &str = "update_node_subscription"; /// The Atoma's contract method name for submitting a node key rotation attestation const ROTATE_NODE_PUBLIC_KEY: &str = "rotate_node_public_key"; -/// A client for interacting with the Atoma network using the Sui blockchain. -/// -/// The `AtomaSuiClient` struct provides methods to perform various operations -/// in the Atoma network, such as registering nodes, subscribing to models and tasks, -/// and managing transactions. It maintains a wallet context and optionally stores -/// a node badge representing the client's node registration status. -pub struct AtomaSuiClient { - /// Configuration settings for the Atoma client, including paths and timeouts. - config: AtomaSuiConfig, - - /// The wallet context used for managing blockchain interactions. +/// Client for interacting with Atoma's Sui blockchain functionality +pub struct Client { + /// Configuration settings for the Atoma client + config: SuiConfig, + /// The Sui client for blockchain interactions wallet_ctx: WalletContext, - - /// An optional tuple containing the ObjectID and small ID of the node badge, - /// which represents the node's registration in the Atoma network. 
+ /// An optional tuple containing the `ObjectID` and small ID of the node badge, + /// used for authentication node_badge: Option<(ObjectID, u64)>, - - /// The ObjectID of the USDC wallet address - /// for the current operator - usdc_wallet_id: Option, + /// The `ObjectID` of the USDC wallet address + usdc_wallet: Option, } -impl AtomaSuiClient { - /// Constructor - pub async fn new(config: AtomaSuiConfig) -> Result { +impl Client { + /// Creates a new Sui client instance + /// + /// # Arguments + /// * `config` - Configuration settings for the client + /// + /// # Returns + /// A new client instance + /// + /// # Errors + /// Returns an error if: + /// - Failed to initialize wallet context + /// - Failed to get client from wallet context + /// - Failed to get active address + /// - Failed to retrieve node badge + pub async fn new(config: SuiConfig) -> Result { let sui_config_path = config.sui_config_path(); let sui_config_path = Path::new(&sui_config_path); let mut wallet_ctx = WalletContext::new( @@ -88,30 +89,18 @@ impl AtomaSuiClient { config, wallet_ctx, node_badge, - usdc_wallet_id: None, + usdc_wallet: None, }) } - /// Creates a new `AtomaSuiClient` instance from a configuration file. - /// - /// This method reads the configuration from the specified file path and initializes - /// a new `AtomaSuiClient` with the loaded configuration. - /// - /// # Arguments - /// - /// * `config_path` - A path-like type that represents the location of the configuration file. - /// - /// # Returns - /// - /// * `Result` - A Result containing the new `AtomaSuiClient` instance if successful, - /// or an error if the configuration couldn't be read. + /// Creates a new client instance /// /// # Errors - /// - /// This function will return an error if: - /// * The configuration file cannot be read or parsed. 
- pub async fn new_from_config>(config_path: P) -> Result { - let config = AtomaSuiConfig::from_file_path(config_path); + /// - If Sui client initialization fails + /// - If keystore operations fail + /// - If network connection fails + pub async fn new_from_config + Send>(config_path: P) -> Result { + let config = SuiConfig::from_file_path(config_path); Self::new(config).await } @@ -446,7 +435,7 @@ impl AtomaSuiClient { /// None, // default gas budget /// None // default gas price /// ).await?; - /// + /// /// // Or with custom gas settings and specific node badge ID /// let gas_object = ObjectID::new([1; 32]); /// let node_badge = ObjectID::new([2; 32]); @@ -458,7 +447,7 @@ impl AtomaSuiClient { /// Some(10_000_000), // 0.01 SUI gas budget /// Some(1000) // specific gas price /// ).await?; - /// + /// /// println!("Task subscription updated: {}", tx_digest); /// Ok(()) /// # } @@ -1068,7 +1057,7 @@ impl AtomaSuiClient { /// async fn example(client: &mut AtomaSuiClient) -> Result<()> { /// let tdx_quote = vec![1, 2, 3, 4]; // Your TDX quote bytes /// let public_key = [0u8; 32]; // Your new public key - /// + /// /// // Submit with default gas settings /// let tx_digest = client.submit_key_rotation_remote_attestation( /// tdx_quote, @@ -1077,7 +1066,7 @@ impl AtomaSuiClient { /// None, // default gas budget /// None, // default gas price /// ).await?; - /// + /// /// println!("Key rotation submitted: {}", tx_digest); /// Ok(()) /// } @@ -1165,7 +1154,7 @@ impl AtomaSuiClient { address = %self.wallet_ctx.active_address().unwrap() ))] pub async fn get_or_load_usdc_wallet_object_id(&mut self) -> Result { - if let Some(usdc_wallet_id) = self.usdc_wallet_id { + if let Some(usdc_wallet_id) = self.usdc_wallet { Ok(usdc_wallet_id) } else { let active_address = self.wallet_ctx.active_address()?; @@ -1177,7 +1166,7 @@ impl AtomaSuiClient { .await { Ok(usdc_wallet) => { - self.usdc_wallet_id = Some(usdc_wallet); + self.usdc_wallet = Some(usdc_wallet); Ok(usdc_wallet) } Err(e) 
=> Err(e), @@ -1209,37 +1198,39 @@ pub enum AtomaSuiClientError { } pub(crate) mod utils { - use super::*; + use super::{AtomaSuiClientError, ObjectID, Result, MODULE_ID}; use sui_sdk::{ - rpc_types::{Page, SuiObjectDataFilter, SuiObjectDataOptions, SuiObjectResponseQuery}, + rpc_types::{ + Page, SuiData, SuiObjectDataFilter, SuiObjectDataOptions, SuiObjectResponseQuery, + }, types::base_types::{ObjectType, SuiAddress}, SuiClient, }; - use tracing::error; + use tracing::{error, instrument}; /// The name of the Atoma's contract node badge type const DB_NODE_TYPE_NAME: &str = "NodeBadge"; - /// Retrieves the node badge (ObjectID and small_id) associated with a given address. + /// Retrieves the node badge (`ObjectID` and `small_id`) associated with a given address. /// - /// This function queries the Sui blockchain to find a NodeBadge object owned by the specified - /// address that was created by the specified package. The NodeBadge represents a node's + /// This function queries the Sui blockchain to find a `NodeBadge` object owned by the specified + /// address that was created by the specified package. The `NodeBadge` represents a node's /// registration in the Atoma network. 
/// /// # Arguments /// - /// * `client` - A reference to the SuiClient used to interact with the blockchain - /// * `package` - The ObjectID of the Atoma package that created the NodeBadge - /// * `active_address` - The SuiAddress to query for owned NodeBadge objects + /// * `client` - A reference to the `SuiClient` used to interact with the blockchain + /// * `package` - The `ObjectID` of the Atoma package that created the `NodeBadge` + /// * `active_address` - The `SuiAddress` to query for owned `NodeBadge` objects /// /// # Returns /// /// Returns `Option<(ObjectID, u64)>` where: - /// - `Some((object_id, small_id))` if a NodeBadge is found, where: - /// - `object_id` is the unique identifier of the NodeBadge object + /// - `Some((object_id, small_id))` if a `NodeBadge` is found, where: + /// - `object_id` is the unique identifier of the `NodeBadge` object /// - `small_id` is the node's numeric identifier in the Atoma network /// - `None` if: - /// - No NodeBadge is found + /// - No `NodeBadge` is found /// - The query fails /// - The object data cannot be parsed /// @@ -1252,7 +1243,7 @@ pub(crate) mod utils { /// async fn example(client: &SuiClient) { /// let package_id = ObjectID::new([1; 32]); /// let address = SuiAddress::random_for_testing_only(); - /// + /// /// match get_node_badge(client, package_id, address).await { /// Some((badge_id, small_id)) => { /// println!("Found NodeBadge: ID={}, small_id={}", badge_id, small_id); @@ -1267,9 +1258,9 @@ pub(crate) mod utils { /// # Implementation Notes /// /// - The function queries up to 100 objects at a time - /// - The function filters objects by package and looks for the specific NodeBadge type - /// - Object content is parsed to extract the small_id from the Move object's fields - pub(crate) async fn get_node_badge( + /// - The function filters objects by package and looks for the specific `NodeBadge` type + /// - Object content is parsed to extract the `small_id` from the Move object's fields + pub async 
fn get_node_badge( client: &SuiClient, package: ObjectID, active_address: SuiAddress, @@ -1354,7 +1345,7 @@ pub(crate) mod utils { endpoint = "find_usdc_token_wallet", address = %active_address ))] - pub(crate) async fn find_usdc_token_wallet( + pub async fn find_usdc_token_wallet( client: &SuiClient, usdc_package: ObjectID, active_address: SuiAddress, diff --git a/atoma-sui/src/config.rs b/atoma-sui/src/config.rs index d174b4ba..c87afa5d 100644 --- a/atoma-sui/src/config.rs +++ b/atoma-sui/src/config.rs @@ -1,15 +1,15 @@ use std::{path::Path, time::Duration}; -use config::Config; +use config::Config as RustConfig; use serde::{Deserialize, Serialize}; use sui_sdk::types::base_types::ObjectID; -/// Configuration for the Sui Event Subscriber +/// Configuration for Sui blockchain interactions. /// /// This struct holds the necessary configuration parameters for connecting to and /// interacting with a Sui network, including URLs, package ID, timeout, and small IDs. #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct AtomaSuiConfig { +pub struct Config { /// The HTTP URL for a Sui RPC node, to which the subscriber will connect /// This is used for making HTTP requests to the Sui RPC node http_rpc_node_addr: String, @@ -58,100 +58,80 @@ pub struct AtomaSuiConfig { cursor_path: String, } -impl AtomaSuiConfig { - /// Constructor - #[allow(clippy::too_many_arguments)] - pub fn new( - http_rpc_node_addr: String, - atoma_db: ObjectID, - atoma_package_id: ObjectID, - usdc_package_id: ObjectID, - request_timeout: Option, - limit: Option, - node_small_ids: Option>, - task_small_ids: Option>, - max_concurrent_requests: Option, - sui_config_path: String, - sui_keystore_path: String, - cursor_path: String, - ) -> Self { - Self { - http_rpc_node_addr, - atoma_db, - atoma_package_id, - usdc_package_id, - request_timeout, - limit, - node_small_ids, - task_small_ids, - max_concurrent_requests, - sui_config_path, - sui_keystore_path, - cursor_path, - } - } - - /// Getter for 
`http_url` +impl Config { + /// Gets the HTTP RPC node address + #[must_use] pub fn http_rpc_node_addr(&self) -> String { self.http_rpc_node_addr.clone() } /// Getter for `limit` - pub fn limit(&self) -> Option { + #[must_use] + pub const fn limit(&self) -> Option { self.limit } /// Getter for `package_id` - pub fn atoma_package_id(&self) -> ObjectID { + #[must_use] + pub const fn atoma_package_id(&self) -> ObjectID { self.atoma_package_id } /// Getter for `usdc_package_id` - pub fn usdc_package_id(&self) -> ObjectID { + #[must_use] + pub const fn usdc_package_id(&self) -> ObjectID { self.usdc_package_id } /// Getter for `atoma_db` - pub fn atoma_db(&self) -> ObjectID { + #[must_use] + pub const fn atoma_db(&self) -> ObjectID { self.atoma_db } /// Getter for `request_timeout` - pub fn request_timeout(&self) -> Option { + #[must_use] + pub const fn request_timeout(&self) -> Option { self.request_timeout } /// Getter for `small_id` + #[must_use] pub fn node_small_ids(&self) -> Option> { self.node_small_ids.clone() } /// Getter for `task_small_ids` + #[must_use] pub fn task_small_ids(&self) -> Option> { self.task_small_ids.clone() } /// Getter for `max_concurrent_requests` - pub fn max_concurrent_requests(&self) -> Option { + #[must_use] + pub const fn max_concurrent_requests(&self) -> Option { self.max_concurrent_requests } /// Getter for `keystore_path` + #[must_use] pub fn sui_config_path(&self) -> String { self.sui_config_path.clone() } /// Getter for `sui_keystore_path` + #[must_use] pub fn sui_keystore_path(&self) -> String { self.sui_keystore_path.clone() } /// Getter for `cursor_path` + #[must_use] pub fn cursor_path(&self) -> String { self.cursor_path.clone() } - /// Constructs a new `AtomaSuiConfig` instance from a configuration file path. + /// Constructs a new `Config` instance from a configuration file path. 
/// /// # Arguments /// @@ -159,25 +139,26 @@ impl AtomaSuiConfig { /// /// # Returns /// - /// Returns a new `AtomaSuiConfig` instance populated with values from the configuration file. + /// Returns a new `Config` instance populated with values from the configuration file. /// /// # Panics /// /// This method will panic if: /// - The configuration file cannot be read or parsed. /// - The "atoma-sui" section is missing from the configuration file. - /// - The configuration values cannot be deserialized into a `AtomaSuiConfig` instance. + /// - The configuration values cannot be deserialized into a `Config` instance. /// /// # Examples /// /// ```rust,ignore - /// use atoma_sui::config::AtomaSuiConfig; + /// use atoma_sui::config::Config; /// use std::path::Path; /// - /// let config = AtomaSuiConfig::from_file_path("config.toml"); + /// let config = Config::from_file_path("config.toml"); /// ``` + #[must_use] pub fn from_file_path>(config_file_path: P) -> Self { - let builder = Config::builder() + let builder = RustConfig::builder() .add_source(config::File::with_name( config_file_path.as_ref().to_str().unwrap(), )) @@ -196,32 +177,209 @@ impl AtomaSuiConfig { } } +/// Builder for creating Config instances +/// Builder pattern implementation for creating `Config` instances. +/// +/// This struct provides a flexible way to construct `Config` objects by allowing optional +/// setting of individual configuration parameters. Each field is wrapped in an `Option` +/// to track whether it has been explicitly set. 
+/// +/// # Fields +/// +/// * `http_rpc_node_addr` - Optional HTTP URL for the Sui RPC node +/// * `atoma_db` - Optional Atoma's DB object ID on the Sui network +/// * `atoma_package_id` - Optional Atoma's package ID on the Sui network +/// * `usdc_package_id` - Optional USDC token package ID on the Sui network +/// * `request_timeout` - Optional timeout duration for requests +/// * `max_concurrent_requests` - Optional maximum number of concurrent requests +/// * `limit` - Optional limit on number of dynamic fields per iteration +/// * `node_small_ids` - Optional list of node small IDs under control +/// * `task_small_ids` - Optional list of task small IDs under control +/// * `sui_config_path` - Optional path to Sui config file +/// * `sui_keystore_path` - Optional path to Sui keystore +/// * `cursor_path` - Optional path to cursor file +/// +/// # Example +/// +/// ```rust,ignore +/// let config = Builder::new() +/// .http_rpc_node_addr("http://localhost:9000".to_string()) +/// .atoma_db(object_id) +/// .build(); +/// ``` +pub struct Builder { + http_rpc_node_addr: Option, + atoma_db: Option, + atoma_package_id: Option, + usdc_package_id: Option, + request_timeout: Option, + max_concurrent_requests: Option, + limit: Option, + node_small_ids: Option>, + task_small_ids: Option>, + sui_config_path: Option, + sui_keystore_path: Option, + cursor_path: Option, +} + +impl Builder { + #[must_use] + pub const fn new() -> Self { + Self { + http_rpc_node_addr: None, + atoma_db: None, + atoma_package_id: None, + usdc_package_id: None, + request_timeout: None, + max_concurrent_requests: None, + limit: None, + node_small_ids: None, + task_small_ids: None, + sui_config_path: None, + sui_keystore_path: None, + cursor_path: None, + } + } + + #[must_use] + pub fn http_rpc_node_addr(mut self, addr: String) -> Self { + self.http_rpc_node_addr = Some(addr); + self + } + + #[must_use] + pub const fn atoma_db(mut self, db: ObjectID) -> Self { + self.atoma_db = Some(db); + self + } + + 
#[must_use] + pub const fn atoma_package_id(mut self, package_id: ObjectID) -> Self { + self.atoma_package_id = Some(package_id); + self + } + + #[must_use] + pub const fn usdc_package_id(mut self, package_id: ObjectID) -> Self { + self.usdc_package_id = Some(package_id); + self + } + + #[must_use] + pub const fn request_timeout(mut self, timeout: Option) -> Self { + self.request_timeout = timeout; + self + } + + #[must_use] + pub const fn max_concurrent_requests(mut self, requests: Option) -> Self { + self.max_concurrent_requests = requests; + self + } + + #[must_use] + pub const fn limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + #[must_use] + pub fn node_small_ids(mut self, ids: Option>) -> Self { + self.node_small_ids = ids; + self + } + + #[must_use] + pub fn task_small_ids(mut self, ids: Option>) -> Self { + self.task_small_ids = ids; + self + } + + #[must_use] + pub fn sui_config_path(mut self, path: String) -> Self { + self.sui_config_path = Some(path); + self + } + + #[must_use] + pub fn sui_keystore_path(mut self, path: String) -> Self { + self.sui_keystore_path = Some(path); + self + } + + #[must_use] + pub fn cursor_path(mut self, path: String) -> Self { + self.cursor_path = Some(path); + self + } + + /// Builds the final Config from the builder + /// + /// # Returns + /// A new `Config` instance with the configured values + /// + /// # Panics + /// This function will panic if: + /// - `atoma_db` is not set + /// - `atoma_package_id` is not set + /// - `usdc_package_id` is not set + #[must_use] + pub fn build(self) -> Config { + Config { + http_rpc_node_addr: self.http_rpc_node_addr.unwrap_or_default(), + atoma_db: self.atoma_db.expect("atoma_db is required"), + atoma_package_id: self.atoma_package_id.expect("atoma_package_id is required"), + usdc_package_id: self.usdc_package_id.expect("usdc_package_id is required"), + request_timeout: self.request_timeout, + max_concurrent_requests: self.max_concurrent_requests, + limit: 
self.limit, + node_small_ids: self.node_small_ids, + task_small_ids: self.task_small_ids, + sui_config_path: self.sui_config_path.unwrap_or_default(), + sui_keystore_path: self.sui_keystore_path.unwrap_or_default(), + cursor_path: self.cursor_path.unwrap_or_default(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] pub mod tests { use super::*; #[test] fn test_config() { - let config = AtomaSuiConfig::new( - "".to_string(), - "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" - .parse() - .unwrap(), - "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" - .parse() - .unwrap(), - "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" - .parse() - .unwrap(), - Some(Duration::from_secs(5 * 60)), - Some(10), - Some(vec![0, 1, 2]), - Some(vec![3, 4, 5]), - Some(10), - "".to_string(), - "".to_string(), - "".to_string(), - ); + let config = Builder::new() + .http_rpc_node_addr(String::new()) + .atoma_db( + "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" + .parse() + .unwrap(), + ) + .atoma_package_id( + "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" + .parse() + .unwrap(), + ) + .usdc_package_id( + "0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e" + .parse() + .unwrap(), + ) + .request_timeout(Some(Duration::from_secs(5 * 60))) + .limit(Some(10)) + .node_small_ids(Some(vec![0, 1, 2])) + .task_small_ids(Some(vec![3, 4, 5])) + .max_concurrent_requests(Some(10)) + .sui_config_path(String::new()) + .sui_keystore_path(String::new()) + .cursor_path(String::new()) + .build(); let toml_str = toml::to_string(&config).unwrap(); let should_be_toml_str = "http_rpc_node_addr = \"\"\natoma_db = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\natoma_package_id = \"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\nusdc_package_id = 
\"0x8d97f1cd6ac663735be08d1d2b6d02a159e711586461306ce60a2b7a6a565a9e\"\nmax_concurrent_requests = 10\nlimit = 10\nnode_small_ids = [0, 1, 2]\ntask_small_ids = [3, 4, 5]\nsui_config_path = \"\"\nsui_keystore_path = \"\"\ncursor_path = \"\"\n\n[request_timeout]\nsecs = 300\nnanos = 0\n"; diff --git a/atoma-sui/src/events.rs b/atoma-sui/src/events.rs index 091e3d63..5cbe5597 100644 --- a/atoma-sui/src/events.rs +++ b/atoma-sui/src/events.rs @@ -200,11 +200,11 @@ where /// Represents an event that is emitted when the Atoma contract is first published. /// -/// This event contains information about the newly published AtomaDb object id and +/// This event contains information about the newly published `AtomaDb` object id and /// the associated manager badge id. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct PublishedEvent { - /// The object id of the AtomaDb. + /// The object id of the `AtomaDb`. pub db: String, /// The identifier of the manager badge associated with the Atoma contract. @@ -239,7 +239,7 @@ pub struct NodeSubscribedToModelEvent { /// The name of the model that the node is subscribing to. /// This field represents the name of the AI models avaiable in the network - /// (which is compatible with HuggingFace's model naming convention). + /// (which is compatible with `HuggingFace` 's model naming convention). pub model_name: String, /// The echelon ID representing the performance tier or capability level @@ -684,7 +684,7 @@ pub struct Text2ImagePromptEvent { pub nodes: Vec, /// The output destination where the output will be stored. - /// The output is serialized with MessagePack. + /// The output is serialized with `MessagePack`. pub output_destination: Vec, } @@ -721,7 +721,7 @@ pub struct NewlySampledNodesEvent { /// Represents an event emitted when a ticket is settled. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct SettledEvent { - /// The ID of the settlement object. + /// The ID of the settlement object. 
pub ticket_id: String, /// The oracle node ID that settled the ticket. @@ -731,7 +731,7 @@ pub struct SettledEvent { /// Represents an event emitted when a retry settlement is requested. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct RetrySettlementEvent { - /// The ID of the settlement object. + /// The ID of the settlement object. pub ticket_id: String, /// The number of nodes in the echelon that should be used to retry the settlement. @@ -815,7 +815,7 @@ pub struct Text2TextPromptEvent { pub nodes: Vec, /// The output destination where the output will be stored. - /// The output is serialized with MessagePack. + /// The output is serialized with `MessagePack`. pub output_destination: Vec, } @@ -1237,7 +1237,7 @@ mod tests { assert_eq!(timeout.timed_out_count, 2); assert_eq!(timeout.timeout_ms, 5000); assert_eq!(timeout.started_in_epoch, 4000); - assert_eq!(timeout.started_at_epoch_timestamp_ms, 162000); + assert_eq!(timeout.started_at_epoch_timestamp_ms, 162_000); } #[test] @@ -1327,11 +1327,11 @@ mod tests { "prompt": [65, 66, 67], "random_seed": "42", "repeat_last_n": "64", - "repeat_penalty": 1065353216, // 1.0 in IEEE 754 single-precision float + "repeat_penalty": 1_065_353_216, // 1.0 in IEEE 754 single-precision float "should_stream_output": false, - "temperature": 1065353216, // 1.0 in IEEE 754 single-precision float + "temperature": 1_065_353_216, // 1.0 in IEEE 754 single-precision float "top_k": "50", - "top_p": 1065353216 // 1.0 in IEEE 754 single-precision float + "top_p": 1_065_353_216 // 1.0 in IEEE 754 single-precision float }, "chunks_count": "2", "nodes": [{"inner": "3"}, {"inner": "4"}], @@ -1346,11 +1346,11 @@ mod tests { assert_eq!(event.params.prompt, vec![65, 66, 67]); assert_eq!(event.params.random_seed, 42); assert_eq!(event.params.repeat_last_n, 64); - assert_eq!(event.params.repeat_penalty, 1065353216); + assert_eq!(event.params.repeat_penalty, 1_065_353_216); assert!(!event.params.should_stream_output); - 
assert_eq!(event.params.temperature, 1065353216); + assert_eq!(event.params.temperature, 1_065_353_216); assert_eq!(event.params.top_k, 50); - assert_eq!(event.params.top_p, 1065353216); + assert_eq!(event.params.top_p, 1_065_353_216); assert_eq!(event.chunks_count, 2); assert_eq!(event.nodes.len(), 2); assert_eq!(event.nodes[0].inner, 3); diff --git a/atoma-sui/src/lib.rs b/atoma-sui/src/lib.rs index 05400e4a..3f2facd5 100644 --- a/atoma-sui/src/lib.rs +++ b/atoma-sui/src/lib.rs @@ -2,6 +2,3 @@ pub mod client; pub mod config; pub mod events; pub mod subscriber; - -pub use config::AtomaSuiConfig; -pub use subscriber::SuiEventSubscriber; diff --git a/atoma-sui/src/subscriber.rs b/atoma-sui/src/subscriber.rs index fea1a45a..0e05c266 100644 --- a/atoma-sui/src/subscriber.rs +++ b/atoma-sui/src/subscriber.rs @@ -1,5 +1,5 @@ use crate::{ - config::AtomaSuiConfig, + config::Config as SuiConfig, events::{ AtomaEvent, AtomaEventIdentifier, StackCreateAndUpdateEvent, StackCreatedEvent, SuiEventParseError, @@ -31,7 +31,6 @@ pub(crate) type Result = std::result::Result; /// Represents the number of compute units available, stored as a 64-bit unsigned integer. type ComputeUnits = i64; - /// Represents the small identifier for a stack, stored as a 64-bit unsigned integer. type StackSmallId = i64; @@ -50,9 +49,9 @@ pub(crate) type StackRetrieveReceiver = mpsc::UnboundedReceiver<( /// /// This struct provides functionality to subscribe to and process events /// from the Sui blockchain based on specified filters. -pub struct SuiEventSubscriber { +pub struct Subscriber { /// The configuration values for the subscriber. - config: AtomaSuiConfig, + config: SuiConfig, /// The event filter used to specify which events to subscribe to. 
filter: EventFilter, @@ -73,10 +72,15 @@ pub struct SuiEventSubscriber { shutdown_signal: Receiver, } -impl SuiEventSubscriber { +impl Subscriber { /// Constructor + /// + /// # Panics + /// - If identifier creation fails for DB module name + /// - If event filtering setup fails + #[must_use] pub fn new( - config: AtomaSuiConfig, + config: SuiConfig, state_manager_sender: Sender, stack_retrieve_receiver: StackRetrieveReceiver, confidential_compute_service_sender: UnboundedSender, @@ -96,10 +100,10 @@ impl SuiEventSubscriber { } } - /// Creates a new `SuiEventSubscriber` instance from a configuration file. + /// Creates a new `Subscriber` instance from a configuration file. /// /// This method reads the configuration from the specified file path and initializes - /// a new `SuiEventSubscriber` with the loaded configuration. + /// a new `Subscriber` with the loaded configuration. /// /// # Arguments /// @@ -107,7 +111,7 @@ impl SuiEventSubscriber { /// /// # Returns /// - /// * `Result` - A Result containing the new `SuiEventSubscriber` instance if successful, + /// * `Result` - A Result containing the new `Subscriber` instance if successful, /// or an error if the configuration couldn't be read. 
/// /// # Errors @@ -121,7 +125,7 @@ impl SuiEventSubscriber { confidential_compute_service_sender: UnboundedSender, shutdown_signal: Receiver, ) -> Self { - let config = AtomaSuiConfig::from_file_path(config_path); + let config = SuiConfig::from_file_path(config_path); Self::new( config, state_manager_sender, @@ -155,7 +159,7 @@ impl SuiEventSubscriber { #[instrument(level = "info", skip_all, fields( http_rpc_node_addr = %config.http_rpc_node_addr() ))] - pub async fn build_client(config: &AtomaSuiConfig) -> Result { + pub async fn build_client(config: &SuiConfig) -> Result { let mut client_builder = SuiClientBuilder::default(); if let Some(request_timeout) = config.request_timeout() { client_builder = client_builder.request_timeout(request_timeout); @@ -212,151 +216,151 @@ impl SuiEventSubscriber { let mut cursor = read_cursor_from_toml_file(&self.config.cursor_path())?; loop { tokio::select! { - Some((tx_digest, estimated_compute_units, selected_stack_small_id, result_sender)) = self.stack_retrieve_receiver.recv() => { - let tx_events = client - .read_api() - .get_transaction_with_options( - tx_digest, - SuiTransactionBlockResponseOptions { - show_events: true, ..Default::default() + Some((tx_digest, estimated_compute_units, selected_stack_small_id, result_sender)) = self.stack_retrieve_receiver.recv() => { + let tx_events = client + .read_api() + .get_transaction_with_options( + tx_digest, + SuiTransactionBlockResponseOptions { + show_events: true, ..Default::default() + } + ) + .await? + .events; + let mut compute_units = None; + let mut stack_small_id = None; + if let Some(tx_events) = tx_events { + for event in &tx_events.data { + let event_identifier = AtomaEventIdentifier::from_str(event.type_.name.as_str())?; + if event_identifier == AtomaEventIdentifier::StackCreatedEvent { + // NOTE: In this case, the transaction contains a stack creation event, + // which means that whoever made a request to the service has already paid + // to buy new compute units. 
+ // We need to count the compute units used by the transaction. + let event: StackCreatedEvent = serde_json::from_value(event.parsed_json.clone())?; + + // Move the cast to a separate statement with the attribute + #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)] + let selected_stack_small_id_u64 = selected_stack_small_id as u64; + if event.stack_small_id.inner != selected_stack_small_id_u64 { + continue; } - ) - .await? - .events; - let mut compute_units = None; - let mut stack_small_id = None; - if let Some(tx_events) = tx_events { - for event in tx_events.data.iter() { - let event_identifier = AtomaEventIdentifier::from_str(event.type_.name.as_str())?; - if event_identifier == AtomaEventIdentifier::StackCreatedEvent { - // NOTE: In this case, the transaction contains a stack creation event, - // which means that whoever made a request to the service has already paid - // to buy new compute units. - // We need to count the compute units used by the transaction. - let event: StackCreatedEvent = serde_json::from_value(event.parsed_json.clone())?; - if event.stack_small_id.inner as i64 != selected_stack_small_id { - // NOTE: This is a safety check to ensure that the stack small id - // is the same as the one defined in the original transaction - continue; - } - if estimated_compute_units > event.num_compute_units as i64 { - // NOTE: If the estimated compute units are greater than the event compute units, - // this means that whoever made a request to the service has requested more compute units - // than those that it paid for. In this case, we should not process the event, and break - // out of the loop. This will send `None` values to the Atoma service, which will - // trigger an error back to the client. - // SAFETY: It is fine if we do not process the [`StackCreatedEvent`] right away, as it will - // be catched later by the Sui's event subscriber. 
- error!( - target = "atoma-sui-subscriber", - event = "subscriber-stack-create-event-error", - "Stack create event with id {} has more compute units than the transaction used, this is not possible", - event.stack_small_id.inner - ); - break; - } - let event: StackCreateAndUpdateEvent = (event, estimated_compute_units).into(); - // NOTE: We also send the event to the state manager, so it can be processed - // right away. - compute_units = Some(event.num_compute_units as i64); - stack_small_id = Some(event.stack_small_id.inner as i64); - self.state_manager_sender - .send(AtomaEvent::StackCreateAndUpdateEvent(event)) - .map_err(Box::new)?; - // We found the stack creation event, so we can break out of the loop + + // Move the cast to a separate statement with the attribute + #[allow(clippy::cast_possible_wrap)] + let event_compute_units = event.num_compute_units as i64; + if estimated_compute_units > event_compute_units { break; } + + let event: StackCreateAndUpdateEvent = (event, estimated_compute_units).into(); + + // Move the casts to separate statements with attributes + #[allow(clippy::cast_possible_wrap)] + let compute_units_val = event.num_compute_units as i64; + compute_units = Some(compute_units_val); + + #[allow(clippy::cast_possible_wrap)] + let stack_small_id_val = event.stack_small_id.inner as i64; + stack_small_id = Some(stack_small_id_val); + + self.state_manager_sender + .send(AtomaEvent::StackCreateAndUpdateEvent(event)) + .map_err(Box::new)?; + // We found the stack creation event, so we can break out of the loop + break; } } - // Send the compute units to the Atoma service, so it can be used to validate the - // request. 
- result_sender - .send((stack_small_id, compute_units)) - .map_err(|_| SuiEventSubscriberError::SendComputeUnitsError)?; } - page = client.event_api().query_events(self.filter.clone(), cursor, limit, false) => { - let EventPage { - data, - next_cursor, - has_next_page, - } = match page { - Ok(page) => page, - Err(e) => { - error!( - target = "atoma-sui-subscriber", - event = "subscriber-read-events-error", - "Failed to read paged events, with error: {e}" - ); - continue; - } - }; - cursor = next_cursor; - - for sui_event in data { - let event_name = sui_event.type_.name; - trace!( + // Send the compute units to the Atoma service, so it can be used to validate the + // request. + result_sender + .send((stack_small_id, compute_units)) + .map_err(|_| SuiEventSubscriberError::SendComputeUnitsError)?; + } + page = client.event_api().query_events(self.filter.clone(), cursor, limit, false) => { + let EventPage { + data, + next_cursor, + has_next_page, + } = match page { + Ok(page) => page, + Err(e) => { + error!( target = "atoma-sui-subscriber", - event = "subscriber-received-new-event", - event_name = %event_name, - "Received new event: {event_name:#?}" + event = "subscriber-read-events-error", + "Failed to read paged events, with error: {e}" ); - match AtomaEventIdentifier::from_str(event_name.as_str()) { - Ok(atoma_event_id) => { - let sender = sui_event.sender; - let atoma_event = match parse_event(&atoma_event_id, sui_event.parsed_json, sender, sui_event.timestamp_ms).await { - Ok(atoma_event) => atoma_event, - Err(e) => { - error!( - target = "atoma-sui-subscriber", - event = "subscriber-event-parse-error", - event_name = %event_name, - "Failed to parse event: {e}", - ); - continue; - } - }; - if filter_event( - &atoma_event, - self.config.node_small_ids().as_ref(), - self.config.task_small_ids().as_ref(), - ) { - self.handle_atoma_event(atoma_event_id, atoma_event).await?; - } else { + continue; + } + }; + cursor = next_cursor; + + for sui_event in data { + let 
event_name = sui_event.type_.name; + trace!( + target = "atoma-sui-subscriber", + event = "subscriber-received-new-event", + event_name = %event_name, + "Received new event: {event_name:#?}" + ); + match AtomaEventIdentifier::from_str(event_name.as_str()) { + Ok(atoma_event_id) => { + let sender = sui_event.sender; + let atoma_event = match parse_event(&atoma_event_id, sui_event.parsed_json, sender, sui_event.timestamp_ms).await { + Ok(atoma_event) => atoma_event, + Err(e) => { + error!( + target = "atoma-sui-subscriber", + event = "subscriber-event-parse-error", + event_name = %event_name, + "Failed to parse event: {e}", + ); continue; } - } - Err(e) => { - error!( - target = "atoma-sui-subscriber", - event = "subscriber-event-parse-error", - "Failed to parse event: {e}", - ); - // NOTE: `AtomaEvent` didn't match any known event, so we skip it. + }; + if filter_event( + &atoma_event, + self.config.node_small_ids().as_ref(), + self.config.task_small_ids().as_ref(), + ) { + self.handle_atoma_event(atoma_event_id, atoma_event).await?; + } else { + continue; } } + Err(e) => { + error!( + target = "atoma-sui-subscriber", + event = "subscriber-event-parse-error", + "Failed to parse event: {e}", + ); + // NOTE: `AtomaEvent` didn't match any known event, so we skip it. + } } + } - if !has_next_page { - // Update the cursor file with the current cursor - write_cursor_to_toml_file(cursor, &self.config.cursor_path())?; - // No new events to read, so let's wait for a while - trace!( - target = "atoma-sui-subscriber", - event = "subscriber-no-new-events", - wait_duration = DURATION_TO_WAIT_FOR_NEW_EVENTS_IN_MILLIS, - "No new events to read, the node is now synced with the Atoma protocol, waiting until the next synchronization..." 
- ); - tokio::time::sleep(Duration::from_millis( - DURATION_TO_WAIT_FOR_NEW_EVENTS_IN_MILLIS, - )) - .await; - } + if !has_next_page { + // Update the cursor file with the current cursor + write_cursor_to_toml_file(cursor, &self.config.cursor_path())?; + // No new events to read, so let's wait for a while + trace!( + target = "atoma-sui-subscriber", + event = "subscriber-no-new-events", + wait_duration = DURATION_TO_WAIT_FOR_NEW_EVENTS_IN_MILLIS, + "No new events to read, the node is now synced with the Atoma protocol, waiting until the next synchronization..." + ); + tokio::time::sleep(Duration::from_millis( + DURATION_TO_WAIT_FOR_NEW_EVENTS_IN_MILLIS, + )) + .await; } - shutdown_signal_changed = self.shutdown_signal.changed() => { - match shutdown_signal_changed { - Ok(()) => { - if *self.shutdown_signal.borrow() { - info!( + } + shutdown_signal_changed = self.shutdown_signal.changed() => { + match shutdown_signal_changed { + Ok(()) => { + if *self.shutdown_signal.borrow() { + info!( target = "atoma-sui-subscriber", event = "subscriber-stopped", "Shutdown signal received, gracefully stopping subscriber..." @@ -655,155 +659,78 @@ async fn parse_event( } } -/// Filters events based on a list of small IDs. +/// Filters an Atoma event based on a list of node small IDs. /// -/// This function checks if the given `AtomaEvent` is associated with any of the small IDs -/// provided in the `node_small_ids` and `task_small_ids` options. It returns `true` if the event -/// is relevant to the specified small IDs, and `false` otherwise. +/// This function checks if the given event is related to any of the nodes specified by their small IDs. +/// For node-specific events (like registration, subscriptions, etc.), it returns true only if the +/// event's node small ID is in the provided list. For all other event types, it returns true. /// /// # Arguments /// -/// * `event` - A reference to the `AtomaEvent` enum indicating the type of event to filter. 
-/// * `node_small_ids` - An optional reference to a vector of node IDs that are relevant for the current context. -/// * `task_small_ids` - An optional reference to a vector of task IDs that are relevant for the current context. +/// * `event` - Reference to the Atoma event to filter +/// * `node_small_ids` - Slice containing the node small IDs to filter by /// /// # Returns /// -/// Returns a `bool` indicating whether the event is associated with any of the small IDs: -/// * `true` if the event is relevant to the small IDs, -/// * `false` if it is not. -/// -/// # Event Types +/// Returns `true` if: +/// - The event is not node-specific +/// - The event's node small ID is contained in `node_small_ids` /// -/// The function specifically checks for the following event types: -/// * `NodeSubscribedToTaskEvent` -/// * `NodeUnsubscribedFromTaskEvent` -/// * `NodeSubscriptionUpdatedEvent` -/// * `StackCreatedEvent` -/// * `StackTrySettleEvent` -/// * `NewStackSettlementAttestationEvent` -/// * `StackSettlementTicketEvent` -/// * `StackSettlementTicketClaimedEvent` -/// * `TaskDeprecationEvent` -/// * `TaskRemovedEvent` -/// -/// For all other event types, the function returns `true`, indicating that they are not -/// filtered out by small IDs. 
+/// Returns `false` if the event is node-specific but its node small ID is not in `node_small_ids` +fn filter_event_by_node(event: &AtomaEvent, node_small_ids: &[u64]) -> bool { + match event { + AtomaEvent::NodeRegisteredEvent((event, _)) => { + node_small_ids.contains(&event.node_small_id.inner) + } + AtomaEvent::NodeSubscribedToModelEvent(event) => { + node_small_ids.contains(&event.node_small_id.inner) + } + AtomaEvent::NodeSubscribedToTaskEvent(event) => { + node_small_ids.contains(&event.node_small_id.inner) + } + AtomaEvent::NodeUnsubscribedFromTaskEvent(event) => { + node_small_ids.contains(&event.node_small_id.inner) + } + AtomaEvent::NodeSubscriptionUpdatedEvent(event) => { + node_small_ids.contains(&event.node_small_id.inner) + } + _ => true, + } +} + +fn filter_event_by_task(event: &AtomaEvent, task_small_ids: &[u64]) -> bool { + match event { + AtomaEvent::TaskDeprecationEvent(event) => { + task_small_ids.contains(&event.task_small_id.inner) + } + AtomaEvent::TaskRemovedEvent(event) => task_small_ids.contains(&event.task_small_id.inner), + AtomaEvent::StackCreatedEvent((event, _)) => { + task_small_ids.contains(&event.task_small_id.inner) + } + AtomaEvent::NodeSubscribedToTaskEvent(event) => { + task_small_ids.contains(&event.task_small_id.inner) + } + AtomaEvent::NodeUnsubscribedFromTaskEvent(event) => { + task_small_ids.contains(&event.task_small_id.inner) + } + AtomaEvent::NodeSubscriptionUpdatedEvent(event) => { + task_small_ids.contains(&event.task_small_id.inner) + } + _ => true, + } +} + fn filter_event( event: &AtomaEvent, node_small_ids: Option<&Vec>, task_small_ids: Option<&Vec>, ) -> bool { match (node_small_ids, task_small_ids) { - (Some(node_small_ids), Some(task_small_ids)) => match event { - AtomaEvent::NodeSubscribedToTaskEvent(event) => { - node_small_ids.contains(&event.node_small_id.inner) - && task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::NodeUnsubscribedFromTaskEvent(event) => { - 
node_small_ids.contains(&event.node_small_id.inner) - && task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::NodeSubscriptionUpdatedEvent(event) => { - node_small_ids.contains(&event.node_small_id.inner) - && task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::StackCreatedEvent((event, _)) => { - node_small_ids.contains(&event.selected_node_id.inner) - && task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::StackTrySettleEvent((event, _)) => { - node_small_ids.contains(&event.selected_node_id.inner) - || event - .requested_attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - AtomaEvent::NewStackSettlementAttestationEvent(event) => { - node_small_ids.contains(&event.attestation_node_id.inner) - } - AtomaEvent::StackSettlementTicketEvent(event) => { - node_small_ids.contains(&event.selected_node_id.inner) - || event - .requested_attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - AtomaEvent::StackSettlementTicketClaimedEvent(event) => { - node_small_ids.contains(&event.selected_node_id.inner) - || event - .attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - AtomaEvent::TaskDeprecationEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::TaskRemovedEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - _ => true, - }, - (Some(node_small_ids), None) => match event { - AtomaEvent::NodeSubscribedToTaskEvent(event) => { - node_small_ids.contains(&event.node_small_id.inner) - } - AtomaEvent::NodeUnsubscribedFromTaskEvent(event) => { - node_small_ids.contains(&event.node_small_id.inner) - } - AtomaEvent::NodeSubscriptionUpdatedEvent(event) => { - node_small_ids.contains(&event.node_small_id.inner) - } - AtomaEvent::StackCreatedEvent((event, _)) => { - node_small_ids.contains(&event.selected_node_id.inner) - } - AtomaEvent::StackTrySettleEvent((event, _)) => { - 
node_small_ids.contains(&event.selected_node_id.inner) - || event - .requested_attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - AtomaEvent::NewStackSettlementAttestationEvent(event) => { - node_small_ids.contains(&event.attestation_node_id.inner) - } - AtomaEvent::StackSettlementTicketEvent(event) => { - node_small_ids.contains(&event.selected_node_id.inner) - || event - .requested_attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - AtomaEvent::StackSettlementTicketClaimedEvent(event) => { - node_small_ids.contains(&event.selected_node_id.inner) - || event - .attestation_nodes - .iter() - .any(|id| node_small_ids.contains(&id.inner)) - } - _ => true, - }, - (None, Some(task_small_ids)) => match event { - AtomaEvent::TaskDeprecationEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::TaskRemovedEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::StackCreatedEvent((event, _)) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::NodeSubscribedToTaskEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::NodeUnsubscribedFromTaskEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - AtomaEvent::NodeSubscriptionUpdatedEvent(event) => { - task_small_ids.contains(&event.task_small_id.inner) - } - _ => true, - }, + (Some(node_ids), Some(task_ids)) => { + filter_event_by_node(event, node_ids) && filter_event_by_task(event, task_ids) + } + (Some(node_ids), None) => filter_event_by_node(event, node_ids), + (None, Some(task_ids)) => filter_event_by_task(event, task_ids), (None, None) => true, } } @@ -826,6 +753,8 @@ pub enum SuiEventSubscriberError { SerializeCursorError(#[from] toml::ser::Error), #[error("Failed to deserialize cursor: {0}")] DeserializeCursorError(#[from] toml::de::Error), + #[error("Failed to convert stack small id: {0}")] + 
ConversionError(#[from] std::num::TryFromIntError), } #[cfg(test)] diff --git a/atoma-utils/src/encryption.rs b/atoma-utils/src/encryption.rs index c5151e5d..92b8f4d8 100644 --- a/atoma-utils/src/encryption.rs +++ b/atoma-utils/src/encryption.rs @@ -1,58 +1,27 @@ use aes_gcm::{aead::Aead, Aes256Gcm, Error as AesError, KeyInit}; use hkdf::Hkdf; use sha2::Sha256; -use thiserror::Error; use x25519_dalek::SharedSecret; pub const NONCE_BYTE_SIZE: usize = 12; -type Result = std::result::Result; +type Result = std::result::Result; -/// Decrypts ciphertext using AES-256-GCM with a derived key from a shared secret. -/// -/// This function performs the following steps: -/// 1. Derives a symmetric key from the shared secret using HKDF-SHA256 -/// 2. Initializes an AES-256-GCM cipher with the derived key -/// 3. Decrypts the ciphertext using the provided nonce +/// Decrypts a ciphertext using the provided shared secret and nonce. /// /// # Arguments -/// -/// * `shared_secret` - The shared secret derived from X25519 key exchange +/// * `shared_secret` - The shared secret key for decryption /// * `ciphertext` - The encrypted data to decrypt -/// * `salt` - Salt value used in the key derivation process -/// * `nonce` - Unique nonce (number used once) for AES-GCM +/// * `salt` - Salt used in key derivation +/// * `nonce` - Nonce used in encryption /// /// # Returns +/// The decrypted plaintext as a byte vector /// -/// Returns the decrypted plaintext as a vector of bytes, or a `DecryptionError` if the operation fails. 
-/// -/// # Example -/// -/// ```rust,ignore -/// use atoma_tdx::decryption::decrypt_ciphertext; -/// # use your_crate::SharedSecret; -/// -/// # fn main() -> Result<(), Box> { -/// # let shared_secret = SharedSecret::new(); -/// let ciphertext = vec![/* encrypted data */]; -/// let salt = b"unique_salt_value"; -/// let nonce = b"unique_nonce_12"; // Must be 12 bytes for AES-GCM -/// -/// let plaintext = decrypt_ciphertext( -/// shared_secret, -/// &ciphertext, -/// salt, -/// nonce -/// )?; -/// # Ok(()) -/// # } -/// ``` -/// -/// # Security Considerations -/// -/// - The nonce must be unique for each encryption operation -/// - The salt should be randomly generated for each key derivation -/// - The shared secret should be derived using secure key exchange +/// # Errors +/// Returns an error if: +/// - Key derivation fails +/// - Decryption fails due to invalid data or parameters pub fn decrypt_ciphertext( shared_secret: &SharedSecret, ciphertext: &[u8], @@ -62,57 +31,29 @@ pub fn decrypt_ciphertext( let hkdf = Hkdf::::new(Some(salt), shared_secret.as_bytes()); let mut symmetric_key = [0u8; 32]; hkdf.expand(b"", &mut symmetric_key) - .map_err(EncryptionError::KeyExpansionFailed)?; + .map_err(Error::KeyExpansionFailed)?; let cipher = Aes256Gcm::new(&symmetric_key.into()); cipher .decrypt(nonce.into(), ciphertext) - .map_err(EncryptionError::DecryptionFailed) + .map_err(Error::DecryptionFailed) } -/// Encrypts plaintext using AES-256-GCM with a derived key from a shared secret. -/// -/// This function performs the following steps: -/// 1. Derives a symmetric key from the shared secret using HKDF-SHA256 -/// 2. Initializes an AES-256-GCM cipher with the derived key -/// 3. Generates a random nonce -/// 4. Encrypts the plaintext using the generated nonce +/// Encrypts plaintext using the provided shared secret. 
/// /// # Arguments -/// /// * `plaintext` - The data to encrypt -/// * `shared_secret` - The shared secret derived from X25519 key exchange -/// * `salt` - Salt value used in the key derivation process +/// * `shared_secret` - The shared secret key for encryption +/// * `salt` - Salt for key derivation +/// * `nonce` - Optional nonce (generated if None) /// /// # Returns +/// Tuple of (encrypted data, nonce used) /// -/// Returns a tuple containing the encrypted ciphertext and the generated nonce, or a `DecryptionError` if the operation fails. -/// -/// # Example -/// -/// ```rust,ignore -/// use atoma_tdx::decryption::encrypt_plaintext; -/// # use your_crate::SharedSecret; -/// -/// # fn main() -> Result<(), Box> { -/// # let shared_secret = SharedSecret::new(); -/// let plaintext = b"secret message"; -/// let salt = b"unique_salt_value"; -/// -/// let (ciphertext, nonce) = encrypt_plaintext( -/// plaintext, -/// shared_secret, -/// salt -/// )?; -/// # Ok(()) -/// # } -/// ``` -/// -/// # Security Considerations -/// -/// - The salt should be randomly generated for each key derivation -/// - The shared secret should be derived using secure key exchange -/// - The generated nonce is guaranteed to be unique for each encryption operation +/// # Errors +/// Returns an error if: +/// - Key derivation fails +/// - Encryption operation fails pub fn encrypt_plaintext( plaintext: &[u8], shared_secret: &SharedSecret, @@ -122,19 +63,20 @@ pub fn encrypt_plaintext( let hkdf = Hkdf::::new(Some(salt), shared_secret.as_bytes()); let mut symmetric_key = [0u8; 32]; hkdf.expand(b"", &mut symmetric_key) - .map_err(EncryptionError::KeyExpansionFailed)?; + .map_err(Error::KeyExpansionFailed)?; let cipher = Aes256Gcm::new(&symmetric_key.into()); - let nonce = nonce.unwrap_or(rand::random::<[u8; NONCE_BYTE_SIZE]>()); + let nonce = nonce.unwrap_or_else(rand::random::<[u8; NONCE_BYTE_SIZE]>); let ciphertext = cipher .encrypt(&nonce.into(), plaintext) - 
.map_err(EncryptionError::EncryptionFailed)?; + .map_err(Error::EncryptionFailed)?; Ok((ciphertext, nonce)) } -#[derive(Debug, Error)] -pub enum EncryptionError { +/// Errors that can occur during encryption/decryption operations +#[derive(Debug, thiserror::Error)] +pub enum Error { #[error("Failed to decrypt ciphertext, with error: `{0}`")] DecryptionFailed(AesError), #[error("Failed to encrypt plaintext, with error: `{0}`")] diff --git a/atoma-utils/src/hashing.rs b/atoma-utils/src/hashing.rs index 24b19eb4..e7b89263 100644 --- a/atoma-utils/src/hashing.rs +++ b/atoma-utils/src/hashing.rs @@ -3,21 +3,13 @@ use blake2::{ Blake2b, Digest, }; -/// Computes a Blake2b hash of the input data +/// Computes the `BLAKE2b` hash of the provided data /// /// # Arguments /// * `slice` - A byte slice containing the data to be hashed -/// /// # Returns -/// A 32-byte [`GenericArray`] containing the computed hash -/// -/// # Example -/// ```rust,ignore -/// use atoma_utils::hashing::blake2b_hash; -/// -/// let data = b"Hello, world!"; -/// let hash = blake2b_hash(data); -/// ``` +/// The 32-byte `BLAKE2b` hash +#[must_use] pub fn blake2b_hash(slice: &[u8]) -> GenericArray { let mut hasher = Blake2b::new(); hasher.update(slice); diff --git a/atoma-utils/src/lib.rs b/atoma-utils/src/lib.rs index 9419c285..589c2580 100644 --- a/atoma-utils/src/lib.rs +++ b/atoma-utils/src/lib.rs @@ -207,23 +207,19 @@ pub fn verify_signature( )] pub fn parse_json_byte_array(value: &serde_json::Value, field: &str) -> Result, String> { let array = value.get(field).and_then(|v| v.as_array()).ok_or_else(|| { - error!("Error getting field array {} from JSON", field); - format!("Error getting field array {} from JSON", field) + error!("Error getting field array {field} from JSON"); + format!("Error getting field array {field} from JSON") })?; array .iter() .map(|b| { - b.as_u64().map(|u| u as u8).ok_or_else(|| { - error!( - "Error parsing field array {} values as bytes from JSON", - field - ); - format!( 
- "Error parsing field array {} values as bytes from JSON", - field - ) - }) + b.as_u64() + .and_then(|u| u8::try_from(u).ok()) + .ok_or_else(|| { + error!("Error parsing field array {field} values as bytes from JSON"); + format!("Error parsing field array {field} values as bytes from JSON") + }) }) .collect() }