-
Notifications
You must be signed in to change notification settings - Fork 184
Parameter Passing for Predict via Remote Connector #4121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; | ||
| import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Arrays; | ||
| import java.util.Collection; | ||
| import java.util.HashMap; | ||
|
|
@@ -28,11 +29,14 @@ | |
| import org.opensearch.common.collect.Tuple; | ||
| import org.opensearch.common.unit.TimeValue; | ||
| import org.opensearch.common.util.TokenBucket; | ||
| import org.opensearch.common.xcontent.XContentFactory; | ||
| import org.opensearch.commons.ConfigConstants; | ||
| import org.opensearch.commons.authuser.User; | ||
| import org.opensearch.core.action.ActionListener; | ||
| import org.opensearch.core.rest.RestStatus; | ||
| import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
| import org.opensearch.core.xcontent.ToXContent; | ||
| import org.opensearch.core.xcontent.XContentBuilder; | ||
| import org.opensearch.ml.common.FunctionName; | ||
| import org.opensearch.ml.common.connector.Connector; | ||
| import org.opensearch.ml.common.connector.ConnectorAction; | ||
|
|
@@ -42,6 +46,7 @@ | |
| import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
| import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
| import org.opensearch.ml.common.input.MLInput; | ||
| import org.opensearch.ml.common.input.parameter.MLAlgoParams; | ||
| import org.opensearch.ml.common.model.MLGuard; | ||
| import org.opensearch.ml.common.output.model.ModelTensorOutput; | ||
| import org.opensearch.ml.common.output.model.ModelTensors; | ||
|
|
@@ -50,6 +55,8 @@ | |
| import org.opensearch.threadpool.ThreadPool; | ||
| import org.opensearch.transport.client.Client; | ||
|
|
||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
|
|
||
| import lombok.Builder; | ||
|
|
||
| public interface RemoteConnectorExecutor { | ||
|
|
@@ -83,6 +90,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask | |
| MLInput | ||
| .builder() | ||
| .algorithm(FunctionName.TEXT_EMBEDDING) | ||
| .parameters(mlInput.getParameters()) | ||
| .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) | ||
| .build(), | ||
| new ExecutionContext(sequence++), | ||
|
|
@@ -187,6 +195,17 @@ default void preparePayloadAndInvoke( | |
| inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); | ||
| } | ||
| parameters.putAll(inputParameters); | ||
|
|
||
| MLAlgoParams algoParams = mlInput.getParameters(); | ||
| if (algoParams != null) { | ||
| try { | ||
| Map<String, String> parametersMap = getParams(mlInput); | ||
| parameters.putAll(parametersMap); | ||
| } catch (IOException e) { | ||
| actionListener.onFailure(e); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add return here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| } | ||
| } | ||
|
|
||
| RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService()); | ||
| if (inputData.getParameters() != null) { | ||
| parameters.putAll(inputData.getParameters()); | ||
|
|
@@ -227,6 +246,23 @@ && getUserRateLimiterMap().get(user.getName()) != null | |
| } | ||
| } | ||
|
|
||
| default Map<String, String> getParams(MLInput mlInput) throws IOException { | ||
|
||
| Map<String, String> result = new HashMap<>(); | ||
| XContentBuilder builder = XContentFactory.jsonBuilder(); | ||
| mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS); | ||
| builder.flush(); | ||
| String json = builder.toString(); | ||
|
|
||
| ObjectMapper mapper = new ObjectMapper(); | ||
|
||
| Map<String, Object> tempMap = mapper.readValue(json, Map.class); | ||
|
|
||
| HashMap<String, String> paramMap = new HashMap<>(); | ||
| for (Map.Entry<String, Object> entry : tempMap.entrySet()) { | ||
| paramMap.put(entry.getKey(), entry.getValue() != null ? entry.getValue().toString() : null); | ||
|
||
| } | ||
| return paramMap; | ||
| } | ||
|
|
||
| default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) { | ||
| switch (connectorClientConfig.getRetryBackoffPolicy()) { | ||
| case EXPONENTIAL_EQUAL_JITTER: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add description for this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need remove the unused field. We should keep consistent behavior with previous code. I.e. if user set a parameter in template and don't provide values, we just keep it. Otherwise it may change the behavior at user side
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand. I’ve updated the code to follow this logic – now it keeps the unused fields consistent with the previous behavior.