@@ -15,6 +15,7 @@
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

@@ -35,6 +36,8 @@ public class MLPostProcessFunction {
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
    // The ML Commons passthrough unwraps a remote ml-commons response and reconstructs the model tensors directly from the remote inference results
public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";
Collaborator: as this is a special kind, can we please add a comment here?

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

@@ -46,6 +49,8 @@ public class MLPostProcessFunction {
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction =
new RemoteMlCommonsPassthroughPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
@@ -61,6 +66,7 @@ public class MLPostProcessFunction {
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
@@ -76,6 +82,7 @@ public class MLPostProcessFunction {
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction);
}

public static String getResponseFilter(String postProcessFunction) {
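A hedged usage sketch (not part of the diff; it assumes only the getResponseFilter accessor shown above): callers can resolve the passthrough registration by its constant name.

String filter = MLPostProcessFunction.getResponseFilter(MLPostProcessFunction.ML_COMMONS_PASSTHROUGH);
// filter == "$", i.e. the entire remote response body is handed to the passthrough function
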
@@ -0,0 +1,192 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
* A post-processing function for connectors that call a remote ml-commons instance. It preserves the original response
* structure (e.g. for neural sparse models) to avoid double-wrapping tensors when receiving responses from another ML-Commons instance.
*/
Collaborator: Can we add an example in the doc for better understanding?

public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
@Override
public void validate(Object input) {
if (!(input instanceof Map) && !(input instanceof List)) {
throw new IllegalArgumentException("Post process function input must be a Map or List");
}
}

/**
* Example raw (wrapped) response from the remote ml-commons instance:
* {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "response": [
* {
* "increasingly": 0.028670792,
* "achievements": 0.4906937,
* ...
* }
* ]
* }
* }
* ],
* "status_code": 200.0
* }
* ]
* }
* }
* ],
* "status_code": 200
* }
* ]
* }
*
* Example unwrapped response after post-processing:
*
* {
* "inference_results": [
* {
* "output": [
* {
* "name": "output",
* "dataAsMap": {
* "response": [
* {
* "increasingly": 0.028670792,
* "achievements": 0.4906937,
* ...
* }
* ]
* }
* }
* ],
* "status_code": 200
* }
* ]
* }
*
* @param mlCommonsResponse raw remote ml commons response
* @param dataType the datatype of the result, not used since datatype is set based on the response body
* @return a list of model tensors representing the inner model tensors
*/
@Override
public List<ModelTensor> process(Map<String, Object> mlCommonsResponse, MLResultDataType dataType) {
// Check if this is an ML-Commons response with inference_results
if (mlCommonsResponse.containsKey("inference_results") && mlCommonsResponse.get("inference_results") instanceof List) {
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) mlCommonsResponse.get("inference_results");

List<ModelTensor> modelTensors = new ArrayList<>();
for (Map<String, Object> result : inferenceResults) {
// Extract the output field which contains the ModelTensor data
if (result.containsKey("output") && result.get("output") instanceof List) {
List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
for (Map<String, Object> output : outputs) {
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
ModelTensor modelTensor = createModelTensorFromMap(output);
if (modelTensor != null) {
modelTensors.add(modelTensor);
}
}
}
}

return modelTensors;
}

// Fallback for non-ML-Commons responses
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(mlCommonsResponse).build();

return List.of(tensor);
}
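
    // Illustrative sketch (hypothetical, mirroring the Javadoc example above): a minimal
    // remote ml-commons response and the tensor this method would produce.
    //   Map<String, Object> response = Map.of(
    //       "inference_results", List.of(
    //           Map.of(
    //               "output", List.of(Map.of(
    //                   "name", "output",
    //                   "dataAsMap", Map.of("response", List.of(Map.of("increasingly", 0.028670792))))),
    //               "status_code", 200)));
    //   List<ModelTensor> tensors = new RemoteMlCommonsPassthroughPostProcessFunction().process(response, null);
    //   // -> one ModelTensor named "output" whose dataAsMap holds the sparse token weights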

/**
* Creates a ModelTensor from a Map<String, Object> representation based on the API format
* of the /_predict API
*/
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
return null;
}

// Get name. If name is null or not a String, default to OUTPUT_FIELD
Object uncastedName = map.get(ModelTensor.NAME_FIELD);
String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD;
String result = (String) map.get(ModelTensor.RESULT_FIELD);

// Handle data as map
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);

// Handle data type. For certain models like neural sparse and non-dense remote models, this field
// is not populated and left as null instead, which is still valid
MLResultDataType dataType = null;
if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
if (dataTypeObj instanceof String) {
try {
dataType = MLResultDataType.valueOf((String) dataTypeObj);
} catch (IllegalArgumentException e) {
Collaborator: IllegalArgumentException is a 500 level error. Should we treat this as a 5xx error?

Collaborator: Also, what is the reason for leaving it as null?

Contributor Author (@q-andy, Aug 25, 2025): I left it as null instead of treating it like a 5xx error because even if the model data type is invalid, there may be use cases where we can still parse and use the model data or dataAsMap. E.g. some model types like neural sparse or NER don't include a data type as part of the model response, so null is still valid. Right now, since we're focused on neural sparse and dense, my thought process is that it's better to leave it flexible so it can handle different model response formats. For example, in the future we may add a new data type for dense models and run inference against an older version of ml-commons; perhaps we could still use the data by casting it at the processor level. Updated the comment to explain this.

// Invalid data type, leave as null in case inner data is still useful to be parsed in the future
}
}
}

// Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result
// is stored in dataAsMap, not data/shape field
long[] shape = null;
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
Number[] numbers = processNumericalArray(map, ModelTensor.SHAPE_FIELD, Number.class);
if (numbers != null) {
shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray();
}
}

// Handle data. For certain models like neural sparse and non-dense, null is valid since the inference result
// is stored in dataAsMap, not the data/shape fields
Number[] data = null;
if (map.containsKey(ModelTensor.DATA_FIELD)) {
Collaborator: Seems like both pieces of the underlying logic could be put in a common method like:

private <T> T[] processNumericArray(Object obj, Class<T> type) {
    if (obj instanceof List<?> list) {
        T[] result = (T[]) Array.newInstance(type, list.size());
        // ... process the list
        return result;
    }
    return null;
}

data = processNumericalArray(map, ModelTensor.DATA_FIELD, Number.class);
}

// For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.

return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
}
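
    // For reference, an illustrative input map in the /_predict tensor format (field names taken
    // from the ModelTensor.*_FIELD constants; the values here are made up):
    //   { "name": "sentence_embedding", "data_type": "FLOAT32", "shape": [2], "data": [0.1, 0.2] }
    // would yield a ModelTensor with name "sentence_embedding", dataType FLOAT32, shape {2} and data {0.1, 0.2}.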

private static <T> T[] processNumericalArray(Map<String, Object> map, String key, Class<T> type) {
Object obj = map.get(key);
if (obj instanceof List<?> list) {
T[] array = (T[]) Array.newInstance(type, list.size());
for (int i = 0; i < list.size(); i++) {
Object item = list.get(i);
if (type.isInstance(item)) {
array[i] = type.cast(item);
}
}
return array;
}
return null;
}
}
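
A minimal sketch of the helper's coercion behavior (illustrative only; since processNumericalArray is private, this shows what it does rather than an external call site):

Number[] numbers = processNumericalArray(Map.of("shape", List.of(1, 768)), "shape", Number.class);
long[] shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray(); // -> {1, 768}
// List items that are not instances of the requested type are left as null slots in the resulting array.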