Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

Expand All @@ -35,6 +36,7 @@ public class MLPostProcessFunction {
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as this is a special kind, can we please add a comment here?


private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

Expand All @@ -46,6 +48,8 @@ public class MLPostProcessFunction {
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction =
new RemoteMlCommonsPassthroughPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
Expand All @@ -61,6 +65,7 @@ public class MLPostProcessFunction {
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
Expand All @@ -76,6 +81,7 @@ public class MLPostProcessFunction {
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction);
}

public static String getResponseFilter(String postProcessFunction) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
* A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure
* to avoid double-wrapping when receiving responses from another ML-Commons instance.
*/
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add example in the doc for better understanding?

public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
@Override
public void validate(Object input) {
if (!(input instanceof Map) && !(input instanceof List)) {
throw new IllegalArgumentException("Post process function input must be a Map or List");
}
}

@Override
public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dataType) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of Map<String, Object> input, Since this parameter represents ML Commons inference results or response data, some better name suggestions would be:

mlCommonsResponse
inferenceResponse
responseData
modelResponse

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, didn't realize you could rename the parameter variables of an overridden function. Changed to mlCommonsResponse

// Check if this is an ML-Commons response with inference_results
if (input.containsKey("inference_results") && input.get("inference_results") instanceof List) {
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) input.get("inference_results");

List<ModelTensor> modelTensors = new ArrayList<>();
for (Map<String, Object> result : inferenceResults) {
// Extract the output field which contains the ModelTensor data
if (result.containsKey("output") && result.get("output") instanceof List) {
List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
for (Map<String, Object> output : outputs) {
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
ModelTensor modelTensor = createModelTensorFromMap(output);
if (modelTensor != null) {
modelTensors.add(modelTensor);
}
}
}
}

return modelTensors;
}

// Fallback for non-ML-Commons responses
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(input).build();

return List.of(tensor);
}

/**
* Creates a ModelTensor from a Map<String, Object> representation based on the API format
* of the /_predict API
*/
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
return null;
}

// Get name. If name is null or not a String, default to OUTPUT_FIELD
Object uncastedName = map.get(ModelTensor.NAME_FIELD);
String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD;
String result = (String) map.get(ModelTensor.RESULT_FIELD);
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);

// Handle data type
MLResultDataType dataType = null;
if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
if (dataTypeObj instanceof String) {
try {
dataType = MLResultDataType.valueOf((String) dataTypeObj);
} catch (IllegalArgumentException e) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IllegalArgumentException is a 500 level error. Should we treat this as a 5xx error?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also what is the reason for leaving as null?

Copy link
Contributor Author

@q-andy q-andy Aug 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left as null instead of treating like 5xx is because even if the model data type is invalid, there may be usecases where we can still parse and use the model data or dataAsMap. E.g. some model types like neural sparse or NER don't include data type as part of the model response, so null is still valid.

Right now since we're focused on neural sparse and dense, my thought process is its better to leave it flexible to be able to possible handle different model response formats. For example, in the future we me add a new datatype for dense models and we might inference that from an older version of ml-commons: perhaps we can still use the data by casting the data at the processor level. Updated the comment to explain this.

// Invalid data type, leave as null
}
}
}

// Handle shape
long[] shape = null;
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
Object shapeObj = map.get(ModelTensor.SHAPE_FIELD);
if (shapeObj instanceof List<?> shapeList) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the else we are sending null for shape, is that expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, e.g. neural sparse and dense models populate different fields, sparse leaves data_type, data, and shape as null.

Dense

{
	"name": "sentence_embedding",
	"data_type": "FLOAT32",
	"shape": [
		768
	],
	"data": [...]
}

Sparse

{
	"name": "output",
	"dataAsMap": {
		"response": [
			{ ... }
		]
	}
}

shape = new long[shapeList.size()];
for (int i = 0; i < shapeList.size(); i++) {
Object item = shapeList.get(i);
if (item instanceof Number) {
shape[i] = ((Number) item).longValue();
}
}
}
}

// Handle data array
Number[] data = null;
if (map.containsKey(ModelTensor.DATA_FIELD)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like both of the underlying logic could be put in a common method like:

private <T> T[] processNumericArray(Object obj, Class<T> type) {
    if (obj instanceof List<?> list) {
        T[] result = (T[]) Array.newInstance(type, list.size());
        // ... process the list
        return result;
    }
    return null;
}

Object dataObj = map.get(ModelTensor.DATA_FIELD);
if (dataObj instanceof List<?> dataList) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we send data as null to the model Tensor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment, this is valid for neural sparse and other remote model format. Looks like data field is primarily used for vector/numerical info, and if a model output doesn't include that, then it uses dataAsMap instead and data field being null is valid.

data = new Number[dataList.size()];
for (int i = 0; i < dataList.size(); i++) {
Object item = dataList.get(i);
if (item instanceof Number) {
data[i] = (Number) item;
}
}
}
}

// For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.

return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

public class RemoteMlCommonsPassthroughPostProcessFunctionTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    RemoteMlCommonsPassthroughPostProcessFunction passthroughFunction;

    @Before
    public void setUp() {
        passthroughFunction = new RemoteMlCommonsPassthroughPostProcessFunction();
    }

    @Test
    public void process_WrongInput_NotMapOrList() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("Post process function input must be a Map or List");
        passthroughFunction.apply("abc", null);
    }

    /**
     * Sparse vector case: an ML-Commons response whose tensor carries token-score pairs in
     * dataAsMap should be parsed with the nested structure preserved as-is.
     */
    @Test
    public void process_MLCommonsResponse_RankFeatures() {
        Map<String, Object> tokenWeights = Map
            .of("increasingly", 0.028670792, "achievements", 0.4906937, "nation", 0.15371077, "hello", 0.35982144, "today", 3.0966291);
        Map<String, Object> innerDataAsMap = Map.of("response", List.of(tokenWeights));
        Map<String, Object> tensorMap = Map.of("name", "output", "dataAsMap", innerDataAsMap);
        Map<String, Object> mlResponse = Map.of("inference_results", List.of(Map.of("output", List.of(tensorMap))));

        List<ModelTensor> tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(1, tensors.size());
        ModelTensor sparseTensor = tensors.getFirst();
        assertEquals("output", sparseTensor.getName());
        assertEquals(innerDataAsMap, sparseTensor.getDataAsMap());

        // Drill into the nested sparse structure to verify individual token scores
        Map<String, Object> parsedDataAsMap = (Map<String, Object>) sparseTensor.getDataAsMap();
        List<Map<String, Object>> responseEntries = (List<Map<String, Object>>) parsedDataAsMap.get("response");
        assertEquals(1, responseEntries.size());
        assertEquals(0.35982144, (Double) responseEntries.get(0).get("hello"), 0.0001);
        assertEquals(3.0966291, (Double) responseEntries.get(0).get("today"), 0.0001);
    }

    /**
     * Dense vector case: an ML-Commons response with data_type, shape, and data fields should
     * be reconstructed into a fully populated numeric tensor.
     */
    @Test
    public void process_MLCommonsResponse_DenseVector() {
        Map<String, Object> tensorMap = Map
            .of(
                "name",
                "sentence_embedding",
                "data_type",
                "FLOAT32",
                "shape",
                List.of(3L),
                "data",
                List.of(0.5400895, -0.19082281, 0.4996347)
            );
        Map<String, Object> mlResponse = Map.of("inference_results", List.of(Map.of("output", List.of(tensorMap))));

        List<ModelTensor> tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(1, tensors.size());
        ModelTensor embedding = tensors.getFirst();
        assertEquals("sentence_embedding", embedding.getName());
        assertEquals(MLResultDataType.FLOAT32, embedding.getDataType());
        assertEquals(1, embedding.getShape().length);
        assertEquals(3L, embedding.getShape()[0]);
        assertEquals(3, embedding.getData().length);
        assertEquals(0.5400895, embedding.getData()[0].doubleValue(), 0.0001);
    }

    /**
     * A single inference result may hold several output tensors; each one must come back as
     * its own ModelTensor, in order.
     */
    @Test
    public void process_MLCommonsResponse_MultipleOutputs() {
        Map<String, Object> firstOutput = Map.of("name", "output1", "result", "result1");
        Map<String, Object> secondOutput = Map.of("name", "output2", "result", "result2");
        Map<String, Object> mlResponse = Map.of("inference_results", List.of(Map.of("output", List.of(firstOutput, secondOutput))));

        List<ModelTensor> tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(2, tensors.size());
        assertEquals("output1", tensors.get(0).getName());
        assertEquals("result1", tensors.get(0).getResult());
        assertEquals("output2", tensors.get(1).getName());
        assertEquals("result2", tensors.get(1).getResult());
    }

    /**
     * An empty inference_results array is a valid response and must yield an empty tensor list.
     */
    @Test
    public void process_MLCommonsResponse_EmptyInferenceResults() {
        List<ModelTensor> tensors = passthroughFunction.apply(Map.of("inference_results", List.of()), null);

        assertEquals(0, tensors.size());
    }

    /**
     * Inference results lacking the expected format are skipped, while responses without an
     * inference_results key at all fall back to a single passthrough tensor.
     */
    @Test
    public void process_MLCommonsResponse_InvalidOutputs() {
        // Inference result without an "output" field: nothing to parse
        Map<String, Object> mlResponse = Map.of("inference_results", List.of(Map.of("other_field", "value")));

        List<ModelTensor> tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(0, tensors.size());

        // Correct format, but the output entry is an empty map
        mlResponse = Map.of("inference_results", List.of(Map.of("output", List.of(Map.of()))));

        tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(0, tensors.size());

        // Non-ML-Commons response: wrapped verbatim in one tensor named "response"
        Map<String, Object> nonMlCommonsResponse = Map.of("invalid_format", "invalid value");
        tensors = passthroughFunction.apply(nonMlCommonsResponse, null);

        assertEquals(1, tensors.size());
        assertEquals(nonMlCommonsResponse, tensors.getFirst().getDataAsMap());
        assertEquals("response", tensors.getFirst().getName());
    }

    /**
     * Mistyped tensor fields must not blow up: name falls back to OUTPUT_FIELD and
     * data_type/shape/data are parsed as nulls.
     */
    @Test
    public void process_MLCommonsResponse_InvalidDenseVectorFormat() {
        Map<String, Object> tensorMap = Map
            .of(
                "name",
                List.of("Not a string"),
                "data_type",
                "NON-EXISTENT TYPE",
                "shape",
                "not a list of long",
                "data",
                "not a list of numbers"
            );
        Map<String, Object> mlResponse = Map.of("inference_results", List.of(Map.of("output", List.of(tensorMap))));

        List<ModelTensor> tensors = passthroughFunction.apply(mlResponse, null);

        assertEquals(1, tensors.size());
        ModelTensor malformedTensor = tensors.getFirst();
        assertEquals(OUTPUT_FIELD, malformedTensor.getName());
        assertNull(malformedTensor.getShape());
        assertNull(malformedTensor.getData());
        assertNull(malformedTensor.getDataType());
    }
}
Loading