diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 2f780db110..aa8c935138 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -17,6 +17,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.NovaMultiModalEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -30,6 +31,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding"; + public static final String BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.embedding"; public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding"; public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding"; public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding"; @@ -49,6 +51,8 @@ public class MLPreProcessFunction { CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction(); MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); + NovaMultiModalEmbeddingPreProcessFunction novaMultiModalEmbeddingPreProcessFunction = + new NovaMultiModalEmbeddingPreProcessFunction(); ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction(); VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction(); AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction(); @@ -57,6 +61,7 @@ public class MLPreProcessFunction { PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT, novaMultiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..2f7c3c34ec --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunction.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class NovaMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public NovaMultiModalEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { + throw new IllegalArgumentException("No input provided"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + String input = inputData.getDocs().get(0); + + Map parametersMap = new HashMap<>(); + String parameterName = detectModalityParameter(input); + parametersMap.put(parameterName, input); + + return RemoteInferenceInputDataSet + .builder() + .parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))) + .build(); + } + + private String detectModalityParameter(String input) { + try { + if (input.contains("\"text\"")) { + return "inputText"; + } + if (input.contains("\"image\"")) { + return "inputImage"; + } + if (input.contains("\"video\"")) { + return "inputVideo"; + } + if (input.contains("\"audio\"")) { + return "inputAudio"; + } + return "inputText"; + } catch (Exception e) { + log.warn("Failed to detect modality from input, defaulting to text: {}", e.getMessage()); + return "inputText"; + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..47b9aafd38 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/NovaMultiModalEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +public class NovaMultiModalEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + NovaMultiModalEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + + MLInput textEmbeddingInput; + MLInput textSimilarityInput; + MLInput remoteInferenceInput; + + @Before + public void setUp() { + function = new NovaMultiModalEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); + + textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + function.apply(textSimilarityInput); + } + + @Test + public void process_TextInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + } + + @Test + public void process_JsonImageInput() { + TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"image\": \"base64data\"}")).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("{\"image\": \"base64data\"}", dataSet.getParameters().get("inputImage")); + } + + @Test + public void process_VideoInput() { + TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"video\": \"videodata\"}")).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("{\"video\": \"videodata\"}", dataSet.getParameters().get("inputVideo")); + } + + @Test + public void process_AudioInput() { + TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"audio\": \"audiodata\"}")).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("{\"audio\": \"audiodata\"}", dataSet.getParameters().get("inputAudio")); + } + + @Test + public void process_EmptyDocs() { + TextDocsInputDataSet mockDataSet = mock(TextDocsInputDataSet.class); + when(mockDataSet.getDocs()).thenReturn(Collections.emptyList()); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(mockDataSet).build(); + + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No input provided"); + function.apply(mlInput); + } + + @Test + public void process_RemoteInferenceInput() { + RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); + assertEquals(remoteInferenceInputDataSet, dataSet); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cc0736c9e2..d63b203473 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -53,6 +53,9 @@ import org.opensearch.ml.engine.processor.ProcessorChain; import org.opensearch.script.ScriptService; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import com.jayway.jsonpath.JsonPath; import lombok.extern.log4j.Log4j2; @@ -72,6 +75,7 @@ public class ConnectorUtils { private static final AwsV4HttpSigner signer; public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters"; + public static final String BEDROCK_NOVA_MODEL = "amazon.nova-2-multimodal-embeddings-v1:0"; static { signer = AwsV4HttpSigner.create(); @@ -340,6 +344,13 @@ public static SdkHttpFullRequest buildSdkRequest( SdkHttpMethod method ) { String charset = parameters.getOrDefault("charset", "UTF-8"); + + // Clean empty JSON sections for Bedrock Nova embedding requests + String model = connector.getParameters().get("model"); + if (payload != null && model != null && model.equals(BEDROCK_NOVA_MODEL)) { + payload = cleanBedrockNovaRequest(payload); + } + RequestBody requestBody; if (payload != null) { requestBody = RequestBody.fromString(payload, Charset.forName(charset)); @@ -480,4 +491,38 @@ public static ConnectorAction createConnectorAction(Connector connector, Connect .headers(batchPredictAction.get().getHeaders()) .build(); } + + private static String cleanBedrockNovaRequest(String json) { + try { + JsonObject root = JsonParser.parseString(json).getAsJsonObject(); + JsonObject params = root.getAsJsonObject("singleEmbeddingParams"); + if (params == null) + return json; + + removeIfNull(params, "text"); + removeIfNull(params, "image"); + removeIfNull(params, "video"); + removeIfNull(params, "audio"); + + return gson.toJson(root); + } catch (Exception e) { + log.warn("Failed to clean empty JSON sections: {}", e.getMessage()); + return json; + } + } + + private static void removeIfNull(JsonObject parent, String fieldName) { + JsonObject field = parent.getAsJsonObject(fieldName); + if (field == null) + return; + + // Check text field's value or other fields' source.bytes + JsonElement element = "text".equals(fieldName) + ? field.get("value") + : (field.getAsJsonObject("source") != null ? field.getAsJsonObject("source").get("bytes") : null); + + if (element != null && element.isJsonNull()) { + parent.remove(fieldName); + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index a5821944d0..9b1859a0e9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -10,12 +10,14 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.BEDROCK_NOVA_MODEL; import java.io.IOException; import java.util.ArrayList; @@ -47,6 +49,7 @@ import com.google.common.collect.ImmutableMap; import okhttp3.Request; +import software.amazon.awssdk.http.SdkHttpFullRequest; public class ConnectorUtilsTest { @@ -1057,6 +1060,51 @@ public void buildSdkRequest_CancelBatchPredictWithEmptyPayload() { } } + @Test + public void buildSdkRequest_NovaModelCleansJson() throws IOException { + Connector connector = mock(Connector.class); + when(connector.getParameters()).thenReturn(Map.of("model", BEDROCK_NOVA_MODEL)); + when(connector.getActionEndpoint("predict", Map.of())) + .thenReturn("https://bedrock-runtime.us-east-1.amazonaws.com/model/test/invoke"); + when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json")); + + String payloadWithNulls = + "{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"},\"video\":{\"source\":{\"bytes\":null}},\"audio\":{\"source\":{\"bytes\":null}}}}"; + + SdkHttpFullRequest request = ConnectorUtils + .buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST); + + // Verify request was created successfully + assertNotNull(request); + assertTrue(request.contentStreamProvider().isPresent()); + + // Verify the payload was cleaned, null values removed + String actualPayload = new String(request.contentStreamProvider().get().newStream().readAllBytes()); + String expectedPayload = "{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"}}}"; + assertEquals(expectedPayload, actualPayload); + } + + @Test + public void testBuildSdkRequest_NonNovaModelSkipsCleaning() throws IOException { + Connector connector = mock(Connector.class); + when(connector.getParameters()).thenReturn(Map.of("model", "gpt-3.5-turbo")); + when(connector.getActionEndpoint("predict", Map.of())).thenReturn("https://api.openai.com/v1/chat/completions"); + when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json")); + + String payloadWithNulls = "{\"video\":{\"source\":{\"bytes\":null}}}"; + + SdkHttpFullRequest request = ConnectorUtils + .buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST); + + // Verify request was created successfully + assertNotNull(request); + assertTrue(request.contentStreamProvider().isPresent()); + + // Verify the payload was not cleaned, null values preserved + String actualPayload = new String(request.contentStreamProvider().get().newStream().readAllBytes()); + assertEquals(payloadWithNulls, actualPayload); + } + @Test public void createConnectorAction_WithEmptyParameters() { Connector connector = HttpConnector