|
9 | 9 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
|
10 | 10 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
|
11 | 11 |
|
| 12 | +import java.io.IOException; |
| 13 | +import java.lang.reflect.Field; |
12 | 14 | import java.util.Arrays;
|
13 | 15 | import java.util.Collection;
|
14 | 16 | import java.util.HashMap;
|
|
17 | 19 | import java.util.Map;
|
18 | 20 | import java.util.Optional;
|
19 | 21 | import java.util.concurrent.atomic.AtomicBoolean;
|
| 22 | +import java.util.regex.Matcher; |
| 23 | +import java.util.regex.Pattern; |
20 | 24 |
|
21 | 25 | import org.apache.logging.log4j.Logger;
|
22 | 26 | import org.opensearch.ExceptionsHelper;
|
|
28 | 32 | import org.opensearch.common.collect.Tuple;
|
29 | 33 | import org.opensearch.common.unit.TimeValue;
|
30 | 34 | import org.opensearch.common.util.TokenBucket;
|
| 35 | +import org.opensearch.common.xcontent.XContentFactory; |
31 | 36 | import org.opensearch.commons.ConfigConstants;
|
32 | 37 | import org.opensearch.commons.authuser.User;
|
33 | 38 | import org.opensearch.core.action.ActionListener;
|
34 | 39 | import org.opensearch.core.rest.RestStatus;
|
35 | 40 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
| 41 | +import org.opensearch.core.xcontent.ToXContent; |
| 42 | +import org.opensearch.core.xcontent.XContentBuilder; |
36 | 43 | import org.opensearch.ml.common.FunctionName;
|
37 | 44 | import org.opensearch.ml.common.connector.Connector;
|
38 | 45 | import org.opensearch.ml.common.connector.ConnectorAction;
|
|
42 | 49 | import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
|
43 | 50 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
|
44 | 51 | import org.opensearch.ml.common.input.MLInput;
|
| 52 | +import org.opensearch.ml.common.input.parameter.MLAlgoParams; |
| 53 | +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; |
45 | 54 | import org.opensearch.ml.common.model.MLGuard;
|
46 | 55 | import org.opensearch.ml.common.output.model.ModelTensorOutput;
|
47 | 56 | import org.opensearch.ml.common.output.model.ModelTensors;
|
@@ -83,6 +92,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
|
83 | 92 | MLInput
|
84 | 93 | .builder()
|
85 | 94 | .algorithm(FunctionName.TEXT_EMBEDDING)
|
| 95 | + .parameters(mlInput.getParameters()) |
86 | 96 | .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
|
87 | 97 | .build(),
|
88 | 98 | new ExecutionContext(sequence++),
|
@@ -187,6 +197,26 @@ default void preparePayloadAndInvoke(
|
187 | 197 | inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
|
188 | 198 | }
|
189 | 199 | parameters.putAll(inputParameters);
|
| 200 | + |
| 201 | + MLAlgoParams algoParams = mlInput.getParameters(); |
| 202 | + if (algoParams != null) { |
| 203 | + Map<String, String> parametersMap = new HashMap<>(); |
| 204 | + String algoParamsStr = algoParams.toString(); |
| 205 | + Pattern pattern = Pattern.compile("\\(([^)]+)\\)"); |
| 206 | + Matcher matcher = pattern.matcher(algoParamsStr); |
| 207 | + if (matcher.find()) { |
| 208 | + String bracketContent = matcher.group(1); |
| 209 | + pattern = Pattern.compile("(\\w+)=([^,\\s]+)"); |
| 210 | + matcher = pattern.matcher(bracketContent); |
| 211 | + while (matcher.find()) { |
| 212 | + String fieldName = matcher.group(1); |
| 213 | + String fieldValue = matcher.group(2); |
| 214 | + parametersMap.put(fieldName, fieldValue); |
| 215 | + } |
| 216 | + } |
| 217 | + parameters.putAll(parametersMap); |
| 218 | + } |
| 219 | + |
190 | 220 | RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
|
191 | 221 | if (inputData.getParameters() != null) {
|
192 | 222 | parameters.putAll(inputData.getParameters());
|
|
0 commit comments