Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions model_gateway/src/routers/grpc/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,6 @@ impl RequestContext {
}

/// Get Arc clone of completion request (panics if not completion)
#[expect(
dead_code,
reason = "Arc accessor is introduced before later stacked PRs use it from completion stages"
)]
#[expect(
clippy::panic,
reason = "typed accessor: caller guarantees variant via RequestType construction"
Expand Down Expand Up @@ -651,10 +647,6 @@ pub(crate) enum ExecutionResult {

/// Final processed response
#[derive(Debug)]
#[expect(
dead_code,
reason = "Completion final response is introduced in the plumbing PR before later stacked PRs produce it"
)]
pub(crate) enum FinalResponse {
Chat(ChatCompletionResponse),
/// Generate response is a Vec of GenerateResponse (n=1 returns single item, n>1 returns multiple)
Expand Down
4 changes: 0 additions & 4 deletions model_gateway/src/routers/grpc/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,6 @@ impl RequestPipeline {
}

/// Execute the complete pipeline for a completion request
#[expect(
dead_code,
reason = "Completion pipeline entrypoint is introduced before later stacked PRs wire the router to call it"
)]
pub async fn execute_completion(
&self,
request: Arc<CompletionRequest>,
Expand Down
15 changes: 15 additions & 0 deletions model_gateway/src/routers/grpc/regular/stages/completion/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//! Completion endpoint pipeline stages
//!
//! These stages handle completion-specific preprocessing, request building, and
//! response processing. `CompletionRequest` flows natively as
//! `RequestType::Completion` through every pipeline stage — preparation, worker
//! selection, client acquisition, request building, execution, and response
//! processing — following the same architecture as chat and generate.

// One submodule per completion-specific pipeline phase.
mod preparation;
mod request_building;
mod response_processing;

// Re-exported at the `stages::completion` level so pipeline wiring can
// reference the stage types without reaching into the submodules.
pub(crate) use preparation::CompletionPreparationStage;
pub(crate) use request_building::CompletionRequestBuildingStage;
pub(crate) use response_processing::CompletionResponseProcessingStage;
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//! Completion preparation stage: Resolve prompt, tokenize, create stop decoder
//!
//! Structurally mirrors `GeneratePreparationStage` but reads from
//! `CompletionRequest` fields directly (prompt, stop, skip_special_tokens).
//! No chat template, no tool calls, no multimodal.

use async_trait::async_trait;
use axum::response::Response;
use openai_protocol::common::StringOrArray;
use tracing::error;

use crate::routers::{
error,
grpc::{
common::stages::PipelineStage,
context::{PreparationOutput, RequestContext},
utils,
},
};

/// Preparation stage for `/v1/completions`: resolves the prompt text,
/// tokenizes it, and installs a stop decoder on the request context.
pub(crate) struct CompletionPreparationStage;

#[async_trait]
impl PipelineStage for CompletionPreparationStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        let request = ctx.completion_request_arc();

        // Resolve (and cache) the tokenizer before touching the prompt so
        // tokenizer failures surface first, matching the other pipelines.
        let tokenizer =
            utils::resolve_tokenizer(ctx, "CompletionPreparationStage::execute").map_err(|e| *e)?;

        // Only a single string prompt is supported on this path; batched
        // prompt arrays are rejected up front.
        let StringOrArray::String(prompt) = &request.prompt else {
            return Err(error::bad_request(
                "batch_prompts_not_supported",
                "Batched prompt arrays are not supported for gRPC /v1/completions yet",
            ));
        };
        let prompt = prompt.clone();

        let encoded = match tokenizer.encode(&prompt, false) {
            Ok(enc) => enc,
            Err(e) => {
                error!(
                    function = "CompletionPreparationStage::execute",
                    error = %e,
                    "Tokenization failed"
                );
                return Err(error::bad_request(
                    "tokenization_failed",
                    format!("Tokenization failed: {e}"),
                ));
            }
        };

        // Stop decoder honors request-level stop strings / stop token ids.
        ctx.state.response.stop_decoder = Some(utils::create_stop_decoder(
            &tokenizer,
            request.stop.as_ref(),
            request.stop_token_ids.as_ref(),
            request.skip_special_tokens,
            request.no_stop_trim,
        ));

        // Completions carry no chat template, tools, or multimodal inputs,
        // so every chat-specific field stays empty.
        ctx.state.preparation = Some(PreparationOutput {
            original_text: Some(prompt),
            token_ids: encoded.token_ids().to_vec(),
            processed_messages: None,
            tool_constraints: None,
            filtered_request: None,
            harmony_mode: false,
            selection_text: None,
            harmony_messages: None,
            harmony_stop_ids: None,
        });

        Ok(None)
    }

    fn name(&self) -> &'static str {
        "CompletionPreparation"
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//! Completion request building stage: Build proto GenerateRequest from CompletionRequest
//!
//! Follows the same pattern as `ChatRequestBuildingStage` and
//! `GenerateRequestBuildingStage`: reads preparation output, gets the gRPC
//! client, and calls `builder_client.build_completion_request()` to convert
//! `CompletionRequest` directly to the backend proto format.

use async_trait::async_trait;
use axum::response::Response;
use tracing::error;
use uuid::Uuid;

use crate::routers::{
error,
grpc::{
common::stages::{helpers, PipelineStage},
context::{ClientSelection, RequestContext},
proto_wrapper::ProtoRequest,
},
};

/// Request-building stage for `/v1/completions`: converts the prepared
/// `CompletionRequest` into the backend proto request.
pub(crate) struct CompletionRequestBuildingStage {
    // When true, PD (prefill/decode) routing metadata is injected into the
    // proto request after it is built.
    inject_pd_metadata: bool,
}

impl CompletionRequestBuildingStage {
    /// Creates the stage; `inject_pd_metadata` enables PD metadata injection.
    pub fn new(inject_pd_metadata: bool) -> Self {
        Self { inject_pd_metadata }
    }
}

#[async_trait]
impl PipelineStage for CompletionRequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx.state.preparation.as_ref().ok_or_else(|| {
error!(
function = "CompletionRequestBuildingStage::execute",
"Preparation not completed"
);
error::internal_error("preparation_not_completed", "Preparation not completed")
})?;

let clients = ctx.state.clients.as_ref().ok_or_else(|| {
error!(
function = "CompletionRequestBuildingStage::execute",
"Client acquisition not completed"
);
error::internal_error(
"client_acquisition_not_completed",
"Client acquisition not completed",
)
})?;

let completion_request = ctx.completion_request_arc();

let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};

let request_id = format!("cmpl-{}", Uuid::now_v7());

let mut proto_request = builder_client
.build_completion_request(
request_id,
&completion_request,
prep.original_text.clone().unwrap_or_default(),
prep.token_ids.clone(),
)
.map_err(|e| {
error!(
function = "CompletionRequestBuildingStage::execute",
error = %e,
"Failed to build completion request"
);
error::bad_request(
"invalid_request_parameters",
format!("Invalid request parameters: {e}"),
)
})?;
Comment on lines +63 to +80
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CompletionPreparationStage always sets original_text, so prep.original_text should not be None at this point. Using unwrap_or_default() can hide a potential logic error if it were ever None, which could lead to silent failures where an empty string is used as the prompt. It would be more robust to explicitly handle the None case as an internal error. This makes the contract between stages explicit and ensures the system fails fast if an invariant is broken.

        let original_text = prep.original_text.as_ref().ok_or_else(|| {
            error!(
                function = "CompletionRequestBuildingStage::execute",
                "original_text not found in preparation output for completion request"
            );
            error::internal_error(
                "missing_preparation_output",
                "original_text not found in preparation output for completion request",
            )
        })?;

        let mut proto_request = builder_client
            .build_completion_request(
                request_id,
                &completion_request,
                original_text.clone(),
                prep.token_ids.clone(),
            )
            .map_err(|e| {
                error!(
                    function = "CompletionRequestBuildingStage::execute",
                    error = %e,
                    "Failed to build completion request"
                );
                error::bad_request(
                    "invalid_request_parameters",
                    format!("Invalid request parameters: {e}"),
                )
            })?;


if self.inject_pd_metadata {
if let Some(workers) = ctx.state.workers.as_ref() {
helpers::maybe_inject_pd_metadata(&mut proto_request, workers);
}
}

ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request));
Ok(None)
}

fn name(&self) -> &'static str {
"CompletionRequestBuilding"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
//! Completion response processing stage
//!
//! Non-streaming: collects generate results via the shared processor and wraps
//! them as `CompletionResponse`. Streaming: delegates to the completion-aware
//! streaming processor, which emits OpenAI `CompletionStreamResponse` chunks
//! directly from typed proto responses.

use std::sync::Arc;

use async_trait::async_trait;
use axum::response::Response;
use tracing::error;

use crate::{
core::AttachedBody,
routers::{
error,
grpc::{
common::stages::PipelineStage,
context::{FinalResponse, RequestContext},
regular::{processor, streaming},
},
},
};

/// Response-processing stage for `/v1/completions`.
pub(crate) struct CompletionResponseProcessingStage {
    // Shared non-streaming response processor.
    processor: processor::ResponseProcessor,
    // Completion-aware streaming processor; Arc because streaming hands it
    // off to the response body.
    streaming_processor: Arc<streaming::StreamingProcessor>,
}

impl CompletionResponseProcessingStage {
    /// Creates the stage from the shared processors.
    pub fn new(
        processor: processor::ResponseProcessor,
        streaming_processor: Arc<streaming::StreamingProcessor>,
    ) -> Self {
        Self {
            processor,
            streaming_processor,
        }
    }
}

#[async_trait]
impl PipelineStage for CompletionResponseProcessingStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        // Thin adapter: all real work lives in the inherent method below.
        self.process_completion_response(ctx).await
    }

    fn name(&self) -> &'static str {
        "CompletionResponseProcessing"
    }
}

impl CompletionResponseProcessingStage {
    /// Turns the stage's execution result into the final completion output.
    ///
    /// Streaming: delegates to the streaming processor and returns the
    /// response immediately, attaching load guards to the body when present.
    /// Non-streaming: collects results via the shared processor and stores a
    /// `FinalResponse::Completion` on the context.
    ///
    /// Returns internal errors if any upstream stage output (execution
    /// result, dispatch metadata, tokenizer, prompt text, stop decoder) is
    /// missing.
    async fn process_completion_response(
        &self,
        ctx: &mut RequestContext,
    ) -> Result<Option<Response>, Response> {
        let is_streaming = ctx.is_streaming();
        let completion_req = ctx.completion_request_arc();

        // The execution result is consumed (take) — this stage runs once.
        let execution_result = ctx.state.response.execution_result.take().ok_or_else(|| {
            error!(
                function = "CompletionResponseProcessingStage::execute",
                "No execution result"
            );
            error::internal_error("no_execution_result", "No execution result")
        })?;

        let dispatch = ctx
            .state
            .dispatch
            .as_ref()
            .ok_or_else(|| {
                error!(
                    function = "CompletionResponseProcessingStage::execute",
                    "Dispatch metadata not set"
                );
                error::internal_error("dispatch_metadata_not_set", "Dispatch metadata not set")
            })?
            .clone();

        let tokenizer = ctx.tokenizer_arc().ok_or_else(|| {
            error!(
                function = "CompletionResponseProcessingStage::execute",
                "Tokenizer not cached in context"
            );
            error::internal_error(
                "tokenizer_not_cached",
                "Tokenizer not cached in context - preparation stage may have been skipped",
            )
        })?;

        // CompletionPreparationStage always records the prompt text in its
        // preparation output. Treat a missing value as a broken stage
        // contract and fail fast instead of silently substituting an empty
        // prompt (the previous unwrap_or_default() behavior).
        let prompt_text = ctx
            .state
            .preparation
            .as_ref()
            .and_then(|p| p.original_text.clone())
            .ok_or_else(|| {
                error!(
                    function = "CompletionResponseProcessingStage::execute",
                    "original_text not found in preparation output for completion request"
                );
                error::internal_error(
                    "missing_preparation_output",
                    "original_text not found in preparation output for completion request",
                )
            })?;

        if is_streaming {
            let response = self
                .streaming_processor
                .clone()
                .process_streaming_completion(
                    execution_result,
                    completion_req.clone(),
                    dispatch,
                    tokenizer,
                    prompt_text,
                );

            // Keep worker load guards alive for the lifetime of the stream body.
            let response = match ctx.state.load_guards.take() {
                Some(guards) => AttachedBody::wrap_response(response, guards),
                None => response,
            };
            return Ok(Some(response));
        }

        // Non-streaming path: the stop decoder was installed by preparation.
        let stop_decoder = ctx.state.response.stop_decoder.as_mut().ok_or_else(|| {
            error!(
                function = "CompletionResponseProcessingStage::execute",
                "Stop decoder not initialized"
            );
            error::internal_error(
                "stop_decoder_not_initialized",
                "Stop decoder not initialized",
            )
        })?;

        let completion_response = self
            .processor
            .process_non_streaming_completion_response(
                execution_result,
                completion_req,
                dispatch,
                tokenizer,
                stop_decoder,
                &prompt_text,
            )
            .await?;

        ctx.state.response.final_response = Some(FinalResponse::Completion(completion_response));
        Ok(None)
    }
}
1 change: 1 addition & 0 deletions model_gateway/src/routers/grpc/regular/stages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

pub(crate) mod chat;
pub(crate) mod classify;
pub(crate) mod completion;
pub(crate) mod embedding;
pub(crate) mod generate;
pub(crate) mod messages;
Expand Down
Loading