Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 75 additions & 14 deletions crates/protocols/src/model_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ bitflags! {
const AUDIO = 1 << 10;
/// Content moderation models
const MODERATION = 1 << 11;
/// Diffusion models (Stable Diffusion, Flux, etc.)
const DIFFUSION = 1 << 12;

/// Standard LLM: chat + completions + responses + tools
const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
Expand Down Expand Up @@ -62,6 +64,9 @@ bitflags! {

/// Content moderation model only
const MODERATION_MODEL = Self::MODERATION.bits();

/// Diffusion model only (Stable Diffusion, Flux, etc.)
const DIFFUSION_MODEL = Self::DIFFUSION.bits();
Comment on lines +68 to +69
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

DIFFUSION_MODEL is missing the actual serving capability.

Every other *_MODEL alias includes the endpoint bit it serves; this one is just the marker flag. Cards classified as DIFFUSION_MODEL therefore fail every model-serving branch in supports_endpoint(...), so the new discovery paths can identify diffusion models but not route to them. Please pair DIFFUSION with the transport capability those models use, or add explicit endpoint handling and a regression test.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@crates/protocols/src/model_type.rs` around lines 68-69: the DIFFUSION_MODEL
constant currently expands to only Self::DIFFUSION.bits(), so it lacks the
serving/endpoint bit consulted by supports_endpoint(...). Update the DIFFUSION_MODEL
alias to OR the DIFFUSION marker with the transport/endpoint capability
that diffusion models serve (e.g., the same endpoint bit other *_MODEL aliases use,
such as IMAGE_GEN, or the relevant transport flag) so that
supports_endpoint(...) recognizes and routes diffusion cards. Also add or adjust a
regression test covering supports_endpoint(...) for DIFFUSION_MODEL to ensure
routing works.

}
}

Expand All @@ -79,6 +84,7 @@ const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
(ModelType::IMAGE_GEN, "image_gen"),
(ModelType::AUDIO, "audio"),
(ModelType::MODERATION, "moderation"),
(ModelType::DIFFUSION, "diffusion"),
];

impl ModelType {
Expand Down Expand Up @@ -154,6 +160,12 @@ impl ModelType {
self.contains(Self::MODERATION)
}

/// Check if this model type supports diffusion (image generation via diffusion)
#[inline]
pub fn supports_diffusion(self) -> bool {
self.contains(Self::DIFFUSION)
}

/// Check if this model type supports a given endpoint
pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
match endpoint {
Expand Down Expand Up @@ -213,6 +225,12 @@ impl ModelType {
pub fn is_moderation_model(self) -> bool {
self.supports_moderation() && !self.supports_chat()
}

/// Check if this is a diffusion model
#[inline]
pub fn is_diffusion_model(self) -> bool {
self.supports_diffusion() && !self.supports_chat()
}
}

impl std::fmt::Display for ModelType {
Expand Down Expand Up @@ -296,20 +314,12 @@ impl JsonSchema for ModelType {
use schemars::schema::*;
let items = SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(vec![
"chat".into(),
"completions".into(),
"responses".into(),
"embeddings".into(),
"rerank".into(),
"generate".into(),
"vision".into(),
"tools".into(),
"reasoning".into(),
"image_gen".into(),
"audio".into(),
"moderation".into(),
]),
enum_values: Some(
CAPABILITY_NAMES
.iter()
.map(|(_, name)| (*name).into())
.collect(),
),
..Default::default()
};
SchemaObject {
Expand Down Expand Up @@ -406,3 +416,54 @@ impl std::fmt::Display for Endpoint {
}
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The bare DIFFUSION flag carries no other capability bits.
    #[test]
    fn test_diffusion_flag_basics() {
        let flags = ModelType::DIFFUSION;
        assert!(!flags.supports_chat());
        assert!(!flags.supports_vision());
        assert!(flags.supports_diffusion());
    }

    /// DIFFUSION_MODEL classifies as a pure diffusion model and nothing else.
    #[test]
    fn test_diffusion_model_composite() {
        let model = ModelType::DIFFUSION_MODEL;
        assert!(model.supports_diffusion());
        assert!(model.is_diffusion_model());
        assert!(!model.is_image_model());
        assert!(!model.is_llm());
    }

    /// Adding CHAT makes the type an LLM, so it is no longer a *pure* diffusion model.
    #[test]
    fn test_diffusion_with_chat_is_not_diffusion_model() {
        let combined = ModelType::DIFFUSION | ModelType::CHAT;
        assert!(combined.supports_diffusion());
        assert!(combined.is_llm());
        // CHAT is present, so the pure-diffusion predicate must reject it.
        assert!(!combined.is_diffusion_model());
    }

    /// The capability-name table maps DIFFUSION_MODEL to exactly one name.
    #[test]
    fn test_diffusion_capability_name() {
        let names = ModelType::DIFFUSION_MODEL.as_capability_names();
        assert_eq!(names, vec!["diffusion"]);
    }

    /// Serde round-trip: serialize to the capability-name array and back.
    #[test]
    fn test_diffusion_serialization_roundtrip() {
        let original = ModelType::DIFFUSION_MODEL;
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#"["diffusion"]"#);
        let decoded: ModelType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Display joins capability names; a single flag renders as its bare name.
    #[test]
    fn test_diffusion_display() {
        assert_eq!(ModelType::DIFFUSION_MODEL.to_string(), "diffusion");
    }
}
64 changes: 64 additions & 0 deletions model_gateway/src/core/steps/worker/external/discover_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ pub fn infer_model_type_from_id(id: &str) -> ModelType {
return ModelType::RERANK_MODEL;
}

// Diffusion models (Stable Diffusion, Flux, SDXL, etc.)
// Must be checked before the generic "image" heuristic so that IDs like
// "flux-image-*" or "stable-diffusion-image-*" are not misclassified.
if id_lower.starts_with("sd-")
|| id_lower.starts_with("sd3")
|| id_lower.starts_with("sdxl")
|| id_lower.starts_with("flux")
|| id_lower.contains("diffusion")
Comment on lines +112 to +116
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Match diffusion heuristics against the basename, not the full ID

These checks only recognize flux, sdxl, and sd3 when they start the entire model ID. The rest of the repo already treats namespaced IDs like org/model as normal, so external /v1/models entries such as black-forest-labs/FLUX.1-dev or stabilityai/sdxl-turbo fall through every diffusion branch here and get classified as ModelType::LLM. That misses two of the main families called out in the feature summary unless the ID happens to contain the literal word diffusion.

Useful? React with 👍 / 👎.

{
return ModelType::DIFFUSION_MODEL;
}

// Image generation models
if id_lower.starts_with("dall-e")
|| id_lower.starts_with("sora")
Expand Down Expand Up @@ -318,3 +330,55 @@ impl StepExecutor<WorkerWorkflowData> for DiscoverModelsStep {
true
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// IDs from the Stable Diffusion / SD / SDXL / Flux families, plus any ID
    /// containing "diffusion", must all resolve to DIFFUSION_MODEL.
    #[test]
    fn test_infer_diffusion_models() {
        let diffusion_ids = [
            "stable-diffusion-xl-base-1.0",
            "stable_diffusion_3",
            "sd-v1-5",
            "sd3-medium",
            "sdxl-turbo",
            "flux-1-dev",
            // Matching is case-insensitive.
            "FLUX-schnell",
            "my-custom-diffusion-model",
        ];
        for id in diffusion_ids {
            assert_eq!(
                infer_model_type_from_id(id),
                ModelType::DIFFUSION_MODEL,
                "expected {id:?} to be classified as a diffusion model"
            );
        }
    }

    /// Non-diffusion IDs keep their existing classifications.
    #[test]
    fn test_infer_non_diffusion_models() {
        assert_eq!(infer_model_type_from_id("gpt-4o"), ModelType::VISION_LLM);
        assert_eq!(infer_model_type_from_id("dall-e-3"), ModelType::IMAGE_MODEL);
        assert_eq!(
            infer_model_type_from_id("text-embedding-3-small"),
            ModelType::EMBED_MODEL
        );
        assert_eq!(infer_model_type_from_id("llama-3-70b"), ModelType::LLM);
    }
}
79 changes: 75 additions & 4 deletions model_gateway/src/core/steps/worker/local/create_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,17 +231,30 @@ fn build_model_card(
.map(|s| s == "true")
.unwrap_or(false);

if !user_provided {
// The "model_type" label is currently only reported by SGLang workers
// (via the /model_info endpoint for multimodal_gen servers).
let is_diffusion = labels
.get("model_type")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is only supported for SGLang atm.

.is_some_and(|s| s.to_lowercase() == "diffusion");

if user_provided {
if has_vision && !card.model_type.supports_vision() {
card.model_type |= ModelType::VISION;
}
if is_diffusion && !card.model_type.supports_diffusion() {
card.model_type |= ModelType::DIFFUSION;
Comment on lines +244 to +245
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Replace default LLM type for user-provided diffusion models

When user_provided is true, the card usually comes from ModelCard::new, which defaults model_type to LLM in crates/protocols/src/model_card.rs. OR-ing DIFFUSION onto that default means a config like models: [{id: "black-forest-labs/FLUX.1-dev"}] becomes chat|completions|responses|tools|diffusion instead of a pure diffusion model, so the new detection still reports it as an LLM and is_diffusion_model() never becomes true for manually configured local diffusion workers.

Useful? React with 👍 / 👎.

}
} else {
let is_embedding = labels.get("is_embedding").is_some_and(|s| s == "true");
let is_non_generation = labels.get("is_generation").is_some_and(|s| s == "false");

if is_embedding || is_non_generation {
if is_diffusion {
card.model_type = ModelType::DIFFUSION_MODEL;
} else if is_embedding || is_non_generation {
Comment on lines +251 to +253
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Preserve VISION when a diffusion worker is also multimodal.

This assignment overwrites the card with DIFFUSION_MODEL, so a worker that reports both model_type=diffusion and supports_vision=true loses the input-image capability that has_vision just detected. Reapply VISION after setting the diffusion type, and add a regression test for the combined labels.

Suggested fix
-        if is_diffusion {
-            card.model_type = ModelType::DIFFUSION_MODEL;
+        if is_diffusion {
+            card.model_type = ModelType::DIFFUSION_MODEL;
+            if has_vision {
+                card.model_type |= ModelType::VISION;
+            }
         } else if is_embedding || is_non_generation {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if is_diffusion {
card.model_type = ModelType::DIFFUSION_MODEL;
} else if is_embedding || is_non_generation {
if is_diffusion {
card.model_type = ModelType::DIFFUSION_MODEL;
if has_vision {
card.model_type |= ModelType::VISION;
}
} else if is_embedding || is_non_generation {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@model_gateway/src/core/steps/worker/local/create_worker.rs` around lines 251-253:
the assignment `card.model_type = ModelType::DIFFUSION_MODEL` overwrites any prior
VISION capability detected via has_vision/supports_vision. Update the
is_diffusion branch so that, after setting DIFFUSION_MODEL, the VISION flag is
reapplied when has_vision is true (i.e., combine VISION with DIFFUSION rather than
replacing it). Then add a regression test that builds a card with both
model_type=diffusion and supports_vision=true labels and asserts the resulting card
retains the vision capability.

card.model_type = infer_non_generation_type(labels);
} else if has_vision && !card.model_type.supports_vision() {
card.model_type |= ModelType::VISION;
}
} else if has_vision && !card.model_type.supports_vision() {
card.model_type |= ModelType::VISION;
}

card
Expand Down Expand Up @@ -275,3 +288,61 @@ fn normalize_url(url: &str, connection_mode: ConnectionMode) -> String {
}
}
}

#[cfg(test)]
mod tests {
    use openai_protocol::worker::WorkerSpec;

    use super::*;

    /// Minimal worker spec pointing at a local endpoint.
    fn default_config() -> WorkerSpec {
        serde_json::from_str(r#"{"url": "http://localhost:30000"}"#).unwrap()
    }

    /// Build a label map from string pairs.
    fn labels_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect()
    }

    /// A `model_type: diffusion` label yields a pure diffusion card.
    #[test]
    fn test_build_model_card_diffusion_from_model_type_label() {
        let labels = labels_from(&[("model_type", "diffusion"), ("is_generation", "true")]);
        let card = build_model_card("stable-diffusion-xl", &default_config(), &labels);
        assert!(card.model_type.supports_diffusion());
        assert!(card.model_type.is_diffusion_model());
    }

    /// The label value is matched case-insensitively.
    #[test]
    fn test_build_model_card_diffusion_case_insensitive() {
        let labels = labels_from(&[("model_type", "Diffusion")]);
        let card = build_model_card("flux-dev", &default_config(), &labels);
        assert!(card.model_type.supports_diffusion());
    }

    /// Any other `model_type` value leaves the default LLM classification intact.
    #[test]
    fn test_build_model_card_non_diffusion_llm() {
        let labels = labels_from(&[("model_type", "llama"), ("is_generation", "true")]);
        let card = build_model_card("llama-3-70b", &default_config(), &labels);
        assert!(card.model_type.is_llm());
        assert!(!card.model_type.supports_diffusion());
    }

    /// The diffusion branch is evaluated before the embedding/non-generation branch.
    #[test]
    fn test_build_model_card_diffusion_takes_precedence_over_embedding() {
        let labels = labels_from(&[("model_type", "diffusion"), ("is_generation", "false")]);
        let card = build_model_card("sd-model", &default_config(), &labels);
        assert!(card.model_type.supports_diffusion());
        assert!(!card.model_type.supports_embeddings());
    }
}
Loading