diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 48d58162..a97ed236 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -34,7 +34,7 @@ use crate::{ middleware::RequestMetadata, }; -use super::handle_confidential_compute_encryption_response; +use super::{handle_confidential_compute_encryption_response, handle_status_code_error}; /// The path for confidential chat completions requests pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions"; @@ -683,10 +683,11 @@ async fn handle_streaming_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: "Inference service returned error".to_string(), - endpoint, - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code_error(response.status(), &endpoint, error)?; } let stream = response.bytes_stream(); @@ -1066,6 +1067,8 @@ pub mod utils { use atoma_utils::constants::PAYLOAD_HASH_SIZE; use prometheus::HistogramTimer; + use crate::handlers::handle_status_code_error; + use super::{ handle_confidential_compute_encryption_response, info, instrument, sign_response_and_update_stack_hash, update_stack_num_compute_units, AppState, @@ -1271,13 +1274,11 @@ pub mod utils { })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code_error(response.status(), endpoint, error)?; } response.json::().await.map_err(|e| { diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs index 4c9b28a2..7a016eb6 100644 --- a/atoma-service/src/handlers/embeddings.rs +++ b/atoma-service/src/handlers/embeddings.rs @@ -19,6 +19,8 @@ use serde_json::Value; use tracing::{info, instrument}; use utoipa::OpenApi; +use super::handle_status_code_error; + /// The path for confidential embeddings requests pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings"; @@ -321,13 +323,11 @@ async fn handle_embeddings_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code_error(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs index d19bedc5..ee7085c0 100644 --- a/atoma-service/src/handlers/image_generations.rs +++ b/atoma-service/src/handlers/image_generations.rs @@ -18,7 +18,10 @@ use serde_json::Value; use tracing::{info, instrument}; use utoipa::OpenApi; -use super::{handle_confidential_compute_encryption_response, sign_response_and_update_stack_hash}; +use super::{ + handle_confidential_compute_encryption_response, handle_status_code_error, + sign_response_and_update_stack_hash, +}; /// The path for confidential image generations requests pub const CONFIDENTIAL_IMAGE_GENERATIONS_PATH: &str = "/v1/confidential/images/generations"; @@ -307,13 +310,11 @@ async fn handle_image_generations_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code_error(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index 4ba00449..4f4e0676 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -9,6 +9,7 @@ use atoma_confidential::types::{ use atoma_utils::hashing::blake2b_hash; use base64::{engine::general_purpose::STANDARD, Engine}; use flume::Sender; +use hyper::StatusCode; use image_generations::CONFIDENTIAL_IMAGE_GENERATIONS_PATH; use serde_json::{json, Value}; use tracing::{info, instrument}; @@ -305,3 +306,42 @@ pub fn update_stack_num_compute_units( endpoint: endpoint.to_string(), }) } + +/// Handles the status code returned by the inference service. +/// +/// This function maps the status code to an appropriate error variant. +/// +/// # Arguments +/// +/// * `status_code` - The status code returned by the inference service +/// * `endpoint` - The API endpoint path where the request was received +/// * `error` - The error message returned by the inference service +/// +/// # Returns +/// +/// Returns an `AtomaServiceError` variant based on the status code. +#[instrument(level = "info", skip_all, fields(endpoint))] +pub fn handle_status_code_error( + status_code: StatusCode, + endpoint: &str, + error: &str, +) -> Result<(), AtomaServiceError> { + match status_code { + StatusCode::UNAUTHORIZED => Err(AtomaServiceError::AuthError { + auth_error: format!("Unauthorized response from inference service: {error}"), + endpoint: endpoint.to_string(), + }), + StatusCode::INTERNAL_SERVER_ERROR => Err(AtomaServiceError::InternalError { + message: format!("Inference service returned internal server error: {error}"), + endpoint: endpoint.to_string(), + }), + StatusCode::BAD_REQUEST => Err(AtomaServiceError::InvalidBody { + message: format!("Inference service returned bad request error: {error}"), + endpoint: endpoint.to_string(), + }), + _ => Err(AtomaServiceError::InternalError { + message: format!("Inference service returned non-success error: {error}"), + endpoint: endpoint.to_string(), + }), + } +}