@@ -278,13 +278,18 @@ public void resolvePromptGroup(MLExecuteTaskRequest request) throws IOException
278
278
if (!inputParameters .containsKey ("prompt_group" )) {
279
279
return ;
280
280
}
281
- resolvePromptGroups (inputParameters );
281
+ resolvePromptGroup (inputParameters );
282
282
283
283
((RemoteInferenceInputDataSet ) inputDataset ).setParameters (inputParameters );
284
284
mlInput .setInputDataset (inputDataset );
285
285
request = MLExecuteTaskRequest .builder ().functionName (request .getFunctionName ()).input (mlInput ).build ();
286
286
}
287
287
288
+ /**
289
+ * Resolve single prompt that has following pull_prompt syntax: pull_prompt(id).key
290
+ *
291
+ * @param inputParameters parameters fetched from Input Dataset
292
+ */
288
293
private void resolveSinglePrompts (Map <String , String > inputParameters ) throws IOException {
289
294
List <String > unresolvedPrompts = ML_PROMPT_MATCHING_KEYS
290
295
.values ()
@@ -302,7 +307,12 @@ private void resolveSinglePrompts(Map<String, String> inputParameters) throws IO
302
307
}
303
308
}
304
309
305
- private void resolvePromptGroups (Map <String , String > inputParameters ) {
310
+ /**
311
+ * Resolve prompt group that has following pull_prompt syntax: Either pull_prompt(id) or pull_prompt(id, [filter_list])
312
+ *
313
+ * @param inputParameters parameters fetched from Input Dataset
314
+ */
315
+ private void resolvePromptGroup (Map <String , String > inputParameters ) {
306
316
List <String > promptGroupParameters = validatePullPromptGroupSyntax (inputParameters .get ("prompt_group" ));
307
317
String nameOrID = promptGroupParameters .getFirst ();
308
318
String promptId = resolvePromptID (nameOrID , null );
@@ -377,8 +387,13 @@ public void buildInputParameters(
377
387
* @param promptParam Prompt Parameters field that holds user-defined values to placeholder variables
378
388
* @param tenantId tenant id
379
389
*/
380
- private void handlePromptField (Map <String , String > parameters , String promptType , String promptContent , PromptParameters promptParam , String tenantId )
381
- throws IOException {
390
+ private void handlePromptField (
391
+ Map <String , String > parameters ,
392
+ String promptType ,
393
+ String promptContent ,
394
+ PromptParameters promptParam ,
395
+ String tenantId
396
+ ) throws IOException {
382
397
Tuple <String , String > IDAndKey = validatePullPromptSyntax (promptContent );
383
398
String promptId = IDAndKey .v1 ();
384
399
String key = IDAndKey .v2 ();
@@ -436,9 +451,18 @@ private static Tuple<String, String> validatePullPromptSyntax(String input) {
436
451
);
437
452
}
438
453
454
+ /**
455
+ * Validate pull_prompt syntax and Retrieves prompt reference and key that are needed to retrieve a specific prompt
456
+ *
457
+ * @param content content that contains pull_prompt syntax alongside prompt reference and key
458
+ * @return List that contains prompt reference and list of keys user wants to extract
459
+ * @throws InvalidPullPromptSyntaxException if invalid syntax is provided
460
+ */
439
461
private static List <String > validatePullPromptGroupSyntax (String content ) {
440
462
if (content != null && content .contains ("pull_prompt(" )) {
463
+ // e.g. pull_prompt(prompt_id)
441
464
String pullPromptWithOnlyIDRegex = "pull_prompt\\ (\\ s*([a-zA-Z0-9_\\ -]+)\\ s*\\ )" ;
465
+ // e.g. pull_prompt(prompt_id, [filter_list])
442
466
String pullPromptWithIDAndFilterListRegex =
443
467
"pull_prompt\\ (\\ s*([a-zA-Z0-9_\\ -]+)\\ s*,\\ s*\\ [\\ s*([a-zA-Z0-9_\\ -\\ s,]*)\\ s*]\\ s*\\ )" ;
444
468
@@ -448,13 +472,13 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
448
472
Matcher matcherWithOnlyID = patternWithOnlyID .matcher (content );
449
473
Matcher matcherWithIDAndFilterList = patternWithIDAndFilterList .matcher (content );
450
474
451
- while (matcherWithOnlyID .find ()) {
475
+ if (matcherWithOnlyID .matches ()) {
452
476
String promptId = matcherWithOnlyID .group (1 );
453
477
454
478
return List .of (promptId );
455
479
}
456
480
457
- while (matcherWithIDAndFilterList .find ()) {
481
+ if (matcherWithIDAndFilterList .matches ()) {
458
482
String promptId = matcherWithIDAndFilterList .group (1 );
459
483
String filterList = matcherWithIDAndFilterList .group (2 );
460
484
@@ -466,11 +490,14 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
466
490
.filter (s -> !s .isEmpty ())
467
491
.toArray (String []::new );
468
492
469
- List <String > filterListArray = Arrays .asList (promptList );
470
- return filterListArray ;
493
+ return Arrays .asList (promptList );
471
494
}
472
495
}
473
- throw new IllegalArgumentException ("Wrong pull_prompt syntax is provided: " + content );
496
+ throw new InvalidPullPromptSyntaxException (
497
+ "Invalid pull_prompt syntax is provided: "
498
+ + content
499
+ + ". Expected: pull_prompt(prompt_id) or pull_prompt(prompt_id, [filtered_list])"
500
+ );
474
501
}
475
502
476
503
/**
@@ -491,7 +518,7 @@ private static List<String> validatePullPromptGroupSyntax(String content) {
491
518
* </p>
492
519
* @throws OpenSearchStatusException if the ML Prompt is not found
493
520
*/
494
- public PromptResult pullPrompt (String promptRef , String key , PromptParameters promptParameters , String tenantId ) throws IOException {
521
+ public PromptResult pullPrompt (String promptRef , String key , PromptParameters promptParameters , String tenantId ) {
495
522
try {
496
523
String promptId = resolvePromptID (promptRef , tenantId );
497
524
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
@@ -584,9 +611,13 @@ private String populatePlaceholders(String content, PromptParameters promptParam
584
611
public MLPrompt getPrompt (GetDataObjectRequest getDataObjectRequest ) {
585
612
GetDataObjectResponse getPromptResponse = sdkClient .getDataObject (getDataObjectRequest );
586
613
try (
587
- XContentParser parser = XContentType .JSON
588
- .xContent ()
589
- .createParser (NamedXContentRegistry .EMPTY , LoggingDeprecationHandler .INSTANCE , getPromptResponse .getResponse ().getSourceAsString ())
614
+ XContentParser parser = XContentType .JSON
615
+ .xContent ()
616
+ .createParser (
617
+ NamedXContentRegistry .EMPTY ,
618
+ LoggingDeprecationHandler .INSTANCE ,
619
+ getPromptResponse .getResponse ().getSourceAsString ()
620
+ )
590
621
) {
591
622
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
592
623
return MLPrompt .parse (parser );
@@ -663,9 +694,9 @@ static class Message {
663
694
"role": "user",
664
695
"content": "pull_prompt(prompt_id).<key>"
665
696
}
666
-
697
+
667
698
After parsing:
668
-
699
+
669
700
this.role = user
670
701
this.content = pull_prompt(prompt_id).<key>
671
702
this.promptId = prompt_id
@@ -729,9 +760,9 @@ static class PromptParameters {
729
760
"name": "jeff"
730
761
}
731
762
}
732
-
763
+
733
764
After parsing:
734
-
765
+
735
766
this.parameters = Map.of("name", "jeff")
736
767
*/
737
768
private final Map <String , Map <String , String >> parameters ;
0 commit comments