Skip to content

Commit 248006f

Browse files
committed
Introduce sdk client to LocalRegexGuardrail
Signed-off-by: Yuanchun Shen <[email protected]>
1 parent 71d47e9 commit 248006f

File tree

8 files changed: 82 additions (+82) and 54 deletions (-54).

common/src/main/java/org/opensearch/ml/common/model/Guardrail.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.core.common.io.stream.StreamOutput;
1212
import org.opensearch.core.xcontent.NamedXContentRegistry;
1313
import org.opensearch.core.xcontent.ToXContentObject;
14+
import org.opensearch.remote.metadata.client.SdkClient;
1415
import org.opensearch.transport.client.Client;
1516

1617
public abstract class Guardrail implements ToXContentObject {
@@ -19,5 +20,5 @@ public abstract class Guardrail implements ToXContentObject {
1920

2021
public abstract Boolean validate(String input, Map<String, String> parameters);
2122

22-
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
23+
public abstract void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId);
2324
}

common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java

Lines changed: 31 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
import java.util.stream.Collectors;
2626

2727
import org.opensearch.action.LatchedActionListener;
28-
import org.opensearch.action.search.SearchRequest;
2928
import org.opensearch.action.search.SearchResponse;
3029
import org.opensearch.common.util.concurrent.ThreadContext;
3130
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
@@ -36,6 +35,9 @@
3635
import org.opensearch.core.xcontent.NamedXContentRegistry;
3736
import org.opensearch.core.xcontent.XContentBuilder;
3837
import org.opensearch.core.xcontent.XContentParser;
38+
import org.opensearch.remote.metadata.client.SdkClient;
39+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
40+
import org.opensearch.remote.metadata.common.SdkClientUtils;
3941
import org.opensearch.search.builder.SearchSourceBuilder;
4042
import org.opensearch.transport.client.Client;
4143

@@ -58,6 +60,8 @@ public class LocalRegexGuardrail extends Guardrail {
5860
private Map<String, List<String>> stopWordsIndicesInput;
5961
private NamedXContentRegistry xContentRegistry;
6062
private Client client;
63+
private SdkClient sdkClient;
64+
private String tenantId;
6165

6266
@Builder(toBuilder = true)
6367
public LocalRegexGuardrail(List<StopWords> stopWords, String[] regex) {
@@ -109,9 +113,11 @@ public Boolean validate(String input, Map<String, String> parameters) {
109113
}
110114

111115
@Override
112-
public void init(NamedXContentRegistry xContentRegistry, Client client) {
116+
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
113117
this.xContentRegistry = xContentRegistry;
114118
this.client = client;
119+
this.sdkClient = sdkClient;
120+
this.tenantId = tenantId;
115121
init();
116122
}
117123

@@ -211,8 +217,7 @@ public Boolean validateStopWords(String input, Map<String, List<String>> stopWor
211217
* @return true if no stop words matching, otherwise false.
212218
*/
213219
public Boolean validateStopWordsSingleIndex(String input, String indexName, List<String> fieldNames) {
214-
SearchRequest searchRequest;
215-
AtomicBoolean hitStopWords = new AtomicBoolean(false);
220+
AtomicBoolean passedStopWorkCheck = new AtomicBoolean(false);
216221
String queryBody;
217222
Map<String, String> documentMap = new HashMap<>();
218223
for (String field : fieldNames) {
@@ -230,32 +235,35 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
230235
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
231236
searchSourceBuilder.parseXContent(queryParser);
232237
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
233-
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
238+
var responseListener = new LatchedActionListener<>(ActionListener.<SearchResponse>wrap(r -> {
239+
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
240+
passedStopWorkCheck.set(true);
241+
}
242+
}, e -> {
243+
log.error("Failed to search stop words index {}", indexName, e);
244+
passedStopWorkCheck.set(true);
245+
}), latch);
246+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
247+
.builder()
248+
.indices(indexName)
249+
.searchSourceBuilder(searchSourceBuilder)
250+
.tenantId(tenantId)
251+
.build();
234252
if (isStopWordsSystemIndex(indexName)) {
235253
context = client.threadPool().getThreadContext().stashContext();
236254
ThreadContext.StoredContext finalContext = context;
237-
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
238-
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
239-
hitStopWords.set(true);
240-
}
241-
}, e -> {
242-
log.error("Failed to search stop words index {}", indexName, e);
243-
hitStopWords.set(true);
244-
}), latch), () -> finalContext.restore()));
255+
sdkClient
256+
.searchDataObjectAsync(searchDataObjectRequest)
257+
.whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, finalContext::restore)));
245258
} else {
246-
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
247-
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
248-
hitStopWords.set(true);
249-
}
250-
}, e -> {
251-
log.error("Failed to search stop words index {}", indexName, e);
252-
hitStopWords.set(true);
253-
}), latch));
259+
sdkClient
260+
.searchDataObjectAsync(searchDataObjectRequest)
261+
.whenComplete(SdkClientUtils.wrapSearchCompletion(responseListener));
254262
}
255263
} catch (Exception e) {
256264
log.error("[validateStopWords] Searching stop words index failed.", e);
257265
latch.countDown();
258-
hitStopWords.set(true);
266+
passedStopWorkCheck.set(true);
259267
} finally {
260268
if (context != null) {
261269
context.close();
@@ -268,7 +276,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
268276
log.error("[validateStopWords] Searching stop words index was timeout.", e);
269277
throw new IllegalStateException(e);
270278
}
271-
return hitStopWords.get();
279+
return passedStopWorkCheck.get();
272280
}
273281

274282
private boolean isStopWordsSystemIndex(String index) {

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99

1010
import org.opensearch.core.xcontent.NamedXContentRegistry;
11+
import org.opensearch.remote.metadata.client.SdkClient;
1112
import org.opensearch.transport.client.Client;
1213

1314
import lombok.Getter;
@@ -18,17 +19,21 @@
1819
public class MLGuard {
1920
private NamedXContentRegistry xContentRegistry;
2021
private Client client;
22+
private final SdkClient sdkClient;
23+
private final String tenantId;
2124
private Guardrails guardrails;
2225

23-
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
26+
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
2427
this.xContentRegistry = xContentRegistry;
2528
this.client = client;
29+
this.sdkClient = sdkClient;
30+
this.tenantId = tenantId;
2631
this.guardrails = guardrails;
2732
if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) {
28-
this.guardrails.getInputGuardrail().init(xContentRegistry, client);
33+
this.guardrails.getInputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
2934
}
3035
if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) {
31-
this.guardrails.getOutputGuardrail().init(xContentRegistry, client);
36+
this.guardrails.getOutputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
3237
}
3338
}
3439

common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import org.opensearch.ml.common.transport.MLTaskResponse;
3838
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3939
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
40+
import org.opensearch.remote.metadata.client.SdkClient;
4041
import org.opensearch.transport.client.Client;
4142

4243
import lombok.Builder;
@@ -58,6 +59,8 @@ public class ModelGuardrail extends Guardrail {
5859
private String responseAccept;
5960
private NamedXContentRegistry xContentRegistry;
6061
private Client client;
62+
private SdkClient sdkClient;
63+
private String tenantId;
6164
private Pattern regexAcceptPattern;
6265

6366
@Builder(toBuilder = true)
@@ -141,9 +144,11 @@ public Boolean validate(String in, Map<String, String> parameters) {
141144
}
142145

143146
@Override
144-
public void init(NamedXContentRegistry xContentRegistry, Client client) {
147+
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
145148
this.xContentRegistry = xContentRegistry;
146149
this.client = client;
150+
this.sdkClient = sdkClient;
151+
this.tenantId = tenantId;
147152
regexAcceptPattern = Pattern.compile(responseAccept);
148153
}
149154

common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import org.opensearch.common.settings.Settings;
2020
import org.opensearch.common.util.concurrent.ThreadContext;
2121
import org.opensearch.core.xcontent.NamedXContentRegistry;
22+
import org.opensearch.remote.metadata.client.SdkClient;
2223
import org.opensearch.search.SearchModule;
2324
import org.opensearch.threadpool.ThreadPool;
2425
import org.opensearch.transport.client.Client;
@@ -29,8 +30,11 @@ public class MLGuardTests {
2930
@Mock
3031
Client client;
3132
@Mock
33+
SdkClient sdkClient;
34+
@Mock
3235
ThreadPool threadPool;
3336
ThreadContext threadContext;
37+
String tenantId;
3438

3539
StopWords stopWords;
3640
String[] regex;
@@ -48,14 +52,15 @@ public void setUp() {
4852
this.threadContext = new ThreadContext(settings);
4953
when(this.client.threadPool()).thenReturn(this.threadPool);
5054
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
55+
tenantId = "tenantId";
5156

5257
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
5358
regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]);
5459
regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*"));
5560
inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
5661
outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
5762
guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail);
58-
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
63+
mlGuard = new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId);
5964
}
6065

6166
@Test
@@ -74,7 +79,7 @@ public void validateInitializedStopWordsEmpty() {
7479
inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
7580
outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
7681
guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail);
77-
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
82+
mlGuard = new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId);
7883

7984
String input = "\n\nHuman:hello good words.\n\nAssistant:";
8085
Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT, Collections.emptyMap());

common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,17 @@
2121
import org.opensearch.core.xcontent.XContentBuilder;
2222
import org.opensearch.core.xcontent.XContentParser;
2323
import org.opensearch.ml.common.TestHelper;
24+
import org.opensearch.remote.metadata.client.SdkClient;
2425
import org.opensearch.search.SearchModule;
2526
import org.opensearch.transport.client.Client;
2627

2728
public class ModelGuardrailTests {
2829
NamedXContentRegistry xContentRegistry;
2930
@Mock
3031
Client client;
32+
@Mock
33+
SdkClient sdkClient;
34+
String tenantId;
3135

3236
Pattern regexPattern;
3337
ModelGuardrail modelGuardrail;
@@ -38,6 +42,7 @@ public void setUp() {
3842
xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
3943
doNothing().when(this.client).execute(any(), any(), any());
4044
modelGuardrail = new ModelGuardrail("test_model_id", "$.test", "^accept$");
45+
tenantId = "tenant_id";
4146
regexPattern = Pattern.compile("^accept$");
4247
}
4348

@@ -70,7 +75,7 @@ public void validateParametersEmpty1() {
7075
@Test
7176
public void init() {
7277
Assert.assertNull(modelGuardrail.getRegexAcceptPattern());
73-
modelGuardrail.init(xContentRegistry, client);
78+
modelGuardrail.init(xContentRegistry, client, sdkClient, tenantId);
7479
Assert.assertEquals(regexPattern.toString(), modelGuardrail.getRegexAcceptPattern().toString());
7580
}
7681

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1222,7 +1222,7 @@ public void deployModel(
12221222
}
12231223

12241224
setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter());
1225-
setupMLGuard(modelId, mlModel.getGuardrails());
1225+
setupMLGuard(modelId, tenantId, mlModel.getGuardrails());
12261226
setupModelInterface(modelId, mlModel.getModelInterface());
12271227
deployControllerWithDeployingModel(mlModel, eligibleNodeCount);
12281228
// check circuit breaker before deploying custom model chunks
@@ -1373,7 +1373,7 @@ public void deployModel(
13731373
}
13741374

13751375
setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter());
1376-
setupMLGuard(modelId, mlModel.getGuardrails());
1376+
setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails());
13771377
setupModelInterface(modelId, mlModel.getModelInterface());
13781378
deployControllerWithDeployingModel(mlModel, eligibleNodeCount);
13791379
// check circuit breaker before deploying custom model chunks
@@ -1438,7 +1438,7 @@ public void deployModel(
14381438
private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener<String> wrappedListener) {
14391439
String modelId = mlModel.getModelId();
14401440
setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter());
1441-
setupMLGuard(modelId, mlModel.getGuardrails());
1441+
setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails());
14421442
setupModelInterface(modelId, mlModel.getModelInterface());
14431443
if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
14441444
setupParamsAndPredictable(modelId, mlModel);
@@ -1461,12 +1461,12 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou
14611461
}
14621462

14631463
private void setupParamsAndPredictable(String modelId, MLModel mlModel) {
1464-
Map<String, Object> params = setUpParameterMap(modelId);
1464+
Map<String, Object> params = setUpParameterMap(modelId, mlModel.getTenantId());
14651465
Predictable predictable = mlEngine.deploy(mlModel, params);
14661466
modelCacheHelper.setPredictor(modelId, predictable);
14671467
}
14681468

1469-
private Map<String, Object> setUpParameterMap(String modelId) {
1469+
private Map<String, Object> setUpParameterMap(String modelId, String tenantId) {
14701470
TokenBucket rateLimiter = getRateLimiter(modelId);
14711471
Map<String, TokenBucket> userRateLimiterMap = getUserRateLimiterMap(modelId);
14721472
MLGuard mlGuard = getMLGuard(modelId);
@@ -1519,7 +1519,7 @@ public synchronized void updateModelCache(String modelId, ActionListener<String>
15191519
int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length;
15201520
modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled());
15211521
setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter());
1522-
setupMLGuard(modelId, mlModel.getGuardrails());
1522+
setupMLGuard(modelId, mlModel.getTenantId(), mlModel.getGuardrails());
15231523
setupModelInterface(modelId, mlModel.getModelInterface());
15241524
if (mlModel.getAlgorithm() == FunctionName.REMOTE) {
15251525
if (mlModel.getConnector() != null) {
@@ -1852,23 +1852,29 @@ public Map<String, String> getModelInterface(String modelId) {
18521852
* @param guardrails guardrail for the model
18531853
*/
18541854

1855-
private void setupMLGuard(String modelId, Guardrails guardrails) {
1855+
private void setupMLGuard(String modelId, String tenantId, Guardrails guardrails) {
18561856
if (guardrails != null) {
1857-
modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client));
1857+
modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId));
18581858
} else {
18591859
modelCacheHelper.removeMLGuard(modelId);
18601860
}
18611861
}
18621862

1863-
private MLGuard createMLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
1864-
1865-
return new MLGuard(guardrails, xContentRegistry, client);
1863+
private MLGuard createMLGuard(
1864+
Guardrails guardrails,
1865+
NamedXContentRegistry xContentRegistry,
1866+
Client client,
1867+
SdkClient sdkClient,
1868+
String tenantId
1869+
) {
1870+
return new MLGuard(guardrails, xContentRegistry, client, sdkClient, tenantId);
18661871
}
18671872

18681873
/**
18691874
* Get ML guard with model id.
18701875
*
18711876
* @param modelId model id
1877+
*
18721878
* @return a ML guard
18731879
*/
18741880
public MLGuard getMLGuard(String modelId) {

0 commit comments

Comments
 (0)