Add ml-commons passthrough post process function #4111
Changes from 3 commits
@@ -0,0 +1,121 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
 * A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure
 * to avoid double-wrapping when receiving responses from another ML-Commons instance.
 */
Collaborator: Can we add example in the doc for better understanding?
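For illustration only (this sketch is not part of the PR or the review thread): a minimal example of the passthrough behavior the class javadoc describes, built from the same map shapes used in the unit tests further down. The class name PassthroughExample and the token weights are made up.

import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

public class PassthroughExample {
    public static void main(String[] args) {
        // A neural sparse /_predict response from a remote ML-Commons cluster, already
        // parsed into a Map. Token weights here are made-up values.
        Map<String, Object> tokenWeights = Map.of("hello", 0.36, "world", 1.27);
        Map<String, Object> output = Map.of("name", "output", "dataAsMap", Map.of("response", List.of(tokenWeights)));
        Map<String, Object> remoteResponse = Map.of("inference_results", List.of(Map.of("output", List.of(output))));

        RemoteMlCommonsPassthroughPostProcessFunction function = new RemoteMlCommonsPassthroughPostProcessFunction();
        List<ModelTensor> tensors = function.apply(remoteResponse, null);

        // The original token-weight structure is preserved instead of being wrapped in
        // another "response" layer.
        System.out.println(tensors.get(0).getName());      // output
        System.out.println(tensors.get(0).getDataAsMap()); // {response=[{hello=0.36, world=1.27}]} (key order may vary)
    }
}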
public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
    @Override
    public void validate(Object input) {
        if (!(input instanceof Map) && !(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input must be a Map or List");
        }
    }

    @Override
    public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dataType) {
        // Check if this is an ML-Commons response with inference_results
        if (input.containsKey("inference_results") && input.get("inference_results") instanceof List) {
            List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) input.get("inference_results");

            List<ModelTensor> modelTensors = new ArrayList<>();
            for (Map<String, Object> result : inferenceResults) {
                // Extract the output field which contains the ModelTensor data
                if (result.containsKey("output") && result.get("output") instanceof List) {
                    List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
                    for (Map<String, Object> output : outputs) {
                        // This inner map should represent a model tensor, so we try to parse and instantiate a new one.
                        ModelTensor modelTensor = createModelTensorFromMap(output);
                        if (modelTensor != null) {
                            modelTensors.add(modelTensor);
                        }
                    }
                }
            }

            return modelTensors;
        }

        // Fallback for non-ML-Commons responses
        ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(input).build();

        return List.of(tensor);
    }

    /**
     * Creates a ModelTensor from a Map<String, Object> representation based on the API format
     * of the /_predict API
     */
    private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return null;
        }

        // Get name. If name is null or not a String, default to OUTPUT_FIELD
        Object uncastedName = map.get(ModelTensor.NAME_FIELD);
        String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD;
        String result = (String) map.get(ModelTensor.RESULT_FIELD);
        Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);

        // Handle data type
        MLResultDataType dataType = null;
        if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
            Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
            if (dataTypeObj instanceof String) {
                try {
                    dataType = MLResultDataType.valueOf((String) dataTypeObj);
                } catch (IllegalArgumentException e) {
Collaborator: Also, what is the reason for leaving this as null?

Contributor (Author): I left it as null instead of treating it like a 5xx because even if the model data type is invalid, there may be use cases where we can still parse and use the model data or dataAsMap. E.g. some model types like neural sparse or NER don't include a data type as part of the model response, so null is still valid. Right now, since we're focused on neural sparse and dense, my thought process is that it's better to leave it flexible so we can possibly handle different model response formats. For example, in the future we may add a new data type for dense models and might run inference against an older version of ml-commons; perhaps we could still use the data by casting it at the processor level. Updated the comment to explain this.
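A hypothetical consumer-side sketch of the flexibility described above (not code from this PR): even when the data type cannot be resolved and is left null, a caller can still read the sparse token weights out of dataAsMap. The helper name and the "response" key handling mirror the rank-features test below but are otherwise assumptions.

// Hypothetical consumer-side helper (illustration only, not part of this PR).
// A tensor whose data_type could not be parsed (getDataType() == null) is still
// usable through dataAsMap, e.g. the neural sparse token-weight map under "response".
@SuppressWarnings("unchecked")
private static Map<String, Double> tokenWeights(ModelTensor tensor) {
    Map<String, ?> dataAsMap = tensor.getDataAsMap();
    Object response = dataAsMap == null ? null : dataAsMap.get("response");
    if (response instanceof List<?> list && !list.isEmpty() && list.get(0) instanceof Map<?, ?> weights) {
        return (Map<String, Double>) weights;
    }
    return Map.of();
}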
                    // Invalid data type, leave as null
                }
            }
        }

        // Handle shape
        long[] shape = null;
        if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
            Object shapeObj = map.get(ModelTensor.SHAPE_FIELD);
            if (shapeObj instanceof List<?> shapeList) {
                shape = new long[shapeList.size()];
                for (int i = 0; i < shapeList.size(); i++) {
                    Object item = shapeList.get(i);
                    if (item instanceof Number) {
                        shape[i] = ((Number) item).longValue();
                    }
                }
            }
        }

        // Handle data array
        Number[] data = null;
        if (map.containsKey(ModelTensor.DATA_FIELD)) {
Collaborator: Seems like both of the underlying logic could be put in a common method like:
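The snippet the reviewer had in mind is not shown in this excerpt. Purely as an illustration of the idea (a hypothetical helper, not the reviewer's actual suggestion and not code from this PR), the shape and data branches could share something like:

// Hypothetical shared helper (illustration only): pulls a list of Numbers out of the raw
// response map, returning null when the field is missing or is not a list. The shape
// branch would then map the result to long[] and the data branch to Number[].
private static List<Number> extractNumberList(Map<String, Object> map, String field) {
    Object value = map.get(field);
    if (!(value instanceof List<?> list)) {
        return null;
    }
    List<Number> numbers = new ArrayList<>();
    for (Object item : list) {
        if (item instanceof Number number) {
            numbers.add(number);
        }
    }
    return numbers;
}

Note that, unlike the inline loops in the PR, this sketch drops non-numeric entries rather than leaving default-valued slots, so it is only a starting point for the refactor.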
            Object dataObj = map.get(ModelTensor.DATA_FIELD);
            if (dataObj instanceof List<?> dataList) {
                data = new Number[dataList.size()];
                for (int i = 0; i < dataList.size(); i++) {
                    Object item = dataList.get(i);
                    if (item instanceof Number) {
                        data[i] = (Number) item;
                    }
                }
            }
        }

        // For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.

        return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
    }
}

@@ -0,0 +1,193 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

public class RemoteMlCommonsPassthroughPostProcessFunctionTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    RemoteMlCommonsPassthroughPostProcessFunction function;

    @Before
    public void setUp() {
        function = new RemoteMlCommonsPassthroughPostProcessFunction();
    }

    @Test
    public void process_WrongInput_NotMapOrList() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("Post process function input must be a Map or List");
        function.apply("abc", null);
    }

    /**
     * Tests processing of ML-Commons response containing sparse vector data with rank features.
     * Validates that sparse vectors with dataAsMap containing token-score pairs are correctly parsed.
     */
    @Test
    public void process_MLCommonsResponse_RankFeatures() {
        Map<String, Object> rankFeatures = Map
            .of("increasingly", 0.028670792, "achievements", 0.4906937, "nation", 0.15371077, "hello", 0.35982144, "today", 3.0966291);
        Map<String, Object> innerDataAsMap = Map.of("response", Arrays.asList(rankFeatures));
        Map<String, Object> output = Map.of("name", "output", "dataAsMap", innerDataAsMap);
        Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output));
        Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(1, result.size());
        ModelTensor tensor = result.get(0);
        assertEquals("output", tensor.getName());
        assertEquals(innerDataAsMap, tensor.getDataAsMap());

        // Verify the nested sparse data structure
        Map<String, Object> dataAsMap = (Map<String, Object>) tensor.getDataAsMap();
        List<Map<String, Object>> response = (List<Map<String, Object>>) dataAsMap.get("response");
        assertEquals(1, response.size());
        assertEquals(0.35982144, (Double) response.get(0).get("hello"), 0.0001);
        assertEquals(3.0966291, (Double) response.get(0).get("today"), 0.0001);
    }

    /**
     * Tests processing of ML-Commons response containing dense vector data with numerical arrays.
     * Validates that dense vectors with data_type, shape, and data fields are correctly parsed.
     */
    @Test
    public void process_MLCommonsResponse_DenseVector() {
        Map<String, Object> output = Map
            .of(
                "name",
                "sentence_embedding",
                "data_type",
                "FLOAT32",
                "shape",
                Arrays.asList(3L),
                "data",
                Arrays.asList(0.5400895, -0.19082281, 0.4996347)
            );
        Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output));
        Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(1, result.size());
        ModelTensor tensor = result.get(0);
        assertEquals("sentence_embedding", tensor.getName());
        assertEquals(MLResultDataType.FLOAT32, tensor.getDataType());
        assertEquals(1, tensor.getShape().length);
        assertEquals(3L, tensor.getShape()[0]);
        assertEquals(3, tensor.getData().length);
        assertEquals(0.5400895, tensor.getData()[0].doubleValue(), 0.0001);
    }

    /**
     * Tests processing of ML-Commons response with multiple output tensors in a single inference result.
     * Ensures all outputs are processed and returned as separate ModelTensor objects.
     */
    @Test
    public void process_MLCommonsResponse_MultipleOutputs() {
        Map<String, Object> output1 = Map.of("name", "output1", "result", "result1");
        Map<String, Object> output2 = Map.of("name", "output2", "result", "result2");
        Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output1, output2));
        Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(2, result.size());
        assertEquals("output1", result.get(0).getName());
        assertEquals("result1", result.get(0).getResult());
        assertEquals("output2", result.get(1).getName());
        assertEquals("result2", result.get(1).getResult());
    }

    /**
     * Tests edge case where ML-Commons response has empty inference_results array.
     * Should return empty list without errors.
     */
    @Test
    public void process_MLCommonsResponse_EmptyInferenceResults() {
        Map<String, Object> input = Map.of("inference_results", Arrays.asList());

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(0, result.size());
    }

    /**
     * Tests edge cases where inference result lacks the expected format.
     * Should skip processing and return empty list.
     */
    @Test
    public void process_MLCommonsResponse_InvalidOutputs() {
        Map<String, Object> inferenceResult = Map.of("other_field", "value");
        Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(0, result.size());

        // correct format, but with empty output
        inferenceResult = Map.of("output", List.of(Map.of()));
        input = Map.of("inference_results", List.of(inferenceResult));

        result = function.apply(input, null);

        assertEquals(0, result.size());

        // Fallback for non-ml-commons responses
        input = Map.of("invalid_format", "invalid value");
        result = function.apply(input, null);

        assertEquals(1, result.size());
        assertEquals(input, result.getFirst().getDataAsMap());
        assertEquals("response", result.getFirst().getName());
    }

    /**
     * Tests processing of an ML-Commons response where the dense vector fields have invalid types.
     * Validates that when the types are incorrect, the corresponding values are parsed as nulls.
     */
    @Test
    public void process_MLCommonsResponse_InvalidDenseVectorFormat() {
        Map<String, Object> output = Map
            .of(
                "name",
                List.of("Not a string"),
                "data_type",
                "NON-EXISTENT TYPE",
                "shape",
                "not a list of long",
                "data",
                "not a list of numbers"
            );
        Map<String, Object> inferenceResult = Map.of("output", Arrays.asList(output));
        Map<String, Object> input = Map.of("inference_results", Arrays.asList(inferenceResult));

        List<ModelTensor> result = function.apply(input, null);

        assertEquals(1, result.size());
        ModelTensor tensor = result.getFirst();
        assertEquals(OUTPUT_FIELD, tensor.getName());
        assertNull(tensor.getShape());
        assertNull(tensor.getData());
        assertNull(tensor.getDataType());
    }
}
Review comment: as this is a special kind, can we please add a comment here?