From 8e44fc2ad28309f802c072a7e99bce21cb356a2d Mon Sep 17 00:00:00 2001 From: SenreySong <25841017+SenreySong@users.noreply.github.com> Date: Thu, 13 Nov 2025 16:11:42 +0800 Subject: [PATCH] Fix extraBody loss during ModelOptionsUtils.merge() The extraBody field was being lost when merging OpenAiChatOptions into ChatCompletionRequest using ModelOptionsUtils.merge(). This fix ensures that extraBody parameters (like top_k, repetition_penalty, etc.) are properly preserved during the merge operation. Added test ExtraBodyMergeTest to verify the fix. Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../ai/openai/OpenAiChatModel.java | 5 +- .../ai/openai/api/OpenAiApi.java | 25 ++++++++ .../ai/openai/ExtraBodyMergeTest.java | 60 +++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ExtraBodyMergeTest.java diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 19a884363c7..f631d27d303 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -628,14 +628,15 @@ else if (message.getMessageType() == MessageType.TOOL) { ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions(); - request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); + request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class, + ChatCompletionRequest.getMergableFieldNames()); // Add the tool definitions to the request's tools parameter. List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); if (!CollectionUtils.isEmpty(toolDefinitions)) { request = ModelOptionsUtils.merge( OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, - ChatCompletionRequest.class); + ChatCompletionRequest.class, ChatCompletionRequest.getMergableFieldNames()); } // Remove `streamOptions` from the request if it is not a streaming request diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index e4ea5e4010c..5b797a49f3d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -16,6 +16,8 @@ package org.springframework.ai.openai.api; +import java.lang.reflect.Field; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -1261,6 +1263,29 @@ private void setExtraBodyProperty(String key, Object value) { } } + /** + * Returns the list of field names that can be merged with other ChatCompletionRequest instances. + * This includes all @JsonProperty annotated fields plus special fields like extra_body + * that use @JsonAnyGetter/@JsonAnySetter. + * @return list of mergeable field names + */ + public static List getMergableFieldNames() { + List fieldNames = new ArrayList<>(); + + // Get all @JsonProperty annotated fields + for (Field field : ChatCompletionRequest.class.getDeclaredFields()) { + JsonProperty jsonPropertyAnnotation = field.getAnnotation(JsonProperty.class); + if (jsonPropertyAnnotation != null) { + fieldNames.add(jsonPropertyAnnotation.value()); + } + } + + // Add extra_body field (uses @JsonAnyGetter/@JsonAnySetter, not @JsonProperty) + fieldNames.add("extra_body"); + + return fieldNames; + } + /** * Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name. */ diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ExtraBodyMergeTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ExtraBodyMergeTest.java new file mode 100644 index 00000000000..bd96a1e23ec --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ExtraBodyMergeTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests to verify that extraBody is preserved when using ModelOptionsUtils.merge(). + * + * @author senrey_song + */ +class ExtraBodyMergeTest { + + @Test + void shouldPreserveExtraBodyAfterMerge() { + List messages = List + .of(new ChatCompletionMessage("test message", OpenAiApi.ChatCompletionMessage.Role.USER)); + ChatCompletionRequest request = new ChatCompletionRequest(messages, false); + + OpenAiChatOptions requestOptions = OpenAiChatOptions.builder() + .extraBody(Map.of("top_k", 50, "repetition_penalty", 1.1, "custom_param", "custom_value")) + .build(); + + request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class, + ChatCompletionRequest.getMergableFieldNames()); + + assertThat(request.extraBody()).isNotNull(); + @SuppressWarnings("unchecked") + Map extraBodyMap = (Map) request.extraBody().get("extra_body"); + assertThat(extraBodyMap).hasSize(3) + .containsEntry("top_k", 50) + .containsEntry("repetition_penalty", 1.1) + .containsEntry("custom_param", "custom_value"); + } + +}