From 82fccd4f128432295f9b0ffcaa18c8f9c4f00202 Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Thu, 14 Aug 2025 15:17:52 +0800 Subject: [PATCH 1/4] Introduce sdk client to LocalRegexGuardrail Signed-off-by: Yuanchun Shen --- .../opensearch/ml/common/model/Guardrail.java | 3 +- .../ml/common/model/LocalRegexGuardrail.java | 54 +++++++++++-------- .../opensearch/ml/common/model/MLGuard.java | 11 ++-- .../ml/common/model/ModelGuardrail.java | 7 ++- .../ml/common/model/MLGuardTests.java | 9 +++- .../ml/common/model/ModelGuardrailTests.java | 7 ++- .../opensearch/ml/model/MLModelManager.java | 28 ++++++---- .../ml/rest/RestMLGuardrailsIT.java | 17 ++---- 8 files changed, 82 insertions(+), 54 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java index c707b849d0..3bf2cae7e1 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java @@ -11,6 +11,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; public abstract class Guardrail implements ToXContentObject { @@ -19,5 +20,5 @@ public abstract class Guardrail implements ToXContentObject { public abstract Boolean validate(String input, Map parameters); - public abstract void init(NamedXContentRegistry xContentRegistry, Client client); + public abstract void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java index ebd0f4dce6..3d01379d7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -25,7 +25,6 @@ import java.util.stream.Collectors; import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -36,6 +35,9 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; @@ -58,6 +60,8 @@ public class LocalRegexGuardrail extends Guardrail { private Map> stopWordsIndicesInput; private NamedXContentRegistry xContentRegistry; private Client client; + private SdkClient sdkClient; + private String tenantId; @Builder(toBuilder = true) public LocalRegexGuardrail(List stopWords, String[] regex) { @@ -109,9 +113,11 @@ public Boolean validate(String input, Map parameters) { } @Override - public void init(NamedXContentRegistry xContentRegistry, Client client) { + public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) { this.xContentRegistry = xContentRegistry; this.client = client; + this.sdkClient = sdkClient; + this.tenantId = tenantId; init(); } @@ -211,8 +217,7 @@ public Boolean validateStopWords(String input, Map> stopWor * @return true if no stop words matching, otherwise false. */ public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { - SearchRequest searchRequest; - AtomicBoolean hitStopWords = new AtomicBoolean(false); + AtomicBoolean passedStopWordCheck = new AtomicBoolean(false); String queryBody; Map documentMap = new HashMap<>(); for (String field : fieldNames) { @@ -230,32 +235,35 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); searchSourceBuilder.parseXContent(queryParser); searchSourceBuilder.size(1); // Only need 1 doc returned, if hit. - searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); + var responseListener = new LatchedActionListener<>(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) { + passedStopWordCheck.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + passedStopWordCheck.set(true); + }), latch); + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(indexName) + .searchSourceBuilder(searchSourceBuilder) + .tenantId(tenantId) + .build(); if (isStopWordsSystemIndex(indexName)) { context = client.threadPool().getThreadContext().stashContext(); ThreadContext.StoredContext finalContext = context; - client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch), () -> finalContext.restore())); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest) + .whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, finalContext::restore))); } else { - client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch)); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest) + .whenComplete(SdkClientUtils.wrapSearchCompletion(responseListener)); } } catch (Exception e) { log.error("[validateStopWords] Searching stop words index failed.", e); latch.countDown(); - hitStopWords.set(true); + passedStopWordCheck.set(true); } finally { if (context != null) { context.close(); @@ -268,7 +276,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List log.error("[validateStopWords] Searching stop words index was timeout.", e); throw new IllegalStateException(e); } - return hitStopWords.get(); + return passedStopWordCheck.get(); } private boolean isStopWordsSystemIndex(String index) { diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java index 8ea842a20e..1e78cca6b5 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -8,6 +8,7 @@ import java.util.Map; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; import lombok.Getter; @@ -18,17 +19,21 @@ public class MLGuard { private NamedXContentRegistry xContentRegistry; private Client client; + private final SdkClient sdkClient; + private final String tenantId; private Guardrails guardrails; - public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { + public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) { this.xContentRegistry = xContentRegistry; this.client = client; + this.sdkClient = sdkClient; + this.tenantId = tenantId; this.guardrails = guardrails; if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) { - this.guardrails.getInputGuardrail().init(xContentRegistry, client); + this.guardrails.getInputGuardrail().init(xContentRegistry, client, sdkClient, tenantId); } if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) { - this.guardrails.getOutputGuardrail().init(xContentRegistry, client); + this.guardrails.getOutputGuardrail().init(xContentRegistry, client, sdkClient, tenantId); } } diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index afb32ddd7f..9b1b6c6a81 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -37,6 +37,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; import lombok.Builder; @@ -58,6 +59,8 @@ public class ModelGuardrail extends Guardrail { private String responseAccept; private NamedXContentRegistry xContentRegistry; private Client client; + private SdkClient sdkClient; + private String tenantId; private Pattern regexAcceptPattern; @Builder(toBuilder = true) @@ -141,9 +144,11 @@ public Boolean validate(String in, Map parameters) { } @Override - public void init(NamedXContentRegistry xContentRegistry, Client client) { + public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) { this.xContentRegistry = xContentRegistry; this.client = client; + this.sdkClient = sdkClient; + this.tenantId = tenantId; regexAcceptPattern = Pattern.compile(responseAccept); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java index 7244ad21b3..2c5eb809bf 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java @@ -19,6 +19,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.SearchModule; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; @@ -29,8 +30,11 @@ public class MLGuardTests { @Mock Client client; @Mock + SdkClient sdkClient; + @Mock ThreadPool threadPool; ThreadContext threadContext; + String tenantId; StopWords stopWords; String[] regex; @@ -48,6 +52,7 @@ public void setUp() { this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + tenantId = "tenantId"; stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); @@ -55,7 +60,7 @@ public void setUp() { inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail); - mlGuard = new MLGuard(guardrails, xContentRegistry, client); + mlGuard = new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId); } @Test @@ -74,7 +79,7 @@ public void validateInitializedStopWordsEmpty() { inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail); - mlGuard = new MLGuard(guardrails, xContentRegistry, client); + mlGuard = new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId); String input = "\n\nHuman:hello good words.\n\nAssistant:"; Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT, Collections.emptyMap()); diff --git a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java index b5da05751b..9d82aef807 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java @@ -21,6 +21,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.SearchModule; import org.opensearch.transport.client.Client; @@ -28,6 +29,9 @@ public class ModelGuardrailTests { NamedXContentRegistry xContentRegistry; @Mock Client client; + @Mock + SdkClient sdkClient; + String tenantId; Pattern regexPattern; ModelGuardrail modelGuardrail; @@ -38,6 +42,7 @@ public void setUp() { xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); doNothing().when(this.client).execute(any(), any(), any()); modelGuardrail = new ModelGuardrail("test_model_id", "$.test", "^accept$"); + tenantId = "tenant_id"; regexPattern = Pattern.compile("^accept$"); } @@ -70,7 +75,7 @@ public void validateParametersEmpty1() { @Test public void init() { Assert.assertNull(modelGuardrail.getRegexAcceptPattern()); - modelGuardrail.init(xContentRegistry, client); + modelGuardrail.init(xContentRegistry, client, sdkClient, tenantId); Assert.assertEquals(regexPattern.toString(), modelGuardrail.getRegexAcceptPattern().toString()); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index f20c8819cf..d4bd04238e 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -1222,7 +1222,7 @@ public void deployModel( } setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); - setupMLGuard(modelId, mlModel.getGuardrails()); + setupMLGuard(modelId, tenantId, mlModel.getGuardrails()); setupModelInterface(modelId, mlModel.getModelInterface()); deployControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks @@ -1373,7 +1373,7 @@ public void deployModel( } setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); - setupMLGuard(modelId, mlModel.getGuardrails()); + setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails()); setupModelInterface(modelId, mlModel.getModelInterface()); deployControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks @@ -1438,7 +1438,7 @@ public void deployModel( private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener wrappedListener) { String modelId = mlModel.getModelId(); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); - setupMLGuard(modelId, mlModel.getGuardrails()); + setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails()); setupModelInterface(modelId, mlModel.getModelInterface()); if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupParamsAndPredictable(modelId, mlModel); @@ -1461,12 +1461,12 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou } private void setupParamsAndPredictable(String modelId, MLModel mlModel) { - Map params = setUpParameterMap(modelId); + Map params = setUpParameterMap(modelId, mlModel.getTenantId()); Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); } - private Map setUpParameterMap(String modelId) { + private Map setUpParameterMap(String modelId, String tenantId) { TokenBucket rateLimiter = getRateLimiter(modelId); Map userRateLimiterMap = getUserRateLimiterMap(modelId); MLGuard mlGuard = getMLGuard(modelId); @@ -1519,7 +1519,7 @@ public synchronized void updateModelCache(String modelId, ActionListener int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); - setupMLGuard(modelId, mlModel.getGuardrails()); + setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails()); setupModelInterface(modelId, mlModel.getModelInterface()); if (mlModel.getAlgorithm() == FunctionName.REMOTE) { if (mlModel.getConnector() != null) { @@ -1852,23 +1852,29 @@ public Map getModelInterface(String modelId) { * @param guardrails guardrail for the model */ - private void setupMLGuard(String modelId, Guardrails guardrails) { + private void setupMLGuard(String modelId, String tenantId, Guardrails guardrails) { if (guardrails != null) { - modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client)); + modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId)); } else { modelCacheHelper.removeMLGuard(modelId); } } - private MLGuard createMLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { - - return new MLGuard(guardrails, xContentRegistry, client); + private MLGuard createMLGuard( + Guardrails guardrails, + NamedXContentRegistry xContentRegistry, + Client client, + SdkClient sdkClient, + String tenantId + ) { + return new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId); } /** * Get ML guard with model id. * * @param modelId model id + * * @return a ML guard */ public MLGuard getMLGuard(String modelId) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index f31b4aeafc..18366398b2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -12,6 +12,7 @@ import java.util.regex.Pattern; import org.junit.Assert; +import org.junit.Assume; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -155,9 +156,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep public void testPredictRemoteModelFailed() throws IOException, InterruptedException { // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } + Assume.assumeNotNull(OPENAI_KEY); exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); Response response = createConnector(completionModelConnectorEntity); @@ -180,9 +179,7 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept public void testPredictRemoteModelFailedNonType() throws IOException, InterruptedException { // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } + Assume.assumeNotNull(OPENAI_KEY); exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); Response response = createConnector(completionModelConnectorEntity); @@ -205,9 +202,7 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte @Ignore public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException { // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } + Assume.assumeNotNull(OPENAI_KEY); // Create guardrails model. Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); @@ -279,9 +274,7 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, InterruptedException { // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } + Assume.assumeNotNull(OPENAI_KEY); exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); // Create guardrails model. From c62c460f7b25ff28cd5b13b2826ac7617f58b415 Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Wed, 27 Aug 2025 17:11:26 +0800 Subject: [PATCH 2/4] Use try-with-resource when validating stop words Signed-off-by: Yuanchun Shen --- .../ml/common/model/LocalRegexGuardrail.java | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java index 3d01379d7e..808270963e 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -7,7 +7,6 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.stopWordsIndices; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -225,8 +224,6 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List } Map queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); CountDownLatch latch = new CountDownLatch(1); - ThreadContext.StoredContext context = null; - try { queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); @@ -249,25 +246,15 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List .searchSourceBuilder(searchSourceBuilder) .tenantId(tenantId) .build(); - if (isStopWordsSystemIndex(indexName)) { - context = client.threadPool().getThreadContext().stashContext(); - ThreadContext.StoredContext finalContext = context; - sdkClient - .searchDataObjectAsync(searchDataObjectRequest) - .whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, finalContext::restore))); - } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { sdkClient .searchDataObjectAsync(searchDataObjectRequest) - .whenComplete(SdkClientUtils.wrapSearchCompletion(responseListener)); + .whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, context::restore))); } } catch (Exception e) { log.error("[validateStopWords] Searching stop words index failed.", e); latch.countDown(); passedStopWordCheck.set(true); - } finally { - if (context != null) { - context.close(); - } } try { @@ -278,8 +265,4 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List } return passedStopWordCheck.get(); } - - private boolean isStopWordsSystemIndex(String index) { - return stopWordsIndices.contains(index); - } } From bafe5f0c967715230ed808646938e0019ac6fe30 Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Fri, 29 Aug 2025 00:00:31 +0800 Subject: [PATCH 3/4] Improve unit tests for LocalRegexGuardrail Signed-off-by: Yuanchun Shen --- .../ml/common/model/LocalRegexGuardrail.java | 24 +-- .../model/LocalRegexGuardrailTests.java | 145 +++++++++--------- 2 files changed, 84 insertions(+), 85 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java index 808270963e..0d76609060 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -226,12 +226,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List CountDownLatch latch = new CountDownLatch(1); try { queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); - searchSourceBuilder.parseXContent(queryParser); - searchSourceBuilder.size(1); // Only need 1 doc returned, if hit. + SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest(indexName, queryBody); var responseListener = new LatchedActionListener<>(ActionListener.wrap(r -> { if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) { passedStopWordCheck.set(true); @@ -240,12 +235,6 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List log.error("Failed to search stop words index {}", indexName, e); passedStopWordCheck.set(true); }), latch); - SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest - .builder() - .indices(indexName) - .searchSourceBuilder(searchSourceBuilder) - .tenantId(tenantId) - .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { sdkClient .searchDataObjectAsync(searchDataObjectRequest) @@ -265,4 +254,15 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List } return passedStopWordCheck.get(); } + + protected SearchDataObjectRequest buildSearchDataObjectRequest(String indexName, String queryBody) throws IOException { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.size(1); // Only need 1 doc returned, if hit. + + return SearchDataObjectRequest.builder().indices(indexName).searchSourceBuilder(searchSourceBuilder).tenantId(tenantId).build(); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java index 43cb7ddfab..e72158fb46 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java @@ -6,13 +6,16 @@ package org.opensearch.ml.common.model; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.CompletableFuture; import java.util.regex.Pattern; import org.apache.lucene.search.TotalHits; @@ -20,13 +23,12 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.common.action.ActionFuture; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -36,6 +38,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.SearchDataObjectResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; @@ -52,26 +57,32 @@ public class LocalRegexGuardrailTests { Client client; @Mock ThreadPool threadPool; + @Mock + SdkClient sdkClient; ThreadContext threadContext; StopWords stopWords; String[] regex; List regexPatterns; LocalRegexGuardrail localRegexGuardrail; + final String tenantId = "tenant_id"; + final String indexName = "test_index"; + final String testField = "test_field"; @Before public void setUp() { MockitoAnnotations.openMocks(this); - xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + xContentRegistry = spy(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents())); Settings settings = Settings.builder().build(); this.threadContext = new ThreadContext(settings); when(this.client.threadPool()).thenReturn(this.threadPool); when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); - stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + stopWords = new StopWords(indexName, List.of(testField).toArray(new String[0])); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + localRegexGuardrail.init(xContentRegistry, client, sdkClient, tenantId); } @Test @@ -188,36 +199,78 @@ public void validateRegexFailed() { } @Test - public void validateStopWords() throws IOException { - Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); + public void testValidateStopWordsPass() { + Map> stopWordsIndices = Map.of(indexName, List.of(testField)); + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(true).when(spyGuardrail).validateStopWordsSingleIndex("hello world", indexName, List.of(testField)); - Boolean res = localRegexGuardrail.validateStopWords("hello world", stopWordsIndices); - Assert.assertTrue(res); + Boolean resPass = spyGuardrail.validateStopWords("hello world", stopWordsIndices); + Assert.assertTrue(resPass); + } + + @Test + public void testValidateStopWordsFail() { + Map> stopWordsIndices = Map.of(indexName, List.of(testField)); + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(false).when(spyGuardrail).validateStopWordsSingleIndex("stop word", indexName, List.of(testField)); + + Boolean resFail = spyGuardrail.validateStopWords("stop word", stopWordsIndices); + Assert.assertFalse(resFail); } @Test - public void validateStopWordsNull() { + public void testValidateStopWordsNull() { Boolean res = localRegexGuardrail.validateStopWords("hello world", null); Assert.assertTrue(res); } @Test - public void validateStopWordsEmpty() { + public void testValidateStopWordsEmpty() { Boolean res = localRegexGuardrail.validateStopWords("hello world", Map.of()); Assert.assertTrue(res); } @Test - public void validateStopWordsSingleIndex() throws IOException { - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); + public void testValidateStopWordsSingleIndexWithoutHit() throws Exception { + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(mock(SearchDataObjectRequest.class)).when(spyGuardrail).buildSearchDataObjectRequest(any(), any()); + + SearchResponse emptySearchResponse = createSearchResponse(0); + SearchDataObjectResponse searchDataObjectResponse = new SearchDataObjectResponse(emptySearchResponse); + CompletableFuture completedFuture = CompletableFuture.completedFuture(searchDataObjectResponse); + + // Mock the searchDataObjectAsync to return our future with empty response + when(sdkClient.searchDataObjectAsync(any())).thenReturn(completedFuture); - Boolean res = localRegexGuardrail.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); + Boolean res = spyGuardrail.validateStopWordsSingleIndex("hello world", indexName, List.of(testField)); Assert.assertTrue(res); + Mockito.verify(sdkClient, Mockito.times(1)).searchDataObjectAsync(any()); + } + + @Test + public void testValidateStopWordsSingleIndexWithStopWordHit() throws Exception { + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(mock(SearchDataObjectRequest.class)).when(spyGuardrail).buildSearchDataObjectRequest(any(), any()); + + // Mock the sdkClient response - search returns a hit (stop word found) + SearchResponse searchResponseWithHit = createSearchResponse(1); + SearchDataObjectResponse searchDataObjectResponse = new SearchDataObjectResponse(searchResponseWithHit); + // Create a completable future that will immediately execute the callback with a response containing hits + CompletableFuture completedFuture = CompletableFuture.completedFuture(searchDataObjectResponse); + + // Mock the searchDataObjectAsync to return our future that has hits + when(sdkClient.searchDataObjectAsync(any())).thenReturn(completedFuture); + + Boolean res = spyGuardrail.validateStopWordsSingleIndex("hello bad word", indexName, List.of(testField)); + Assert.assertFalse(res); + Mockito.verify(sdkClient, Mockito.times(1)).searchDataObjectAsync(any()); + } + + @Test + public void testBuildSearchDataObjectRequest() throws IOException { + SearchDataObjectRequest request = localRegexGuardrail.buildSearchDataObjectRequest(indexName, "{}"); + Assert.assertEquals(indexName, request.indices()[0]); + Assert.assertEquals(tenantId, request.tenantId()); } private SearchResponse createSearchResponse(int size) throws IOException { @@ -245,58 +298,4 @@ private SearchResponse createSearchResponse(int size) throws IOException { SearchResponse.Clusters.EMPTY ); } - - private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { - return new ActionFuture<>() { - @Override - public SearchResponse actionGet() { - return searchResponse; - } - - @Override - public SearchResponse actionGet(String timeout) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeoutMillis) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeout, TimeUnit unit) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(TimeValue timeout) { - return searchResponse; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public SearchResponse get() { - return searchResponse; - } - - @Override - public SearchResponse get(long timeout, TimeUnit unit) { - return searchResponse; - } - }; - } } From b98c2640b4edbe9a14f34ceb467a5edb9fdf1b2d Mon Sep 17 00:00:00 2001 From: Yuanchun Shen Date: Fri, 5 Sep 2025 17:03:57 +0800 Subject: [PATCH 4/4] Unit test failed cases for validateStopWordsSingleIndex Signed-off-by: Yuanchun Shen --- .../model/LocalRegexGuardrailTests.java | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java index e72158fb46..25dab021d2 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java @@ -266,6 +266,34 @@ public void testValidateStopWordsSingleIndexWithStopWordHit() throws Exception { Mockito.verify(sdkClient, Mockito.times(1)).searchDataObjectAsync(any()); } + @Test + public void testValidateStopWordsSingleIndexFailedSearchingIndex() throws IOException { + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(mock(SearchDataObjectRequest.class)).when(spyGuardrail).buildSearchDataObjectRequest(any(), any()); + + // Create a completable future that throws an exception when get() is called + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new IOException("Index not found")); + when(sdkClient.searchDataObjectAsync(any())).thenReturn(failedFuture); + + // Covers error "Failed to search stop words index test_index" + Boolean res = spyGuardrail.validateStopWordsSingleIndex("hello world", indexName, List.of(testField)); + Assert.assertTrue(res); + Mockito.verify(sdkClient, Mockito.times(1)).searchDataObjectAsync(any()); + } + + @Test + public void testValidateStopWordsSingleIndexFailed() throws IOException { + LocalRegexGuardrail spyGuardrail = spy(localRegexGuardrail); + doReturn(mock(SearchDataObjectRequest.class)).when(spyGuardrail).buildSearchDataObjectRequest(any(), any()); + + when(sdkClient.searchDataObjectAsync(any())).thenThrow(new RuntimeException("test exception")); + // Covers error "[validateStopWords] Searching stop words index failed." + Boolean res = spyGuardrail.validateStopWordsSingleIndex("hello world", indexName, List.of(testField)); + Assert.assertTrue(res); + Mockito.verify(sdkClient, Mockito.times(1)).searchDataObjectAsync(any()); + } + @Test public void testBuildSearchDataObjectRequest() throws IOException { SearchDataObjectRequest request = localRegexGuardrail.buildSearchDataObjectRequest(indexName, "{}");