Skip to content

Commit d3c64e7

Browse files
committed
Remove the removeMissParameterFields
Signed-off-by: Shiqi Xia <[email protected]>
1 parent e7d6d73 commit d3c64e7

File tree

5 files changed

+126
-95
lines changed

5 files changed

+126
-95
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,6 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
346346
String payload = connectorAction.get().getRequestBody();
347347
payload = fillNullParameters(parameters, payload);
348348
parseParameters(parameters);
349-
payload = removeMissingParameterFields(payload, parameters);
350349
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
351350
payload = substitutor.replace(payload);
352351

@@ -358,30 +357,6 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
358357
return (T) parameters.get("http_body");
359358
}
360359

361-
/**
362-
* Removes fields from the given JSON payload string that correspond to parameters
363-
* not present in the provided parameter map.
364-
*/
365-
public String removeMissingParameterFields(String payload, Map<String, String> params) {
366-
if (params == null) {
367-
return payload;
368-
}
369-
Pattern pattern = Pattern.compile("\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?");
370-
Matcher matcher = pattern.matcher(payload);
371-
StringBuffer sb = new StringBuffer();
372-
373-
while (matcher.find()) {
374-
String paramName = matcher.group(2); // yyy
375-
if (!params.containsKey(paramName) && !"input".equals(paramName)) {
376-
matcher.appendReplacement(sb, "");
377-
} else {
378-
matcher.appendReplacement(sb, Matcher.quoteReplacement(matcher.group(0)));
379-
}
380-
}
381-
matcher.appendTail(sb);
382-
return sb.toString().replaceAll(",\\s*}", "}").replaceAll(",\\s*]", "]");
383-
}
384-
385360
protected String fillNullParameters(Map<String, String> parameters, String payload) {
386361
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
387362
String newPayload = payload;

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public class StringUtils {
7878
}
7979
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
8080

81-
private static final ObjectMapper MAPPER = new ObjectMapper();
81+
public static final ObjectMapper MAPPER = new ObjectMapper();
8282

8383
public static boolean isValidJsonString(String json) {
8484
if (json == null || json.isBlank()) {

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,23 @@ public void createPayload_ExtraParams() {
211211
Assert.assertEquals(expected, predictPayload);
212212
}
213213

214+
@Test
215+
public void createPayload_MissingParamsInvalidJson() {
216+
exceptionRule.expect(IllegalArgumentException.class);
217+
exceptionRule
218+
.expectMessage(
219+
"Invalid payload: {\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": ${parameters.content_type} }}"
220+
);
221+
String requestBody =
222+
"{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": ${parameters.content_type} }}";
223+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
224+
Map<String, String> parameters = new HashMap<>();
225+
parameters.put("input", "test value");
226+
parameters.put("sparseEmbeddingFormat", "WORD");
227+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
228+
connector.validatePayload(predictPayload);
229+
}
230+
214231
@Test
215232
public void parseResponse_modelTensorJson() throws IOException {
216233
HttpConnector connector = createHttpConnector();
@@ -425,58 +442,4 @@ public void parse_WithTenantId() throws IOException {
425442
Assert.assertEquals("test_tenant", connector.getTenantId());
426443
}
427444

428-
@Test
429-
public void removeMissingParameterFields() {
430-
HttpConnector connector = createHttpConnector();
431-
Map<String, String> params = new HashMap<>();
432-
params.put("input", "test value");
433-
params.put("sparseEmbeddingFormat", "WORD");
434-
params.put("content_type", "query");
435-
436-
String payload =
437-
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
438-
String expected =
439-
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
440-
String result = connector.removeMissingParameterFields(payload, params);
441-
Assert.assertEquals(expected, result);
442-
}
443-
444-
@Test
445-
public void removeMissingParameterFields_MissingParameters() {
446-
HttpConnector connector = createHttpConnector();
447-
Map<String, String> params = new HashMap<>();
448-
params.put("input", "test value");
449-
450-
String payload =
451-
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
452-
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
453-
String result = connector.removeMissingParameterFields(payload, params);
454-
Assert.assertEquals(expected, result);
455-
}
456-
457-
@Test
458-
public void removeMissingParameterFields_MissingAll() {
459-
HttpConnector connector = createHttpConnector();
460-
Map<String, String> params = new HashMap<>();
461-
462-
String payload =
463-
"{\"input\": ${parameters.input}, \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\"}}";
464-
String expected = "{\"input\": ${parameters.input}, \"parameters\": {}}";
465-
String result = connector.removeMissingParameterFields(payload, params);
466-
Assert.assertEquals(expected, result);
467-
}
468-
469-
@Test
470-
public void removeMissingParameterFields_Nest() {
471-
HttpConnector connector = createHttpConnector();
472-
Map<String, String> params = new HashMap<>();
473-
params.put("input", "test value");
474-
475-
String payload =
476-
"{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\"}}}";
477-
String expected = "{\"input\": \"${parameters.input}\", \"parameters\": {\"nested\": {}}}";
478-
String result = connector.removeMissingParameterFields(payload, params);
479-
Assert.assertEquals(expected, result);
480-
}
481-
482445
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
89
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
910
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
1011
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
@@ -51,12 +52,11 @@
5152
import org.opensearch.ml.common.output.model.ModelTensorOutput;
5253
import org.opensearch.ml.common.output.model.ModelTensors;
5354
import org.opensearch.ml.common.transport.MLTaskResponse;
55+
import org.opensearch.ml.common.utils.StringUtils;
5456
import org.opensearch.script.ScriptService;
5557
import org.opensearch.threadpool.ThreadPool;
5658
import org.opensearch.transport.client.Client;
5759

58-
import com.fasterxml.jackson.databind.ObjectMapper;
59-
6060
import lombok.Builder;
6161

6262
public interface RemoteConnectorExecutor {
@@ -203,6 +203,7 @@ default void preparePayloadAndInvoke(
203203
parameters.putAll(parametersMap);
204204
} catch (IOException e) {
205205
actionListener.onFailure(e);
206+
return;
206207
}
207208
}
208209

@@ -246,21 +247,13 @@ && getUserRateLimiterMap().get(user.getName()) != null
246247
}
247248
}
248249

249-
default Map<String, String> getParams(MLInput mlInput) throws IOException {
250-
Map<String, String> result = new HashMap<>();
250+
static Map<String, String> getParams(MLInput mlInput) throws IOException {
251251
XContentBuilder builder = XContentFactory.jsonBuilder();
252252
mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS);
253253
builder.flush();
254254
String json = builder.toString();
255-
256-
ObjectMapper mapper = new ObjectMapper();
257-
Map<String, Object> tempMap = mapper.readValue(json, Map.class);
258-
259-
HashMap<String, String> paramMap = new HashMap<>();
260-
for (Map.Entry<String, Object> entry : tempMap.entrySet()) {
261-
paramMap.put(entry.getKey(), entry.getValue() != null ? entry.getValue().toString() : null);
262-
}
263-
return paramMap;
255+
Map<String, Object> tempMap = StringUtils.MAPPER.readValue(json, Map.class);
256+
return getParameterMap(tempMap);
264257
}
265258

266259
default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import static org.mockito.ArgumentMatchers.any;
99
import static org.mockito.Mockito.argThat;
10+
import static org.mockito.Mockito.doThrow;
1011
import static org.mockito.Mockito.spy;
1112
import static org.mockito.Mockito.times;
13+
import static org.mockito.Mockito.verify;
1214
import static org.mockito.Mockito.when;
1315
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
1416
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
@@ -32,6 +34,7 @@
3234
import org.opensearch.common.settings.Settings;
3335
import org.opensearch.common.util.concurrent.ThreadContext;
3436
import org.opensearch.core.action.ActionListener;
37+
import org.opensearch.core.xcontent.XContentBuilder;
3538
import org.opensearch.ingest.TestTemplateService;
3639
import org.opensearch.ml.common.FunctionName;
3740
import org.opensearch.ml.common.connector.AwsConnector;
@@ -41,6 +44,8 @@
4144
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
4245
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4346
import org.opensearch.ml.common.input.MLInput;
47+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
48+
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
4449
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
4550
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
4651
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -68,6 +73,9 @@ public class RemoteConnectorExecutorTest {
6873
@Mock
6974
ActionListener<Tuple<Integer, ModelTensors>> actionListener;
7075

76+
@Mock
77+
private MLAlgoParams mlInputParams;
78+
7179
@Before
7280
public void setUp() {
7381
MockitoAnnotations.openMocks(this);
@@ -174,6 +182,62 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
174182
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
175183
}
176184

185+
@Test
186+
public void executePreparePayloadAndInvoke_PassingParameter() {
187+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
188+
Connector connector = getConnector(parameters);
189+
AwsConnectorExecutor executor = getExecutor(connector);
190+
191+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
192+
.builder()
193+
.parameters(Map.of("input", "You are a ${parameters.role}"))
194+
.actionType(PREDICT)
195+
.build();
196+
String actionType = inputDataSet.getActionType().toString();
197+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
198+
.builder()
199+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
200+
.embeddingContentType(null)
201+
.build();
202+
MLInput mlInput = MLInput
203+
.builder()
204+
.algorithm(FunctionName.TEXT_EMBEDDING)
205+
.parameters(inputParams)
206+
.inputDataset(inputDataSet)
207+
.build();
208+
209+
Exception exception = Assert
210+
.assertThrows(
211+
IllegalArgumentException.class,
212+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
213+
);
214+
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
215+
}
216+
217+
@Test
218+
public void executePreparePayloadAndInvoke_GetParamsIOException() throws Exception {
219+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
220+
Connector connector = getConnector(parameters);
221+
AwsConnectorExecutor executor = getExecutor(connector);
222+
223+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
224+
.builder()
225+
.parameters(Map.of("input", "test input"))
226+
.actionType(PREDICT)
227+
.build();
228+
String actionType = inputDataSet.getActionType().toString();
229+
doThrow(new IOException("UT test IOException")).when(mlInputParams).toXContent(any(XContentBuilder.class), any());
230+
MLInput mlInput = MLInput
231+
.builder()
232+
.algorithm(FunctionName.TEXT_EMBEDDING)
233+
.parameters(mlInputParams)
234+
.inputDataset(inputDataSet)
235+
.build();
236+
237+
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
238+
verify(actionListener).onFailure(argThat(e -> e instanceof IOException && e.getMessage().contains("UT test IOException")));
239+
}
240+
177241
@Test
178242
public void executeGetParams_MissingParameter() {
179243
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
@@ -199,7 +263,7 @@ public void executeGetParams_MissingParameter() {
199263
.build();
200264

201265
try {
202-
Map<String, String> paramsMap = executor.getParams(mlInput);
266+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
203267
Map<String, String> expectedMap = new HashMap<>();
204268
expectedMap.put("sparse_embedding_format", "WORD");
205269
Assert.assertEquals(expectedMap, paramsMap);
@@ -233,7 +297,7 @@ public void executeGetParams_PassingParameter() {
233297
.build();
234298

235299
try {
236-
Map<String, String> paramsMap = executor.getParams(mlInput);
300+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
237301
Map<String, String> expectedMap = new HashMap<>();
238302
expectedMap.put("sparse_embedding_format", "WORD");
239303
expectedMap.put("content_type", "PASSAGE");
@@ -242,4 +306,40 @@ public void executeGetParams_PassingParameter() {
242306
e.printStackTrace();
243307
}
244308
}
309+
310+
@Test
311+
public void executeGetParams_ConvertToString() {
312+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
313+
Connector connector = getConnector(parameters);
314+
AwsConnectorExecutor executor = getExecutor(connector);
315+
316+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
317+
.builder()
318+
.parameters(Map.of("input", "${parameters.input}"))
319+
.actionType(PREDICT)
320+
.build();
321+
KMeansParams inputParams = KMeansParams
322+
.builder()
323+
.centroids(5)
324+
.iterations(100)
325+
.distanceType(KMeansParams.DistanceType.EUCLIDEAN)
326+
.build();
327+
MLInput mlInput = MLInput
328+
.builder()
329+
.algorithm(FunctionName.TEXT_EMBEDDING)
330+
.parameters(inputParams)
331+
.inputDataset(inputDataSet)
332+
.build();
333+
334+
try {
335+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
336+
Map<String, String> expectedMap = new HashMap<>();
337+
expectedMap.put("centroids", "5");
338+
expectedMap.put("iterations", "100");
339+
expectedMap.put("distance_type", "EUCLIDEAN");
340+
Assert.assertEquals(expectedMap, paramsMap);
341+
} catch (IOException e) {
342+
e.printStackTrace();
343+
}
344+
}
245345
}

0 commit comments

Comments
 (0)