|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.engine.algorithms.remote;
|
7 | 7 |
|
| 8 | +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; |
| 9 | +//import static org.opensearch.ml.common.dataset.SearchQueryInputDataset.xContentRegistry; |
8 | 10 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
|
9 | 11 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
|
10 | 12 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
|
11 | 13 |
|
12 | 14 | import java.io.IOException;
|
13 | 15 | import java.lang.reflect.Field;
|
14 |
| -import java.util.Arrays; |
15 |
| -import java.util.Collection; |
16 |
| -import java.util.HashMap; |
17 |
| -import java.util.List; |
18 |
| -import java.util.Locale; |
19 |
| -import java.util.Map; |
20 |
| -import java.util.Optional; |
| 16 | +import java.util.*; |
21 | 17 | import java.util.concurrent.atomic.AtomicBoolean;
|
22 | 18 | import java.util.regex.Matcher;
|
23 | 19 | import java.util.regex.Pattern;
|
24 | 20 |
|
| 21 | +import com.fasterxml.jackson.databind.ObjectMapper; |
25 | 22 | import org.apache.logging.log4j.Logger;
|
26 | 23 | import org.opensearch.ExceptionsHelper;
|
27 | 24 | import org.opensearch.OpenSearchStatusException;
|
|
33 | 30 | import org.opensearch.common.unit.TimeValue;
|
34 | 31 | import org.opensearch.common.util.TokenBucket;
|
35 | 32 | import org.opensearch.common.xcontent.XContentFactory;
|
| 33 | +import org.opensearch.common.xcontent.XContentType; |
36 | 34 | import org.opensearch.commons.ConfigConstants;
|
37 | 35 | import org.opensearch.commons.authuser.User;
|
38 | 36 | import org.opensearch.core.action.ActionListener;
|
39 | 37 | import org.opensearch.core.rest.RestStatus;
|
40 |
| -import org.opensearch.core.xcontent.NamedXContentRegistry; |
41 |
| -import org.opensearch.core.xcontent.ToXContent; |
42 |
| -import org.opensearch.core.xcontent.XContentBuilder; |
| 38 | +import org.opensearch.core.xcontent.*; |
43 | 39 | import org.opensearch.ml.common.FunctionName;
|
44 | 40 | import org.opensearch.ml.common.connector.Connector;
|
45 | 41 | import org.opensearch.ml.common.connector.ConnectorAction;
|
|
50 | 46 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
|
51 | 47 | import org.opensearch.ml.common.input.MLInput;
|
52 | 48 | import org.opensearch.ml.common.input.parameter.MLAlgoParams;
|
53 |
| -import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; |
| 49 | +//import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; |
54 | 50 | import org.opensearch.ml.common.model.MLGuard;
|
55 | 51 | import org.opensearch.ml.common.output.model.ModelTensorOutput;
|
56 | 52 | import org.opensearch.ml.common.output.model.ModelTensors;
|
@@ -200,22 +196,30 @@ default void preparePayloadAndInvoke(
|
200 | 196 |
|
201 | 197 | MLAlgoParams algoParams = mlInput.getParameters();
|
202 | 198 | if (algoParams != null) {
|
203 |
| - Map<String, String> parametersMap = new HashMap<>(); |
204 |
| - String algoParamsStr = algoParams.toString(); |
205 |
| - Pattern pattern = Pattern.compile("\\(([^)]+)\\)"); |
206 |
| - Matcher matcher = pattern.matcher(algoParamsStr); |
207 |
| - if (matcher.find()) { |
208 |
| - String bracketContent = matcher.group(1); |
209 |
| - pattern = Pattern.compile("(\\w+)=([^,\\s]+)"); |
210 |
| - matcher = pattern.matcher(bracketContent); |
211 |
| - while (matcher.find()) { |
212 |
| - String fieldName = matcher.group(1); |
213 |
| - String fieldValue = matcher.group(2); |
214 |
| - parametersMap.put(fieldName, fieldValue); |
215 |
| - } |
| 199 | + try { |
| 200 | + Map<String, String> parametersMap = getParams(mlInput); |
| 201 | + parameters.putAll(parametersMap); |
| 202 | + } catch (IOException e) { |
| 203 | + actionListener.onFailure(e); |
216 | 204 | }
|
217 |
| - parameters.putAll(parametersMap); |
218 | 205 | }
|
| 206 | +// if (algoParams != null) { |
| 207 | +// Map<String, String> parametersMap = new HashMap<>(); |
| 208 | +// String algoParamsStr = algoParams.toString(); |
| 209 | +// Pattern pattern = Pattern.compile("\\(([^)]+)\\)"); |
| 210 | +// Matcher matcher = pattern.matcher(algoParamsStr); |
| 211 | +// if (matcher.find()) { |
| 212 | +// String bracketContent = matcher.group(1); |
| 213 | +// pattern = Pattern.compile("(\\w+)=([^,\\s]+)"); |
| 214 | +// matcher = pattern.matcher(bracketContent); |
| 215 | +// while (matcher.find()) { |
| 216 | +// String fieldName = matcher.group(1); |
| 217 | +// String fieldValue = matcher.group(2); |
| 218 | +// parametersMap.put(fieldName, fieldValue); |
| 219 | +// } |
| 220 | +// } |
| 221 | +// parameters.putAll(parametersMap); |
| 222 | +// } |
219 | 223 |
|
220 | 224 | RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
|
221 | 225 | if (inputData.getParameters() != null) {
|
@@ -257,6 +261,23 @@ && getUserRateLimiterMap().get(user.getName()) != null
|
257 | 261 | }
|
258 | 262 | }
|
259 | 263 |
|
| 264 | + default Map<String, String> getParams(MLInput mlInput) throws IOException { |
| 265 | + Map<String, String> result = new HashMap<>(); |
| 266 | + XContentBuilder builder = XContentFactory.jsonBuilder(); |
| 267 | + mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS); |
| 268 | + builder.flush(); |
| 269 | + String json = builder.toString(); |
| 270 | + |
| 271 | + ObjectMapper mapper = new ObjectMapper(); |
| 272 | + Map<String, Object> tempMap = mapper.readValue(json, Map.class); |
| 273 | + |
| 274 | + HashMap<String, String> paramMap = new HashMap<>(); |
| 275 | + for (Map.Entry<String, Object> entry : tempMap.entrySet()) { |
| 276 | + paramMap.put(entry.getKey(), entry.getValue() != null ? entry.getValue().toString() : null); |
| 277 | + } |
| 278 | + return paramMap; |
| 279 | + } |
| 280 | + |
260 | 281 | default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {
|
261 | 282 | switch (connectorClientConfig.getRetryBackoffPolicy()) {
|
262 | 283 | case EXPONENTIAL_EQUAL_JITTER:
|
|
0 commit comments