@@ -73,10 +73,9 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
7373 Some ( Value :: from_serialize ( & updated_tools) )
7474}
7575
76- fn may_be_fix_msg_content ( messages : serde_json:: Value ) -> Value {
77- // If messages[content] is provided as a list containing ONLY text parts,
78- // concatenate them into a string to match chat template expectations.
79- // Mixed content types are left for chat templates to handle.
76+ fn may_be_fix_msg_content ( messages : serde_json:: Value , preserve_arrays : bool ) -> Value {
77+ // preserve_arrays=true: strings → arrays (multimodal)
78+ // preserve_arrays=false: text-only arrays → strings (standard)
8079
8180 let Some ( arr) = messages. as_array ( ) else {
8281 return Value :: from_serialize ( & messages) ;
@@ -86,7 +85,20 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
8685 . iter ( )
8786 . map ( |msg| {
8887 match msg. get ( "content" ) {
89- Some ( serde_json:: Value :: Array ( content_array) ) => {
88+ // Case 1: String to Array (for multimodal templates)
89+ Some ( serde_json:: Value :: String ( text) ) if preserve_arrays => {
90+ let mut modified_msg = msg. clone ( ) ;
91+ if let Some ( msg_object) = modified_msg. as_object_mut ( ) {
92+ let content_array = serde_json:: json!( [ {
93+ "type" : "text" ,
94+ "text" : text
95+ } ] ) ;
96+ msg_object. insert ( "content" . to_string ( ) , content_array) ;
97+ }
98+ modified_msg
99+ }
100+ // Case 2: Array to String (for standard templates)
101+ Some ( serde_json:: Value :: Array ( content_array) ) if !preserve_arrays => {
90102 let is_text_only_array = !content_array. is_empty ( )
91103 && content_array. iter ( ) . all ( |part| {
92104 part. get ( "type" )
@@ -114,7 +126,7 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
114126 msg. clone ( ) // Mixed content or non-text only
115127 }
116128 }
117- _ => msg. clone ( ) , // String content or missing content - return unchanged
129+ _ => msg. clone ( ) , // No conversion needed
118130 }
119131 } )
120132 . collect ( ) ;
@@ -159,19 +171,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
159171
160172 fn messages ( & self ) -> Value {
161173 let messages_json = serde_json:: to_value ( & self . inner . messages ) . unwrap ( ) ;
162-
163- let needs_fixing = if let Some ( arr) = messages_json. as_array ( ) {
164- arr. iter ( )
165- . any ( |msg| msg. get ( "content" ) . and_then ( |c| c. as_array ( ) ) . is_some ( ) )
166- } else {
167- false
168- } ;
169-
170- if needs_fixing {
171- may_be_fix_msg_content ( messages_json)
172- } else {
173- Value :: from_serialize ( & messages_json)
174- }
174+ Value :: from_serialize ( & messages_json)
175175 }
176176
177177 fn tools ( & self ) -> Option < Value > {
@@ -301,6 +301,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
301301 let messages_canonical = req. messages ( ) ;
302302 let mut messages_for_template: serde_json:: Value =
303303 serde_json:: to_value ( & messages_canonical) . unwrap ( ) ;
304+
305+ messages_for_template = serde_json:: to_value ( may_be_fix_msg_content (
306+ messages_for_template,
307+ self . requires_content_arrays ,
308+ ) )
309+ . unwrap ( ) ;
310+
304311 normalize_tool_arguments_in_messages ( & mut messages_for_template) ;
305312
306313 let ctx = context ! {
@@ -457,7 +464,10 @@ mod tests {
457464 }"# ;
458465
459466 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
460- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
467+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
468+
469+ // Test array → string normalization (preserve_arrays=false for standard templates)
470+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
461471
462472 // Verify: text-only array is concatenated into a single string
463473 assert_eq ! (
@@ -500,7 +510,10 @@ mod tests {
500510 }"# ;
501511
502512 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
503- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
513+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
514+
515+ // Test array → string normalization (preserve_arrays=false for standard templates)
516+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
504517
505518 // Verify: System message with string content remains unchanged
506519 assert_eq ! (
@@ -541,7 +554,10 @@ mod tests {
541554 }"# ;
542555
543556 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
544- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
557+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
558+
559+ // Empty arrays should be preserved regardless of preserve_arrays setting
560+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
545561
546562 // Verify: Empty arrays are preserved as-is
547563 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
@@ -562,7 +578,10 @@ mod tests {
562578 }"# ;
563579
564580 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
565- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
581+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
582+
583+ // Test with preserve_arrays=false (standard templates)
584+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
566585
567586 // Verify: String content is not modified
568587 assert_eq ! (
@@ -589,7 +608,10 @@ mod tests {
589608 }"# ;
590609
591610 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
592- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
611+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
612+
613+ // Mixed content should be preserved regardless of preserve_arrays setting
614+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
593615
594616 // Verify: Mixed content types are preserved as array for template handling
595617 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
@@ -617,7 +639,10 @@ mod tests {
617639 }"# ;
618640
619641 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
620- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
642+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
643+
644+ // Non-text arrays should be preserved regardless of preserve_arrays setting
645+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
621646
622647 // Verify: Non-text content arrays are preserved for template handling
623648 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
@@ -713,7 +738,8 @@ NORMAL MODE
713738 }"# ;
714739
715740 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
716- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
741+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
742+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
717743
718744 // Mixed types should preserve array structure
719745 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
@@ -735,7 +761,8 @@ NORMAL MODE
735761 }"# ;
736762
737763 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
738- let messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
764+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
765+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
739766
740767 // Unknown types mixed with text should preserve array
741768 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
@@ -873,11 +900,15 @@ NORMAL MODE
873900 }"# ;
874901
875902 let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
876- let mut messages = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
903+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
904+
905+ // Apply content normalization with preserve_arrays=false (standard templates)
906+ let mut messages =
907+ serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, false ) ) . unwrap ( ) ;
877908
878909 normalize_tool_arguments_in_messages ( & mut messages) ;
879910
880- // Multimodal content preserved as array
911+ // Multimodal content preserved as array (mixed types not flattened)
881912 assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
882913 assert_eq ! ( messages[ 0 ] [ "content" ] . as_array( ) . unwrap( ) . len( ) , 3 ) ;
883914
@@ -889,6 +920,63 @@ NORMAL MODE
889920 ) ;
890921 }
891922
923+ /// Tests string → array normalization for multimodal templates
924+ #[ test]
925+ fn test_may_be_fix_msg_content_string_to_array ( ) {
926+ let json_str = r#"{
927+ "model": "gpt-4o",
928+ "messages": [
929+ {
930+ "role": "user",
931+ "content": "Hello, how are you?"
932+ }
933+ ]
934+ }"# ;
935+
936+ let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
937+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
938+
939+ // Test with preserve_arrays=true (multimodal templates)
940+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, true ) ) . unwrap ( ) ;
941+
942+ // Verify: String is converted to array format
943+ assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
944+ let content_array = messages[ 0 ] [ "content" ] . as_array ( ) . unwrap ( ) ;
945+ assert_eq ! ( content_array. len( ) , 1 ) ;
946+ assert_eq ! ( content_array[ 0 ] [ "type" ] , "text" ) ;
947+ assert_eq ! ( content_array[ 0 ] [ "text" ] , "Hello, how are you?" ) ;
948+ }
949+
950+ /// Tests that arrays are preserved when preserve_arrays=true
951+ #[ test]
952+ fn test_may_be_fix_msg_content_array_preserved_with_multimodal ( ) {
953+ let json_str = r#"{
954+ "model": "gpt-4o",
955+ "messages": [
956+ {
957+ "role": "user",
958+ "content": [
959+ {"type": "text", "text": "part 1"},
960+ {"type": "text", "text": "part 2"}
961+ ]
962+ }
963+ ]
964+ }"# ;
965+
966+ let request: NvCreateChatCompletionRequest = serde_json:: from_str ( json_str) . unwrap ( ) ;
967+ let messages_raw = serde_json:: to_value ( request. messages ( ) ) . unwrap ( ) ;
968+
969+ // Test with preserve_arrays=true (multimodal templates)
970+ let messages = serde_json:: to_value ( may_be_fix_msg_content ( messages_raw, true ) ) . unwrap ( ) ;
971+
972+ // Verify: Array is preserved as-is
973+ assert ! ( messages[ 0 ] [ "content" ] . is_array( ) ) ;
974+ let content_array = messages[ 0 ] [ "content" ] . as_array ( ) . unwrap ( ) ;
975+ assert_eq ! ( content_array. len( ) , 2 ) ;
976+ assert_eq ! ( content_array[ 0 ] [ "text" ] , "part 1" ) ;
977+ assert_eq ! ( content_array[ 1 ] [ "text" ] , "part 2" ) ;
978+ }
979+
892980 fn user ( ) -> Msg {
893981 Msg :: User ( Default :: default ( ) )
894982 }
0 commit comments