feat(gateway): add DIFFUSION model capability detection#736
feat(gateway): add DIFFUSION model capability detection#736Kangyan-Zhou wants to merge 2 commits intolightseekorg:mainfrom
Conversation
Add a new DIFFUSION bitflag (1 << 12) to ModelType for detecting diffusion models (Stable Diffusion, Flux, SDXL, etc.) served by SGLang's multimodal_gen server. Detection works via: - External: model ID pattern matching (stable-diffusion, flux, sd-*, etc.) - Local: model_type="diffusion" label from /model_info response Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 WalkthroughWalkthroughThis pull request introduces a new diffusion model capability to the system. It adds DIFFUSION bitflags to ModelType, implements detection logic for diffusion models based on model identifiers, and updates model card creation to properly handle and enable the diffusion capability when appropriate. Changes
Sequence DiagramsequenceDiagram
participant ModelDiscovery as Model Discovery
participant TypeInference as Type Inference
participant ModelCard as Model Card Builder
ModelDiscovery->>ModelDiscovery: Parse model identifier
ModelDiscovery->>TypeInference: Detect diffusion pattern<br/>(stable-diffusion, sd-, flux, etc.)
TypeInference->>TypeInference: Infer model_type<br/>as DIFFUSION_MODEL
TypeInference->>ModelCard: Pass detected type
ModelCard->>ModelCard: Check if diffusion<br/>from model_type label
alt User-Provided Model
ModelCard->>ModelCard: Enable DIFFUSION flag<br/>if not supported
else Non-User-Provided Model
ModelCard->>ModelCard: Set DIFFUSION_MODEL<br/>with precedence over embedding
end
ModelCard->>ModelCard: Return configured<br/>model_card
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment Tip You can customize the high-level summary generated by CodeRabbit.Configure the |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive support for diffusion models within the gateway. It establishes a new Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces support for diffusion models by adding a DIFFUSION capability to ModelType. The changes include updating model discovery logic for both external and local workers to correctly identify these models based on their IDs or labels. The implementation is well-structured and accompanied by a comprehensive set of unit tests. I have one suggestion to simplify the model ID matching logic for better readability and maintainability.
| if id_lower.contains("stable-diffusion") | ||
| || id_lower.contains("stable_diffusion") | ||
| || id_lower.starts_with("sd-") | ||
| || id_lower.starts_with("sd3") | ||
| || id_lower.starts_with("sdxl") | ||
| || id_lower.starts_with("flux") | ||
| || id_lower.contains("diffusion") | ||
| { | ||
| return ModelType::DIFFUSION_MODEL; | ||
| } |
There was a problem hiding this comment.
The checks for stable-diffusion and stable_diffusion are redundant because the more general diffusion check at the end of the condition will catch them anyway. Removing these specific checks simplifies the code without changing the logic. Reordering to check starts_with before contains can also be slightly more performant.
if id_lower.starts_with("sd-")
|| id_lower.starts_with("sd3")
|| id_lower.starts_with("sdxl")
|| id_lower.starts_with("flux")
|| id_lower.contains("diffusion")
{
return ModelType::DIFFUSION_MODEL;
}|
Hi @Kangyan-Zhou, the DCO sign-off check has failed. All commits must include a To fix existing commits: # Sign off the last N commits (replace N with the number of unsigned commits)
git rebase HEAD~N --signoff
git push --force-with-leaseTo sign off future commits automatically:
|
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
crates/protocols/src/model_type.rs (1)
313-331: 🧹 Nitpick | 🔵 TrivialReduce schema drift risk by deriving enum values from
CAPABILITY_NAMES.The list at Line 317–Line 331 duplicates capability names already maintained in
CAPABILITY_NAMES; this can silently diverge in future flag additions.♻️ Proposed refactor
fn json_schema(_gen: &mut SchemaGenerator) -> Schema { use schemars::schema::*; let items = SchemaObject { instance_type: Some(InstanceType::String.into()), - enum_values: Some(vec![ - "chat".into(), - "completions".into(), - "responses".into(), - "embeddings".into(), - "rerank".into(), - "generate".into(), - "vision".into(), - "tools".into(), - "reasoning".into(), - "image_gen".into(), - "audio".into(), - "moderation".into(), - "diffusion".into(), - ]), + enum_values: Some( + CAPABILITY_NAMES + .iter() + .map(|(_, name)| (*name).into()) + .collect(), + ), ..Default::default() };🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@crates/protocols/src/model_type.rs` around lines 313 - 331, The hard-coded enum_values array in the json_schema implementation for json_schema should be replaced with values derived from the existing CAPABILITY_NAMES constant to avoid duplication and drift: locate the json_schema function and replace the literal vec![...strings...] with an expression that iterates over CAPABILITY_NAMES (e.g., CAPABILITY_NAMES.iter().map(|s| s.into()).collect()) so the SchemaObject.enum_values is built from CAPABILITY_NAMES; ensure the produced collection matches schemars::schema::SingleOrVec<serde_json::Value> (or the appropriate type expected by enum_values) and import or convert types as needed so compilation continues to succeed.model_gateway/src/core/steps/worker/external/discover_models.rs (1)
109-127:⚠️ Potential issue | 🟠 MajorDiffusion classification is shadowed by the earlier generic image heuristic.
At Line 112, any ID containing
"image"returnsIMAGE_MODELbefore the new diffusion checks run. That means IDs likeflux-image-*orstable-diffusion-image-*will not be classified as diffusion.🐛 Proposed fix (check diffusion before generic image matching)
- // Image generation models - if id_lower.starts_with("dall-e") - || id_lower.starts_with("sora") - || (id_lower.contains("image") && !id_lower.contains("vision")) - { - return ModelType::IMAGE_MODEL; - } - // Diffusion models (Stable Diffusion, Flux, SDXL, etc.) if id_lower.contains("stable-diffusion") || id_lower.contains("stable_diffusion") || id_lower.starts_with("sd-") || id_lower.starts_with("sd3") || id_lower.starts_with("sdxl") || id_lower.starts_with("flux") || id_lower.contains("diffusion") { return ModelType::DIFFUSION_MODEL; } + + // Image generation models + if id_lower.starts_with("dall-e") + || id_lower.starts_with("sora") + || (id_lower.contains("image") && !id_lower.contains("vision")) + { + return ModelType::IMAGE_MODEL; + }🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@model_gateway/src/core/steps/worker/external/discover_models.rs` around lines 109 - 127, The image-vs-diffusion heuristic currently checks the generic image condition on id_lower before diffusion checks, causing IDs like "stable-diffusion-image-*" to be classified as IMAGE_MODEL; reorder or refine the checks so diffusion detection runs first: move the block that returns ModelType::DIFFUSION_MODEL (matching stable-diffusion, sd-, sdxl, flux, diffusion) ahead of the generic image heuristic that returns ModelType::IMAGE_MODEL, or alternatively narrow the image condition (id_lower.contains("image") && !id_lower.contains("vision")) to explicitly exclude diffusion keywords; update code paths that reference id_lower and the ModelType::DIFFUSION_MODEL / ModelType::IMAGE_MODEL returns accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@crates/protocols/src/model_type.rs`:
- Around line 313-331: The hard-coded enum_values array in the json_schema
implementation for json_schema should be replaced with values derived from the
existing CAPABILITY_NAMES constant to avoid duplication and drift: locate the
json_schema function and replace the literal vec![...strings...] with an
expression that iterates over CAPABILITY_NAMES (e.g.,
CAPABILITY_NAMES.iter().map(|s| s.into()).collect()) so the
SchemaObject.enum_values is built from CAPABILITY_NAMES; ensure the produced
collection matches schemars::schema::SingleOrVec<serde_json::Value> (or the
appropriate type expected by enum_values) and import or convert types as needed
so compilation continues to succeed.
In `@model_gateway/src/core/steps/worker/external/discover_models.rs`:
- Around line 109-127: The image-vs-diffusion heuristic currently checks the
generic image condition on id_lower before diffusion checks, causing IDs like
"stable-diffusion-image-*" to be classified as IMAGE_MODEL; reorder or refine
the checks so diffusion detection runs first: move the block that returns
ModelType::DIFFUSION_MODEL (matching stable-diffusion, sd-, sdxl, flux,
diffusion) ahead of the generic image heuristic that returns
ModelType::IMAGE_MODEL, or alternatively narrow the image condition
(id_lower.contains("image") && !id_lower.contains("vision")) to explicitly
exclude diffusion keywords; update code paths that reference id_lower and the
ModelType::DIFFUSION_MODEL / ModelType::IMAGE_MODEL returns accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: ASSERTIVE
Plan: Pro
Run ID: 43b5b55d-8f04-4741-b363-f0fa06175195
📒 Files selected for processing (3)
crates/protocols/src/model_type.rsmodel_gateway/src/core/steps/worker/external/discover_models.rsmodel_gateway/src/core/steps/worker/local/create_worker.rs
|
|
||
| if !user_provided { | ||
| let is_diffusion = labels | ||
| .get("model_type") |
There was a problem hiding this comment.
I think this is only supported for SGLang atm.
Summary
DIFFUSIONbitflag (1 << 12) toModelTypefor detecting diffusion models (Stable Diffusion, Flux, SDXL, SD3, etc.)stable-diffusion,flux*,sd-*,sdxl*,*diffusion*)model_type="diffusion"label from SGLang multimodal_gen server's/model_inforesponseDIFFUSION_MODELcomposite,supports_diffusion()/is_diffusion_model()helpers, serde + JSON schema supportTest plan
model_type.rs(flag basics, composite, serde roundtrip, display)discover_models.rs(diffusion ID patterns, non-diffusion negative cases)create_worker.rs(label detection, case insensitivity, precedence over embedding, negative case)cargo clippycleanpre-commit run --all-filespasses🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Tests