Skip to content

Commit 11b5bb0

Browse files
committed
Add ml-commons passthrough post process function
Signed-off-by: Andy Qin <[email protected]>
1 parent 85a3135 commit 11b5bb0

File tree

3 files changed

+266
-0
lines changed

3 files changed

+266
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
1616
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
1717
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
18+
import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
1819
import org.opensearch.ml.common.output.model.MLResultDataType;
1920
import org.opensearch.ml.common.output.model.ModelTensor;
2021

@@ -35,6 +36,7 @@ public class MLPostProcessFunction {
3536
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
3637
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
3738
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
39+
public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";
3840

3941
private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();
4042

@@ -46,6 +48,8 @@ public class MLPostProcessFunction {
4648
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
4749
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
4850
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
51+
RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction =
52+
new RemoteMlCommonsPassthroughPostProcessFunction();
4953
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
5054
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
5155
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
@@ -61,6 +65,7 @@ public class MLPostProcessFunction {
6165
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
6266
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
6367
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
68+
JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response
6469
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
6570
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
6671
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
@@ -76,6 +81,7 @@ public class MLPostProcessFunction {
7681
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
7782
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
7883
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
84+
POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction);
7985
}
8086

8187
public static String getResponseFilter(String postProcessFunction) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;
9+
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.output.model.MLResultDataType;
15+
import org.opensearch.ml.common.output.model.ModelTensor;
16+
17+
/**
18+
* A post-processing function for calling a remote ml commons instancethat preserves the original neural sparse response structure
19+
* to avoid double-wrapping when receiving responses from another ML-Commons instance.
20+
*/
21+
public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
22+
@Override
23+
public void validate(Object input) {
24+
if (!(input instanceof Map) && !(input instanceof List)) {
25+
throw new IllegalArgumentException("Post process function input must be a Map or List");
26+
}
27+
}
28+
29+
@SuppressWarnings("unchecked")
30+
@Override
31+
public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dataType) {
32+
// Check if this is an ML-Commons response with inference_results
33+
if (input.containsKey("inference_results") && input.get("inference_results") instanceof List) {
34+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) input.get("inference_results");
35+
36+
List<ModelTensor> modelTensors = new ArrayList<>();
37+
for (Map<String, Object> result : inferenceResults) {
38+
// Extract the output field which contains the ModelTensor data
39+
if (result.containsKey("output") && result.get("output") instanceof List) {
40+
List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
41+
for (Map<String, Object> output : outputs) {
42+
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
43+
ModelTensor modelTensor = createModelTensorFromMap(output);
44+
modelTensors.add(modelTensor);
45+
}
46+
}
47+
}
48+
49+
return modelTensors;
50+
}
51+
52+
// Fallback for non-ML-Commons responses
53+
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(input).build();
54+
55+
return List.of(tensor);
56+
}
57+
58+
/**
59+
* Creates a ModelTensor from a Map<String, Object> representation based on the API format
60+
* of the /_predict API
61+
*/
62+
@SuppressWarnings("unchecked")
63+
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
64+
String name = (String) map.getOrDefault(ModelTensor.NAME_FIELD, OUTPUT_FIELD);
65+
String result = (String) map.get(ModelTensor.RESULT_FIELD);
66+
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);
67+
68+
// Handle data type
69+
MLResultDataType dataType = null;
70+
if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
71+
Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
72+
if (dataTypeObj instanceof String) {
73+
try {
74+
dataType = MLResultDataType.valueOf((String) dataTypeObj);
75+
} catch (IllegalArgumentException e) {
76+
// Invalid data type, leave as null
77+
}
78+
}
79+
}
80+
81+
// Handle shape
82+
long[] shape = null;
83+
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
84+
Object shapeObj = map.get(ModelTensor.SHAPE_FIELD);
85+
if (shapeObj instanceof List) {
86+
List<?> shapeList = (List<?>) shapeObj;
87+
shape = new long[shapeList.size()];
88+
for (int i = 0; i < shapeList.size(); i++) {
89+
Object item = shapeList.get(i);
90+
if (item instanceof Number) {
91+
shape[i] = ((Number) item).longValue();
92+
}
93+
}
94+
}
95+
}
96+
97+
// Handle data array
98+
Number[] data = null;
99+
if (map.containsKey(ModelTensor.DATA_FIELD)) {
100+
Object dataObj = map.get(ModelTensor.DATA_FIELD);
101+
if (dataObj instanceof List) {
102+
List<?> dataList = (List<?>) dataObj;
103+
data = new Number[dataList.size()];
104+
for (int i = 0; i < dataList.size(); i++) {
105+
Object item = dataList.get(i);
106+
if (item instanceof Number) {
107+
data[i] = (Number) item;
108+
}
109+
}
110+
}
111+
}
112+
113+
// For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.
114+
115+
return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
116+
}
117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
import java.util.Arrays;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.junit.Before;
15+
import org.junit.Rule;
16+
import org.junit.Test;
17+
import org.junit.rules.ExpectedException;
18+
import org.opensearch.ml.common.output.model.ModelTensor;
19+
20+
public class RemoteMlCommonsPassthroughPostProcessFunctionTest {
21+
@Rule
22+
public ExpectedException exceptionRule = ExpectedException.none();
23+
24+
RemoteMlCommonsPassthroughPostProcessFunction function;
25+
26+
@Before
27+
public void setUp() {
28+
function = new RemoteMlCommonsPassthroughPostProcessFunction();
29+
}
30+
31+
@Test
32+
public void process_WrongInput_NotMapOrList() {
33+
exceptionRule.expect(IllegalArgumentException.class);
34+
exceptionRule.expectMessage("Post process function input must be a Map or List");
35+
function.apply("abc", null);
36+
}
37+
38+
/**
39+
* Tests processing of ML-Commons response containing sparse vector data with rank features.
40+
* Validates that sparse vectors with dataAsMap containing token-score pairs are correctly parsed.
41+
*/
42+
@Test
43+
public void process_MLCommonsResponse_RankFeatures() {
44+
Map<String, Object> rankFeatures = Map
45+
.of("increasingly", 0.028670792, "achievements", 0.4906937, "nation", 0.15371077, "hello", 0.35982144, "today", 3.0966291);
46+
Map<String, Object> innerDataAsMap = Map.of("response", Arrays.asList(rankFeatures));
47+
Map<String, Object> output = Map.of("name", "output", "dataAsMap", innerDataAsMap);
48+
Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output));
49+
Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));
50+
51+
List<ModelTensor> result = function.apply(input, null);
52+
53+
assertEquals(1, result.size());
54+
ModelTensor tensor = result.get(0);
55+
assertEquals("output", tensor.getName());
56+
assertEquals(innerDataAsMap, tensor.getDataAsMap());
57+
58+
// Verify the nested sparse data structure
59+
Map<String, Object> dataAsMap = (Map<String, Object>) tensor.getDataAsMap();
60+
List<Map<String, Object>> response = (List<Map<String, Object>>) dataAsMap.get("response");
61+
assertEquals(1, response.size());
62+
assertEquals(0.35982144, (Double) response.get(0).get("hello"), 0.0001);
63+
assertEquals(3.0966291, (Double) response.get(0).get("today"), 0.0001);
64+
}
65+
66+
/**
67+
* Tests processing of ML-Commons response containing dense vector data with numerical arrays.
68+
* Validates that dense vectors with data_type, shape, and data fields are correctly parsed.
69+
*/
70+
@Test
71+
public void process_MLCommonsResponse_DenseVector() {
72+
Map<String, Object> output = Map
73+
.of(
74+
"name",
75+
"sentence_embedding",
76+
"data_type",
77+
"FLOAT32",
78+
"shape",
79+
Arrays.asList(3L),
80+
"data",
81+
Arrays.asList(0.5400895, -0.19082281, 0.4996347)
82+
);
83+
Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output));
84+
Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));
85+
86+
List<ModelTensor> result = function.apply(input, null);
87+
88+
assertEquals(1, result.size());
89+
ModelTensor tensor = result.get(0);
90+
assertEquals("sentence_embedding", tensor.getName());
91+
assertEquals(3, tensor.getShape().length);
92+
assertEquals(3L, tensor.getShape()[0]);
93+
assertEquals(3, tensor.getData().length);
94+
assertEquals(0.5400895, tensor.getData()[0].doubleValue(), 0.0001);
95+
}
96+
97+
/**
98+
* Tests processing of ML-Commons response with multiple output tensors in a single inference result.
99+
* Ensures all outputs are processed and returned as separate ModelTensor objects.
100+
*/
101+
@Test
102+
public void process_MLCommonsResponse_MultipleOutputs() {
103+
Map<String, Object> output1 = Map.of("name", "output1", "result", "result1");
104+
Map<String, Object> output2 = Map.of("name", "output2", "result", "result2");
105+
Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output1, output2));
106+
Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));
107+
108+
List<ModelTensor> result = function.apply(input, null);
109+
110+
assertEquals(2, result.size());
111+
assertEquals("output1", result.get(0).getName());
112+
assertEquals("result1", result.get(0).getResult());
113+
assertEquals("output2", result.get(1).getName());
114+
assertEquals("result2", result.get(1).getResult());
115+
}
116+
117+
/**
118+
* Tests edge case where ML-Commons response has empty inference_results array.
119+
* Should return empty list without errors.
120+
*/
121+
@Test
122+
public void process_MLCommonsResponse_EmptyInferenceResults() {
123+
Map<String, Object> input = Map.of("inference_results", Arrays.asList());
124+
125+
List<ModelTensor> result = function.apply(input, null);
126+
127+
assertEquals(0, result.size());
128+
}
129+
130+
/**
131+
* Tests edge case where inference result lacks the expected "output" field.
132+
* Should skip processing and return empty list.
133+
*/
134+
@Test
135+
public void process_MLCommonsResponse_NoOutputField() {
136+
Map<String, Object> inferenceResult = Map.of("other_field", "value");
137+
Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));
138+
139+
List<ModelTensor> result = function.apply(input, null);
140+
141+
assertEquals(0, result.size());
142+
}
143+
}

0 commit comments

Comments
 (0)