Skip to content

Commit fd6a817

Browse files
Clean up, add tests
Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
1 parent d8a9c75 commit fd6a817

File tree

6 files changed

+207
-11
lines changed

6 files changed

+207
-11
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,8 @@ private void handleStreamEvent(
212212
} else if (isStreamComplete(event)) {
213213
// For PER agent, we should keep the connection open after the planner LLM finish
214214
if ("per".equals(agentType)) {
215-
sendPlannerResponse(false, listener, String.valueOf(accumulatedContent));
216215
currentState.set(StreamState.WAITING_FOR_TOOL_RESULT);
217-
log.info("PER agent planner phase completed - waiting for execution phase");
216+
sendPlannerResponse(false, listener, String.valueOf(accumulatedContent));
218217
} else {
219218
currentState.set(StreamState.COMPLETED);
220219
sendCompletionResponse(isStreamClosed, listener);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ public void testToolParameters() {
710710
// Verify the size of parameters passed in the tool run method.
711711
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
712712
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
713-
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
713+
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
714714

715715
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
716716
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
@@ -738,7 +738,7 @@ public void testToolUseOriginalInput() {
738738
// Verify the size of parameters passed in the tool run method.
739739
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
740740
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
741-
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
741+
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
742742
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
743743

744744
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -804,7 +804,7 @@ public void testToolConfig() {
804804
// Verify the size of parameters passed in the tool run method.
805805
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
806806
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
807-
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
807+
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
808808
// The value of input should be "config_value".
809809
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
810810

@@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() {
834834
// Verify the size of parameters passed in the tool run method.
835835
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
836836
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
837-
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
837+
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
838838
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
839839
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
840840

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.junit.Assert.assertNull;
1212
import static org.junit.Assert.assertThrows;
1313
import static org.junit.Assert.assertTrue;
14+
import static org.junit.Assert.fail;
1415
import static org.mockito.ArgumentMatchers.any;
1516
import static org.mockito.ArgumentMatchers.anyInt;
1617
import static org.mockito.ArgumentMatchers.anyString;
@@ -115,6 +116,8 @@ public class MLPlanExecuteAndReflectAgentRunnerTest extends MLStaticMockBase {
115116
private MLTaskResponse mlTaskResponse;
116117
@Mock
117118
private MLExecuteTaskResponse mlExecuteTaskResponse;
119+
@Mock
120+
private StreamingWrapper streamingWrapper;
118121

119122
@Captor
120123
private ArgumentCaptor<Object> objectCaptor;
@@ -174,6 +177,15 @@ public void setup() {
174177
encryptor
175178
);
176179

180+
// Set streaming wrapper
181+
try {
182+
java.lang.reflect.Field streamingWrapperField = MLPlanExecuteAndReflectAgentRunner.class.getDeclaredField("streamingWrapper");
183+
streamingWrapperField.setAccessible(true);
184+
streamingWrapperField.set(mlPlanExecuteAndReflectAgentRunner, streamingWrapper);
185+
} catch (Exception e) {
186+
fail("Exception thrown: " + e.getMessage());
187+
}
188+
177189
// Setup tools
178190
when(firstToolFactory.create(any())).thenReturn(firstTool);
179191
when(secondToolFactory.create(any())).thenReturn(secondTool);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public void testExecuteRequestNonStreaming() {
159159

160160
@Test
161161
public void testSendCompletionChunkStreaming() throws Exception {
162-
streamingWrapper.sendCompletionChunk("session1", "parent1");
162+
streamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent");
163163

164164
ArgumentCaptor<MLTaskResponse> responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
165165
verify(channel).sendResponseBatch(responseCaptor.capture());
@@ -170,7 +170,7 @@ public void testSendCompletionChunkStreaming() throws Exception {
170170

171171
@Test
172172
public void testSendCompletionChunkNonStreaming() throws Exception {
173-
nonStreamingWrapper.sendCompletionChunk("session1", "parent1");
173+
nonStreamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent");
174174

175175
verify(channel, never()).sendResponseBatch(any());
176176
}
@@ -180,7 +180,7 @@ public void testSendCompletionChunkWithException() throws Exception {
180180
doThrow(new RuntimeException("Channel error")).when(channel).sendResponseBatch(any());
181181

182182
// Should not throw exception, just log warning
183-
streamingWrapper.sendCompletionChunk("session1", "parent1");
183+
streamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent");
184184

185185
verify(channel).sendResponseBatch(any());
186186
}
@@ -236,7 +236,7 @@ public void testSendToolResponseWithException() throws Exception {
236236

237237
@Test
238238
public void testCreateStreamChunkStructure() throws Exception {
239-
streamingWrapper.sendCompletionChunk("test-session", "test-parent");
239+
streamingWrapper.sendCompletionChunk("test-session", "test-parent", "executeMemory", "executeParent");
240240

241241
ArgumentCaptor<MLTaskResponse> responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
242242
verify(channel).sendResponseBatch(responseCaptor.capture());
@@ -245,7 +245,7 @@ public void testCreateStreamChunkStructure() throws Exception {
245245
ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
246246
List<ModelTensor> tensors = output.getMlModelOutputs().get(0).getMlModelTensors();
247247

248-
assertEquals(3, tensors.size());
248+
assertEquals(5, tensors.size());
249249

250250
// Find specific tensors by name
251251
ModelTensor memoryTensor = tensors.stream().filter(t -> "memory_id".equals(t.getName())).findFirst().orElse(null);
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.algorithms.remote.streaming;
7+
8+
import static org.junit.Assert.assertNotNull;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.verify;
12+
import static org.mockito.Mockito.when;
13+
14+
import java.util.HashMap;
15+
import java.util.Map;
16+
17+
import org.junit.Before;
18+
import org.junit.Test;
19+
import org.mockito.Mock;
20+
import org.mockito.MockitoAnnotations;
21+
import org.opensearch.ml.common.connector.AwsConnector;
22+
import org.opensearch.ml.common.exception.MLException;
23+
import org.opensearch.ml.common.transport.MLTaskResponse;
24+
25+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
26+
27+
public class BedrockStreamingHandlerTest {
28+
29+
@Mock
30+
private SdkAsyncHttpClient httpClient;
31+
@Mock
32+
private AwsConnector connector;
33+
@Mock
34+
private StreamPredictActionListener<MLTaskResponse, ?> actionListener;
35+
36+
private BedrockStreamingHandler bedrockStreamingHandler;
37+
38+
@Before
39+
public void setup() {
40+
MockitoAnnotations.openMocks(this);
41+
42+
when(connector.getRegion()).thenReturn("us-east-1");
43+
when(connector.getAccessKey()).thenReturn("test-access-key");
44+
when(connector.getSecretKey()).thenReturn("test-secret-key");
45+
46+
bedrockStreamingHandler = new BedrockStreamingHandler(httpClient, connector);
47+
}
48+
49+
@Test
50+
public void testConstructor() {
51+
assertNotNull(bedrockStreamingHandler);
52+
}
53+
54+
@Test
55+
public void testHandleError() {
56+
Exception testException = new RuntimeException("Test error");
57+
58+
doAnswer(invocation -> {
59+
MLException exception = invocation.getArgument(0);
60+
assertNotNull(exception);
61+
return null;
62+
}).when(actionListener).onFailure(any(MLException.class));
63+
64+
bedrockStreamingHandler.handleError(testException, actionListener);
65+
verify(actionListener).onFailure(any(MLException.class));
66+
}
67+
68+
@Test
69+
public void testStartStreamInvalidPayload() {
70+
Map<String, String> parameters = new HashMap<>();
71+
parameters.put("model", "test-model");
72+
parameters.put("agent_type", "test");
73+
74+
String invalidPayload = "invalid json";
75+
76+
bedrockStreamingHandler.startStream("test_action", parameters, invalidPayload, actionListener);
77+
78+
verify(actionListener).onFailure(any(MLException.class));
79+
}
80+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.algorithms.remote.streaming;
7+
8+
import static org.junit.Assert.assertNotNull;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.verify;
12+
import static org.mockito.Mockito.when;
13+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;
14+
15+
import java.util.HashMap;
16+
import java.util.Map;
17+
18+
import org.junit.Before;
19+
import org.junit.Test;
20+
import org.mockito.Mock;
21+
import org.mockito.MockitoAnnotations;
22+
import org.opensearch.ml.common.connector.Connector;
23+
import org.opensearch.ml.common.connector.ConnectorClientConfig;
24+
import org.opensearch.ml.common.exception.MLException;
25+
import org.opensearch.ml.common.transport.MLTaskResponse;
26+
27+
public class HttpStreamingHandlerTest {
28+
29+
@Mock
30+
private Connector connector;
31+
@Mock
32+
private ConnectorClientConfig connectorClientConfig;
33+
@Mock
34+
private StreamPredictActionListener<MLTaskResponse, ?> actionListener;
35+
36+
private HttpStreamingHandler httpStreamingHandler;
37+
38+
@Before
39+
public void setup() {
40+
MockitoAnnotations.openMocks(this);
41+
42+
when(connectorClientConfig.getConnectionTimeout()).thenReturn(30);
43+
when(connectorClientConfig.getReadTimeout()).thenReturn(30);
44+
45+
httpStreamingHandler = new HttpStreamingHandler(LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, connector, connectorClientConfig);
46+
}
47+
48+
@Test
49+
public void testConstructor() {
50+
assertNotNull(httpStreamingHandler);
51+
}
52+
53+
@Test
54+
public void testHandleError() {
55+
Exception testException = new RuntimeException("Test error");
56+
57+
doAnswer(invocation -> {
58+
MLException exception = invocation.getArgument(0);
59+
assertNotNull(exception);
60+
return null;
61+
}).when(actionListener).onFailure(any(MLException.class));
62+
63+
httpStreamingHandler.handleError(testException, actionListener);
64+
verify(actionListener).onFailure(any(MLException.class));
65+
}
66+
67+
@Test
68+
public void testStartStreamWithException() {
69+
Map<String, String> parameters = new HashMap<>();
70+
parameters.put("agent_type", "test");
71+
72+
when(connector.getActions()).thenReturn(null);
73+
httpStreamingHandler.startStream("test_action", parameters, "test_payload", actionListener);
74+
75+
verify(actionListener).onFailure(any(MLException.class));
76+
}
77+
78+
@Test
79+
public void testHTTPEventSourceListenerConstructor() {
80+
HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener(
81+
actionListener, LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, "test_agent"
82+
);
83+
assertNotNull(listener);
84+
}
85+
86+
@Test
87+
public void testHTTPEventSourceListenerOnFailureWithThrowable() {
88+
HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener(
89+
actionListener, LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, "test_agent"
90+
);
91+
92+
RuntimeException testException = new RuntimeException("Test error");
93+
listener.onFailure(null, testException, null);
94+
95+
verify(actionListener).onFailure(any(MLException.class));
96+
}
97+
98+
@Test(expected = IllegalArgumentException.class)
99+
public void testUnsupportedLLMInterface() {
100+
HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener(
101+
actionListener, "unsupported_interface", "test_agent"
102+
);
103+
listener.onEvent(null, null, null, "test data");
104+
}
105+
}

0 commit comments

Comments
 (0)