Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.NovaMultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand All @@ -30,6 +31,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
public static final String BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.embedding";
public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding";
public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding";
public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding";
Expand All @@ -49,6 +51,8 @@ public class MLPreProcessFunction {
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
NovaMultiModalEmbeddingPreProcessFunction novaMultiModalEmbeddingPreProcessFunction =
new NovaMultiModalEmbeddingPreProcessFunction();
ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction();
VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction();
AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction();
Expand All @@ -57,6 +61,7 @@ public class MLPreProcessFunction {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT, novaMultiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import lombok.extern.log4j.Log4j2;

/**
 * Pre-process function that turns the first document of a {@link TextDocsInputDataSet} into a
 * {@link RemoteInferenceInputDataSet} carrying exactly one modality-specific parameter
 * ({@code inputText}, {@code inputImage}, {@code inputVideo} or {@code inputAudio}).
 *
 * <p>Only the first document is used; any further documents in the data set are ignored.
 *
 * <p>The modality is chosen by a plain substring heuristic on the document (see
 * {@link #detectModalityParameter(String)}); the document itself is forwarded unchanged as the
 * parameter value.
 */
@Log4j2
public class NovaMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

    /** Parameter name used for text input and as the fallback when no modality marker is found. */
    private static final String TEXT_PARAMETER = "inputText";

    /**
     * Creates the function and sets the inherited {@code returnDirectlyForRemoteInferenceInput} flag.
     * NOTE(review): presumably this makes the base class hand back a {@code RemoteInferenceInputDataSet}
     * untouched instead of calling {@link #process(MLInput)} - confirm against ConnectorPreProcessFunction.
     */
    public NovaMultiModalEmbeddingPreProcessFunction() {
        this.returnDirectlyForRemoteInferenceInput = true;
    }

    /**
     * Checks that the input is a text-docs data set whose first document is present.
     *
     * <p>{@link #process(MLInput)} always reads the first document, so a missing or {@code null}
     * first document is rejected regardless of how many documents follow it.
     *
     * @param mlInput the input to validate
     * @throws IllegalArgumentException if there are no documents or the first document is {@code null}
     *         (further checks are delegated to {@code validateTextDocsInput})
     */
    @Override
    public void validate(MLInput mlInput) {
        validateTextDocsInput(mlInput);
        List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
        if (docs == null || docs.isEmpty() || docs.get(0) == null) {
            throw new IllegalArgumentException("No input provided");
        }
    }

    /**
     * Wraps the first document in a single-entry parameter map keyed by the detected modality.
     *
     * @param mlInput input whose data set is a {@link TextDocsInputDataSet}
     * @return a remote inference data set whose parameters are produced by
     *         {@code convertScriptStringToJsonString} from {@code {"parameters": {<modality>: <doc>}}}
     */
    @Override
    public RemoteInferenceInputDataSet process(MLInput mlInput) {
        TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
        String input = inputData.getDocs().get(0);

        // HashMap rather than Map.of so that a null document (only possible when validate() was
        // bypassed) does not raise a NullPointerException here.
        Map<String, String> parametersMap = new HashMap<>();
        parametersMap.put(detectModalityParameter(input), input);

        return RemoteInferenceInputDataSet
            .builder()
            .parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
            .build();
    }

    /**
     * Picks the parameter name for a document by looking for a quoted modality key in it.
     *
     * <p>The checks run in the order text, image, video, audio and the first match wins, so a
     * document containing several markers is classified by the earliest one in that order.
     * A document with no marker - or a {@code null} document - is treated as text.
     *
     * @param input the raw document, may be {@code null}
     * @return one of {@code inputText}, {@code inputImage}, {@code inputVideo}, {@code inputAudio}
     */
    private String detectModalityParameter(String input) {
        if (input == null) {
            // Explicit guard instead of catching the NullPointerException a contains() call would raise.
            log.warn("Failed to detect modality from input, defaulting to text: {}", "input is null");
            return TEXT_PARAMETER;
        }
        // "text" is checked first on purpose: it takes precedence over the other markers.
        if (input.contains("\"text\"")) {
            return TEXT_PARAMETER;
        }
        if (input.contains("\"image\"")) {
            return "inputImage";
        }
        if (input.contains("\"video\"")) {
            return "inputVideo";
        }
        if (input.contains("\"audio\"")) {
            return "inputAudio";
        }
        return TEXT_PARAMETER;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
 * Unit tests for {@link NovaMultiModalEmbeddingPreProcessFunction}: input validation, modality
 * detection for text / image / video / audio documents, and pass-through of remote inference input.
 */
public class NovaMultiModalEmbeddingPreProcessFunctionTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    NovaMultiModalEmbeddingPreProcessFunction function;

    TextSimilarityInputDataSet textSimilarityInputDataSet;
    TextDocsInputDataSet textDocsInputDataSet;
    RemoteInferenceInputDataSet remoteInferenceInputDataSet;

    MLInput textEmbeddingInput;
    MLInput textSimilarityInput;
    MLInput remoteInferenceInput;

    /** Builds a fresh function and one fixture per supported/unsupported data set type. */
    @Before
    public void setUp() {
        function = new NovaMultiModalEmbeddingPreProcessFunction();

        // Plain text document fixture.
        textDocsInputDataSet = docsOf("hello");
        textEmbeddingInput = embeddingInput(textDocsInputDataSet);

        // Data set type the function must reject.
        textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
        textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();

        // Data set type the function must hand back unchanged.
        remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
        remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
    }

    @Test
    public void process_NullInput() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("Preprocess function input can't be null");
        function.apply(null);
    }

    @Test
    public void process_WrongInput() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
        function.apply(textSimilarityInput);
    }

    @Test
    public void process_TextInput() {
        assertSingleParameter("hello", "inputText");
    }

    @Test
    public void process_JsonImageInput() {
        assertSingleParameter("{\"image\": \"base64data\"}", "inputImage");
    }

    @Test
    public void process_VideoInput() {
        assertSingleParameter("{\"video\": \"videodata\"}", "inputVideo");
    }

    @Test
    public void process_AudioInput() {
        assertSingleParameter("{\"audio\": \"audiodata\"}", "inputAudio");
    }

    @Test
    public void process_EmptyDocs() {
        // A mock is used so that getDocs() can return an empty list.
        TextDocsInputDataSet emptyDocs = mock(TextDocsInputDataSet.class);
        when(emptyDocs.getDocs()).thenReturn(Collections.emptyList());
        MLInput emptyInput = embeddingInput(emptyDocs);

        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("No input provided");
        function.apply(emptyInput);
    }

    @Test
    public void process_RemoteInferenceInput() {
        RemoteInferenceInputDataSet returned = function.apply(remoteInferenceInput);
        assertEquals(remoteInferenceInputDataSet, returned);
    }

    /** Builds a text-docs data set holding exactly the given document. */
    private static TextDocsInputDataSet docsOf(String doc) {
        return TextDocsInputDataSet.builder().docs(Arrays.asList(doc)).build();
    }

    /** Wraps a text-docs data set in a TEXT_EMBEDDING input. */
    private static MLInput embeddingInput(TextDocsInputDataSet docs) {
        return MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(docs).build();
    }

    /**
     * Applies the function to a single document and asserts that the result has exactly one
     * parameter, stored under {@code expectedKey}, whose value is the document itself.
     */
    private void assertSingleParameter(String doc, String expectedKey) {
        RemoteInferenceInputDataSet result = function.apply(embeddingInput(docsOf(doc)));
        assertEquals(1, result.getParameters().size());
        assertEquals(doc, result.getParameters().get(expectedKey));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
import org.opensearch.ml.engine.processor.ProcessorChain;
import org.opensearch.script.ScriptService;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.jayway.jsonpath.JsonPath;

import lombok.extern.log4j.Log4j2;
Expand All @@ -72,6 +75,7 @@ public class ConnectorUtils {

private static final AwsV4HttpSigner signer;
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
public static final String BEDROCK_NOVA_MODEL = "amazon.nova-2-multimodal-embeddings-v1:0";

static {
signer = AwsV4HttpSigner.create();
Expand Down Expand Up @@ -340,6 +344,13 @@ public static SdkHttpFullRequest buildSdkRequest(
SdkHttpMethod method
) {
String charset = parameters.getOrDefault("charset", "UTF-8");

// Clean empty JSON sections for Bedrock Nova embedding requests
String model = connector.getParameters().get("model");
if (payload != null && model != null && model.equals(BEDROCK_NOVA_MODEL)) {
payload = cleanBedrockNovaRequest(payload);
}

RequestBody requestBody;
if (payload != null) {
requestBody = RequestBody.fromString(payload, Charset.forName(charset));
Expand Down Expand Up @@ -480,4 +491,38 @@ public static ConnectorAction createConnectorAction(Connector connector, Connect
.headers(batchPredictAction.get().getHeaders())
.build();
}

/**
 * Best-effort cleanup of a Bedrock Nova embedding request body: drops the {@code text},
 * {@code image}, {@code video} and {@code audio} sections of {@code singleEmbeddingParams}
 * whose payload is JSON null (see {@code removeIfNull}).
 *
 * @param json the request body as a JSON string
 * @return the re-serialized body after cleanup; the original string unchanged when there is no
 *         {@code singleEmbeddingParams} member or when parsing/cleaning throws
 */
private static String cleanBedrockNovaRequest(String json) {
    try {
        JsonObject requestBody = JsonParser.parseString(json).getAsJsonObject();
        JsonObject embeddingParams = requestBody.getAsJsonObject("singleEmbeddingParams");
        if (embeddingParams != null) {
            for (String section : new String[] { "text", "image", "video", "audio" }) {
                removeIfNull(embeddingParams, section);
            }
            return gson.toJson(requestBody);
        }
        // Nothing to clean: hand back the caller's string untouched.
        return json;
    } catch (Exception e) {
        // Cleaning is optional; on any failure fall back to the original payload.
        log.warn("Failed to clean empty JSON sections: {}", e.getMessage());
        return json;
    }
}

/**
 * Removes {@code fieldName} from {@code parent} when that section's payload is JSON null.
 *
 * <p>The payload is {@code value} for the {@code "text"} section and {@code source.bytes} for
 * every other section. A section is left untouched when it is absent, when it (or its
 * {@code source}) is not a JSON object, or when the payload member is absent or non-null.
 *
 * <p>Unexpected shapes are skipped rather than cast: {@code JsonObject.getAsJsonObject(String)}
 * throws {@code ClassCastException} for a JSON null or primitive member, which would otherwise
 * abort cleaning of the remaining sections in the caller.
 *
 * @param parent    the object holding the modality sections
 * @param fieldName the section to inspect, e.g. {@code "text"} or {@code "video"}
 */
private static void removeIfNull(JsonObject parent, String fieldName) {
    JsonElement section = parent.get(fieldName);
    if (section == null || !section.isJsonObject()) {
        return;
    }
    JsonObject field = section.getAsJsonObject();

    // Check text field's value or other fields' source.bytes
    JsonElement element;
    if ("text".equals(fieldName)) {
        element = field.get("value");
    } else {
        JsonElement source = field.get("source");
        element = (source != null && source.isJsonObject()) ? source.getAsJsonObject().get("bytes") : null;
    }

    if (element != null && element.isJsonNull()) {
        parent.remove(fieldName);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.BEDROCK_NOVA_MODEL;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -47,6 +49,7 @@
import com.google.common.collect.ImmutableMap;

import okhttp3.Request;
import software.amazon.awssdk.http.SdkHttpFullRequest;

public class ConnectorUtilsTest {

Expand Down Expand Up @@ -1057,6 +1060,51 @@ public void buildSdkRequest_CancelBatchPredictWithEmptyPayload() {
}
}

/**
 * For a connector whose {@code model} parameter is the Nova embedding model, sections of
 * {@code singleEmbeddingParams} with a null payload must be stripped from the request body.
 */
@Test
public void buildSdkRequest_NovaModelCleansJson() throws IOException {
    Connector connector = mock(Connector.class);
    when(connector.getParameters()).thenReturn(Map.of("model", BEDROCK_NOVA_MODEL));
    when(connector.getActionEndpoint("predict", Map.of()))
        .thenReturn("https://bedrock-runtime.us-east-1.amazonaws.com/model/test/invoke");
    when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json"));

    String payloadWithNulls =
        "{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"},\"video\":{\"source\":{\"bytes\":null}},\"audio\":{\"source\":{\"bytes\":null}}}}";

    SdkHttpFullRequest request = ConnectorUtils
        .buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST);

    // Verify request was created successfully
    assertNotNull(request);
    assertTrue(request.contentStreamProvider().isPresent());

    // Verify the payload was cleaned, null values removed.
    // No "charset" parameter is passed, so the body is read back explicitly as UTF-8 instead of
    // the platform default charset; the stream is closed once read.
    String actualPayload;
    try (java.io.InputStream body = request.contentStreamProvider().get().newStream()) {
        actualPayload = new String(body.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
    }
    String expectedPayload = "{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"}}}";
    assertEquals(expectedPayload, actualPayload);
}

/**
 * For a connector whose {@code model} parameter is not the Nova embedding model, the request
 * body must be sent exactly as given, null members included.
 */
@Test
public void testBuildSdkRequest_NonNovaModelSkipsCleaning() throws IOException {
    Connector connector = mock(Connector.class);
    when(connector.getParameters()).thenReturn(Map.of("model", "gpt-3.5-turbo"));
    when(connector.getActionEndpoint("predict", Map.of())).thenReturn("https://api.openai.com/v1/chat/completions");
    when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json"));

    String payloadWithNulls = "{\"video\":{\"source\":{\"bytes\":null}}}";

    SdkHttpFullRequest request = ConnectorUtils
        .buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST);

    // Verify request was created successfully
    assertNotNull(request);
    assertTrue(request.contentStreamProvider().isPresent());

    // Verify the payload was not cleaned, null values preserved.
    // No "charset" parameter is passed, so the body is read back explicitly as UTF-8 instead of
    // the platform default charset; the stream is closed once read.
    String actualPayload;
    try (java.io.InputStream body = request.contentStreamProvider().get().newStream()) {
        actualPayload = new String(body.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
    }
    assertEquals(payloadWithNulls, actualPayload);
}

@Test
public void createConnectorAction_WithEmptyParameters() {
Connector connector = HttpConnector
Expand Down
Loading