Skip to content

Commit b69fab3

Browse files
XiaSq-eng
authored andcommitted
poc
Signed-off-by: Shiqi Xia <[email protected]>
1 parent 71d47e9 commit b69fab3

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
346346
String payload = connectorAction.get().getRequestBody();
347347
payload = fillNullParameters(parameters, payload);
348348
parseParameters(parameters);
349+
payload = removeMissingParameterFields(payload, parameters);
349350
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
350351
payload = substitutor.replace(payload);
351352

@@ -357,6 +358,27 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
357358
return (T) parameters.get("http_body");
358359
}
359360

361+
private String removeMissingParameterFields(String payload, Map<String, String> params) {
362+
// Match: "xxx": "${parameters.yyy}" or "xxx": {parameters.yyy}
363+
Pattern pattern = Pattern.compile(
364+
"\\s*\"[^\"]+\"\\s*:\\s*(\"?\\$?\\{parameters\\.([^}]+)\\}\"?)\\s*,?"
365+
);
366+
Matcher matcher = pattern.matcher(payload);
367+
StringBuffer sb = new StringBuffer();
368+
369+
while (matcher.find()) {
370+
String paramName = matcher.group(2); // yyy
371+
if (!params.containsKey(paramName)) {
372+
matcher.appendReplacement(sb, "");
373+
} else {
374+
matcher.appendReplacement(sb, Matcher.quoteReplacement(matcher.group(0)));
375+
}
376+
}
377+
matcher.appendTail(sb);
378+
return sb.toString().replaceAll(",\\s*}", "}").replaceAll(",\\s*]", "]");
379+
}
380+
381+
360382
protected String fillNullParameters(Map<String, String> parameters, String payload) {
361383
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
362384
String newPayload = payload;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
1010
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1111

12+
import java.io.IOException;
13+
import java.lang.reflect.Field;
1214
import java.util.Arrays;
1315
import java.util.Collection;
1416
import java.util.HashMap;
@@ -17,6 +19,8 @@
1719
import java.util.Map;
1820
import java.util.Optional;
1921
import java.util.concurrent.atomic.AtomicBoolean;
22+
import java.util.regex.Matcher;
23+
import java.util.regex.Pattern;
2024

2125
import org.apache.logging.log4j.Logger;
2226
import org.opensearch.ExceptionsHelper;
@@ -28,11 +32,14 @@
2832
import org.opensearch.common.collect.Tuple;
2933
import org.opensearch.common.unit.TimeValue;
3034
import org.opensearch.common.util.TokenBucket;
35+
import org.opensearch.common.xcontent.XContentFactory;
3136
import org.opensearch.commons.ConfigConstants;
3237
import org.opensearch.commons.authuser.User;
3338
import org.opensearch.core.action.ActionListener;
3439
import org.opensearch.core.rest.RestStatus;
3540
import org.opensearch.core.xcontent.NamedXContentRegistry;
41+
import org.opensearch.core.xcontent.ToXContent;
42+
import org.opensearch.core.xcontent.XContentBuilder;
3643
import org.opensearch.ml.common.FunctionName;
3744
import org.opensearch.ml.common.connector.Connector;
3845
import org.opensearch.ml.common.connector.ConnectorAction;
@@ -42,6 +49,8 @@
4249
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
4350
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4451
import org.opensearch.ml.common.input.MLInput;
52+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
53+
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
4554
import org.opensearch.ml.common.model.MLGuard;
4655
import org.opensearch.ml.common.output.model.ModelTensorOutput;
4756
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -83,6 +92,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
8392
MLInput
8493
.builder()
8594
.algorithm(FunctionName.TEXT_EMBEDDING)
95+
.parameters(mlInput.getParameters())
8696
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
8797
.build(),
8898
new ExecutionContext(sequence++),
@@ -187,6 +197,26 @@ default void preparePayloadAndInvoke(
187197
inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
188198
}
189199
parameters.putAll(inputParameters);
200+
201+
MLAlgoParams algoParams = mlInput.getParameters();
202+
if (algoParams != null) {
203+
Map<String, String> parametersMap = new HashMap<>();
204+
String algoParamsStr = algoParams.toString();
205+
Pattern pattern = Pattern.compile("\\(([^)]+)\\)");
206+
Matcher matcher = pattern.matcher(algoParamsStr);
207+
if (matcher.find()) {
208+
String bracketContent = matcher.group(1);
209+
pattern = Pattern.compile("(\\w+)=([^,\\s]+)");
210+
matcher = pattern.matcher(bracketContent);
211+
while (matcher.find()) {
212+
String fieldName = matcher.group(1);
213+
String fieldValue = matcher.group(2);
214+
parametersMap.put(fieldName, fieldValue);
215+
}
216+
}
217+
parameters.putAll(parametersMap);
218+
}
219+
190220
RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
191221
if (inputData.getParameters() != null) {
192222
parameters.putAll(inputData.getParameters());

0 commit comments

Comments
 (0)