Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions atoma-service/src/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
middleware::RequestMetadata,
};

use super::handle_confidential_compute_encryption_response;
use super::{handle_confidential_compute_encryption_response, handle_status_code_error};

/// The path for confidential chat completions requests
pub const CONFIDENTIAL_CHAT_COMPLETIONS_PATH: &str = "/v1/confidential/chat/completions";
Expand Down Expand Up @@ -683,10 +683,11 @@ async fn handle_streaming_response(
})?;

if !response.status().is_success() {
return Err(AtomaServiceError::InternalError {
message: "Inference service returned error".to_string(),
endpoint,
});
let error = response
.status()
.canonical_reason()
.unwrap_or("Unknown error");
handle_status_code_error(response.status(), &endpoint, error)?;
}

let stream = response.bytes_stream();
Expand Down Expand Up @@ -1066,6 +1067,8 @@ pub mod utils {
use atoma_utils::constants::PAYLOAD_HASH_SIZE;
use prometheus::HistogramTimer;

use crate::handlers::handle_status_code_error;

use super::{
handle_confidential_compute_encryption_response, info, instrument,
sign_response_and_update_stack_hash, update_stack_num_compute_units, AppState,
Expand Down Expand Up @@ -1271,13 +1274,11 @@ pub mod utils {
})?;

if !response.status().is_success() {
return Err(AtomaServiceError::InternalError {
message: format!(
"Inference service returned non-success status code: {}",
response.status()
),
endpoint: endpoint.to_string(),
});
let error = response
.status()
.canonical_reason()
.unwrap_or("Unknown error");
handle_status_code_error(response.status(), endpoint, error)?;
}

response.json::<Value>().await.map_err(|e| {
Expand Down
14 changes: 7 additions & 7 deletions atoma-service/src/handlers/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use serde_json::Value;
use tracing::{info, instrument};
use utoipa::OpenApi;

use super::handle_status_code_error;

/// The path for confidential embeddings requests
pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings";

Expand Down Expand Up @@ -321,13 +323,11 @@ async fn handle_embeddings_response(
})?;

if !response.status().is_success() {
return Err(AtomaServiceError::InternalError {
message: format!(
"Inference service returned non-success status code: {}",
response.status()
),
endpoint: endpoint.to_string(),
});
let error = response
.status()
.canonical_reason()
.unwrap_or("Unknown error");
handle_status_code_error(response.status(), endpoint, error)?;
}

let mut response_body =
Expand Down
17 changes: 9 additions & 8 deletions atoma-service/src/handlers/image_generations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use serde_json::Value;
use tracing::{info, instrument};
use utoipa::OpenApi;

use super::{handle_confidential_compute_encryption_response, sign_response_and_update_stack_hash};
use super::{
handle_confidential_compute_encryption_response, handle_status_code_error,
sign_response_and_update_stack_hash,
};

/// The path for confidential image generations requests
pub const CONFIDENTIAL_IMAGE_GENERATIONS_PATH: &str = "/v1/confidential/images/generations";
Expand Down Expand Up @@ -307,13 +310,11 @@ async fn handle_image_generations_response(
})?;

if !response.status().is_success() {
return Err(AtomaServiceError::InternalError {
message: format!(
"Inference service returned non-success status code: {}",
response.status()
),
endpoint: endpoint.to_string(),
});
let error = response
.status()
.canonical_reason()
.unwrap_or("Unknown error");
handle_status_code_error(response.status(), endpoint, error)?;
}

let mut response_body =
Expand Down
40 changes: 40 additions & 0 deletions atoma-service/src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use atoma_confidential::types::{
use atoma_utils::hashing::blake2b_hash;
use base64::{engine::general_purpose::STANDARD, Engine};
use flume::Sender;
use hyper::StatusCode;
use image_generations::CONFIDENTIAL_IMAGE_GENERATIONS_PATH;
use serde_json::{json, Value};
use tracing::{info, instrument};
Expand Down Expand Up @@ -305,3 +306,42 @@ pub fn update_stack_num_compute_units(
endpoint: endpoint.to_string(),
})
}

/// Translates a non-success status code from the inference service into the
/// corresponding `AtomaServiceError` variant.
///
/// # Arguments
///
/// * `status_code` - The HTTP status code returned by the inference service
/// * `endpoint` - The API endpoint path where the request was received
/// * `error` - The error message returned by the inference service
///
/// # Errors
///
/// Always returns `Err`:
/// * `AtomaServiceError::AuthError` for `401 Unauthorized`
/// * `AtomaServiceError::InternalError` for `500 Internal Server Error`
/// * `AtomaServiceError::InvalidBody` for `400 Bad Request`
/// * `AtomaServiceError::InternalError` for any other status code
#[instrument(level = "info", skip_all, fields(endpoint))]
pub fn handle_status_code_error(
    status_code: StatusCode,
    endpoint: &str,
    error: &str,
) -> Result<(), AtomaServiceError> {
    // Build the owned endpoint once; every variant stores it.
    let endpoint = endpoint.to_string();
    // Select the error variant that matches the upstream status code.
    let service_error = if status_code == StatusCode::UNAUTHORIZED {
        AtomaServiceError::AuthError {
            auth_error: format!("Unauthorized response from inference service: {error}"),
            endpoint,
        }
    } else if status_code == StatusCode::INTERNAL_SERVER_ERROR {
        AtomaServiceError::InternalError {
            message: format!("Inference service returned internal server error: {error}"),
            endpoint,
        }
    } else if status_code == StatusCode::BAD_REQUEST {
        AtomaServiceError::InvalidBody {
            message: format!("Inference service returned bad request error: {error}"),
            endpoint,
        }
    } else {
        // Any other non-success status is surfaced as an internal error.
        AtomaServiceError::InternalError {
            message: format!("Inference service returned non-success error: {error}"),
            endpoint,
        }
    };
    Err(service_error)
}
Loading