Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;

public abstract class Guardrail implements ToXContentObject {
Expand All @@ -19,5 +20,5 @@ public abstract class Guardrail implements ToXContentObject {

public abstract Boolean validate(String input, Map<String, String> parameters);

public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
public abstract void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.stopWordsIndices;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
Expand All @@ -25,7 +24,6 @@
import java.util.stream.Collectors;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
Expand All @@ -36,6 +34,9 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;

Expand All @@ -58,6 +59,8 @@ public class LocalRegexGuardrail extends Guardrail {
private Map<String, List<String>> stopWordsIndicesInput;
private NamedXContentRegistry xContentRegistry;
private Client client;
private SdkClient sdkClient;
private String tenantId;

@Builder(toBuilder = true)
public LocalRegexGuardrail(List<StopWords> stopWords, String[] regex) {
Expand Down Expand Up @@ -109,9 +112,11 @@ public Boolean validate(String input, Map<String, String> parameters) {
}

@Override
public void init(NamedXContentRegistry xContentRegistry, Client client) {
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
this.xContentRegistry = xContentRegistry;
this.client = client;
this.sdkClient = sdkClient;
this.tenantId = tenantId;
init();
}

Expand Down Expand Up @@ -211,55 +216,34 @@ public Boolean validateStopWords(String input, Map<String, List<String>> stopWor
* @return true if no stop words matching, otherwise false.
*/
public Boolean validateStopWordsSingleIndex(String input, String indexName, List<String> fieldNames) {
SearchRequest searchRequest;
AtomicBoolean hitStopWords = new AtomicBoolean(false);
AtomicBoolean passedStopWordCheck = new AtomicBoolean(false);
String queryBody;
Map<String, String> documentMap = new HashMap<>();
for (String field : fieldNames) {
documentMap.put(field, input);
}
Map<String, Object> queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
CountDownLatch latch = new CountDownLatch(1);
ThreadContext.StoredContext context = null;

try {
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
if (isStopWordsSystemIndex(indexName)) {
context = client.threadPool().getThreadContext().stashContext();
ThreadContext.StoredContext finalContext = context;
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> finalContext.restore()));
} else {
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch));
SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest(indexName, queryBody);
var responseListener = new LatchedActionListener<>(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
passedStopWordCheck.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
passedStopWordCheck.set(true);
}), latch);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
sdkClient
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the if...else can be merged, and check the other places using SdkClient, try to always use the same approach like:

 try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((

Without this try block, it could cause issue some edge cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By merging the if...else, do you mean I can always stash and restore the thread context no matter whether it is a stop word system index or not?

Actually I don't really know why it has to stash and restore the thread context when the index is a system index, but does not have to do so when the index is not. Explanations will be helpful

.searchDataObjectAsync(searchDataObjectRequest)
.whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, context::restore)));
}
} catch (Exception e) {
log.error("[validateStopWords] Searching stop words index failed.", e);
latch.countDown();
hitStopWords.set(true);
} finally {
if (context != null) {
context.close();
}
passedStopWordCheck.set(true);
}

try {
Expand All @@ -268,10 +252,17 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
log.error("[validateStopWords] Searching stop words index was timeout.", e);
throw new IllegalStateException(e);
}
return hitStopWords.get();
return passedStopWordCheck.get();
}

private boolean isStopWordsSystemIndex(String index) {
return stopWordsIndices.contains(index);
protected SearchDataObjectRequest buildSearchDataObjectRequest(String indexName, String queryBody) throws IOException {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.

return SearchDataObjectRequest.builder().indices(indexName).searchSourceBuilder(searchSourceBuilder).tenantId(tenantId).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Map;

import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;

import lombok.Getter;
Expand All @@ -18,17 +19,21 @@
public class MLGuard {
private NamedXContentRegistry xContentRegistry;
private Client client;
private final SdkClient sdkClient;
private final String tenantId;
private Guardrails guardrails;

public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
this.xContentRegistry = xContentRegistry;
this.client = client;
this.sdkClient = sdkClient;
this.tenantId = tenantId;
this.guardrails = guardrails;
if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) {
this.guardrails.getInputGuardrail().init(xContentRegistry, client);
this.guardrails.getInputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
}
if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) {
this.guardrails.getOutputGuardrail().init(xContentRegistry, client);
this.guardrails.getOutputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;

import lombok.Builder;
Expand All @@ -58,6 +59,8 @@ public class ModelGuardrail extends Guardrail {
private String responseAccept;
private NamedXContentRegistry xContentRegistry;
private Client client;
private SdkClient sdkClient;
private String tenantId;
private Pattern regexAcceptPattern;

@Builder(toBuilder = true)
Expand Down Expand Up @@ -141,9 +144,11 @@ public Boolean validate(String in, Map<String, String> parameters) {
}

@Override
public void init(NamedXContentRegistry xContentRegistry, Client client) {
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
this.xContentRegistry = xContentRegistry;
this.client = client;
this.sdkClient = sdkClient;
this.tenantId = tenantId;
regexAcceptPattern = Pattern.compile(responseAccept);
}

Expand Down
Loading
Loading