diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index e81ccc54a3..a4e65ce381 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -1225,4 +1225,46 @@ public void testDeserializeNullFloat_ToNull() { assertTrue(m.get("fPrim").isJsonPrimitive()); assertEquals(1.0f, m.get("fPrim").getAsFloat(), 1e-9f); } + + @Test + public void testProcessTextDoc_ExceptionMessageEscaping() { + // Test the problematic exception message from the error + String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}], \"messages\": [...] }\n" + + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; + + String escapedMessage = StringUtils.processTextDoc(problematicMessage); + + // Verify that problematic characters are escaped + assertFalse( + "Escaped message should not contain unescaped newlines", + escapedMessage.contains("\n") && !escapedMessage.contains("\\n") + ); + assertFalse( + "Escaped message should not contain unescaped quotes", + escapedMessage.contains("\"") && !escapedMessage.contains("\\\"") + ); + } + + @Test + public void testProcessTextDoc_GsonParsingErrorMessageEscaping() { + // Test the specific Gson error message pattern + String gsonError = "Expected BEGIN_ARRAY but was STRING at line 1 column 1 path $\n" + + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; + + String escapedMessage = StringUtils.processTextDoc(gsonError); + + // The escaped message should be safe for JSON inclusion + assertTrue("Escaped message should be safe for JSON", !escapedMessage.contains("\n") || escapedMessage.contains("\\n")); + } + + @Test + public void testProcessTextDoc_NormalMessagePassthrough() { + // Test that normal messages without special characters pass through unchanged + String normalMessage = "Tool execution failed with normal error"; + + String escapedMessage = StringUtils.processTextDoc(normalMessage); + + // Normal messages should be handled properly + assertTrue("Normal messages should be handled properly", escapedMessage.length() > 0); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..09fcf67dda 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -625,7 +625,13 @@ private static void runTool( .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()), + Map + .of( + TOOL_CALL_ID, + toolCallId, + "tool_response", + "Tool " + action + " failed: " + processTextDoc(e.getMessage()) + ), INTERACTIONS_PREFIX ) ); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..acdd5d1795 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -691,6 +691,38 @@ public void testToolThrowException() { assertNotNull(modelTensorOutput); } + @Test + public void testToolExceptionMessageEscaping() { + // Mock tool validation to return true + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with tools + MLAgent mlAgent = createMLAgentWithTools(); + + // Create parameters for the agent + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + + // Mock tool to throw exception with problematic characters (quotes, newlines) + String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}] }\n" + + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; + + Mockito + .doThrow(new IllegalArgumentException(problematicMessage)) + .when(firstTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + + // Run the MLChatAgentRunner + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); + + // Verify that the tool's run method was called + verify(firstTool).run(any(), any()); + + // Verify that the agent completes without throwing JSON parsing exceptions + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull("Agent should complete successfully even with problematic exception messages", modelTensorOutput); + } + @Test public void testToolParameters() { // Mock tool validation to return false.