Skip to content

Commit 4cba96b

Browse files
XiaSq-eng
authored andcommitted
passed all UT
Signed-off-by: Shiqi Xia <[email protected]>
1 parent f9270db commit 4cba96b

File tree

3 files changed

+22
-29
lines changed

3 files changed

+22
-29
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,9 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
359359
}
360360

361361
public String removeMissingParameterFields(String payload, Map<String, String> params) {
362-
// Match: "xxx": "${parameters.yyy}" or "xxx": {parameters.yyy}
362+
if (params == null) {
363+
return payload;
364+
}
363365
Pattern pattern = Pattern.compile(
364366
"\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?"
365367
);
@@ -368,7 +370,7 @@ public String removeMissingParameterFields(String payload, Map<String, String> p
368370

369371
while (matcher.find()) {
370372
String paramName = matcher.group(2); // yyy
371-
if (!params.containsKey(paramName)) {
373+
if (!params.containsKey(paramName) && !"input".equals(paramName)) {
372374
matcher.appendReplacement(sb, "");
373375
} else {
374376
matcher.appendReplacement(sb, Matcher.quoteReplacement(matcher.group(0)));

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,22 @@ public void createPayload() {
193193
Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload);
194194
}
195195

196+
@Test
197+
public void createPayload_ExtraParams() {
198+
199+
String requestBody = "{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
200+
String expected = "{\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": \"query\" }}";
201+
202+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
203+
Map<String, String> parameters = new HashMap<>();
204+
parameters.put("input", "test value");
205+
parameters.put("sparseEmbeddingFormat", "WORD");
206+
parameters.put("content_type", "query");
207+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
208+
connector.validatePayload(predictPayload);
209+
Assert.assertEquals(expected, predictPayload);
210+
}
211+
196212
@Test
197213
public void parseResponse_modelTensorJson() throws IOException {
198214
HttpConnector connector = createHttpConnector();
@@ -439,7 +455,7 @@ public void removeMissingParameterFields_MissingAll() {
439455
Map<String, String> params = new HashMap<>();
440456

441457
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
442-
String expected = "{ \"parameters\": {}}";
458+
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
443459
String result = connector.removeMissingParameterFields(payload, params);
444460
Assert.assertEquals(expected, result);
445461
}
@@ -450,7 +466,6 @@ public void removeMissingParameterFields_Nest() {
450466
Map<String, String> params = new HashMap<>();
451467
params.put("input", "test value");
452468

453-
// Case 1: Payload with valid parameter placeholders
454469
String payload = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\"}}}";
455470
String expected = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {}}}";
456471
String result = connector.removeMissingParameterFields(payload, params);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,13 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8-
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9-
//import static org.opensearch.ml.common.dataset.SearchQueryInputDataset.xContentRegistry;
108
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
119
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
1210
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1311

1412
import java.io.IOException;
15-
import java.lang.reflect.Field;
1613
import java.util.*;
1714
import java.util.concurrent.atomic.AtomicBoolean;
18-
import java.util.regex.Matcher;
19-
import java.util.regex.Pattern;
2015

2116
import com.fasterxml.jackson.databind.ObjectMapper;
2217
import org.apache.logging.log4j.Logger;
@@ -30,7 +25,6 @@
3025
import org.opensearch.common.unit.TimeValue;
3126
import org.opensearch.common.util.TokenBucket;
3227
import org.opensearch.common.xcontent.XContentFactory;
33-
import org.opensearch.common.xcontent.XContentType;
3428
import org.opensearch.commons.ConfigConstants;
3529
import org.opensearch.commons.authuser.User;
3630
import org.opensearch.core.action.ActionListener;
@@ -46,7 +40,6 @@
4640
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4741
import org.opensearch.ml.common.input.MLInput;
4842
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
49-
//import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
5043
import org.opensearch.ml.common.model.MLGuard;
5144
import org.opensearch.ml.common.output.model.ModelTensorOutput;
5245
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -88,7 +81,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
8881
MLInput
8982
.builder()
9083
.algorithm(FunctionName.TEXT_EMBEDDING)
91-
.parameters(mlInput.getParameters())
84+
// .parameters(mlInput.getParameters())
9285
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
9386
.build(),
9487
new ExecutionContext(sequence++),
@@ -203,23 +196,6 @@ default void preparePayloadAndInvoke(
203196
actionListener.onFailure(e);
204197
}
205198
}
206-
// if (algoParams != null) {
207-
// Map<String, String> parametersMap = new HashMap<>();
208-
// String algoParamsStr = algoParams.toString();
209-
// Pattern pattern = Pattern.compile("\\(([^)]+)\\)");
210-
// Matcher matcher = pattern.matcher(algoParamsStr);
211-
// if (matcher.find()) {
212-
// String bracketContent = matcher.group(1);
213-
// pattern = Pattern.compile("(\\w+)=([^,\\s]+)");
214-
// matcher = pattern.matcher(bracketContent);
215-
// while (matcher.find()) {
216-
// String fieldName = matcher.group(1);
217-
// String fieldValue = matcher.group(2);
218-
// parametersMap.put(fieldName, fieldValue);
219-
// }
220-
// }
221-
// parameters.putAll(parametersMap);
222-
// }
223199

224200
RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
225201
if (inputData.getParameters() != null) {

0 commit comments

Comments
 (0)