From e7de6d7d84775149a8ab8d9f2f7ad01e7b139dfd Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Wed, 29 Jan 2025 17:33:21 +0000 Subject: [PATCH 1/2] update return of appropriate status code from inference service --- .../src/handlers/chat_completions.rs | 25 ++++++------ atoma-service/src/handlers/embeddings.rs | 14 +++---- .../src/handlers/image_generations.rs | 17 ++++---- atoma-service/src/handlers/mod.rs | 40 +++++++++++++++++++ 4 files changed, 69 insertions(+), 27 deletions(-) diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 48d58162..9efa897e 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -34,7 +34,7 @@ use crate::{ middleware::RequestMetadata, }; -use super::handle_confidential_compute_encryption_response; +use super::{handle_confidential_compute_encryption_response, handle_status_code}; /// The path for confidential chat completions requests pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions"; @@ -683,10 +683,11 @@ async fn handle_streaming_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: "Inference service returned error".to_string(), - endpoint, - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code(response.status(), &endpoint, error)?; } let stream = response.bytes_stream(); @@ -1066,6 +1067,8 @@ pub mod utils { use atoma_utils::constants::PAYLOAD_HASH_SIZE; use prometheus::HistogramTimer; + use crate::handlers::handle_status_code; + use super::{ handle_confidential_compute_encryption_response, info, instrument, sign_response_and_update_stack_hash, update_stack_num_compute_units, AppState, @@ -1271,13 +1274,11 @@ pub mod utils { })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code(response.status(), endpoint, error)?; } response.json::().await.map_err(|e| { diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs index 4c9b28a2..b13e2407 100644 --- a/atoma-service/src/handlers/embeddings.rs +++ b/atoma-service/src/handlers/embeddings.rs @@ -19,6 +19,8 @@ use serde_json::Value; use tracing::{info, instrument}; use utoipa::OpenApi; +use super::handle_status_code; + /// The path for confidential embeddings requests pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings"; @@ -321,13 +323,11 @@ async fn handle_embeddings_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs index d19bedc5..22bb16ab 100644 --- a/atoma-service/src/handlers/image_generations.rs +++ b/atoma-service/src/handlers/image_generations.rs @@ -18,7 +18,10 @@ use serde_json::Value; use tracing::{info, instrument}; use utoipa::OpenApi; -use super::{handle_confidential_compute_encryption_response, sign_response_and_update_stack_hash}; +use super::{ + handle_confidential_compute_encryption_response, handle_status_code, + sign_response_and_update_stack_hash, +}; /// The path for confidential image generations requests pub const CONFIDENTIAL_IMAGE_GENERATIONS_PATH: &str = "/v1/confidential/images/generations"; @@ -307,13 +310,11 @@ async fn handle_image_generations_response( })?; if !response.status().is_success() { - return Err(AtomaServiceError::InternalError { - message: format!( - "Inference service returned non-success status code: {}", - response.status() - ), - endpoint: endpoint.to_string(), - }); + let error = response + .status() + .canonical_reason() + .unwrap_or("Unknown error"); + handle_status_code(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index 4ba00449..27d6334c 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -9,6 +9,7 @@ use atoma_confidential::types::{ use atoma_utils::hashing::blake2b_hash; use base64::{engine::general_purpose::STANDARD, Engine}; use flume::Sender; +use hyper::StatusCode; use image_generations::CONFIDENTIAL_IMAGE_GENERATIONS_PATH; use serde_json::{json, Value}; use tracing::{info, instrument}; @@ -305,3 +306,42 @@ pub fn update_stack_num_compute_units( endpoint: endpoint.to_string(), }) } + +/// Handles the status code returned by the inference service. +/// +/// This function maps the status code to an appropriate error variant. +/// +/// # Arguments +/// +/// * `status_code` - The status code returned by the inference service +/// * `endpoint` - The API endpoint path where the request was received +/// * `error` - The error message returned by the inference service +/// +/// # Returns +/// +/// Returns an `AtomaServiceError` variant based on the status code. +#[instrument(level = "info", skip_all, fields(endpoint))] +pub fn handle_status_code( + status_code: StatusCode, + endpoint: &str, + error: &str, +) -> Result<(), AtomaServiceError> { + match status_code { + StatusCode::UNAUTHORIZED => Err(AtomaServiceError::AuthError { + auth_error: format!("Unauthorized response from inference service: {error}"), + endpoint: endpoint.to_string(), + }), + StatusCode::INTERNAL_SERVER_ERROR => Err(AtomaServiceError::InternalError { + message: format!("Inference service returned internal server error: {error}"), + endpoint: endpoint.to_string(), + }), + StatusCode::BAD_REQUEST => Err(AtomaServiceError::InvalidBody { + message: format!("Inference service returned bad request error: {error}"), + endpoint: endpoint.to_string(), + }), + _ => Err(AtomaServiceError::InternalError { + message: format!("Inference service returned non-success error: {error}"), + endpoint: endpoint.to_string(), + }), + } +} From a152304cdbd380bc902c4b2103fb09e385f3c1af Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Wed, 29 Jan 2025 17:47:45 +0000 Subject: [PATCH 2/2] rename method --- atoma-service/src/handlers/chat_completions.rs | 8 ++++---- atoma-service/src/handlers/embeddings.rs | 4 ++-- atoma-service/src/handlers/image_generations.rs | 4 ++-- atoma-service/src/handlers/mod.rs | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 9efa897e..a97ed236 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -34,7 +34,7 @@ use crate::{ middleware::RequestMetadata, }; -use super::{handle_confidential_compute_encryption_response, handle_status_code}; +use super::{handle_confidential_compute_encryption_response, handle_status_code_error}; /// The path for confidential chat completions requests pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions"; @@ -687,7 +687,7 @@ async fn handle_streaming_response( .status() .canonical_reason() .unwrap_or("Unknown error"); - handle_status_code(response.status(), &endpoint, error)?; + handle_status_code_error(response.status(), &endpoint, error)?; } let stream = response.bytes_stream(); @@ -1067,7 +1067,7 @@ pub mod utils { use atoma_utils::constants::PAYLOAD_HASH_SIZE; use prometheus::HistogramTimer; - use crate::handlers::handle_status_code; + use crate::handlers::handle_status_code_error; use super::{ handle_confidential_compute_encryption_response, info, instrument, @@ -1278,7 +1278,7 @@ pub mod utils { .status() .canonical_reason() .unwrap_or("Unknown error"); - handle_status_code(response.status(), endpoint, error)?; + handle_status_code_error(response.status(), endpoint, error)?; } response.json::().await.map_err(|e| { diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs index b13e2407..7a016eb6 100644 --- a/atoma-service/src/handlers/embeddings.rs +++ b/atoma-service/src/handlers/embeddings.rs @@ -19,7 +19,7 @@ use serde_json::Value; use tracing::{info, instrument}; use utoipa::OpenApi; -use super::handle_status_code; +use super::handle_status_code_error; /// The path for confidential embeddings requests pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings"; @@ -327,7 +327,7 @@ async fn handle_embeddings_response( .status() .canonical_reason() .unwrap_or("Unknown error"); - handle_status_code(response.status(), endpoint, error)?; + handle_status_code_error(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs index 22bb16ab..ee7085c0 100644 --- a/atoma-service/src/handlers/image_generations.rs +++ b/atoma-service/src/handlers/image_generations.rs @@ -19,7 +19,7 @@ use tracing::{info, instrument}; use utoipa::OpenApi; use super::{ - handle_confidential_compute_encryption_response, handle_status_code, + handle_confidential_compute_encryption_response, handle_status_code_error, sign_response_and_update_stack_hash, }; @@ -314,7 +314,7 @@ async fn handle_image_generations_response( .status() .canonical_reason() .unwrap_or("Unknown error"); - handle_status_code(response.status(), endpoint, error)?; + handle_status_code_error(response.status(), endpoint, error)?; } let mut response_body = diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index 27d6334c..4f4e0676 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -321,7 +321,7 @@ pub fn update_stack_num_compute_units( /// /// Returns an `AtomaServiceError` variant based on the status code. #[instrument(level = "info", skip_all, fields(endpoint))] -pub fn handle_status_code( +pub fn handle_status_code_error( status_code: StatusCode, endpoint: &str, error: &str,