From 4cc3faeca3e76977b56ba5823af70567c45d46de Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Tue, 8 Jul 2025 12:49:10 +0530 Subject: [PATCH 1/3] [Enhancement] Enhance validation for create connector API - Part 3 This change will address the third part of validation "embeddings pre and post processing function validation". Resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu --- .../ml/common/connector/ConnectorAction.java | 22 ++++++- .../common/connector/ConnectorActionTest.java | 61 ++++++++++++++++--- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index c82f489296..122e8002e0 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -47,6 +47,8 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final List SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE); private static final String INBUILT_FUNC_PREFIX = "connector."; + private static final String EMBED = "embed"; + private static final String EMBEDDING = "embedding"; private static final String PRE_PROCESS_FUNC = "PreProcessFunction"; private static final String POST_PROCESS_FUNC = "PostProcessFunction"; private static final Logger logger = LogManager.getLogger(ConnectorAction.class); @@ -216,8 +218,8 @@ public void validatePrePostProcessFunctions(Map parameters) { String endPoint = substitutor.replace(url); String remoteServer = getRemoteServerFromURL(endPoint); if (!remoteServer.isEmpty()) { - validateProcessFunctions(remoteServer, preProcessFunction, PRE_PROCESS_FUNC); - validateProcessFunctions(remoteServer, postProcessFunction, POST_PROCESS_FUNC); + validateProcessFunctions(endPoint, remoteServer, preProcessFunction, PRE_PROCESS_FUNC); + validateProcessFunctions(endPoint, remoteServer, postProcessFunction, POST_PROCESS_FUNC); } } @@ -232,7 +234,7 @@ public static String getRemoteServerFromURL(String url) { return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse(""); } - private void validateProcessFunctions(String remoteServer, String processFunction, String funcNameForWarnText) { + private void validateProcessFunctions(String endPointUrl, String remoteServer, String processFunction, String funcNameForWarnText) { if (isInBuiltProcessFunction(processFunction)) { switch (remoteServer) { case OPENAI: @@ -255,6 +257,7 @@ private void validateProcessFunctions(String remoteServer, String processFunctio logWarningForInvalidProcessFunc(SAGEMAKER, funcNameForWarnText); } } + validateEmbeddingProcessFunctions(endPointUrl, remoteServer, processFunction, funcNameForWarnText); } } @@ -275,6 +278,19 @@ private void logWarningForInvalidProcessFunc(String remoteServer, String funcNam ); } + private void validateEmbeddingProcessFunctions( + String endPointUrl, + String remoteServer, + String processFunction, + String funcNameForWarnText + ) { + if (endPointUrl.contains(EMBED)) { + if (!processFunction.contains(EMBEDDING) || !processFunction.contains(remoteServer)) { + logWarningForInvalidProcessFunc(remoteServer + " " + EMBEDDING, funcNameForWarnText); + } + } + } + public enum ActionType { PREDICT, EXECUTE, diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index e05d8d04d2..5bed7e6689 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -67,8 +67,11 @@ public class ConnectorActionTest { private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}"; private static final String URL = "https://test.com"; private static final String OPENAI_URL = "https://api.openai.com/v1/chat/completions"; - private static final String COHERE_URL = "https://api.cohere.ai/v1/embed"; - private static final String BEDROCK_URL = "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke"; + private static final String COHERE_URL = "https://api.cohere.ai/v1/rerank"; + private static final String COHERE_EMBED_URL = "https://api.cohere.ai/v1/embed"; + private static final String BEDROCK_URL = "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-v2/invoke"; + private static final String BEDROCK_EMBED_URL = + "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke"; private static final String SAGEMAKER_URL = "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/lmi-model-2023-06-24-01-35-32-275/invocations"; private static final Logger logger = LogManager.getLogger(ConnectorActionTest.class); @@ -232,7 +235,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - COHERE_URL, + COHERE_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, @@ -244,7 +247,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - COHERE_URL, + COHERE_EMBED_URL, null, TEST_REQUEST_BODY, IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, @@ -271,7 +274,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - COHERE_URL, + COHERE_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, @@ -296,7 +299,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - COHERE_URL, + COHERE_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, @@ -316,12 +319,52 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo assertTrue(isWarningLogged); } + @Test + public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongEmbedInBuiltPrePostProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_EMBED_URL, + null, + TEST_REQUEST_BODY, + TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, + COHERE_RERANK + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains( + "LLM service is cohere embedding, so PreProcessFunction should be corresponding to cohere embedding for better results." + ) + ); + assertTrue(isWarningLogged); + isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains( + "LLM service is cohere embedding, so PostProcessFunction should be corresponding to cohere embedding for better results." + ) + ); + assertTrue(isWarningLogged); + } + @Test public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - BEDROCK_URL, + BEDROCK_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, @@ -360,7 +403,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - BEDROCK_URL, + BEDROCK_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, @@ -385,7 +428,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, TEST_METHOD_HTTP, - BEDROCK_URL, + BEDROCK_EMBED_URL, null, TEST_REQUEST_BODY, TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, From 98f7c31a6649a3c20537081bb5052b6084b81e52 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Tue, 8 Jul 2025 13:07:07 +0530 Subject: [PATCH 2/3] [Enhancement] Enhance validation for create connector API - Part 3 This change will address the third part of validation "embeddings pre and post processing function validation". Resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu --- .../org/opensearch/ml/common/connector/ConnectorAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 122e8002e0..639ffb0816 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -285,7 +285,7 @@ private void validateEmbeddingProcessFunctions( String funcNameForWarnText ) { if (endPointUrl.contains(EMBED)) { - if (!processFunction.contains(EMBEDDING) || !processFunction.contains(remoteServer)) { + if (!processFunction.contains(EMBEDDING)) { logWarningForInvalidProcessFunc(remoteServer + " " + EMBEDDING, funcNameForWarnText); } } From 74aa59393ba4a37df24ad7a7c7c60e7a04f85bf6 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Tue, 8 Jul 2025 13:11:25 +0530 Subject: [PATCH 3/3] [Enhancement] Enhance validation for create connector API - Part 3 This change will address the third part of validation "embeddings pre and post processing function validation". Resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu --- .../org/opensearch/ml/common/connector/ConnectorAction.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 639ffb0816..4e7aed6353 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -284,10 +284,8 @@ private void validateEmbeddingProcessFunctions( String processFunction, String funcNameForWarnText ) { - if (endPointUrl.contains(EMBED)) { - if (!processFunction.contains(EMBEDDING)) { - logWarningForInvalidProcessFunc(remoteServer + " " + EMBEDDING, funcNameForWarnText); - } + if (endPointUrl.contains(EMBED) && !(processFunction.contains(EMBEDDING))) { + logWarningForInvalidProcessFunc(remoteServer + " " + EMBEDDING, funcNameForWarnText); } }