Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class ConnectorAction implements ToXContentObject, Writeable {
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE);

private static final String INBUILT_FUNC_PREFIX = "connector.";
private static final String EMBED = "embed";
private static final String EMBEDDING = "embedding";
private static final String PRE_PROCESS_FUNC = "PreProcessFunction";
private static final String POST_PROCESS_FUNC = "PostProcessFunction";
private static final Logger logger = LogManager.getLogger(ConnectorAction.class);
Expand Down Expand Up @@ -216,8 +218,8 @@ public void validatePrePostProcessFunctions(Map<String, String> parameters) {
String endPoint = substitutor.replace(url);
String remoteServer = getRemoteServerFromURL(endPoint);
if (!remoteServer.isEmpty()) {
validateProcessFunctions(remoteServer, preProcessFunction, PRE_PROCESS_FUNC);
validateProcessFunctions(remoteServer, postProcessFunction, POST_PROCESS_FUNC);
validateProcessFunctions(endPoint, remoteServer, preProcessFunction, PRE_PROCESS_FUNC);
validateProcessFunctions(endPoint, remoteServer, postProcessFunction, POST_PROCESS_FUNC);
}
}

Expand All @@ -232,7 +234,7 @@ public static String getRemoteServerFromURL(String url) {
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
}

private void validateProcessFunctions(String remoteServer, String processFunction, String funcNameForWarnText) {
private void validateProcessFunctions(String endPointUrl, String remoteServer, String processFunction, String funcNameForWarnText) {
if (isInBuiltProcessFunction(processFunction)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the issue, I believe we should block the connector creation if the input process function doesn't exist. But based on your PRs I didn't find such validation, am I missing something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zane-neo for review.
Already two PRs done for fixing this bug and this is third part of it.
#3260
#3579.

So validation of required parameters for connector creation is done as part of #3260 . Out of that, process functions are optional field according to https://docs.opensearch.org/docs/latest/ml-commons-plugin/remote-models/blueprints/#configuration-parameters and we didn't add any validation for it. All other parameters are also discussed in that PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand it's an optional field but the point is: if it shows up it's either a painless script or a build-in process function, not any other value. So I believe the validation should be try to compile the value to painless script, if exception occurred, then check if it's a valid value in Pre/PostProcessFunction. Without this validation, user still can pass any value to the process_function field and not error message letting them know it's not a correct value, during runtime they found the function not working at all not knowing the root cause.

Copy link
Contributor Author

@akolarkunnu akolarkunnu Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zane-neo I agree with that, that's the perfect validation here. But there is a technical limitation to implement that. As per my code walkthrough ScriptService instance is not available during the connector creation code flow. It will be available during the actual processing of pre or post process functions(inside the class DefaultPreProcessFunction). Are there anyway I can get ScriptService instance in ConnectorActions class ?

switch (remoteServer) {
case OPENAI:
Expand All @@ -255,6 +257,7 @@ private void validateProcessFunctions(String remoteServer, String processFunctio
logWarningForInvalidProcessFunc(SAGEMAKER, funcNameForWarnText);
}
}
validateEmbeddingProcessFunctions(endPointUrl, remoteServer, processFunction, funcNameForWarnText);
}
}

Expand All @@ -275,6 +278,17 @@ private void logWarningForInvalidProcessFunc(String remoteServer, String funcNam
);
}

private void validateEmbeddingProcessFunctions(
String endPointUrl,
String remoteServer,
String processFunction,
String funcNameForWarnText
) {
if (endPointUrl.contains(EMBED) && !(processFunction.contains(EMBEDDING))) {
logWarningForInvalidProcessFunc(remoteServer + " " + EMBEDDING, funcNameForWarnText);
}
}

public enum ActionType {
PREDICT,
EXECUTE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ public class ConnectorActionTest {
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";
private static final String OPENAI_URL = "https://api.openai.com/v1/chat/completions";
private static final String COHERE_URL = "https://api.cohere.ai/v1/embed";
private static final String BEDROCK_URL = "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke";
private static final String COHERE_URL = "https://api.cohere.ai/v1/rerank";
private static final String COHERE_EMBED_URL = "https://api.cohere.ai/v1/embed";
private static final String BEDROCK_URL = "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-v2/invoke";
private static final String BEDROCK_EMBED_URL =
"https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke";
private static final String SAGEMAKER_URL =
"https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/lmi-model-2023-06-24-01-35-32-275/invocations";
private static final Logger logger = LogManager.getLogger(ConnectorActionTest.class);
Expand Down Expand Up @@ -232,7 +235,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
COHERE_URL,
COHERE_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
Expand All @@ -244,7 +247,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt
action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
COHERE_URL,
COHERE_EMBED_URL,
null,
TEST_REQUEST_BODY,
IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT,
Expand All @@ -271,7 +274,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
COHERE_URL,
COHERE_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT,
Expand All @@ -296,7 +299,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
COHERE_URL,
COHERE_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
Expand All @@ -316,12 +319,52 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo
assertTrue(isWarningLogged);
}

@Test
public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongEmbedInBuiltPrePostProcessFunction() {
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
COHERE_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT,
COHERE_RERANK
);
action.validatePrePostProcessFunctions(Map.of());
boolean isWarningLogged = testAppender
.getLogEvents()
.stream()
.anyMatch(
event -> event.getLevel() == Level.WARN
&& event
.getMessage()
.getFormattedMessage()
.contains(
"LLM service is cohere embedding, so PreProcessFunction should be corresponding to cohere embedding for better results."
)
);
assertTrue(isWarningLogged);
isWarningLogged = testAppender
.getLogEvents()
.stream()
.anyMatch(
event -> event.getLevel() == Level.WARN
&& event
.getMessage()
.getFormattedMessage()
.contains(
"LLM service is cohere embedding, so PostProcessFunction should be corresponding to cohere embedding for better results."
)
);
assertTrue(isWarningLogged);
}

@Test
public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() {
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
BEDROCK_URL,
BEDROCK_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT,
Expand Down Expand Up @@ -360,7 +403,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
BEDROCK_URL,
BEDROCK_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
Expand All @@ -385,7 +428,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP
ConnectorAction action = new ConnectorAction(
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
BEDROCK_URL,
BEDROCK_EMBED_URL,
null,
TEST_REQUEST_BODY,
TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT,
Expand Down
Loading