Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ public MLCreatePromptInput(
if (prompt == null || prompt.isEmpty()) {
throw new IllegalArgumentException("MLPrompt prompt field cannot be empty or null");
}
if (!prompt.containsKey(PROMPT_FIELD_SYSTEM_PROMPT)) {
/*if (!prompt.containsKey(PROMPT_FIELD_SYSTEM_PROMPT)) {
throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_SYSTEM_PROMPT + " parameter");
}
if (!prompt.containsKey(PROMPT_FIELD_USER_PROMPT)) {
throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_USER_PROMPT + " parameter");
}
}*/

this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ public Collection<Object> createComponents(
mlTaskDispatcher,
mlCircuitBreakerService,
nodeHelper,
mlEngine
mlEngine,
mlPromptManager
);

// Register thread-safe ML objects here.
Expand Down
181 changes: 175 additions & 6 deletions plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -38,8 +40,12 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.prompt.MLPrompt;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.utils.MLExceptionUtils;
Expand Down Expand Up @@ -74,6 +80,24 @@ public class MLPromptManager {
public static final String ROLE_PARAMETER = "role";
public static final String CONTENT_PARAMETER = "content";

// Maps the short template keys stored inside an ML prompt document to the
// request-parameter names used by the execute API (consumed by the prompt-group
// resolution path, which copies template values into the request parameters).
// NOTE(review): "planner_with_history_template" maps to itself, unlike every
// other entry which maps to a distinct "*_prompt*" parameter name — possibly a
// typo for "planner_with_history_prompt_template"; confirm against the agent
// parameter names before relying on it.
public static final Map<String, String> ML_PROMPT_MATCHING_KEYS = Map
    .of(
        "planner_template",
        "planner_prompt_template",
        "planner",
        "planner_prompt",
        "reflect_template",
        "reflect_prompt_template",
        "reflect",
        "reflect_prompt",
        "planner_with_history_template",
        "planner_with_history_template",
        "system",
        "system_prompt",
        "executor_system",
        "executor_prompt_system"
    );

public static final String PARAMETERS_PROMPT_FIELD = "prompt";
public static final String PARAMETERS_MESSAGES_FIELD = "messages";
public static final String PARAMETERS_PROMPT_PARAMETERS_FIELD = "prompt_parameters";
Expand Down Expand Up @@ -221,6 +245,97 @@ public static boolean MLPromptNameAlreadyExists(SearchResponse searchResponse) {
&& searchResponse.getHits().getTotalHits().value() != 0;
}

/**
 * Extracts the request parameters carried by the given ML input.
 *
 * @param input the ML input whose dataset is inspected
 * @return the parameters of a {@link RemoteInferenceInputDataSet}, or an empty
 *         mutable map when the dataset is of any other type
 */
public static Map<String, String> extractInputParameter(MLInput input) {
    MLInputDataset dataset = input.getInputDataset();
    if (dataset instanceof RemoteInferenceInputDataSet remoteDataset) {
        return remoteDataset.getParameters();
    }
    return new HashMap<>();
}

/**
 * Resolves a prompt group reference by fetching all associated prompts
 * linked to the given group ID or name.
 *
 * <p>
 * This method checks if the prompt_group field and correct pull_prompt syntax
 * are present in the execution request parameters. It queries the prompt store
 * to retrieve all prompts under the specified prompt reference,
 * and merges them into the request parameters.
 * </p>
 *
 * @param request MLExecuteTaskRequest that contains execution request parameters
 * @throws IOException if failed parsing prompt body
 */
public void resolvePromptGroup(MLExecuteTaskRequest request) throws IOException {
    MLInput mlInput = (MLInput) request.getInput();
    MLInputDataset inputDataset = mlInput.getInputDataset();
    // Reuse the shared helper instead of duplicating the extraction logic.
    // Non-remote datasets yield a detached empty map, so the prompt_group
    // guard below short-circuits before anything is written back.
    Map<String, String> inputParameters = extractInputParameter(mlInput);

    resolveSinglePrompts(inputParameters);

    if (!inputParameters.containsKey("prompt_group")) {
        return;
    }
    resolvePromptGroup(inputParameters);

    // The parameters map belongs to the remote dataset; write the resolved map
    // back so downstream execution observes the changes. The previous
    // re-assignment of `request` via the builder was dead code (Java is
    // pass-by-value, so the caller never saw the new object) and has been
    // removed, as has the no-op setInputDataset on the same reference.
    ((RemoteInferenceInputDataSet) inputDataset).setParameters(inputParameters);
}

/**
 * Resolves every single-prompt reference in the request parameters. A single
 * prompt reference uses the following pull_prompt syntax: pull_prompt(id).key
 *
 * @param inputParameters parameters fetched from the input dataset
 * @throws IOException if a prompt body fails to parse
 */
private void resolveSinglePrompts(Map<String, String> inputParameters) throws IOException {
    // Walk every known prompt parameter name; only values that actually
    // contain a pull_prompt reference need to be resolved.
    for (String parameterName : ML_PROMPT_MATCHING_KEYS.values()) {
        if (!inputParameters.containsKey(parameterName)) {
            continue;
        }
        String value = inputParameters.get(parameterName);
        if (value.contains("pull_prompt")) {
            handlePromptField(inputParameters, parameterName, value, null, null);
        }
    }
}

/**
 * Resolves a prompt group that uses the following pull_prompt syntax: either
 * pull_prompt(id) or pull_prompt(id, [filter_list]).
 *
 * <p>Fetches the referenced prompt document and copies its template entries
 * into the request parameters under the names defined by
 * {@link #ML_PROMPT_MATCHING_KEYS}.</p>
 *
 * @param inputParameters parameters fetched from the input dataset; modified in place
 */
private void resolvePromptGroup(Map<String, String> inputParameters) {
    List<String> promptGroupParameters = validatePullPromptGroupSyntax(inputParameters.get("prompt_group"));
    String nameOrId = promptGroupParameters.getFirst();
    // NOTE(review): tenant id is passed as null here — confirm this is correct
    // for multi-tenant clusters.
    String promptId = resolvePromptID(nameOrId, null);
    inputParameters.remove("prompt_group"); // fully consumed; drop it from the request

    // A bare pull_prompt(id) (list size 1) copies every template key;
    // otherwise only keys named in the filter list are copied.
    boolean copyAllKeys = promptGroupParameters.size() == 1;

    GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest.builder().index(ML_PROMPT_INDEX).id(promptId).build();
    MLPrompt mlPrompt = getPrompt(getDataObjectRequest);
    Map<String, String> promptTemplate = mlPrompt.getPrompt();
    for (Map.Entry<String, String> entry : promptTemplate.entrySet()) {
        if (copyAllKeys || promptGroupParameters.contains(entry.getKey())) {
            // NOTE(review): template keys absent from ML_PROMPT_MATCHING_KEYS
            // resolve to a null parameter name here, same as before — consider
            // skipping unmapped keys explicitly.
            String matchingKey = ML_PROMPT_MATCHING_KEYS.get(entry.getKey());
            // When a key exists in both the prompt group and the request
            // parameters, the value from the request parameters takes precedence.
            if (inputParameters.containsKey(matchingKey)) {
                continue;
            }
            inputParameters.put(matchingKey, entry.getValue());
        }
    }
}

/**
* Builds a new map containing modified request body after pull_prompt is invoked
*
Expand All @@ -246,7 +361,7 @@ public void buildInputParameters(
PromptParameters promptParam = PromptParameters.buildPromptParameters(JsonStrPromptParameters);
switch (promptType) {
case PARAMETERS_PROMPT_FIELD:
handlePromptField(parameters, inputContent, promptParam, tenantId);
handlePromptField(parameters, PARAMETERS_PROMPT_FIELD, inputContent, promptParam, tenantId);
break;
case PARAMETERS_MESSAGES_FIELD:
handleMessagesField(parameters, inputContent, promptParam, tenantId);
Expand All @@ -272,13 +387,18 @@ public void buildInputParameters(
* @param promptParam Prompt Parameters field that holds user-defined values to placeholder variables
* @param tenantId tenant id
*/
private void handlePromptField(Map<String, String> parameters, String promptContent, PromptParameters promptParam, String tenantId)
throws IOException {
private void handlePromptField(
Map<String, String> parameters,
String promptType,
String promptContent,
PromptParameters promptParam,
String tenantId
) throws IOException {
Tuple<String, String> IDAndKey = validatePullPromptSyntax(promptContent);
String promptId = IDAndKey.v1();
String key = IDAndKey.v2();
PromptResult promptResult = pullPrompt(promptId, key, promptParam, tenantId);
parameters.put(PARAMETERS_PROMPT_FIELD, promptResult.getContent());
parameters.put(promptType, promptResult.getContent());
}

/**
Expand Down Expand Up @@ -331,6 +451,55 @@ private static Tuple<String, String> validatePullPromptSyntax(String input) {
);
}

/**
 * Validates pull_prompt group syntax and extracts the prompt reference and the
 * keys the user wants to pull.
 *
 * <p>Accepted forms:
 * <ul>
 *   <li>{@code pull_prompt(prompt_id)}</li>
 *   <li>{@code pull_prompt(prompt_id, [key1, key2, ...])}</li>
 * </ul>
 *
 * @param content content that contains the pull_prompt syntax
 * @return list whose first element is the prompt reference, followed by any
 *         filter keys (an empty filter list yields just the reference)
 * @throws InvalidPullPromptSyntaxException if invalid syntax is provided
 */
private static List<String> validatePullPromptGroupSyntax(String content) {
    if (content != null && content.contains("pull_prompt(")) {
        // NOTE(review): patterns are recompiled on every call; hoist to
        // static final fields if this path becomes hot.
        // e.g. pull_prompt(prompt_id)
        Pattern patternWithOnlyID = Pattern.compile("pull_prompt\\(\\s*([a-zA-Z0-9_\\-]+)\\s*\\)");
        // e.g. pull_prompt(prompt_id, [filter_list])
        Pattern patternWithIDAndFilterList = Pattern
            .compile("pull_prompt\\(\\s*([a-zA-Z0-9_\\-]+)\\s*,\\s*\\[\\s*([a-zA-Z0-9_\\-\\s,]*)\\s*]\\s*\\)");

        Matcher matcherWithOnlyID = patternWithOnlyID.matcher(content);
        if (matcherWithOnlyID.matches()) {
            return List.of(matcherWithOnlyID.group(1));
        }

        Matcher matcherWithIDAndFilterList = patternWithIDAndFilterList.matcher(content);
        if (matcherWithIDAndFilterList.matches()) {
            String promptId = matcherWithIDAndFilterList.group(1);
            String filterList = matcherWithIDAndFilterList.group(2);

            // Build [promptId, key1, key2, ...] directly instead of the old
            // join-then-resplit round trip; keys are trimmed and blanks dropped.
            List<String> result = new ArrayList<>();
            result.add(promptId);
            Arrays.stream(filterList.split(",")).map(String::trim).filter(s -> !s.isEmpty()).forEach(result::add);
            return result;
        }
    }
    throw new InvalidPullPromptSyntaxException(
        "Invalid pull_prompt syntax is provided: "
            + content
            + ". Expected: pull_prompt(prompt_id) or pull_prompt(prompt_id, [filtered_list])"
    );
}

/**
* Fetches the ML Prompt based on prompt id, then replace the content with the retrieved content from prompt
* template based on the specified key.
Expand All @@ -349,7 +518,7 @@ private static Tuple<String, String> validatePullPromptSyntax(String input) {
* </p>
* @throws OpenSearchStatusException if the ML Prompt is not found
*/
public PromptResult pullPrompt(String promptRef, String key, PromptParameters promptParameters, String tenantId) throws IOException {
public PromptResult pullPrompt(String promptRef, String key, PromptParameters promptParameters, String tenantId) {
try {
String promptId = resolvePromptID(promptRef, tenantId);
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
Expand Down Expand Up @@ -418,7 +587,7 @@ private String resolvePromptID(String promptRef, String tenantId) {
* @return
*/
private String populatePlaceholders(String content, PromptParameters promptParameters, String promptRef) {
if (!promptParameters.isEmpty() && content.contains(PROMPT_PARAMETER_PLACEHOLDER)) {
if (promptParameters != null && !promptParameters.isEmpty() && content.contains(PROMPT_PARAMETER_PLACEHOLDER)) {
StringSubstitutor substitutor = new StringSubstitutor(
promptParameters.getParameters(promptRef),
PROMPT_PARAMETER_PLACEHOLDER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.prompt.MLPromptManager;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
Expand All @@ -42,6 +43,7 @@ public class MLExecuteTaskRunner extends MLTaskRunner<MLExecuteTaskRequest, MLEx
protected final DiscoveryNodeHelper nodeHelper;
private final MLEngine mlEngine;
private volatile Boolean isPythonModelEnabled;
private final MLPromptManager mlPromptManager;

public MLExecuteTaskRunner(
ThreadPool threadPool,
Expand All @@ -53,7 +55,8 @@ public MLExecuteTaskRunner(
MLTaskDispatcher mlTaskDispatcher,
MLCircuitBreakerService mlCircuitBreakerService,
DiscoveryNodeHelper nodeHelper,
MLEngine mlEngine
MLEngine mlEngine,
MLPromptManager mlPromptManager
) {
super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
Expand All @@ -66,6 +69,7 @@ public MLExecuteTaskRunner(
this.clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, it -> isPythonModelEnabled = it);
this.mlPromptManager = mlPromptManager;
}

@Override
Expand Down Expand Up @@ -94,6 +98,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecut
.increment();

// ActionListener<MLExecuteTaskResponse> wrappedListener = ActionListener.runBefore(listener, )
mlPromptManager.resolvePromptGroup(request);
Input input = request.getInput();
FunctionName functionName = request.getFunctionName();
if (FunctionName.METRICS_CORRELATION.equals(functionName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.prompt.MLPromptManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
Expand Down Expand Up @@ -76,6 +77,8 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
DiscoveryNodeHelper nodeHelper;
@Mock
ClusterApplierService clusterApplierService;
@Mock
MLPromptManager mlPromptManager;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -136,7 +139,8 @@ public void setup() {
mlTaskDispatcher,
mlCircuitBreakerService,
nodeHelper,
mlEngine
mlEngine,
mlPromptManager
)
);

Expand Down
Loading