-
Notifications
You must be signed in to change notification settings - Fork 53
feat(gateway): add DIFFUSION model capability detection #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,18 @@ pub fn infer_model_type_from_id(id: &str) -> ModelType { | |
| return ModelType::RERANK_MODEL; | ||
| } | ||
|
|
||
| // Diffusion models (Stable Diffusion, Flux, SDXL, etc.) | ||
| // Must be checked before the generic "image" heuristic so that IDs like | ||
| // "flux-image-*" or "stable-diffusion-image-*" are not misclassified. | ||
| if id_lower.starts_with("sd-") | ||
| || id_lower.starts_with("sd3") | ||
| || id_lower.starts_with("sdxl") | ||
| || id_lower.starts_with("flux") | ||
| || id_lower.contains("diffusion") | ||
|
Comment on lines
+112
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
These checks only recognize Useful? React with 👍 / 👎. |
||
| { | ||
| return ModelType::DIFFUSION_MODEL; | ||
| } | ||
|
|
||
| // Image generation models | ||
| if id_lower.starts_with("dall-e") | ||
| || id_lower.starts_with("sora") | ||
|
|
@@ -318,3 +330,55 @@ impl StepExecutor<WorkerWorkflowData> for DiscoverModelsStep { | |
| true | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
|
|
||
| #[test] | ||
| fn test_infer_diffusion_models() { | ||
| assert_eq!( | ||
| infer_model_type_from_id("stable-diffusion-xl-base-1.0"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("stable_diffusion_3"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("sd-v1-5"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("sd3-medium"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("sdxl-turbo"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("flux-1-dev"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("FLUX-schnell"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| assert_eq!( | ||
| infer_model_type_from_id("my-custom-diffusion-model"), | ||
| ModelType::DIFFUSION_MODEL | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_infer_non_diffusion_models() { | ||
| assert_eq!(infer_model_type_from_id("gpt-4o"), ModelType::VISION_LLM); | ||
| assert_eq!(infer_model_type_from_id("dall-e-3"), ModelType::IMAGE_MODEL); | ||
| assert_eq!( | ||
| infer_model_type_from_id("text-embedding-3-small"), | ||
| ModelType::EMBED_MODEL | ||
| ); | ||
| assert_eq!(infer_model_type_from_id("llama-3-70b"), ModelType::LLM); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -231,17 +231,30 @@ fn build_model_card( | |||||||||||||||||||
| .map(|s| s == "true") | ||||||||||||||||||||
| .unwrap_or(false); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if !user_provided { | ||||||||||||||||||||
| // The "model_type" label is currently only reported by SGLang workers | ||||||||||||||||||||
| // (via the /model_info endpoint for multimodal_gen servers). | ||||||||||||||||||||
| let is_diffusion = labels | ||||||||||||||||||||
| .get("model_type") | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is only supported for SGLang atm. |
||||||||||||||||||||
| .is_some_and(|s| s.to_lowercase() == "diffusion"); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if user_provided { | ||||||||||||||||||||
| if has_vision && !card.model_type.supports_vision() { | ||||||||||||||||||||
| card.model_type |= ModelType::VISION; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| if is_diffusion && !card.model_type.supports_diffusion() { | ||||||||||||||||||||
| card.model_type |= ModelType::DIFFUSION; | ||||||||||||||||||||
|
Comment on lines
+244
to
+245
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When Useful? React with 👍 / 👎. |
||||||||||||||||||||
| } | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| let is_embedding = labels.get("is_embedding").is_some_and(|s| s == "true"); | ||||||||||||||||||||
| let is_non_generation = labels.get("is_generation").is_some_and(|s| s == "false"); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if is_embedding || is_non_generation { | ||||||||||||||||||||
| if is_diffusion { | ||||||||||||||||||||
| card.model_type = ModelType::DIFFUSION_MODEL; | ||||||||||||||||||||
| } else if is_embedding || is_non_generation { | ||||||||||||||||||||
|
Comment on lines
+251
to
+253
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Preserve This assignment overwrites the card with Suggested fix- if is_diffusion {
- card.model_type = ModelType::DIFFUSION_MODEL;
+ if is_diffusion {
+ card.model_type = ModelType::DIFFUSION_MODEL;
+ if has_vision {
+ card.model_type |= ModelType::VISION;
+ }
} else if is_embedding || is_non_generation {📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||
| card.model_type = infer_non_generation_type(labels); | ||||||||||||||||||||
| } else if has_vision && !card.model_type.supports_vision() { | ||||||||||||||||||||
| card.model_type |= ModelType::VISION; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } else if has_vision && !card.model_type.supports_vision() { | ||||||||||||||||||||
| card.model_type |= ModelType::VISION; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| card | ||||||||||||||||||||
|
|
@@ -275,3 +288,61 @@ fn normalize_url(url: &str, connection_mode: ConnectionMode) -> String { | |||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #[cfg(test)] | ||||||||||||||||||||
| mod tests { | ||||||||||||||||||||
| use openai_protocol::worker::WorkerSpec; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| use super::*; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| fn default_config() -> WorkerSpec { | ||||||||||||||||||||
| serde_json::from_str(r#"{"url": "http://localhost:30000"}"#).unwrap() | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #[test] | ||||||||||||||||||||
| fn test_build_model_card_diffusion_from_model_type_label() { | ||||||||||||||||||||
| let config = default_config(); | ||||||||||||||||||||
| let mut labels = HashMap::new(); | ||||||||||||||||||||
| labels.insert("model_type".to_string(), "diffusion".to_string()); | ||||||||||||||||||||
| labels.insert("is_generation".to_string(), "true".to_string()); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| let card = build_model_card("stable-diffusion-xl", &config, &labels); | ||||||||||||||||||||
| assert!(card.model_type.supports_diffusion()); | ||||||||||||||||||||
| assert!(card.model_type.is_diffusion_model()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #[test] | ||||||||||||||||||||
| fn test_build_model_card_diffusion_case_insensitive() { | ||||||||||||||||||||
| let config = default_config(); | ||||||||||||||||||||
| let mut labels = HashMap::new(); | ||||||||||||||||||||
| labels.insert("model_type".to_string(), "Diffusion".to_string()); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| let card = build_model_card("flux-dev", &config, &labels); | ||||||||||||||||||||
| assert!(card.model_type.supports_diffusion()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #[test] | ||||||||||||||||||||
| fn test_build_model_card_non_diffusion_llm() { | ||||||||||||||||||||
| let config = default_config(); | ||||||||||||||||||||
| let mut labels = HashMap::new(); | ||||||||||||||||||||
| labels.insert("model_type".to_string(), "llama".to_string()); | ||||||||||||||||||||
| labels.insert("is_generation".to_string(), "true".to_string()); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| let card = build_model_card("llama-3-70b", &config, &labels); | ||||||||||||||||||||
| assert!(!card.model_type.supports_diffusion()); | ||||||||||||||||||||
| assert!(card.model_type.is_llm()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #[test] | ||||||||||||||||||||
| fn test_build_model_card_diffusion_takes_precedence_over_embedding() { | ||||||||||||||||||||
| let config = default_config(); | ||||||||||||||||||||
| let mut labels = HashMap::new(); | ||||||||||||||||||||
| labels.insert("model_type".to_string(), "diffusion".to_string()); | ||||||||||||||||||||
| labels.insert("is_generation".to_string(), "false".to_string()); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| let card = build_model_card("sd-model", &config, &labels); | ||||||||||||||||||||
| // Diffusion check comes before is_generation check | ||||||||||||||||||||
| assert!(card.model_type.supports_diffusion()); | ||||||||||||||||||||
| assert!(!card.model_type.supports_embeddings()); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DIFFUSION_MODELis missing the actual serving capability.Every other
*_MODELalias includes the endpoint bit it serves; this one is just the marker flag. Cards classified asDIFFUSION_MODELtherefore fail every model-serving branch insupports_endpoint(...), so the new discovery paths can identify diffusion models but not route to them. Please pairDIFFUSIONwith the transport capability those models use, or add explicit endpoint handling and a regression test.🤖 Prompt for AI Agents