Skip to content

Commit 898d64b

Browse files
committed
add javadoc & apply spotless
Signed-off-by: seungwon cho <[email protected]>
1 parent 023f528 commit 898d64b

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@
5959
import lombok.experimental.FieldDefaults;
6060
import lombok.extern.log4j.Log4j2;
6161

62-
import java.util.HashMap;
63-
import java.util.Map;
64-
6562
@Log4j2
6663
@FieldDefaults(level = AccessLevel.PRIVATE)
6764
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {

plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,18 @@ public void resolvePromptGroup(MLExecuteTaskRequest request) throws IOException
278278
if (!inputParameters.containsKey("prompt_group")) {
279279
return;
280280
}
281-
resolvePromptGroups(inputParameters);
281+
resolvePromptGroup(inputParameters);
282282

283283
((RemoteInferenceInputDataSet) inputDataset).setParameters(inputParameters);
284284
mlInput.setInputDataset(inputDataset);
285285
request = MLExecuteTaskRequest.builder().functionName(request.getFunctionName()).input(mlInput).build();
286286
}
287287

288+
/**
289+
* Resolve single prompt that has following pull_prompt syntax: pull_prompt(id).key
290+
*
291+
* @param inputParameters parameters fetched from Input Dataset
292+
*/
288293
private void resolveSinglePrompts(Map<String, String> inputParameters) throws IOException {
289294
List<String> unresolvedPrompts = ML_PROMPT_MATCHING_KEYS
290295
.values()
@@ -302,7 +307,12 @@ private void resolveSinglePrompts(Map<String, String> inputParameters) throws IO
302307
}
303308
}
304309

305-
private void resolvePromptGroups(Map<String, String> inputParameters) {
310+
/**
311+
* Resolve prompt group that has following pull_prompt syntax: Either pull_prompt(id) or pull_prompt(id, [filter_list])
312+
*
313+
* @param inputParameters parameters fetched from Input Dataset
314+
*/
315+
private void resolvePromptGroup(Map<String, String> inputParameters) {
306316
List<String> promptGroupParameters = validatePullPromptGroupSyntax(inputParameters.get("prompt_group"));
307317
String nameOrID = promptGroupParameters.getFirst();
308318
String promptId = resolvePromptID(nameOrID, null);
@@ -377,8 +387,13 @@ public void buildInputParameters(
377387
* @param promptParam Prompt Parameters field that holds user-defined values to placeholder variables
378388
* @param tenantId tenant id
379389
*/
380-
private void handlePromptField(Map<String, String> parameters, String promptType, String promptContent, PromptParameters promptParam, String tenantId)
381-
throws IOException {
390+
private void handlePromptField(
391+
Map<String, String> parameters,
392+
String promptType,
393+
String promptContent,
394+
PromptParameters promptParam,
395+
String tenantId
396+
) throws IOException {
382397
Tuple<String, String> IDAndKey = validatePullPromptSyntax(promptContent);
383398
String promptId = IDAndKey.v1();
384399
String key = IDAndKey.v2();
@@ -436,9 +451,18 @@ private static Tuple<String, String> validatePullPromptSyntax(String input) {
436451
);
437452
}
438453

454+
/**
455+
* Validate pull_prompt syntax and Retrieves prompt reference and key that are needed to retrieve a specific prompt
456+
*
457+
* @param content content that contains pull_prompt syntax alongside prompt reference and key
458+
* @return List that contains prompt reference and list of keys user wants to extract
459+
* @throws InvalidPullPromptSyntaxException if invalid syntax is provided
460+
*/
439461
private static List<String> validatePullPromptGroupSyntax(String content) {
440462
if (content != null && content.contains("pull_prompt(")) {
463+
// e.g. pull_prompt(prompt_id)
441464
String pullPromptWithOnlyIDRegex = "pull_prompt\\(\\s*([a-zA-Z0-9_\\-]+)\\s*\\)";
465+
// e.g. pull_prompt(prompt_id, [filter_list])
442466
String pullPromptWithIDAndFilterListRegex =
443467
"pull_prompt\\(\\s*([a-zA-Z0-9_\\-]+)\\s*,\\s*\\[\\s*([a-zA-Z0-9_\\-\\s,]*)\\s*]\\s*\\)";
444468

@@ -448,13 +472,13 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
448472
Matcher matcherWithOnlyID = patternWithOnlyID.matcher(content);
449473
Matcher matcherWithIDAndFilterList = patternWithIDAndFilterList.matcher(content);
450474

451-
while (matcherWithOnlyID.find()) {
475+
if (matcherWithOnlyID.matches()) {
452476
String promptId = matcherWithOnlyID.group(1);
453477

454478
return List.of(promptId);
455479
}
456480

457-
while (matcherWithIDAndFilterList.find()) {
481+
if (matcherWithIDAndFilterList.matches()) {
458482
String promptId = matcherWithIDAndFilterList.group(1);
459483
String filterList = matcherWithIDAndFilterList.group(2);
460484

@@ -466,11 +490,14 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
466490
.filter(s -> !s.isEmpty())
467491
.toArray(String[]::new);
468492

469-
List<String> filterListArray = Arrays.asList(promptList);
470-
return filterListArray;
493+
return Arrays.asList(promptList);
471494
}
472495
}
473-
throw new IllegalArgumentException("Wrong pull_prompt syntax is provided: " + content);
496+
throw new InvalidPullPromptSyntaxException(
497+
"Invalid pull_prompt syntax is provided: "
498+
+ content
499+
+ ". Expected: pull_prompt(prompt_id) or pull_prompt(prompt_id, [filtered_list])"
500+
);
474501
}
475502

476503
/**
@@ -491,7 +518,7 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
491518
* </p>
492519
* @throws OpenSearchStatusException if the ML Prompt is not found
493520
*/
494-
public PromptResult pullPrompt(String promptRef, String key, PromptParameters promptParameters, String tenantId) throws IOException {
521+
public PromptResult pullPrompt(String promptRef, String key, PromptParameters promptParameters, String tenantId) {
495522
try {
496523
String promptId = resolvePromptID(promptRef, tenantId);
497524
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
@@ -584,9 +611,13 @@ private String populatePlaceholders(String content, PromptParameters promptParam
584611
public MLPrompt getPrompt(GetDataObjectRequest getDataObjectRequest) {
585612
GetDataObjectResponse getPromptResponse = sdkClient.getDataObject(getDataObjectRequest);
586613
try (
587-
XContentParser parser = XContentType.JSON
588-
.xContent()
589-
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, getPromptResponse.getResponse().getSourceAsString())
614+
XContentParser parser = XContentType.JSON
615+
.xContent()
616+
.createParser(
617+
NamedXContentRegistry.EMPTY,
618+
LoggingDeprecationHandler.INSTANCE,
619+
getPromptResponse.getResponse().getSourceAsString()
620+
)
590621
) {
591622
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
592623
return MLPrompt.parse(parser);
@@ -663,9 +694,9 @@ static class Message {
663694
"role": "user",
664695
"content": "pull_prompt(prompt_id).<key>"
665696
}
666-
697+
667698
After parsing:
668-
699+
669700
this.role = user
670701
this.content = pull_prompt(prompt_id).<key>
671702
this.promptId = prompt_id
@@ -729,9 +760,9 @@ static class PromptParameters {
729760
"name": "jeff"
730761
}
731762
}
732-
763+
733764
After parsing:
734-
765+
735766
this.parameters = Map.of("name", "jeff")
736767
*/
737768
private final Map<String, Map<String, String>> parameters;

0 commit comments

Comments
 (0)