Skip to content

Commit ea980e3

Browse files
XiaSq-eng
authored andcommitted
passed all UT
Signed-off-by: Shiqi Xia <[email protected]>
1 parent 4cba96b commit ea980e3

File tree

4 files changed

+60
-33
lines changed

4 files changed

+60
-33
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,15 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
358358
return (T) parameters.get("http_body");
359359
}
360360

361+
/**
362+
* Removes fields from the given JSON payload string that correspond to parameters
363+
* not present in the provided parameter map.
364+
*/
361365
public String removeMissingParameterFields(String payload, Map<String, String> params) {
362366
if (params == null) {
363367
return payload;
364368
}
365-
Pattern pattern = Pattern.compile(
366-
"\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?"
367-
);
369+
Pattern pattern = Pattern.compile("\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?");
368370
Matcher matcher = pattern.matcher(payload);
369371
StringBuffer sb = new StringBuffer();
370372

@@ -380,7 +382,6 @@ public String removeMissingParameterFields(String payload, Map<String, String> p
380382
return sb.toString().replaceAll(",\\s*}", "}").replaceAll(",\\s*]", "]");
381383
}
382384

383-
384385
protected String fillNullParameters(Map<String, String> parameters, String payload) {
385386
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
386387
String newPayload = payload;

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,10 @@ public void createPayload() {
196196
@Test
197197
public void createPayload_ExtraParams() {
198198

199-
String requestBody = "{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
200-
String expected = "{\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": \"query\" }}";
199+
String requestBody =
200+
"{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
201+
String expected =
202+
"{\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": \"query\" }}";
201203

202204
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
203205
Map<String, String> parameters = new HashMap<>();
@@ -431,8 +433,10 @@ public void removeMissingParameterFields() {
431433
params.put("sparseEmbeddingFormat", "WORD");
432434
params.put("content_type", "query");
433435

434-
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
435-
String expected = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
436+
String payload =
437+
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
438+
String expected =
439+
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
436440
String result = connector.removeMissingParameterFields(payload, params);
437441
Assert.assertEquals(expected, result);
438442
}
@@ -443,7 +447,8 @@ public void removeMissingParameterFields_MissingParameters() {
443447
Map<String, String> params = new HashMap<>();
444448
params.put("input", "test value");
445449

446-
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
450+
String payload =
451+
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
447452
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
448453
String result = connector.removeMissingParameterFields(payload, params);
449454
Assert.assertEquals(expected, result);
@@ -454,7 +459,8 @@ public void removeMissingParameterFields_MissingAll() {
454459
HttpConnector connector = createHttpConnector();
455460
Map<String, String> params = new HashMap<>();
456461

457-
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
462+
String payload =
463+
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
458464
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
459465
String result = connector.removeMissingParameterFields(payload, params);
460466
Assert.assertEquals(expected, result);
@@ -466,7 +472,8 @@ public void removeMissingParameterFields_Nest() {
466472
Map<String, String> params = new HashMap<>();
467473
params.put("input", "test value");
468474

469-
String payload = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\"}}}";
475+
String payload =
476+
"{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\"}}}";
470477
String expected = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {}}}";
471478
String result = connector.removeMissingParameterFields(payload, params);
472479
Assert.assertEquals(expected, result);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1111

1212
import java.io.IOException;
13-
import java.util.*;
13+
import java.util.Map;
14+
import java.util.Optional;
15+
import java.util.HashMap;
16+
import java.util.Collection;
17+
import java.util.Arrays;
18+
import java.util.List;
19+
import java.util.Locale;
1420
import java.util.concurrent.atomic.AtomicBoolean;
1521

16-
import com.fasterxml.jackson.databind.ObjectMapper;
1722
import org.apache.logging.log4j.Logger;
1823
import org.opensearch.ExceptionsHelper;
1924
import org.opensearch.OpenSearchStatusException;
@@ -29,7 +34,9 @@
2934
import org.opensearch.commons.authuser.User;
3035
import org.opensearch.core.action.ActionListener;
3136
import org.opensearch.core.rest.RestStatus;
32-
import org.opensearch.core.xcontent.*;
37+
import org.opensearch.core.xcontent.NamedXContentRegistry;
38+
import org.opensearch.core.xcontent.XContentBuilder;
39+
import org.opensearch.core.xcontent.ToXContent;
3340
import org.opensearch.ml.common.FunctionName;
3441
import org.opensearch.ml.common.connector.Connector;
3542
import org.opensearch.ml.common.connector.ConnectorAction;
@@ -48,6 +55,8 @@
4855
import org.opensearch.threadpool.ThreadPool;
4956
import org.opensearch.transport.client.Client;
5057

58+
import com.fasterxml.jackson.databind.ObjectMapper;
59+
5160
import lombok.Builder;
5261

5362
public interface RemoteConnectorExecutor {
@@ -81,7 +90,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
8190
MLInput
8291
.builder()
8392
.algorithm(FunctionName.TEXT_EMBEDDING)
84-
// .parameters(mlInput.getParameters())
93+
.parameters(mlInput.getParameters())
8594
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
8695
.build(),
8796
new ExecutionContext(sequence++),

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,22 @@ public void executeGetParams_MissingParameter() {
181181
AwsConnectorExecutor executor = getExecutor(connector);
182182

183183
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
184-
.builder()
185-
.parameters(Map.of("input", "${parameters.input}"))
186-
.actionType(PREDICT)
187-
.build();
184+
.builder()
185+
.parameters(Map.of("input", "${parameters.input}"))
186+
.actionType(PREDICT)
187+
.build();
188188
String actionType = inputDataSet.getActionType().toString();
189189
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
190-
.builder()
191-
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
192-
.embeddingContentType(null)
193-
.build();
194-
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).parameters(inputParams).inputDataset(inputDataSet).build();
190+
.builder()
191+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
192+
.embeddingContentType(null)
193+
.build();
194+
MLInput mlInput = MLInput
195+
.builder()
196+
.algorithm(FunctionName.TEXT_EMBEDDING)
197+
.parameters(inputParams)
198+
.inputDataset(inputDataSet)
199+
.build();
195200

196201
try {
197202
Map<String, String> paramsMap = executor.getParams(mlInput);
@@ -210,17 +215,22 @@ public void executeGetParams_PassingParameter() {
210215
AwsConnectorExecutor executor = getExecutor(connector);
211216

212217
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
213-
.builder()
214-
.parameters(Map.of("input", "${parameters.input}"))
215-
.actionType(PREDICT)
216-
.build();
218+
.builder()
219+
.parameters(Map.of("input", "${parameters.input}"))
220+
.actionType(PREDICT)
221+
.build();
217222
String actionType = inputDataSet.getActionType().toString();
218223
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
219-
.builder()
220-
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
221-
.embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE)
222-
.build();
223-
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).parameters(inputParams).inputDataset(inputDataSet).build();
224+
.builder()
225+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
226+
.embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE)
227+
.build();
228+
MLInput mlInput = MLInput
229+
.builder()
230+
.algorithm(FunctionName.TEXT_EMBEDDING)
231+
.parameters(inputParams)
232+
.inputDataset(inputDataSet)
233+
.build();
224234

225235
try {
226236
Map<String, String> paramsMap = executor.getParams(mlInput);

0 commit comments

Comments
 (0)