diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index efda9c4743..fe10a831f5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -15,6 +15,7 @@ import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; @@ -35,6 +36,8 @@ public class MLPostProcessFunction { public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; + // ML commons passthrough unwraps a remote ml-commons response and reconstructs model tensors directly based on remote inference + public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); @@ -46,6 +49,8 @@ public class MLPostProcessFunction { BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction(); CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction(); + RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction = + new RemoteMlCommonsPassthroughPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float"); @@ -61,6 +66,7 @@ public class MLPostProcessFunction { JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); + JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction); @@ -76,6 +82,7 @@ public class MLPostProcessFunction { POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction); } public static String getResponseFilter(String postProcessFunction) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java new file mode 100644 index 0000000000..b991ee82d8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +/** + * A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure + * to avoid double-wrapping when receiving responses from another ML-Commons instance. + */ +public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction> { + @Override + public void validate(Object input) { + if (!(input instanceof Map) && !(input instanceof List)) { + throw new IllegalArgumentException("Post process function input must be a Map or List"); + } + } + + /** + * Example unwrapped response: + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "response": [ + * { + * "increasingly": 0.028670792, + * "achievements": 0.4906937, + * ... + * } + * ] + * } + * } + * ], + * "status_code": 200.0 + * } + * ] + * } + * } + * ], + * "status_code": 200 + * } + * ] + * } + * + * Example unwrapped response: + * + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "output", + * "dataAsMap": { + * "response": [ + * { + * "increasingly": 0.028670792, + * "achievements": 0.4906937, + * ... + * } + * ] + * } + * }, + * ], + * "status_code": 200 + * } + * ] + * } + * + * @param mlCommonsResponse raw remote ml commons response + * @param dataType the datatype of the result, not used since datatype is set based on the response body + * @return a list of model tensors representing the inner model tensors + */ + @Override + public List process(Map mlCommonsResponse, MLResultDataType dataType) { + // Check if this is an ML-Commons response with inference_results + if (mlCommonsResponse.containsKey("inference_results") && mlCommonsResponse.get("inference_results") instanceof List) { + List> inferenceResults = (List>) mlCommonsResponse.get("inference_results"); + + List modelTensors = new ArrayList<>(); + for (Map result : inferenceResults) { + // Extract the output field which contains the ModelTensor data + if (result.containsKey("output") && result.get("output") instanceof List) { + List> outputs = (List>) result.get("output"); + for (Map output : outputs) { + // This inner map should represent a model tensor, so we try to parse and instantiate a new one. + ModelTensor modelTensor = createModelTensorFromMap(output); + if (modelTensor != null) { + modelTensors.add(modelTensor); + } + } + } + } + + return modelTensors; + } + + // Fallback for non-ML-Commons responses + ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(mlCommonsResponse).build(); + + return List.of(tensor); + } + + /** + * Creates a ModelTensor from a Map representation based on the API format + * of the /_predict API + */ + private ModelTensor createModelTensorFromMap(Map map) { + if (map == null || map.isEmpty()) { + return null; + } + + // Get name. If name is null or not a String, default to OUTPUT_FIELD + Object uncastedName = map.get(ModelTensor.NAME_FIELD); + String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD; + String result = (String) map.get(ModelTensor.RESULT_FIELD); + + // Handle data as map + Map dataAsMap = (Map) map.get(ModelTensor.DATA_AS_MAP_FIELD); + + // Handle data type. For certain models like neural sparse and non-dense remote models, this field + // is not populated and left as null instead, which is still valid + MLResultDataType dataType = null; + if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) { + Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD); + if (dataTypeObj instanceof String) { + try { + dataType = MLResultDataType.valueOf((String) dataTypeObj); + } catch (IllegalArgumentException e) { + // Invalid data type, leave as null in case inner data is still useful to be parsed in the future + } + } + } + + // Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result + // is stored in dataAsMap, not data/shape field + long[] shape = null; + if (map.containsKey(ModelTensor.SHAPE_FIELD)) { + Number[] numbers = processNumericalArray(map, ModelTensor.SHAPE_FIELD, Number.class); + if (numbers != null) { + shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray(); + } + } + + // Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result + // is stored in dataAsMap, not data/shape field + Number[] data = null; + if (map.containsKey(ModelTensor.DATA_FIELD)) { + data = processNumericalArray(map, ModelTensor.DATA_FIELD, Number.class); + } + + // For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases. + + return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build(); + } + + private static T[] processNumericalArray(Map map, String key, Class type) { + Object obj = map.get(key); + if (obj instanceof List list) { + T[] array = (T[]) Array.newInstance(type, list.size()); + for (int i = 0; i < list.size(); i++) { + Object item = list.get(i); + if (type.isInstance(item)) { + array[i] = type.cast(item); + } + } + return array; + } + return null; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java new file mode 100644 index 0000000000..b2cc031f70 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +public class RemoteMlCommonsPassthroughPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + RemoteMlCommonsPassthroughPostProcessFunction function; + + @Before + public void setUp() { + function = new RemoteMlCommonsPassthroughPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotMapOrList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input must be a Map or List"); + function.apply("abc", null); + } + + /** + * Tests processing of ML-Commons response containing sparse vector data with rank features. + * Validates that sparse vectors with dataAsMap containing token-score pairs are correctly parsed. + */ + @Test + public void process_MLCommonsResponse_RankFeatures() { + Map rankFeatures = Map + .of("increasingly", 0.028670792, "achievements", 0.4906937, "nation", 0.15371077, "hello", 0.35982144, "today", 3.0966291); + Map innerDataAsMap = Map.of("response", Arrays.asList(rankFeatures)); + Map output = Map.of("name", "output", "dataAsMap", innerDataAsMap); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.get(0); + assertEquals("output", tensor.getName()); + assertEquals(innerDataAsMap, tensor.getDataAsMap()); + + // Verify the nested sparse data structure + Map dataAsMap = (Map) tensor.getDataAsMap(); + List> response = (List>) dataAsMap.get("response"); + assertEquals(1, response.size()); + assertEquals(0.35982144, (Double) response.get(0).get("hello"), 0.0001); + assertEquals(3.0966291, (Double) response.get(0).get("today"), 0.0001); + } + + /** + * Tests processing of ML-Commons response containing dense vector data with numerical arrays. + * Validates that dense vectors with data_type, shape, and data fields are correctly parsed. + */ + @Test + public void process_MLCommonsResponse_DenseVector() { + Map output = Map + .of( + "name", + "sentence_embedding", + "data_type", + "FLOAT32", + "shape", + Arrays.asList(3L), + "data", + Arrays.asList(0.5400895, -0.19082281, 0.4996347) + ); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.get(0); + assertEquals("sentence_embedding", tensor.getName()); + assertEquals(MLResultDataType.FLOAT32, tensor.getDataType()); + assertEquals(1, tensor.getShape().length); + assertEquals(3L, tensor.getShape()[0]); + assertEquals(3, tensor.getData().length); + assertEquals(0.5400895, tensor.getData()[0].doubleValue(), 0.0001); + } + + /** + * Tests processing of ML-Commons response with multiple output tensors in a single inference result. + * Ensures all outputs are processed and returned as separate ModelTensor objects. + */ + @Test + public void process_MLCommonsResponse_MultipleOutputs() { + Map output1 = Map.of("name", "output1", "result", "result1"); + Map output2 = Map.of("name", "output2", "result", "result2"); + Map inferenceResult = Map.of("output", Arrays.asList(output1, output2)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(2, result.size()); + assertEquals("output1", result.get(0).getName()); + assertEquals("result1", result.get(0).getResult()); + assertEquals("output2", result.get(1).getName()); + assertEquals("result2", result.get(1).getResult()); + } + + /** + * Tests edge case where ML-Commons response has empty inference_results array. + * Should return empty list without errors. + */ + @Test + public void process_MLCommonsResponse_EmptyInferenceResults() { + Map input = Map.of("inference_results", Arrays.asList()); + + List result = function.apply(input, null); + + assertEquals(0, result.size()); + } + + /** + * Tests edge cases where inference result lacks the expected format. + * Should skip processing and return empty list. + */ + @Test + public void process_MLCommonsResponse_InvalidOutputs() { + Map inferenceResult = Map.of("other_field", "value"); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(0, result.size()); + + // correct format, but with empty output + inferenceResult = Map.of("output", List.of(Map.of())); + input = Map.of("inference_results", List.of(inferenceResult)); + + result = function.apply(input, null); + + assertEquals(0, result.size()); + + // Fallback for non-ml-commons responses + input = Map.of("invalid_format", "invalid value"); + result = function.apply(input, null); + + assertEquals(1, result.size()); + assertEquals(input, result.getFirst().getDataAsMap()); + assertEquals("response", result.getFirst().getName()); + } + + /** + * Tests processing of ML-Commons response containing dense vector data with numerical arrays. + * Validates that when the types are incorrect, values are parsed as nulls. + */ + @Test + public void process_MLCommonsResponse_InvalidDenseVectorFormat() { + Map output = Map + .of( + "name", + List.of("Not a string"), + "data_type", + "NON-EXISTENT TYPE", + "shape", + "not a list of long", + "data", + "not a list of numbers" + ); + Map inferenceResult = Map.of("output", Arrays.asList(output)); + Map input = Map.of("inference_results", Arrays.asList(inferenceResult)); + + List result = function.apply(input, null); + + assertEquals(1, result.size()); + ModelTensor tensor = result.getFirst(); + assertEquals(OUTPUT_FIELD, tensor.getName()); + assertNull(tensor.getShape()); + assertNull(tensor.getData()); + assertNull(tensor.getDataType()); + } +}