Skip to content

Commit f9270db

Browse files
XiaSq-eng
authored andcommitted
version that run correctly
Signed-off-by: Shiqi Xia <[email protected]>
1 parent b69fab3 commit f9270db

File tree

4 files changed

+160
-26
lines changed

4 files changed

+160
-26
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
358358
return (T) parameters.get("http_body");
359359
}
360360

361-
private String removeMissingParameterFields(String payload, Map<String, String> params) {
361+
public String removeMissingParameterFields(String payload, Map<String, String> params) {
362362
// Match: "xxx": "${parameters.yyy}" or "xxx": {parameters.yyy}
363363
Pattern pattern = Pattern.compile(
364364
"\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?"

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,4 +407,54 @@ public void parse_WithTenantId() throws IOException {
407407
Assert.assertEquals("test_tenant", connector.getTenantId());
408408
}
409409

410+
@Test
411+
public void removeMissingParameterFields() {
412+
HttpConnector connector = createHttpConnector();
413+
Map<String, String> params = new HashMap<>();
414+
params.put("input", "test value");
415+
params.put("sparseEmbeddingFormat", "WORD");
416+
params.put("content_type", "query");
417+
418+
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
419+
String expected = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
420+
String result = connector.removeMissingParameterFields(payload, params);
421+
Assert.assertEquals(expected, result);
422+
}
423+
424+
@Test
425+
public void removeMissingParameterFields_MissingParameters() {
426+
HttpConnector connector = createHttpConnector();
427+
Map<String, String> params = new HashMap<>();
428+
params.put("input", "test value");
429+
430+
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
431+
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
432+
String result = connector.removeMissingParameterFields(payload, params);
433+
Assert.assertEquals(expected, result);
434+
}
435+
436+
@Test
437+
public void removeMissingParameterFields_MissingAll() {
438+
HttpConnector connector = createHttpConnector();
439+
Map<String, String> params = new HashMap<>();
440+
441+
String payload = "{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
442+
String expected = "{ \"parameters\": {}}";
443+
String result = connector.removeMissingParameterFields(payload, params);
444+
Assert.assertEquals(expected, result);
445+
}
446+
447+
@Test
448+
public void removeMissingParameterFields_Nest() {
449+
HttpConnector connector = createHttpConnector();
450+
Map<String, String> params = new HashMap<>();
451+
params.put("input", "test value");
452+
453+
// Case 1: Payload with valid parameter placeholders
454+
String payload = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\"}}}";
455+
String expected = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {}}}";
456+
String result = connector.removeMissingParameterFields(payload, params);
457+
Assert.assertEquals(expected, result);
458+
}
459+
410460
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,20 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
//import static org.opensearch.ml.common.dataset.SearchQueryInputDataset.xContentRegistry;
810
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
911
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
1012
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1113

1214
import java.io.IOException;
1315
import java.lang.reflect.Field;
14-
import java.util.Arrays;
15-
import java.util.Collection;
16-
import java.util.HashMap;
17-
import java.util.List;
18-
import java.util.Locale;
19-
import java.util.Map;
20-
import java.util.Optional;
16+
import java.util.*;
2117
import java.util.concurrent.atomic.AtomicBoolean;
2218
import java.util.regex.Matcher;
2319
import java.util.regex.Pattern;
2420

21+
import com.fasterxml.jackson.databind.ObjectMapper;
2522
import org.apache.logging.log4j.Logger;
2623
import org.opensearch.ExceptionsHelper;
2724
import org.opensearch.OpenSearchStatusException;
@@ -33,13 +30,12 @@
3330
import org.opensearch.common.unit.TimeValue;
3431
import org.opensearch.common.util.TokenBucket;
3532
import org.opensearch.common.xcontent.XContentFactory;
33+
import org.opensearch.common.xcontent.XContentType;
3634
import org.opensearch.commons.ConfigConstants;
3735
import org.opensearch.commons.authuser.User;
3836
import org.opensearch.core.action.ActionListener;
3937
import org.opensearch.core.rest.RestStatus;
40-
import org.opensearch.core.xcontent.NamedXContentRegistry;
41-
import org.opensearch.core.xcontent.ToXContent;
42-
import org.opensearch.core.xcontent.XContentBuilder;
38+
import org.opensearch.core.xcontent.*;
4339
import org.opensearch.ml.common.FunctionName;
4440
import org.opensearch.ml.common.connector.Connector;
4541
import org.opensearch.ml.common.connector.ConnectorAction;
@@ -50,7 +46,7 @@
5046
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5147
import org.opensearch.ml.common.input.MLInput;
5248
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
53-
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
49+
//import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
5450
import org.opensearch.ml.common.model.MLGuard;
5551
import org.opensearch.ml.common.output.model.ModelTensorOutput;
5652
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -200,22 +196,30 @@ default void preparePayloadAndInvoke(
200196

201197
MLAlgoParams algoParams = mlInput.getParameters();
202198
if (algoParams != null) {
203-
Map<String, String> parametersMap = new HashMap<>();
204-
String algoParamsStr = algoParams.toString();
205-
Pattern pattern = Pattern.compile("\\(([^)]+)\\)");
206-
Matcher matcher = pattern.matcher(algoParamsStr);
207-
if (matcher.find()) {
208-
String bracketContent = matcher.group(1);
209-
pattern = Pattern.compile("(\\w+)=([^,\\s]+)");
210-
matcher = pattern.matcher(bracketContent);
211-
while (matcher.find()) {
212-
String fieldName = matcher.group(1);
213-
String fieldValue = matcher.group(2);
214-
parametersMap.put(fieldName, fieldValue);
215-
}
199+
try {
200+
Map<String, String> parametersMap = getParams(mlInput);
201+
parameters.putAll(parametersMap);
202+
} catch (IOException e) {
203+
actionListener.onFailure(e);
216204
}
217-
parameters.putAll(parametersMap);
218205
}
206+
// if (algoParams != null) {
207+
// Map<String, String> parametersMap = new HashMap<>();
208+
// String algoParamsStr = algoParams.toString();
209+
// Pattern pattern = Pattern.compile("\\(([^)]+)\\)");
210+
// Matcher matcher = pattern.matcher(algoParamsStr);
211+
// if (matcher.find()) {
212+
// String bracketContent = matcher.group(1);
213+
// pattern = Pattern.compile("(\\w+)=([^,\\s]+)");
214+
// matcher = pattern.matcher(bracketContent);
215+
// while (matcher.find()) {
216+
// String fieldName = matcher.group(1);
217+
// String fieldValue = matcher.group(2);
218+
// parametersMap.put(fieldName, fieldValue);
219+
// }
220+
// }
221+
// parameters.putAll(parametersMap);
222+
// }
219223

220224
RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
221225
if (inputData.getParameters() != null) {
@@ -257,6 +261,23 @@ && getUserRateLimiterMap().get(user.getName()) != null
257261
}
258262
}
259263

264+
default Map<String, String> getParams(MLInput mlInput) throws IOException {
265+
Map<String, String> result = new HashMap<>();
266+
XContentBuilder builder = XContentFactory.jsonBuilder();
267+
mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS);
268+
builder.flush();
269+
String json = builder.toString();
270+
271+
ObjectMapper mapper = new ObjectMapper();
272+
Map<String, Object> tempMap = mapper.readValue(json, Map.class);
273+
274+
HashMap<String, String> paramMap = new HashMap<>();
275+
for (Map.Entry<String, Object> entry : tempMap.entrySet()) {
276+
paramMap.put(entry.getKey(), entry.getValue() != null ? entry.getValue().toString() : null);
277+
}
278+
return paramMap;
279+
}
280+
260281
default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {
261282
switch (connectorClientConfig.getRetryBackoffPolicy()) {
262283
case EXPONENTIAL_EQUAL_JITTER:

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
1818
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
1919

20+
import java.io.IOException;
2021
import java.util.Arrays;
22+
import java.util.HashMap;
2123
import java.util.Map;
2224

2325
import org.junit.Assert;
@@ -39,6 +41,8 @@
3941
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
4042
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4143
import org.opensearch.ml.common.input.MLInput;
44+
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
45+
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
4246
import org.opensearch.ml.common.output.model.ModelTensors;
4347
import org.opensearch.ml.engine.encryptor.Encryptor;
4448
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
@@ -169,4 +173,63 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
169173
);
170174
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
171175
}
176+
177+
@Test
178+
public void executeGetParams_MissingParameter() {
179+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
180+
Connector connector = getConnector(parameters);
181+
AwsConnectorExecutor executor = getExecutor(connector);
182+
183+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
184+
.builder()
185+
.parameters(Map.of("input", "${parameters.input}"))
186+
.actionType(PREDICT)
187+
.build();
188+
String actionType = inputDataSet.getActionType().toString();
189+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
190+
.builder()
191+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
192+
.embeddingContentType(null)
193+
.build();
194+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).parameters(inputParams).inputDataset(inputDataSet).build();
195+
196+
try {
197+
Map<String, String> paramsMap = executor.getParams(mlInput);
198+
Map<String, String> expectedMap = new HashMap<>();
199+
expectedMap.put("sparse_embedding_format", "WORD");
200+
Assert.assertEquals(expectedMap, paramsMap);
201+
} catch (IOException e) {
202+
e.printStackTrace();
203+
}
204+
}
205+
206+
@Test
207+
public void executeGetParams_PassingParameter() {
208+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
209+
Connector connector = getConnector(parameters);
210+
AwsConnectorExecutor executor = getExecutor(connector);
211+
212+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
213+
.builder()
214+
.parameters(Map.of("input", "${parameters.input}"))
215+
.actionType(PREDICT)
216+
.build();
217+
String actionType = inputDataSet.getActionType().toString();
218+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
219+
.builder()
220+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
221+
.embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE)
222+
.build();
223+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).parameters(inputParams).inputDataset(inputDataSet).build();
224+
225+
try {
226+
Map<String, String> paramsMap = executor.getParams(mlInput);
227+
Map<String, String> expectedMap = new HashMap<>();
228+
expectedMap.put("sparse_embedding_format", "WORD");
229+
expectedMap.put("content_type", "PASSAGE");
230+
Assert.assertEquals(expectedMap, paramsMap);
231+
} catch (IOException e) {
232+
e.printStackTrace();
233+
}
234+
}
172235
}

0 commit comments

Comments
 (0)