Skip to content

Commit 6767559

Browse files
fix: Support for msg[content] as a list (#4485)
Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com> Signed-off-by: KrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com> Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
1 parent 2f18b23 commit 6767559

File tree

5 files changed

+190
-30
lines changed

5 files changed

+190
-30
lines changed

examples/backends/vllm/launch/agg_multimodal.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ python -m dynamo.frontend --http-port=8000 &
4646
EXTRA_ARGS=""
4747
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
4848
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
49+
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
50+
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
4951
fi
5052

5153
# Start vLLM worker with vision model

lib/llm/src/preprocessor/prompt/template.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ struct HfTokenizerConfigJsonFormatter {
106106
config: ChatTemplate,
107107
mixins: Arc<ContextMixins>,
108108
supports_add_generation_prompt: bool,
109+
requires_content_arrays: bool,
109110
}
110111

111112
// /// OpenAI Standard Prompt Formatter

lib/llm/src/preprocessor/prompt/template/formatters.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,38 @@ use std::sync::Arc;
66
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
77
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
88
use either::Either;
9-
use minijinja::{Environment, Value};
9+
use minijinja::{Environment, Value, context};
10+
use serde_json::json;
1011
use tracing;
1112

13+
/// Detects if a template requires content as arrays (multimodal) vs strings (text-only).
14+
/// Returns true if the template only works with array format.
15+
fn detect_content_array_usage(env: &Environment) -> bool {
16+
// Test with array format
17+
let array_msg = context! {
18+
messages => json!([{"role": "user", "content": [{"type": "text", "text": "template_test"}]}]),
19+
add_generation_prompt => false,
20+
};
21+
22+
// Test with string format
23+
let string_msg = context! {
24+
messages => json!([{"role": "user", "content": "template_test"}]),
25+
add_generation_prompt => false,
26+
};
27+
28+
let out_array = env
29+
.get_template("default")
30+
.and_then(|t| t.render(&array_msg))
31+
.unwrap_or_default();
32+
let out_string = env
33+
.get_template("default")
34+
.and_then(|t| t.render(&string_msg))
35+
.unwrap_or_default();
36+
37+
// If array works but string doesn't, template requires arrays
38+
out_array.contains("template_test") && !out_string.contains("template_test")
39+
}
40+
1241
/// Remove known non-standard Jinja2 tags from chat templates
1342
///
1443
/// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags
@@ -120,11 +149,15 @@ impl HfTokenizerConfigJsonFormatter {
120149
}
121150
}
122151

152+
// Detect at model load time whether this template requires content arrays
153+
let requires_content_arrays = detect_content_array_usage(&env);
154+
123155
Ok(HfTokenizerConfigJsonFormatter {
124156
env,
125157
config,
126158
mixins: Arc::new(mixins),
127159
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
160+
requires_content_arrays,
128161
})
129162
}
130163
}

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 117 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
7373
Some(Value::from_serialize(&updated_tools))
7474
}
7575

76-
fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
77-
// If messages[content] is provided as a list containing ONLY text parts,
78-
// concatenate them into a string to match chat template expectations.
79-
// Mixed content types are left for chat templates to handle.
76+
fn may_be_fix_msg_content(messages: serde_json::Value, preserve_arrays: bool) -> Value {
77+
// preserve_arrays=true: strings → arrays (multimodal)
78+
// preserve_arrays=false: text-only arrays → strings (standard)
8079

8180
let Some(arr) = messages.as_array() else {
8281
return Value::from_serialize(&messages);
@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
8685
.iter()
8786
.map(|msg| {
8887
match msg.get("content") {
89-
Some(serde_json::Value::Array(content_array)) => {
88+
// Case 1: String to Array (for multimodal templates)
89+
Some(serde_json::Value::String(text)) if preserve_arrays => {
90+
let mut modified_msg = msg.clone();
91+
if let Some(msg_object) = modified_msg.as_object_mut() {
92+
let content_array = serde_json::json!([{
93+
"type": "text",
94+
"text": text
95+
}]);
96+
msg_object.insert("content".to_string(), content_array);
97+
}
98+
modified_msg
99+
}
100+
// Case 2: Array to String (for standard templates)
101+
Some(serde_json::Value::Array(content_array)) if !preserve_arrays => {
90102
let is_text_only_array = !content_array.is_empty()
91103
&& content_array.iter().all(|part| {
92104
part.get("type")
@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
114126
msg.clone() // Mixed content or non-text only
115127
}
116128
}
117-
_ => msg.clone(), // String content or missing content - return unchanged
129+
_ => msg.clone(), // No conversion needed
118130
}
119131
})
120132
.collect();
@@ -159,19 +171,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
159171

160172
fn messages(&self) -> Value {
161173
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
162-
163-
let needs_fixing = if let Some(arr) = messages_json.as_array() {
164-
arr.iter()
165-
.any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
166-
} else {
167-
false
168-
};
169-
170-
if needs_fixing {
171-
may_be_fix_msg_content(messages_json)
172-
} else {
173-
Value::from_serialize(&messages_json)
174-
}
174+
Value::from_serialize(&messages_json)
175175
}
176176

177177
fn tools(&self) -> Option<Value> {
@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
301301
let messages_canonical = req.messages();
302302
let mut messages_for_template: serde_json::Value =
303303
serde_json::to_value(&messages_canonical).unwrap();
304+
305+
messages_for_template = serde_json::to_value(may_be_fix_msg_content(
306+
messages_for_template,
307+
self.requires_content_arrays,
308+
))
309+
.unwrap();
310+
304311
normalize_tool_arguments_in_messages(&mut messages_for_template);
305312

306313
let ctx = context! {
@@ -457,7 +464,10 @@ mod tests {
457464
}"#;
458465

459466
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
460-
let messages = serde_json::to_value(request.messages()).unwrap();
467+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
468+
469+
// Test array → string normalization (preserve_arrays=false for standard templates)
470+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
461471

462472
// Verify: text-only array is concatenated into a single string
463473
assert_eq!(
@@ -500,7 +510,10 @@ mod tests {
500510
}"#;
501511

502512
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
503-
let messages = serde_json::to_value(request.messages()).unwrap();
513+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
514+
515+
// Test array → string normalization (preserve_arrays=false for standard templates)
516+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
504517

505518
// Verify: System message with string content remains unchanged
506519
assert_eq!(
@@ -541,7 +554,10 @@ mod tests {
541554
}"#;
542555

543556
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
544-
let messages = serde_json::to_value(request.messages()).unwrap();
557+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
558+
559+
// Empty arrays should be preserved regardless of preserve_arrays setting
560+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
545561

546562
// Verify: Empty arrays are preserved as-is
547563
assert!(messages[0]["content"].is_array());
@@ -562,7 +578,10 @@ mod tests {
562578
}"#;
563579

564580
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
565-
let messages = serde_json::to_value(request.messages()).unwrap();
581+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
582+
583+
// Test with preserve_arrays=false (standard templates)
584+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
566585

567586
// Verify: String content is not modified
568587
assert_eq!(
@@ -589,7 +608,10 @@ mod tests {
589608
}"#;
590609

591610
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
592-
let messages = serde_json::to_value(request.messages()).unwrap();
611+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
612+
613+
// Mixed content should be preserved regardless of preserve_arrays setting
614+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
593615

594616
// Verify: Mixed content types are preserved as array for template handling
595617
assert!(messages[0]["content"].is_array());
@@ -617,7 +639,10 @@ mod tests {
617639
}"#;
618640

619641
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
620-
let messages = serde_json::to_value(request.messages()).unwrap();
642+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
643+
644+
// Non-text arrays should be preserved regardless of preserve_arrays setting
645+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
621646

622647
// Verify: Non-text content arrays are preserved for template handling
623648
assert!(messages[0]["content"].is_array());
@@ -713,7 +738,8 @@ NORMAL MODE
713738
}"#;
714739

715740
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
716-
let messages = serde_json::to_value(request.messages()).unwrap();
741+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
742+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
717743

718744
// Mixed types should preserve array structure
719745
assert!(messages[0]["content"].is_array());
@@ -735,7 +761,8 @@ NORMAL MODE
735761
}"#;
736762

737763
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
738-
let messages = serde_json::to_value(request.messages()).unwrap();
764+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
765+
let messages = serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
739766

740767
// Unknown types mixed with text should preserve array
741768
assert!(messages[0]["content"].is_array());
@@ -873,11 +900,15 @@ NORMAL MODE
873900
}"#;
874901

875902
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
876-
let mut messages = serde_json::to_value(request.messages()).unwrap();
903+
let messages_raw = serde_json::to_value(request.messages()).unwrap();
904+
905+
// Apply content normalization with preserve_arrays=false (standard templates)
906+
let mut messages =
907+
serde_json::to_value(may_be_fix_msg_content(messages_raw, false)).unwrap();
877908

878909
normalize_tool_arguments_in_messages(&mut messages);
879910

880-
// Multimodal content preserved as array
911+
// Multimodal content preserved as array (mixed types not flattened)
881912
assert!(messages[0]["content"].is_array());
882913
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
883914

@@ -889,6 +920,63 @@ NORMAL MODE
889920
);
890921
}
891922

923+
/// Plain string content must be wrapped into a single-element text-part array
/// when normalizing for multimodal (array-requiring) templates.
#[test]
fn test_may_be_fix_msg_content_string_to_array() {
    let json_str = r#"{
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": "Hello, how are you?"
            }
        ]
    }"#;

    let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
    let raw_messages = serde_json::to_value(request.messages()).unwrap();

    // preserve_arrays=true models a multimodal template.
    let normalized =
        serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    // The string should now be a one-element [{type: "text", text: ...}] array.
    let parts = normalized[0]["content"]
        .as_array()
        .expect("string content should have been wrapped in an array");
    assert_eq!(parts.len(), 1);
    assert_eq!(parts[0]["type"], "text");
    assert_eq!(parts[0]["text"], "Hello, how are you?");
}
949+
950+
/// Existing part arrays must pass through untouched when preserve_arrays=true.
#[test]
fn test_may_be_fix_msg_content_array_preserved_with_multimodal() {
    let json_str = r#"{
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "part 1"},
                    {"type": "text", "text": "part 2"}
                ]
            }
        ]
    }"#;

    let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
    let raw_messages = serde_json::to_value(request.messages()).unwrap();

    // preserve_arrays=true models a multimodal template; arrays stay as-is.
    let normalized =
        serde_json::to_value(may_be_fix_msg_content(raw_messages, true)).unwrap();

    let parts = normalized[0]["content"]
        .as_array()
        .expect("array content should be preserved");
    assert_eq!(parts.len(), 2);
    assert_eq!(parts[0]["text"], "part 1");
    assert_eq!(parts[1]["text"], "part 2");
}
979+
892980
fn user() -> Msg {
893981
Msg::User(Default::default())
894982
}

tests/serve/test_vllm.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,42 @@ class VLLMConfig(EngineConfig):
230230
),
231231
],
232232
),
233+
"multimodal_agg_llava": VLLMConfig(
234+
name="multimodal_agg_llava",
235+
directory=vllm_dir,
236+
script_name="agg_multimodal.sh",
237+
marks=[
238+
pytest.mark.gpu_2,
239+
# https://github.com/ai-dynamo/dynamo/issues/4501
240+
pytest.mark.xfail(strict=False),
241+
],
242+
model="llava-hf/llava-1.5-7b-hf",
243+
script_args=["--model", "llava-hf/llava-1.5-7b-hf"],
244+
delayed_start=0,
245+
timeout=360,
246+
request_payloads=[
247+
# HTTP URL test
248+
chat_payload(
249+
[
250+
{"type": "text", "text": "What is in this image?"},
251+
{
252+
"type": "image_url",
253+
"image_url": {
254+
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
255+
},
256+
},
257+
],
258+
repeat_count=1,
259+
expected_response=["bus"],
260+
temperature=0.0,
261+
),
262+
# String content test - verifies string → array conversion for multimodal templates
263+
chat_payload_default(
264+
repeat_count=1,
265+
expected_response=[], # Just validate no error
266+
),
267+
],
268+
),
233269
# TODO: Update this test case when we have video multimodal support in vllm official components
234270
"multimodal_video_agg": VLLMConfig(
235271
name="multimodal_video_agg",

0 commit comments

Comments
 (0)