Skip to content

Commit a89b235

Browse files
committed
Fix casts and unit tests
Signed-off-by: Andy Qin <[email protected]>
1 parent 11b5bb0 commit a89b235

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunction.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.opensearch.ml.common.output.model.ModelTensor;
1616

1717
/**
18-
* A post-processing function for calling a remote ml commons instancethat preserves the original neural sparse response structure
18+
* A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure
1919
* to avoid double-wrapping when receiving responses from another ML-Commons instance.
2020
*/
2121
public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
@@ -26,7 +26,6 @@ public void validate(Object input) {
2626
}
2727
}
2828

29-
@SuppressWarnings("unchecked")
3029
@Override
3130
public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dataType) {
3231
// Check if this is an ML-Commons response with inference_results
@@ -41,7 +40,9 @@ public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dat
4140
for (Map<String, Object> output : outputs) {
4241
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
4342
ModelTensor modelTensor = createModelTensorFromMap(output);
44-
modelTensors.add(modelTensor);
43+
if (modelTensor != null) {
44+
modelTensors.add(modelTensor);
45+
}
4546
}
4647
}
4748
}
@@ -59,8 +60,11 @@ public List<ModelTensor> process(Map<String, Object> input, MLResultDataType dat
5960
* Creates a ModelTensor from a Map<String, Object> representation based on the API format
6061
* of the /_predict API
6162
*/
62-
@SuppressWarnings("unchecked")
6363
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
64+
if (map == null || map.isEmpty()) {
65+
return null;
66+
}
67+
6468
String name = (String) map.getOrDefault(ModelTensor.NAME_FIELD, OUTPUT_FIELD);
6569
String result = (String) map.get(ModelTensor.RESULT_FIELD);
6670
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);
@@ -82,8 +86,7 @@ private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
8286
long[] shape = null;
8387
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
8488
Object shapeObj = map.get(ModelTensor.SHAPE_FIELD);
85-
if (shapeObj instanceof List) {
86-
List<?> shapeList = (List<?>) shapeObj;
89+
if (shapeObj instanceof List<?> shapeList) {
8790
shape = new long[shapeList.size()];
8891
for (int i = 0; i < shapeList.size(); i++) {
8992
Object item = shapeList.get(i);
@@ -98,8 +101,7 @@ private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
98101
Number[] data = null;
99102
if (map.containsKey(ModelTensor.DATA_FIELD)) {
100103
Object dataObj = map.get(ModelTensor.DATA_FIELD);
101-
if (dataObj instanceof List) {
102-
List<?> dataList = (List<?>) dataObj;
104+
if (dataObj instanceof List<?> dataList) {
103105
data = new Number[dataList.size()];
104106
for (int i = 0; i < dataList.size(); i++) {
105107
Object item = dataList.get(i);

common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/RemoteMlCommonsPassthroughPostProcessFunctionTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public void process_MLCommonsResponse_RankFeatures() {
5656
assertEquals(innerDataAsMap, tensor.getDataAsMap());
5757

5858
// Verify the nested sparse data structure
59-
Map<String, Object> dataAsMap = (Map<String, Object>) tensor.getDataAsMap();
59+
Map<String, Object> dataAsMap = (Map<String, Object>) tensor.getDataAsMap();
6060
List<Map<String, Object>> response = (List<Map<String, Object>>) dataAsMap.get("response");
6161
assertEquals(1, response.size());
6262
assertEquals(0.35982144, (Double) response.get(0).get("hello"), 0.0001);
@@ -88,7 +88,7 @@ public void process_MLCommonsResponse_DenseVector() {
8888
assertEquals(1, result.size());
8989
ModelTensor tensor = result.get(0);
9090
assertEquals("sentence_embedding", tensor.getName());
91-
assertEquals(3, tensor.getShape().length);
91+
assertEquals(1, tensor.getShape().length);
9292
assertEquals(3L, tensor.getShape()[0]);
9393
assertEquals(3, tensor.getData().length);
9494
assertEquals(0.5400895, tensor.getData()[0].doubleValue(), 0.0001);

0 commit comments

Comments
 (0)