diff --git a/common/src/main/java/org/opensearch/ml/common/prompt/MLPrompt.java b/common/src/main/java/org/opensearch/ml/common/prompt/MLPrompt.java index 202df7f321..8bee076da1 100644 --- a/common/src/main/java/org/opensearch/ml/common/prompt/MLPrompt.java +++ b/common/src/main/java/org/opensearch/ml/common/prompt/MLPrompt.java @@ -13,6 +13,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -24,29 +25,40 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.Setter; /** * MLPrompt is the class to store prompt information. */ @Getter +@Setter @EqualsAndHashCode public class MLPrompt implements ToXContentObject, Writeable { + // fields public static final String PROMPT_ID_FIELD = "prompt_id"; public static final String NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; public static final String PROMPT_FIELD = "prompt"; + public static final String PROMPT_MANAGEMENT_TYPE_FIELD = "prompt_management_type"; // prompt management type -> MLPrompt or Langfuse public static final String TAGS_FIELD = "tags"; + public static final String PROMPT_EXTRA_CONFIG_FIELD = "extra_config"; public static final String CREATE_TIME_FIELD = "create_time"; public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + // prompt management type + public static final String LANGFUSE = "LANGFUSE"; + public static final String MLPROMPT = "MLPROMPT"; + private String promptId; private String name; private String description; private String version; private Map prompt; + private String promptManagementType; private List tags; + private PromptExtraConfig promptExtraConfig; private String tenantId; private Instant createTime; private Instant lastUpdateTime; @@ -71,7 +83,9 @@ public MLPrompt( String description, String version, Map prompt, + String promptManagementType, List tags, + PromptExtraConfig promptExtraConfig, String tenantId, Instant createTime, Instant lastUpdateTime @@ -81,7 +95,9 @@ public MLPrompt( this.description = description; this.version = version; this.prompt = prompt; + this.promptManagementType = promptManagementType; this.tags = tags; + this.promptExtraConfig = promptExtraConfig; this.tenantId = tenantId; this.createTime = createTime; this.lastUpdateTime = lastUpdateTime; @@ -99,7 +115,9 @@ public MLPrompt(StreamInput input) throws IOException { this.description = input.readOptionalString(); this.version = input.readOptionalString(); this.prompt = input.readMap(s -> s.readString(), s -> s.readString()); + this.promptManagementType = input.readOptionalString(); this.tags = input.readList(StreamInput::readString); + this.promptExtraConfig = new PromptExtraConfig(input); this.tenantId = input.readOptionalString(); this.createTime = input.readOptionalInstant(); this.lastUpdateTime = input.readOptionalInstant(); @@ -118,7 +136,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(description); out.writeOptionalString(version); out.writeMap(prompt, StreamOutput::writeString, StreamOutput::writeString); + out.writeOptionalString(promptManagementType); out.writeCollection(tags, StreamOutput::writeString); + promptExtraConfig.writeTo(out); out.writeOptionalString(tenantId); out.writeOptionalInstant(createTime); out.writeOptionalInstant(lastUpdateTime); @@ -150,9 +170,15 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (prompt != null) { builder.field(PROMPT_FIELD, prompt); } + if (promptManagementType != null) { + builder.field(PROMPT_MANAGEMENT_TYPE_FIELD, promptManagementType); + } if (tags != null) { builder.field(TAGS_FIELD, tags); } + if (promptExtraConfig != null) { + builder.field(PROMPT_EXTRA_CONFIG_FIELD, promptExtraConfig); + } if (tenantId != null) { builder.field(TENANT_ID_FIELD, tenantId); } @@ -189,7 +215,9 @@ public static MLPrompt parse(XContentParser parser) throws IOException { String description = null; String version = null; Map prompt = null; + String promptManagementType = null; List tags = null; + PromptExtraConfig promptExtraConfig = null; String tenantId = null; Instant createTime = null; Instant lastUpdateTime = null; @@ -214,6 +242,9 @@ public static MLPrompt parse(XContentParser parser) throws IOException { case PROMPT_FIELD: prompt = parser.mapStrings(); break; + case PROMPT_MANAGEMENT_TYPE_FIELD: + promptManagementType = parser.text(); + break; case TAGS_FIELD: tags = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); @@ -221,6 +252,9 @@ public static MLPrompt parse(XContentParser parser) throws IOException { tags.add(parser.text()); } break; + case PROMPT_EXTRA_CONFIG_FIELD: + promptExtraConfig = PromptExtraConfig.parse(parser); + break; case TENANT_ID_FIELD: tenantId = parser.text(); break; @@ -242,10 +276,40 @@ public static MLPrompt parse(XContentParser parser) throws IOException { .description(description) .version(version) .prompt(prompt) + .promptManagementType(promptManagementType) .tags(tags) + .promptExtraConfig(promptExtraConfig) .tenantId(tenantId) .createTime(createTime) .lastUpdateTime(lastUpdateTime) .build(); } + + public void encrypt(String promptManagementType, BiFunction function, String tenantId) { + if (promptManagementType.equalsIgnoreCase(LANGFUSE)) { + PromptExtraConfig promptExtraConfig = this.getPromptExtraConfig(); + String publicKey = promptExtraConfig.getPublicKey(); + String accessKey = this.getPromptExtraConfig().getAccessKey(); + + promptExtraConfig.setPublicKey(function.apply(publicKey, tenantId)); + promptExtraConfig.setAccessKey(function.apply(accessKey, tenantId)); + + this.setPromptExtraConfig(promptExtraConfig); + } + // add other prompt management client case here, if needed + } + + public void decrypt(String promptManagementType, BiFunction function, String tenantId) { + if (promptManagementType.equalsIgnoreCase(LANGFUSE)) { + PromptExtraConfig promptExtraConfig = this.getPromptExtraConfig(); + String publicKey = promptExtraConfig.getPublicKey(); + String accessKey = this.getPromptExtraConfig().getAccessKey(); + + promptExtraConfig.setPublicKey(function.apply(publicKey, tenantId)); + promptExtraConfig.setAccessKey(function.apply(accessKey, tenantId)); + + this.setPromptExtraConfig(promptExtraConfig); + } + // add other prompt management client case here, if needed + } } diff --git a/common/src/main/java/org/opensearch/ml/common/prompt/PromptExtraConfig.java b/common/src/main/java/org/opensearch/ml/common/prompt/PromptExtraConfig.java new file mode 100644 index 0000000000..43f686bc8b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/prompt/PromptExtraConfig.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.prompt; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Data; + +@Data +public class PromptExtraConfig implements ToXContentObject, Writeable { + + public static final String LANGFUSE_PROMPT_TYPE_FIELD = "type"; + public static final String LANGFUSE_PROMPT_PUBLIC_KEY_FIELD = "public_key"; + public static final String LANGFUSE_PROMPT_ACCESS_KEY_FIELD = "access_key"; + public static final String LANGFUSE_PROMPT_LABELS_FIELD = "labels"; + + private String type; // required + private String publicKey; // required + private String accessKey; // required + private List labels; // optional + + @Builder(toBuilder = true) + public PromptExtraConfig(String type, String publicKey, String accessKey, List labels) { + this.type = type; + this.publicKey = publicKey; + this.accessKey = accessKey; + this.labels = labels; + } + + public PromptExtraConfig(StreamInput input) throws IOException { + this.type = input.readOptionalString(); + this.publicKey = input.readOptionalString(); + this.accessKey = input.readOptionalString(); + this.labels = input.readList(StreamInput::readString); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeOptionalString(type); + output.writeOptionalString(publicKey); + output.writeOptionalString(accessKey); + output.writeCollection(labels, StreamOutput::writeString); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (type != null) { + builder.field(LANGFUSE_PROMPT_TYPE_FIELD, type); + } + if (publicKey != null) { + builder.field(LANGFUSE_PROMPT_PUBLIC_KEY_FIELD, publicKey); + } + if (accessKey != null) { + builder.field(LANGFUSE_PROMPT_ACCESS_KEY_FIELD, accessKey); + } + if (labels != null) { + builder.field(LANGFUSE_PROMPT_LABELS_FIELD, labels); + } + return builder.endObject(); + } + + public static PromptExtraConfig parse(XContentParser parser) throws IOException { + String type = null; + String publicKey = null; + String accessKey = null; + List labels = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case LANGFUSE_PROMPT_TYPE_FIELD: + type = parser.text(); + break; + case LANGFUSE_PROMPT_PUBLIC_KEY_FIELD: + publicKey = parser.text(); + break; + case LANGFUSE_PROMPT_ACCESS_KEY_FIELD: + accessKey = parser.text(); + break; + case LANGFUSE_PROMPT_LABELS_FIELD: + labels = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + labels.add(parser.text()); + } + break; + default: + parser.skipChildren(); + break; + } + } + return PromptExtraConfig.builder().type(type).publicKey(publicKey).accessKey(accessKey).labels(labels).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java index 546f0005a4..aff4637ad7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; @@ -20,6 +21,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.prompt.PromptExtraConfig; import lombok.Builder; import lombok.Data; @@ -34,18 +36,22 @@ public class MLCreatePromptInput implements ToXContentObject, Writeable { public static final String PROMPT_DESCRIPTION_FIELD = "description"; public static final String PROMPT_VERSION_FIELD = "version"; public static final String PROMPT_PROMPT_FIELD = "prompt"; + public static final String PROMPT_MANAGEMENT_TYPE = "prompt_management_type"; public static final String PROMPT_TAGS_FIELD = "tags"; - public static final String PROMPT_FIELD_USER_PROMPT = "user"; - public static final String PROMPT_FIELD_SYSTEM_PROMPT = "system"; + public static final String PROMPT_EXTRA_CONFIG_FIELD = "extra_config"; + + public static final String PROMPT_VERSION_INITIAL_VERSION = "1"; private String name; private String description; private String version; private Map prompt; + private String promptManagementType; private List tags; @Setter private String tenantId; + private PromptExtraConfig promptExtraConfig; /** * Constructor to pass values to the MLCreatePromptInput constructor. @@ -63,8 +69,10 @@ public MLCreatePromptInput( String description, String version, Map prompt, + String promptManagementType, List tags, - String tenantId + String tenantId, + PromptExtraConfig promptExtraConfig ) { if (name == null) { throw new IllegalArgumentException("MLPrompt name field is null"); @@ -72,19 +80,29 @@ public MLCreatePromptInput( if (prompt == null || prompt.isEmpty()) { throw new IllegalArgumentException("MLPrompt prompt field cannot be empty or null"); } - if (!prompt.containsKey(PROMPT_FIELD_SYSTEM_PROMPT)) { - throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_SYSTEM_PROMPT + " parameter"); - } - if (!prompt.containsKey(PROMPT_FIELD_USER_PROMPT)) { - throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_USER_PROMPT + " parameter"); + if (promptExtraConfig != null && promptManagementType != null && promptManagementType.equalsIgnoreCase(LANGFUSE)) { + if (promptExtraConfig.getType() == null) { + throw new IllegalArgumentException("LangfusePrompt type field is null"); + } + if (promptExtraConfig.getPublicKey() == null) { + throw new IllegalArgumentException("LangfusePrompt Public Key field cannot be null"); + } + if (promptExtraConfig.getAccessKey() == null) { + throw new IllegalArgumentException("LangfusePrompt Access Key field cannot be null"); + } + if (promptExtraConfig.getType().equals("text") && !prompt.containsKey("user")) { + throw new IllegalArgumentException("Langfuse Text Prompt requires User Prompt"); + } } this.name = name; this.description = description; - this.version = version; + this.version = version == null ? PROMPT_VERSION_INITIAL_VERSION : version; this.prompt = prompt; + this.promptManagementType = promptManagementType; this.tags = tags; this.tenantId = tenantId; + this.promptExtraConfig = promptExtraConfig; } /** @@ -98,8 +116,10 @@ public MLCreatePromptInput(StreamInput input) throws IOException { this.description = input.readOptionalString(); this.version = input.readOptionalString(); this.prompt = input.readMap(s -> s.readString(), s -> s.readString()); + this.promptManagementType = input.readOptionalString(); this.tags = input.readList(StreamInput::readString); this.tenantId = input.readOptionalString(); + this.promptExtraConfig = new PromptExtraConfig(input); } /** @@ -114,8 +134,10 @@ public void writeTo(StreamOutput output) throws IOException { output.writeOptionalString(description); output.writeOptionalString(version); output.writeMap(prompt, StreamOutput::writeString, StreamOutput::writeString); + output.writeOptionalString(promptManagementType); output.writeCollection(tags, StreamOutput::writeString); output.writeOptionalString(tenantId); + promptExtraConfig.writeTo(output); } /** @@ -130,8 +152,10 @@ public static MLCreatePromptInput parse(XContentParser parser) throws IOExceptio String description = null; String version = null; Map prompt = null; + String promptManagementType = null; List tags = null; String tenantId = null; + PromptExtraConfig promptExtraConfig = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -150,6 +174,9 @@ public static MLCreatePromptInput parse(XContentParser parser) throws IOExceptio case PROMPT_PROMPT_FIELD: prompt = getParameterMap(parser.map()); break; + case PROMPT_MANAGEMENT_TYPE: + promptManagementType = parser.text(); + break; case PROMPT_TAGS_FIELD: tags = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); @@ -159,6 +186,10 @@ public static MLCreatePromptInput parse(XContentParser parser) throws IOExceptio break; case TENANT_ID_FIELD: tenantId = parser.textOrNull(); + break; + case PROMPT_EXTRA_CONFIG_FIELD: + promptExtraConfig = PromptExtraConfig.parse(parser); + break; default: parser.skipChildren(); break; @@ -170,8 +201,10 @@ public static MLCreatePromptInput parse(XContentParser parser) throws IOExceptio .description(description) .version(version) .prompt(prompt) + .promptManagementType(promptManagementType) .tags(tags) .tenantId(tenantId) + .promptExtraConfig(promptExtraConfig) .build(); } @@ -198,11 +231,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (prompt != null) { builder.field(PROMPT_PROMPT_FIELD, prompt); } + if (promptManagementType != null) { + builder.field(PROMPT_MANAGEMENT_TYPE, promptManagementType); + } if (tags != null) { builder.field(PROMPT_TAGS_FIELD, tags); } - if (tenantId != null) { - builder.field(TENANT_ID_FIELD, tenantId); + if (promptExtraConfig != null) { + builder.field(PROMPT_EXTRA_CONFIG_FIELD, promptExtraConfig); } builder.endObject(); return builder; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptAction.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptAction.java new file mode 100644 index 0000000000..9e3e84f218 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.prompt; + +import org.opensearch.action.ActionType; + +public class MLImportPromptAction extends ActionType { + public static MLImportPromptAction INSTANCE = new MLImportPromptAction(); + public static final String NAME = "cluster:admin/opensearch/ml/import"; + + private MLImportPromptAction() { + super(NAME, MLImportPromptResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptInput.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptInput.java new file mode 100644 index 0000000000..a95001719f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptInput.java @@ -0,0 +1,162 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.prompt; + +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; + +import java.io.IOException; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Data; +import lombok.NonNull; + +@Data +public class MLImportPromptInput implements ToXContentObject, Writeable { + public static final String NAME = "name"; + public static final String TAG = "tag"; + public static final String PROMPT_MANAGEMENT_TYPE = "prompt_management_type"; + public static final String PUBLIC_KEY = "public_key"; + public static final String ACCESS_KEY = "access_key"; + public static final String LIMIT = "limit"; + + private String name; + private String tag; + private String promptManagementType; + private String publicKey; + private String accessKey; + private String limit; + private String tenantId; + + @Builder(toBuilder = true) + public MLImportPromptInput( + String name, + String tag, + @NonNull String promptManagementType, + @NonNull String publicKey, + @NonNull String accessKey, + String limit, + String tenantId + ) { + Objects.requireNonNull(promptManagementType, "must specify prompt management type"); + Objects.requireNonNull(publicKey, "public key can not be null"); + Objects.requireNonNull(accessKey, "access key can not be null"); + this.name = name; + this.tag = tag; + this.promptManagementType = promptManagementType; + this.promptManagementType = promptManagementType; + this.publicKey = publicKey; + this.accessKey = accessKey; + this.limit = limit; + this.tenantId = tenantId; + } + + public MLImportPromptInput(StreamInput input) throws IOException { + this.name = input.readOptionalString(); + this.tag = input.readOptionalString(); + this.promptManagementType = input.readOptionalString(); + this.publicKey = input.readOptionalString(); + this.publicKey = input.readOptionalString(); + this.limit = input.readOptionalString(); + this.tenantId = input.readOptionalString(); + } + + public void writeTo(StreamOutput output) throws IOException { + output.writeOptionalString(name); + output.writeOptionalString(tag); + output.writeOptionalString(promptManagementType); + output.writeOptionalString(publicKey); + output.writeOptionalString(accessKey); + output.writeOptionalString(limit); + output.writeOptionalString(tenantId); + } + + public static MLImportPromptInput parse(XContentParser parser) throws IOException { + String name = null; + String tag = null; + String promptManagementType = null; + String publicKey = null; + String accessKey = null; + String limit = null; + String tenantId = null; + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case NAME: + name = parser.text(); + break; + case TAG: + tag = parser.text(); + break; + case PROMPT_MANAGEMENT_TYPE: + promptManagementType = parser.text(); + break; + case PUBLIC_KEY: + publicKey = parser.text(); + break; + case ACCESS_KEY: + accessKey = parser.text(); + break; + case LIMIT: + limit = parser.text(); + break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return MLImportPromptInput + .builder() + .name(name) + .tag(tag) + .promptManagementType(promptManagementType) + .publicKey(publicKey) + .accessKey(accessKey) + .limit(limit) + .tenantId(tenantId) + .build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(NAME, name); + } + if (tag != null) { + builder.field(TAG, tag); + } + if (promptManagementType != null) { + builder.field(PROMPT_MANAGEMENT_TYPE, promptManagementType); + } + if (publicKey != null) { + builder.field(PUBLIC_KEY, publicKey); + } + if (accessKey != null) { + builder.field(ACCESS_KEY, accessKey); + } + if (limit != null) { + builder.field(LIMIT, limit); + } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptRequest.java new file mode 100644 index 0000000000..d28bace071 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptRequest.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.prompt; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class MLImportPromptRequest extends ActionRequest { + private MLImportPromptInput mlImportPromptInput; + + @Builder + public MLImportPromptRequest(MLImportPromptInput mlImportPromptInput) { + this.mlImportPromptInput = mlImportPromptInput; + } + + public MLImportPromptRequest(StreamInput in) throws IOException { + super(in); + this.mlImportPromptInput = new MLImportPromptInput(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlImportPromptInput == null) { + exception = addValidationError("ML Prompt Input can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + this.mlImportPromptInput.writeTo(output); + } + + public static MLImportPromptRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLImportPromptRequest) { + return (MLImportPromptRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLImportPromptRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLImportPromptRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptResponse.java new file mode 100644 index 0000000000..498791381d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLImportPromptResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.prompt; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MLImportPromptResponse extends ActionResponse implements ToXContentObject { + public static final String IMPORTED_PROMPTS_FIELD = "imported_prompts"; + + private Map importedPrompts; + + public MLImportPromptResponse(Map importedPrompts) { + this.importedPrompts = importedPrompts; + } + + public MLImportPromptResponse(StreamInput in) throws IOException { + super(in); + this.importedPrompts = in.readMap(s -> s.readString(), s -> s.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(importedPrompts, StreamOutput::writeString, StreamOutput::writeString); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(IMPORTED_PROMPTS_FIELD, importedPrompts); + builder.endObject(); + return builder; + } + + public static MLImportPromptResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLImportPromptResponse) { + return (MLImportPromptResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLImportPromptResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLImportPromptResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLUpdatePromptInput.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLUpdatePromptInput.java index b890b9e51c..56396dcba0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLUpdatePromptInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLUpdatePromptInput.java @@ -21,6 +21,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.prompt.PromptExtraConfig; import lombok.Builder; import lombok.Data; @@ -37,6 +38,7 @@ public class MLUpdatePromptInput implements ToXContentObject, Writeable { public static final String PROMPT_VERSION_FIELD = "version"; public static final String PROMPT_PROMPT_FIELD = "prompt"; public static final String PROMPT_TAGS_FIELD = "tags"; + public static final String PROMPT_EXTRA_CONFIG_FIELD = "extra_config"; public static final String PROMPT_FIELD_USER_PROMPT = "user"; public static final String PROMPT_FIELD_SYSTEM_PROMPT = "system"; @@ -48,6 +50,7 @@ public class MLUpdatePromptInput implements ToXContentObject, Writeable { private String version; private Map prompt; private List tags; + private PromptExtraConfig extraConfig; @Setter private String tenantId; private Instant lastUpdateTime; @@ -68,6 +71,7 @@ public MLUpdatePromptInput( String version, Map prompt, List tags, + PromptExtraConfig extraConfig, String tenantId, Instant lastUpdateTime ) { @@ -77,6 +81,7 @@ public MLUpdatePromptInput( this.version = version; this.prompt = prompt; this.tags = tags; + this.extraConfig = extraConfig; this.tenantId = tenantId; this.lastUpdateTime = lastUpdateTime; } @@ -93,6 +98,7 @@ public MLUpdatePromptInput(StreamInput input) throws IOException { this.version = input.readOptionalString(); this.prompt = input.readMap(s -> s.readString(), s -> s.readString()); this.tags = input.readList(StreamInput::readString); + this.extraConfig = new PromptExtraConfig(input); this.tenantId = input.readOptionalString(); this.lastUpdateTime = input.readOptionalInstant(); } @@ -110,6 +116,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeOptionalString(version); output.writeMap(prompt, StreamOutput::writeString, StreamOutput::writeString); output.writeCollection(tags, StreamOutput::writeString); + extraConfig.writeTo(output); output.writeOptionalString(tenantId); output.writeOptionalInstant(lastUpdateTime); } @@ -140,6 +147,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (tags != null) { builder.field(PROMPT_TAGS_FIELD, tags); } + if (extraConfig != null) { + builder.field(PROMPT_EXTRA_CONFIG_FIELD, extraConfig); + } if (tenantId != null) { builder.field(TENANT_ID_FIELD, tenantId); } @@ -163,6 +173,7 @@ public static MLUpdatePromptInput parse(XContentParser parser) throws IOExceptio String version = null; Map prompt = null; List tags = null; + PromptExtraConfig extraConfig = null; String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -189,6 +200,9 @@ public static MLUpdatePromptInput parse(XContentParser parser) throws IOExceptio tags.add(parser.text()); } break; + case PROMPT_EXTRA_CONFIG_FIELD: + extraConfig = PromptExtraConfig.parse(parser); + break; case TENANT_ID_FIELD: tenantId = parser.textOrNull(); default: @@ -203,6 +217,7 @@ public static MLUpdatePromptInput parse(XContentParser parser) throws IOExceptio .version(version) .prompt(prompt) .tags(tags) + .extraConfig(extraConfig) .tenantId(tenantId) .build(); } diff --git a/common/src/main/resources/index-mappings/ml_prompt.json b/common/src/main/resources/index-mappings/ml_prompt.json index 01a05043f4..29a0364ac9 100644 --- a/common/src/main/resources/index-mappings/ml_prompt.json +++ b/common/src/main/resources/index-mappings/ml_prompt.json @@ -21,9 +21,15 @@ "prompt": { "type": "flat_object" }, + "prompt_management_type": { + "type": "keyword" + }, "tags": { "type": "keyword" }, + "extra_config": { + "type": "flat_object" + }, "tenant_id": { "type": "keyword" }, diff --git a/plugin/build.gradle b/plugin/build.gradle index 6d3b04240a..4294034c30 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -97,6 +97,13 @@ dependencies { // https://mvnrepository.com/artifact/io.projectreactor/reactor-test testImplementation("io.projectreactor:reactor-test:3.5.20") + + // Langfuse SDK Client + implementation("com.langfuse:langfuse-java:0.0.6") + implementation("com.squareup.okhttp3:okhttp:5.0.0-alpha.14") + implementation("com.squareup.okio:okio:3.11.0") + implementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.18.2") } publishing { diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/GetPromptTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/GetPromptTransportAction.java index 2981766517..0846565b6f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prompt/GetPromptTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/GetPromptTransportAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.prompt; import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.prompt.AbstractPromptManagement.init; import static org.opensearch.ml.prompt.MLPromptManager.handleFailure; import org.opensearch.action.support.ActionFilters; @@ -17,6 +18,8 @@ import org.opensearch.ml.common.transport.prompt.MLPromptGetAction; import org.opensearch.ml.common.transport.prompt.MLPromptGetRequest; import org.opensearch.ml.common.transport.prompt.MLPromptGetResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.prompt.AbstractPromptManagement; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.GetDataObjectRequest; @@ -39,6 +42,7 @@ public class GetPromptTransportAction extends HandledTransportAction { Client client; SdkClient sdkClient; + EncryptorImpl encryptor; MLPromptManager mlPromptManager; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -49,12 +53,14 @@ public GetPromptTransportAction( ActionFilters actionFilters, Client client, SdkClient sdkClient, + EncryptorImpl encryptor, MLFeatureEnabledSetting mlFeatureEnabledSetting, MLPromptManager mlPromptManager ) { super(MLPromptGetAction.NAME, transportService, actionFilters, MLPromptGetRequest::new); this.client = client; this.sdkClient = sdkClient; + this.encryptor = encryptor; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.mlPromptManager = mlPromptManager; } @@ -85,15 +91,16 @@ protected void doExecute(Task task, MLPromptGetRequest mlPromptGetRequest, Actio .fetchSourceContext(fetchSourceContext) .build(); - mlPromptManager - .getPromptAsync( - getDataObjectRequest, - promptId, - ActionListener - .wrap( - mlPrompt -> actionListener.onResponse(MLPromptGetResponse.builder().mlPrompt(mlPrompt).build()), - e -> handleFailure(e, promptId, actionListener, "Failed to get MLPrompt") - ) - ); + mlPromptManager.getPromptAsync(getDataObjectRequest, promptId, ActionListener.wrap(mlPrompt -> { + try { + mlPrompt.decrypt(mlPrompt.getPromptManagementType(), encryptor::decrypt, tenantId); + AbstractPromptManagement promptManagement = init(mlPrompt.getPromptManagementType(), mlPrompt.getPromptExtraConfig()); + promptManagement.getPrompt(mlPrompt); + actionListener.onResponse(MLPromptGetResponse.builder().mlPrompt(mlPrompt).build()); + } catch (Exception e) { + log.error("Failed to process " + mlPrompt.getPromptManagementType() + " Prompt"); + actionListener.onFailure(e); + } + }, e -> handleFailure(e, promptId, actionListener, "Failed to get MLPrompt"))); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/ImportPromptTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/ImportPromptTransportAction.java new file mode 100644 index 0000000000..9d37801754 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/ImportPromptTransportAction.java @@ -0,0 +1,270 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.prompt; + +import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE; +import static org.opensearch.ml.common.prompt.MLPrompt.MLPROMPT; +import static org.opensearch.ml.prompt.AbstractPromptManagement.init; +import static org.opensearch.ml.prompt.MLPromptManager.handleFailure; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.prompt.MLImportPromptAction; +import org.opensearch.ml.common.transport.prompt.MLImportPromptInput; +import org.opensearch.ml.common.transport.prompt.MLImportPromptRequest; +import org.opensearch.ml.common.transport.prompt.MLImportPromptResponse; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.prompt.AbstractPromptManagement; +import org.opensearch.ml.prompt.MLPromptManager; +import org.opensearch.ml.prompt.PromptImportable; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.PutDataObjectResponse; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ImportPromptTransportAction extends HandledTransportAction { + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final SdkClient sdkClient; + private final MLEngine mlEngine; + private final MLPromptManager mlPromptManager; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Inject + public ImportPromptTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLIndicesHandler mlIndicesHandler, + Client client, + SdkClient sdkClient, + MLEngine mlEngine, + MLPromptManager mlPromptManager, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLImportPromptAction.NAME, transportService, actionFilters, MLImportPromptRequest::new); + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.sdkClient = sdkClient; + this.mlEngine = mlEngine; + this.mlPromptManager = mlPromptManager; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + protected void doExecute(Task task, MLImportPromptRequest mlImportPromptRequest, ActionListener listener) { + MLImportPromptInput mlImportPromptInput = mlImportPromptRequest.getMlImportPromptInput(); + + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, mlImportPromptInput.getTenantId(), listener)) { + return; + } + String promptManagementType = mlImportPromptInput.getPromptManagementType(); + String publicKey = mlImportPromptInput.getPublicKey(); + String accessKey = mlImportPromptInput.getAccessKey(); + String tenantId = mlImportPromptInput.getTenantId(); + + try { + AbstractPromptManagement promptManagement = init( + promptManagementType, + PromptExtraConfig.builder().publicKey(publicKey).accessKey(accessKey).build() + ); + + if (!(promptManagement instanceof PromptImportable importer)) { + throw new OpenSearchStatusException("Import prompt is not supported for MLPromptManagement", RestStatus.BAD_REQUEST); + } else { + List mlPromptList = importer.importPrompts(mlImportPromptInput); + Map responseBody = new HashMap<>(); + if (mlPromptList.isEmpty()) { + listener.onResponse(new MLImportPromptResponse(responseBody)); + return; + } + AtomicInteger remainingMLPrompts = new AtomicInteger(mlPromptList.size()); + for (MLPrompt mlPrompt : mlPromptList) { + mlPrompt.encrypt(promptManagementType, mlEngine::encrypt, tenantId); + handleConflictingName(mlPrompt, tenantId, ActionListener.wrap(promptId -> { + if (promptId == null) { + indexPrompt(mlPrompt, responseBody, remainingMLPrompts, listener); + } else { + updateImportResponseBody(promptId, mlPrompt.getName(), responseBody, remainingMLPrompts, listener); + } + }, listener::onFailure)); + } + } + } catch (Exception e) { + handleFailure(e, null, listener, "Failed to import " + promptManagementType + " Prompts into System Index"); + } + } + + /** + * Store prompt into system index + * + * @param prompt prompt that needs to be stored into the system index + * @param responseBody response body that will be return upon success in the format of prompt name to prompt id + * @param remainingMLPrompts remaining prompt to be stored into the system index + * @param listener actionListener that will be notified upon success or failure of the prompt creation + */ + private void indexPrompt( + MLPrompt prompt, + Map responseBody, + AtomicInteger remainingMLPrompts, + ActionListener listener + ) { + log.info("prompt created, indexing into the prompt system index"); + mlIndicesHandler.initMLPromptIndex(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + Exception exception = new RuntimeException("No response to create ML Prompt Index"); + handleFailure(exception, null, listener, "Failed to create a system index for prompt"); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + PutDataObjectRequest putRequest = buildPromptPutRequest(prompt); + sdkClient.putDataObjectAsync(putRequest).whenComplete((putResponse, throwable) -> { + context.restore(); + handlePromptPutResponse(putResponse, throwable, prompt.getName(), responseBody, remainingMLPrompts, listener); + }); + } + }, e -> { handleFailure(e, null, listener, "Failed to init ML prompt index"); })); + } + + /** + * Builds putRequest to write prompt into index + * + * @param prompt prompt that needs to be stored into the system index + * @return PutDataObjectRequest + */ + private PutDataObjectRequest buildPromptPutRequest(MLPrompt prompt) { + return PutDataObjectRequest.builder().tenantId(prompt.getTenantId()).index(ML_PROMPT_INDEX).dataObject(prompt).build(); + } + + /** + * Handles PutResponse after prompt is indexed + * + * @param putResponse response received after prompt is indexed + * @param throwable throwable + * @param name prompt name that is indexed + * @param responseBody response body that will be return upon success in the format of prompt name to prompt id + * @param remainingMLPrompts remaining prompt to be stored into the system index + * @param listener actionListener that will be notified upon success or failure of the prompt creation + */ + private void handlePromptPutResponse( + PutDataObjectResponse putResponse, + Throwable throwable, + String name, + Map responseBody, + AtomicInteger remainingMLPrompts, + ActionListener listener + ) { + if (putResponse == null || throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + handleFailure(cause, null, listener, "Prompt Put Response cannot be null"); + } + try { + IndexResponse indexResponse = IndexResponse.fromXContent(putResponse.parser()); + log.info("Prompt creation result: {}, prompt id: {}", indexResponse.getResult(), indexResponse.getId()); + updateImportResponseBody(indexResponse.getId(), name, responseBody, remainingMLPrompts, listener); + } catch (Exception e) { + handleFailure(e, null, listener, "Failed to parse PutDataObjectResponse into Index Response from xContent"); + } + } + + /** + * Update the response body with the prompt name and prompt id upon successful import + * + * @param promptId prompt id returned after prompt is successfully indexed into the system index + * @param name name of the prompt that is stored + * @param responseBody response body that will be return upon success in the format of prompt name to prompt id + * @param remainingMLPrompts remaining prompt to be stored into the system index + * @param listener actionListener that will be notified upon success or failure of the prompt creation + */ + private void updateImportResponseBody( + String promptId, + String name, + Map responseBody, + AtomicInteger remainingMLPrompts, + ActionListener listener + ) { + responseBody.put(name, promptId); + remainingMLPrompts.set(remainingMLPrompts.get() - 1); + if (remainingMLPrompts.get() == 0) { + listener.onResponse(new MLImportPromptResponse(responseBody)); + } + } + + /** + * Search name field on prompt system index. + * + * @param importingPrompt prompt that needs to be imported into prompt system index + * @param tenantId tenant id + * @param wrappedListener listener that will be notified with prompt id upon success or failure of the prompt creation + * @throws IOException if search hits, meaning conflicting name exist + */ + private void handleConflictingName(MLPrompt importingPrompt, String tenantId, ActionListener wrappedListener) + throws IOException { + String name = importingPrompt.getName(); + SearchResponse searchResponse = mlPromptManager.searchPromptByName(name, tenantId); + if (searchResponse != null + && searchResponse.getHits().getTotalHits() != null + && searchResponse.getHits().getTotalHits().value() != 0) { + String promptId = searchResponse.getHits().getAt(0).getId(); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_PROMPT_INDEX) + .id(promptId) + .tenantId(tenantId) + .build(); + MLPrompt existingMLPrompt = mlPromptManager.getPrompt(getDataObjectRequest); + + // check the prompt management type + String promptManagementType = existingMLPrompt.getPromptManagementType(); + if (promptManagementType.equalsIgnoreCase(MLPROMPT)) { + throw new IllegalArgumentException("Provided name: " + name + " is already being used by ML Prompt with id: " + promptId); + } else if (promptManagementType.equalsIgnoreCase(LANGFUSE)) { + // update the existing langfuse prompt with new content if the version des not match + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(ML_PROMPT_INDEX) + .id(promptId) + .tenantId(tenantId) + .dataObject(importingPrompt) + .build(); + mlPromptManager.updatePromptIndex(updateDataObjectRequest, promptId, ActionListener.wrap(updateResponse -> { + log.info("{} Prompt with promptId: {} updated successfully", promptManagementType, promptId); + wrappedListener.onResponse(promptId); + }, wrappedListener::onFailure)); + } + } else { + // provided name is unique, good to be imported + wrappedListener.onResponse(null); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java index bcac08feab..5244eb606d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java @@ -7,12 +7,20 @@ import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; @@ -20,6 +28,8 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -68,9 +78,34 @@ protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - final ActionListener wrappedListener = ActionListener + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + + List excludes = Optional + .ofNullable(mlSearchActionRequest.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::excludes) + .map(x -> Arrays.stream(x).collect(Collectors.toList())) + .orElse(new ArrayList<>()); + excludes.add(MLPrompt.PROMPT_EXTRA_CONFIG_FIELD + "." + PromptExtraConfig.LANGFUSE_PROMPT_ACCESS_KEY_FIELD); + excludes.add(MLPrompt.PROMPT_EXTRA_CONFIG_FIELD + "." + PromptExtraConfig.LANGFUSE_PROMPT_PUBLIC_KEY_FIELD); + FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( + Optional + .ofNullable(mlSearchActionRequest.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::fetchSource) + .orElse(true), + Optional + .ofNullable(mlSearchActionRequest.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::includes) + .orElse(null), + excludes.toArray(new String[0]) + ); + mlSearchActionRequest.source().fetchSource(rebuiltFetchSourceContext); + + final ActionListener doubleWrappedListener = ActionListener .runBefore( - ActionListener.wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener)), + ActionListener.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)), context::restore ); @@ -81,7 +116,9 @@ protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, .tenantId(tenantId) .build(); - sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); + sdkClient + .searchDataObjectAsync(searchDataObjectRequest) + .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener)); } catch (Exception e) { log.error("Failed to search ML Prompt", e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/TransportCreatePromptAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/TransportCreatePromptAction.java index a0bf547e12..331cdbc184 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prompt/TransportCreatePromptAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/TransportCreatePromptAction.java @@ -6,13 +6,15 @@ package org.opensearch.ml.action.prompt; import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.common.prompt.MLPrompt.MLPROMPT; +import static org.opensearch.ml.prompt.AbstractPromptManagement.init; import static org.opensearch.ml.prompt.MLPromptManager.MLPromptNameAlreadyExists; import static org.opensearch.ml.prompt.MLPromptManager.TAG_RESTRICTION_ERR_MESSAGE; import static org.opensearch.ml.prompt.MLPromptManager.UNIQUE_NAME_ERR_MESSAGE; import static org.opensearch.ml.prompt.MLPromptManager.handleFailure; import static org.opensearch.ml.prompt.MLPromptManager.validateTags; -import java.time.Instant; +import java.util.Objects; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; @@ -22,12 +24,15 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.prompt.MLCreatePromptAction; import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; import org.opensearch.ml.common.transport.prompt.MLCreatePromptRequest; import org.opensearch.ml.common.transport.prompt.MLCreatePromptResponse; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.prompt.AbstractPromptManagement; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.PutDataObjectRequest; @@ -47,10 +52,10 @@ */ @Log4j2 public class TransportCreatePromptAction extends HandledTransportAction { - private static final String INITIAL_VERSION = "1"; private final MLIndicesHandler mlIndicesHandler; private final Client client; private final SdkClient sdkClient; + private final MLEngine mlEngine; private final MLPromptManager mlPromptManager; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -62,6 +67,7 @@ public TransportCreatePromptAction( MLIndicesHandler mlIndicesHandler, Client client, SdkClient sdkClient, + MLEngine mlEngine, MLPromptManager mlPromptManager, MLFeatureEnabledSetting mlFeatureEnabledSetting ) { @@ -69,6 +75,7 @@ public TransportCreatePromptAction( this.mlIndicesHandler = mlIndicesHandler; this.client = client; this.sdkClient = sdkClient; + this.mlEngine = mlEngine; this.mlPromptManager = mlPromptManager; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @@ -119,19 +126,18 @@ protected void doExecute(Task task, MLCreatePromptRequest mlCreatePromptRequest, ); return; } - String version = mlCreatePromptInput.getVersion(); - MLPrompt mlPrompt = MLPrompt - .builder() - .name(mlCreatePromptInput.getName()) - .description(mlCreatePromptInput.getDescription()) - .version(version == null ? INITIAL_VERSION : version) - .prompt(mlCreatePromptInput.getPrompt()) - .tags(mlCreatePromptInput.getTags()) - .tenantId(mlCreatePromptInput.getTenantId()) - .createTime(Instant.now()) - .lastUpdateTime(Instant.now()) - .build(); + // set prompt management type to default MLPROMPT if not provided + if (mlCreatePromptInput.getPromptManagementType() == null) { + mlCreatePromptInput.setPromptManagementType(MLPROMPT); + } + PromptExtraConfig extraConfig = mlCreatePromptInput.getPromptExtraConfig(); + String promptManagementType = mlCreatePromptInput.getPromptManagementType(); + AbstractPromptManagement promptManagement = init(promptManagementType, extraConfig); + + MLPrompt mlPrompt = promptManagement.createPrompt(mlCreatePromptInput); + mlPrompt.encrypt(promptManagementType, mlEngine::encrypt, mlCreatePromptInput.getTenantId()); + Objects.requireNonNull(mlPrompt, "MLPrompt cannot be null"); indexPrompt(mlPrompt, listener); } catch (Exception e) { handleFailure(e, null, listener, "Failed to create a MLPrompt"); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/UpdatePromptTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/UpdatePromptTransportAction.java index 4d672f89be..8955546a8d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prompt/UpdatePromptTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/UpdatePromptTransportAction.java @@ -6,33 +6,32 @@ package org.opensearch.ml.action.prompt; import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; -import static org.opensearch.ml.prompt.MLPromptManager.MLPromptNameAlreadyExists; +import static org.opensearch.ml.prompt.AbstractPromptManagement.init; import static org.opensearch.ml.prompt.MLPromptManager.TAG_RESTRICTION_ERR_MESSAGE; import static org.opensearch.ml.prompt.MLPromptManager.UNIQUE_NAME_ERR_MESSAGE; import static org.opensearch.ml.prompt.MLPromptManager.handleFailure; import static org.opensearch.ml.prompt.MLPromptManager.validateTags; -import java.time.Instant; - +import org.apache.commons.lang3.StringUtils; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.prompt.MLUpdatePromptAction; import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput; import org.opensearch.ml.common.transport.prompt.MLUpdatePromptRequest; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.prompt.AbstractPromptManagement; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; -import org.opensearch.remote.metadata.client.UpdateDataObjectResponse; -import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -52,6 +51,7 @@ public class UpdatePromptTransportAction extends HandledTransportAction listener ) { - int version = Integer.parseInt(mlPrompt.getVersion()); - mlUpdatePromptInput.setLastUpdateTime(Instant.now()); - mlUpdatePromptInput.setVersion(String.valueOf(version + 1)); - UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest - .builder() - .index(ML_PROMPT_INDEX) - .id(promptId) - .tenantId(tenantId) - .dataObject(mlUpdatePromptInput) - .build(); - - updatePrompt(updateDataObjectRequest, promptId, listener); - } - - /** - * Updates the prompt based on the update contents and replace the old prompt with updated prompt from the index - * - * @param updateDataObjectRequest the updateRequest that needs to be handled - * @param promptId The prompt ID of a prompt that needs to be updated - * @param listener a listener to be notified of the response - */ - private void updatePrompt(UpdateDataObjectRequest updateDataObjectRequest, String promptId, ActionListener listener) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((updateDataObjectResponse, throwable) -> { - context.restore(); - handleUpdateResponse(updateDataObjectResponse, throwable, promptId, listener); - }); - } - } + PromptExtraConfig extraConfig = mlPrompt.getPromptExtraConfig(); + String promptManagementType = mlPrompt.getPromptManagementType(); + mlPrompt.setPromptId(promptId); + mlPrompt.decrypt(mlPrompt.getPromptManagementType(), encryptor::decrypt, tenantId); + AbstractPromptManagement promptManagement = init(promptManagementType, extraConfig); + UpdateDataObjectRequest updateDataObjectRequest = promptManagement.updatePrompt(mlUpdatePromptInput, mlPrompt); - /** - * Handles the response from the update prompt request. If the response is successful, notify the listener - * with the UpdateResponse. Otherwise, notify the failure exception to the listener. - * - * @param updateDataObjectResponse The response from the update prompt request - * @param throwable The exception that occurred during the update prompt request - * @param listener The listener to be notified of the response - */ - private void handleUpdateResponse( - UpdateDataObjectResponse updateDataObjectResponse, - Throwable throwable, - String promptId, - ActionListener listener - ) { - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - handleFailure(cause, promptId, listener, "Failed to update ML prompt {}"); - return; - } - UpdateResponse updateResponse = updateDataObjectResponse.updateResponse(); - listener.onResponse(updateResponse); + mlPromptManager.updatePromptIndex(updateDataObjectRequest, promptId, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index b1c77d5c7f..b736afcbc7 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -104,6 +104,7 @@ import org.opensearch.ml.action.profile.MLProfileTransportAction; import org.opensearch.ml.action.prompt.DeletePromptTransportAction; import org.opensearch.ml.action.prompt.GetPromptTransportAction; +import org.opensearch.ml.action.prompt.ImportPromptTransportAction; import org.opensearch.ml.action.prompt.SearchPromptTransportAction; import org.opensearch.ml.action.prompt.TransportCreatePromptAction; import org.opensearch.ml.action.prompt.UpdatePromptTransportAction; @@ -185,6 +186,7 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prompt.MLCreatePromptAction; +import org.opensearch.ml.common.transport.prompt.MLImportPromptAction; import org.opensearch.ml.common.transport.prompt.MLPromptDeleteAction; import org.opensearch.ml.common.transport.prompt.MLPromptGetAction; import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; @@ -290,6 +292,7 @@ import org.opensearch.ml.rest.RestMLGetPromptAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLGetToolAction; +import org.opensearch.ml.rest.RestMLImportPromptAction; import org.opensearch.ml.rest.RestMLListToolsAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLProfileAction; @@ -489,6 +492,7 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLPromptGetAction.INSTANCE, GetPromptTransportAction.class), new ActionHandler<>(MLPromptDeleteAction.INSTANCE, DeletePromptTransportAction.class), new ActionHandler<>(MLUpdatePromptAction.INSTANCE, UpdatePromptTransportAction.class), + new ActionHandler<>(MLImportPromptAction.INSTANCE, ImportPromptTransportAction.class), new ActionHandler<>(MLPromptSearchAction.INSTANCE, SearchPromptTransportAction.class), new ActionHandler<>(CreateConversationAction.INSTANCE, CreateConversationTransportAction.class), new ActionHandler<>(GetConversationsAction.INSTANCE, GetConversationsTransportAction.class), @@ -611,7 +615,7 @@ public Collection createComponents( this.mlStats = new MLStats(stats); mlTaskManager = new MLTaskManager(client, sdkClient, threadPool, mlIndicesHandler); - mlPromptManager = new MLPromptManager(client, sdkClient); + mlPromptManager = new MLPromptManager(client, sdkClient, (EncryptorImpl) encryptor); modelHelper = new ModelHelper(mlEngine); mlInputDatasetHandler = new MLInputDatasetHandler(client); @@ -870,6 +874,7 @@ public List getRestHandlers( RestMLGetPromptAction restMLGetPromptAction = new RestMLGetPromptAction(mlFeatureEnabledSetting); RestMLSearchPromptAction restMLSearchPromptAction = new RestMLSearchPromptAction(mlFeatureEnabledSetting); RestMLUpdatePromptAction restMLUpdatePromptAction = new RestMLUpdatePromptAction(mlFeatureEnabledSetting); + RestMLImportPromptAction restMLImportPromptAction = new RestMLImportPromptAction(mlFeatureEnabledSetting); RestMLDeletePromptAction restMLDeletePromptAction = new RestMLDeletePromptAction(mlFeatureEnabledSetting); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); @@ -937,6 +942,7 @@ public List getRestHandlers( restMLSearchPromptAction, restMLDeletePromptAction, restMLUpdatePromptAction, + restMLImportPromptAction, restCreateConversationAction, restListConversationsAction, restCreateInteractionAction, diff --git a/plugin/src/main/java/org/opensearch/ml/prompt/AbstractPromptManagement.java b/plugin/src/main/java/org/opensearch/ml/prompt/AbstractPromptManagement.java new file mode 100644 index 0000000000..8d90f93cad --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/prompt/AbstractPromptManagement.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.prompt; + +import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE; +import static org.opensearch.ml.common.prompt.MLPrompt.MLPROMPT; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; +import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; +import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; + +import lombok.Getter; + +@Getter +public abstract class AbstractPromptManagement implements ToXContentObject { + public static AbstractPromptManagement init(String promptManagementType, PromptExtraConfig extraConfig) { + // add additional case for new prompt management client below, if needed + switch (promptManagementType.toUpperCase()) { + case MLPROMPT: + return new MLPromptManagement(); + case LANGFUSE: + return new LangfusePromptManagement(extraConfig.getPublicKey(), extraConfig.getAccessKey()); + default: + throw new IllegalArgumentException("Unsupported prompt management type: " + promptManagementType); + } + } + + public abstract MLPrompt createPrompt(MLCreatePromptInput mlCreatePromptInput); + + public abstract void getPrompt(MLPrompt mlPrompt); + + public abstract UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptInput, MLPrompt mlPrompt); +} diff --git a/plugin/src/main/java/org/opensearch/ml/prompt/LangfusePromptManagement.java b/plugin/src/main/java/org/opensearch/ml/prompt/LangfusePromptManagement.java new file mode 100644 index 0000000000..188e7e394c --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/prompt/LangfusePromptManagement.java @@ -0,0 +1,435 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.prompt; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE; +import static org.opensearch.ml.prompt.MLPromptManagement.INITIAL_VERSION; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.prompt.PromptExtraConfig; +import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; +import org.opensearch.ml.common.transport.prompt.MLImportPromptInput; +import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; + +import com.langfuse.client.LangfuseClient; +import com.langfuse.client.core.LangfuseClientApiException; +import com.langfuse.client.resources.prompts.requests.GetPromptRequest; +import com.langfuse.client.resources.prompts.requests.ListPromptsMetaRequest; +import com.langfuse.client.resources.prompts.types.ChatMessage; +import com.langfuse.client.resources.prompts.types.ChatPrompt; +import com.langfuse.client.resources.prompts.types.CreateChatPromptRequest; +import com.langfuse.client.resources.prompts.types.CreatePromptRequest; +import com.langfuse.client.resources.prompts.types.CreateTextPromptRequest; +import com.langfuse.client.resources.prompts.types.Prompt; +import com.langfuse.client.resources.prompts.types.PromptMeta; +import com.langfuse.client.resources.prompts.types.PromptMetaListResponse; +import com.langfuse.client.resources.prompts.types.TextPrompt; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@Getter +public class LangfusePromptManagement extends AbstractPromptManagement implements PromptImportable { + public static final String PUBLIC_KEY_FIELD = "public_key"; + public static final String ACCESS_KEY_FIELD = "access_key"; + public static final String LANGFUSE_URL = "https://us.cloud.langfuse.com"; + public static final String USER_ROLE = "user"; + public static final String DEFAULT_LIMIT = "20"; + + public final String TEXT_PROMPT = "text"; + public final String CHAT_PROMPT = "chat"; + public final String PRODUCTION_LABEL = "production"; + + private LangfuseClient langfuseClient; + private String publicKey; + private String accessKey; + + public LangfusePromptManagement(String publicKey, String accessKey) { + this.publicKey = publicKey; + this.accessKey = accessKey; + this.langfuseClient = initLangfuseClient(this.publicKey, this.accessKey); + } + + /** + * Initialize Langfuse Client that is used to invoke Langfuse API to Langfuse Server + * + * @param username + * @param password + * @return + */ + public LangfuseClient initLangfuseClient(String username, String password) { + return LangfuseClient.builder().url(LANGFUSE_URL).credentials(username, password).build(); + } + + /** + * Create Langfuse Prompt in Langfuse Server + * + * @param mlCreatePromptInput input that contains metadata to create a prompt + * @return MLPrompt that will be used to create Langfuse Prompt + */ + @Override + public MLPrompt createPrompt(MLCreatePromptInput mlCreatePromptInput) { + PromptExtraConfig promptExtraConfig = mlCreatePromptInput.getPromptExtraConfig(); + String type = promptExtraConfig.getType(); + + // assigns production label by default + List labels = promptExtraConfig.getLabels(); + if (labels == null) { + labels = List.of(PRODUCTION_LABEL); + } else { + labels.add(PRODUCTION_LABEL); + } + promptExtraConfig.setLabels(labels); + + CreatePromptRequest langfuseRequest; + switch (type) { + case TEXT_PROMPT: + langfuseRequest = buildTextPromptRequest(mlCreatePromptInput, promptExtraConfig); + break; + case CHAT_PROMPT: + langfuseRequest = buildChatPromptRequest(mlCreatePromptInput, promptExtraConfig); + break; + default: + log.error("Unable to find prompt template type"); + throw new IllegalArgumentException("Unable to find prompt template type for Langfuse"); + } + try { + langfuseClient.prompts().create(langfuseRequest); + // we only need to store the fields that are necessary to retrieve this prompt later -> name, prompt management type, encrypted + // credentials + return MLPrompt + .builder() + .name(mlCreatePromptInput.getName()) + .promptManagementType(mlCreatePromptInput.getPromptManagementType()) + .promptExtraConfig(PromptExtraConfig.builder().publicKey(publicKey).accessKey(accessKey).build()) + .build(); + } catch (Exception e) { + String errorMessage = e.getMessage(); + if (e instanceof LangfuseClientApiException) { + errorMessage = getLangfuseClientExceptionMessage((LangfuseClientApiException) e); + } + log.error("Failed to create a prompt in Langfuse Server", e); + throw new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR); + } + } + + /** + * Build Text Prompt Request + * + * @param mlCreatePromptInput MLCreatePromptInput that contains metadata to create Langfuse Text Prompt + * @param promptExtraConfig Prompt Extra Config that contains credentials and other metadatas to construct Text Prompt + * @return CreatePromptRequest + */ + private CreatePromptRequest buildTextPromptRequest(MLCreatePromptInput mlCreatePromptInput, PromptExtraConfig promptExtraConfig) { + CreateTextPromptRequest textRequest = CreateTextPromptRequest + .builder() + .name(mlCreatePromptInput.getName()) + .prompt(mlCreatePromptInput.getPrompt().get(USER_ROLE)) + .labels(promptExtraConfig.getLabels()) + .tags(mlCreatePromptInput.getTags()) + .build(); + + return CreatePromptRequest.text(textRequest); + } + + /** + * Build Chat Prompt Request + * + * @param mlCreatePromptInput MLCreatePromptInput that contains metadata to create Langfuse Chat Prompt + * @param promptExtraConfig Prompt Extra Config that contains credentials and other metadatas to construct Chat Prompt + * @return CreatePromptRequest + */ + private CreatePromptRequest buildChatPromptRequest(MLCreatePromptInput mlCreatePromptInput, PromptExtraConfig promptExtraConfig) { + List langfusePromptTemplate = new ArrayList<>(); + Map mlPromptTemplate = mlCreatePromptInput.getPrompt(); + for (String role : mlPromptTemplate.keySet()) { + String content = mlPromptTemplate.get(role); + ChatMessage message = ChatMessage.builder().role(role).content(content).build(); + langfusePromptTemplate.add(message); + } + + CreateChatPromptRequest chatRequest = CreateChatPromptRequest + .builder() + .name(mlCreatePromptInput.getName()) + .prompt(langfusePromptTemplate) + .labels(promptExtraConfig.getLabels()) + .tags(mlCreatePromptInput.getTags()) + .build(); + + return CreatePromptRequest.chat(chatRequest); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (publicKey != null) { + builder.field(PUBLIC_KEY_FIELD, publicKey); + } + if (accessKey != null) { + builder.field(ACCESS_KEY_FIELD, accessKey); + } + builder.endObject(); + return builder; + } + + /** + * Retrieve prompt from Langfuse Server + * + * @param mlPrompt Prompt that contains credentials and prompt name that is used to retrieve Langfuse Prompt + */ + @Override + public void getPrompt(MLPrompt mlPrompt) { + mlPrompt.setPromptExtraConfig(null); // won't include credentials in response body + try { + Prompt langfusePrompt = langfuseClient.prompts().get(mlPrompt.getName()); + Prompt promptWithInitialVersion = langfuseClient + .prompts() + .get(mlPrompt.getName(), GetPromptRequest.builder().version(Integer.parseInt(INITIAL_VERSION)).build()); + + // check if the fetched langfuse prompt is text or chat prompt + if (langfusePrompt.isText() && langfusePrompt.getText().isPresent()) { + buildMLPromptFromTextPrompt(langfusePrompt.getText().get(), mlPrompt, promptWithInitialVersion.getText().get()); + } else if (langfusePrompt.isChat() && langfusePrompt.getChat().isPresent()) { + buildMLPromptFromChatPrompt(langfusePrompt.getChat().get(), mlPrompt, promptWithInitialVersion.getChat().get()); + } else { + log.error("Error when fetching the Langfuse Prompt"); + throw new OpenSearchStatusException("Failed to get a Langfuse Prompt", RestStatus.INTERNAL_SERVER_ERROR); + } + } catch (Exception e) { + String errorMessage = e.getMessage(); + if (e instanceof LangfuseClientApiException) { + errorMessage = getLangfuseClientExceptionMessage((LangfuseClientApiException) e); + } + log.error("Failed to fetch a Langfuse prompt", e); + throw new OpenSearchStatusException("Failed to get Langfuse Prompt: " + errorMessage, RestStatus.INTERNAL_SERVER_ERROR); + } + } + + /** + * Build Langfuse Prompt from TextPrompt + * + * @param textPrompt TextPrompt + * @param mlPrompt Prompt that is used to be deserialized into from retrieved Langfuse Prompt + * @param promptWithInitialVersion Initial version of LangfusePrompt + */ + private void buildMLPromptFromTextPrompt(TextPrompt textPrompt, MLPrompt mlPrompt, TextPrompt promptWithInitialVersion) { + mlPrompt.setVersion(String.valueOf(textPrompt.getVersion())); + mlPrompt.setPrompt(Map.of(USER_ROLE, textPrompt.getPrompt())); + mlPrompt.setTags(!textPrompt.getTags().isEmpty() ? textPrompt.getTags() : null); + + PromptExtraConfig promptExtraConfig = PromptExtraConfig.builder().type(TEXT_PROMPT).labels(textPrompt.getLabels()).build(); + mlPrompt.setPromptExtraConfig(promptExtraConfig); + + // get initial created Time set when initial version prompt is created + setTimeInstants(promptWithInitialVersion.toString(), mlPrompt); + setTimeInstants(textPrompt.toString(), mlPrompt); + } + + /** + * Build Langfuse Prompt from ChatPrompt + * + * @param chatPrompt ChatPrompt + * @param mlPrompt Prompt that is used to be deserialized into from retrieved Langfuse Prompt + * @param promptWithInitialVersion Initial version of LangfusePrompt + */ + private void buildMLPromptFromChatPrompt(ChatPrompt chatPrompt, MLPrompt mlPrompt, ChatPrompt promptWithInitialVersion) { + mlPrompt.setVersion(String.valueOf(chatPrompt.getVersion())); + mlPrompt.setTags(chatPrompt.getTags()); + + PromptExtraConfig promptExtraConfig = PromptExtraConfig.builder().type(CHAT_PROMPT).labels(chatPrompt.getLabels()).build(); + mlPrompt.setPromptExtraConfig(promptExtraConfig); + + List langfusePromptTemplate = chatPrompt.getPrompt(); + Map mlPromptTemplate = new HashMap<>(); + + for (ChatMessage message : langfusePromptTemplate) { + if (mlPromptTemplate.containsKey(message.getRole())) { + continue; + } + mlPromptTemplate.put(message.getRole(), message.getContent()); + } + mlPrompt.setPrompt(mlPromptTemplate); + + // get initial created Time set when initial version prompt is created + setTimeInstants(promptWithInitialVersion.toString(), mlPrompt); + setTimeInstants(chatPrompt.toString(), mlPrompt); + } + + /** + * Set time instant based on fetch source + * + * @param fetchSource response sent from langfuse server upon successful import in JSON format + * @param mlPrompt Prompt + */ + private void setTimeInstants(String fetchSource, MLPrompt mlPrompt) { + int version = 0; + String createdTime = null; + String lastUpdatedTime = null; + + try ( + XContentParser parser = jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, fetchSource) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case "version": + version = parser.intValue(); + case "createdAt": + if (version == Integer.parseInt(INITIAL_VERSION)) { + createdTime = parser.text(); + } + break; + case "updatedAt": + lastUpdatedTime = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + } catch (Exception e) { + throw new OpenSearchStatusException("Failed to parse Langfuse Prompt into MLPrompt", RestStatus.INTERNAL_SERVER_ERROR); + } + if (createdTime != null) { + mlPrompt.setCreateTime(Instant.parse(createdTime)); + } + mlPrompt.setLastUpdateTime(Instant.parse(lastUpdatedTime)); + } + + private String getLangfuseClientExceptionMessage(LangfuseClientApiException langfuseClientApiException) { + String message = langfuseClientApiException.getMessage(); + Object errorBody = langfuseClientApiException.body(); + if (errorBody instanceof LinkedHashMap && ((LinkedHashMap) errorBody).containsKey("message")) { + message = ((LinkedHashMap) errorBody).get("message").toString(); + } + return message; + } + + /** + * Import Langfuse prompt based on user input + * + *

+ * 1. Import a langfuse prompt by name + * 2. Import a langfuse prompt or list of langfuse prompts by shared tag + * 3. Import a langfuse prompt or list of langfuse prompts by setting a limit + *

+ * + * @param mlImportPromptInput MLImportPromptInput that contains importing details + * @return list of imported langfuse prompts + */ + @Override + public List importPrompts(MLImportPromptInput mlImportPromptInput) { + String name = mlImportPromptInput.getName(); + String tag = mlImportPromptInput.getTag(); + String limit = mlImportPromptInput.getLimit() == null ? DEFAULT_LIMIT : mlImportPromptInput.getLimit(); + + try { + if (name != null) { + MLPrompt mlPrompt = MLPrompt.builder().name(name).promptManagementType(LANGFUSE).build(); + getPrompt(mlPrompt); + mlPrompt.setPromptExtraConfig(PromptExtraConfig.builder().accessKey(this.accessKey).publicKey(this.publicKey).build()); + return List.of(mlPrompt); + } + + PromptMetaListResponse promptMetaListResponse = langfuseClient + .prompts() + .list(ListPromptsMetaRequest.builder().tag(tag).limit(Integer.parseInt(limit)).build()); + + List promptMetas = promptMetaListResponse.getData(); + List mlPromptList = new ArrayList<>(); + + // There is no langfuse prompts created in the provided environment + if (promptMetas.isEmpty()) { + log.info("No langfuse prompt is found"); + return mlPromptList; + } + + for (PromptMeta promptMeta : promptMetas) { + MLPrompt mlPrompt = MLPrompt.builder().name(promptMeta.getName()).promptManagementType(LANGFUSE).build(); + getPrompt(mlPrompt); + PromptExtraConfig config = mlPrompt.getPromptExtraConfig(); + config.setAccessKey(this.accessKey); + config.setPublicKey(this.publicKey); + mlPrompt.setPromptExtraConfig(config); + + mlPromptList.add(mlPrompt); + } + return mlPromptList; + } catch (Exception e) { + String errorMessage = e.getMessage(); + if (e instanceof LangfuseClientApiException) { + errorMessage = getLangfuseClientExceptionMessage((LangfuseClientApiException) e); + } + log.error("Failed to import a Langfuse prompt", e); + throw new OpenSearchStatusException( + "Failed to import Langfuse Prompts into ML Prompt Index: " + errorMessage, + RestStatus.INTERNAL_SERVER_ERROR + ); + } + } + + /** + * Update the prompt based on the update content + * + * @param mlUpdatePromptInput content that needs to be updated + * @param mlPrompt prompt that contains content before update + * @return updateDataObjectRequest + */ + @Override + public UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptInput, MLPrompt mlPrompt) { + getPrompt(mlPrompt); + MLCreatePromptInput updateContent = MLCreatePromptInput + .builder() + .name(mlPrompt.getName()) + .tags(mlPrompt.getTags()) + .prompt(mlPrompt.getPrompt()) + .promptExtraConfig(mlPrompt.getPromptExtraConfig()) + .build(); + + // Langfuse does not allow users to change prompt's name + if (mlUpdatePromptInput.getTags() != null) { + updateContent.setTags(mlUpdatePromptInput.getTags()); + } + if (mlUpdatePromptInput.getPrompt() != null) { + updateContent.setPrompt(mlUpdatePromptInput.getPrompt()); + } + if (mlUpdatePromptInput.getExtraConfig() != null && mlUpdatePromptInput.getExtraConfig().getLabels() != null) { + updateContent.getPromptExtraConfig().setLabels(mlUpdatePromptInput.getExtraConfig().getLabels()); + } + + // Langfuse can only be updated via create endpoint + createPrompt(updateContent); + MLUpdatePromptInput input = MLUpdatePromptInput.builder().build(); + return UpdateDataObjectRequest + .builder() + .index(ML_PROMPT_INDEX) + .id(mlPrompt.getPromptId()) + .tenantId(mlUpdatePromptInput.getTenantId()) + .dataObject(input) + .build(); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManagement.java b/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManagement.java new file mode 100644 index 0000000000..d3cc9f910f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManagement.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.prompt; + +import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; +import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MLPromptManagement extends AbstractPromptManagement { + public static final String INITIAL_VERSION = "1"; + + public MLPromptManagement() {} + + @Override + public MLPrompt createPrompt(MLCreatePromptInput mlCreatePromptInput) { + String version = mlCreatePromptInput.getVersion(); + return MLPrompt + .builder() + .name(mlCreatePromptInput.getName()) + .description(mlCreatePromptInput.getDescription()) + .version(version == null ? INITIAL_VERSION : version) + .prompt(mlCreatePromptInput.getPrompt()) + .promptManagementType(mlCreatePromptInput.getPromptManagementType()) + .tags(mlCreatePromptInput.getTags()) + .promptExtraConfig(mlCreatePromptInput.getPromptExtraConfig()) + .tenantId(mlCreatePromptInput.getTenantId()) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .build(); + } + + @Override + public void getPrompt(MLPrompt mlPrompt) { + // do nothing + } + + @Override + public UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptInput, MLPrompt mlPrompt) { + int version = Integer.parseInt(mlPrompt.getVersion()); + mlUpdatePromptInput.setLastUpdateTime(Instant.now()); + mlUpdatePromptInput.setVersion(String.valueOf(version + 1)); + + return UpdateDataObjectRequest + .builder() + .index(ML_PROMPT_INDEX) + .id(mlPrompt.getPromptId()) + .tenantId(mlUpdatePromptInput.getTenantId()) + .dataObject(mlUpdatePromptInput) + .build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java b/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java index 4c9f01aa48..605cc3581d 100644 --- a/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java +++ b/plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManager.java @@ -8,6 +8,8 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE; +import static org.opensearch.ml.prompt.AbstractPromptManagement.init; import java.io.IOException; import java.util.ArrayList; @@ -24,6 +26,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -42,12 +45,15 @@ import org.opensearch.ml.common.prompt.MLPrompt; import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.GetDataObjectResponse; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.client.SearchDataObjectResponse; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.client.UpdateDataObjectResponse; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; @@ -81,10 +87,12 @@ public class MLPromptManager { private final Client client; private final SdkClient sdkClient; + private final EncryptorImpl encryptor; - public MLPromptManager(@NonNull Client client, @NonNull SdkClient sdkClient) { + public MLPromptManager(@NonNull Client client, @NonNull SdkClient sdkClient, EncryptorImpl encryptor) { this.client = Objects.requireNonNull(client, "Client cannot be null"); this.sdkClient = Objects.requireNonNull(sdkClient, "SdkClient cannot be null"); + this.encryptor = encryptor; } /** @@ -360,6 +368,13 @@ public PromptResult pullPrompt(String promptRef, String key, PromptParameters pr .build(); // fetch prompt first based on prompt id MLPrompt mlPrompt = getPrompt(getDataObjectRequest); + // enables user execute the prompt in external prompt management server that is created via ml commons create, without importing + // it + if (fetchPromptExternally(mlPrompt)) { + mlPrompt.decrypt(mlPrompt.getPromptManagementType(), encryptor::decrypt, tenantId); + AbstractPromptManagement promptManagement = init(mlPrompt.getPromptManagementType(), mlPrompt.getPromptExtraConfig()); + promptManagement.getPrompt(mlPrompt); + } // extract a prompt object from retrieved ML Prompt Map promptField = mlPrompt.getPrompt(); // check if the specified key is defined in the prompt @@ -409,6 +424,20 @@ private String resolvePromptID(String promptRef, String tenantId) { return promptId; } + /** + * Checks if the prompt type is Langfuse + * + * @param mlPrompt prompt that is either MLPrompt or LangfusePrompt + * @return true if the given prompt is Langfuse prompt, false otherwise. + */ + boolean fetchPromptExternally(MLPrompt mlPrompt) { + if (mlPrompt.getPromptManagementType().equalsIgnoreCase(LANGFUSE) && mlPrompt.getPrompt() == null) { + return true; + } else { + return false; + } + } + /** * Replace all the placeholder variables with user-defined values provided during execution time. * @@ -418,17 +447,17 @@ private String resolvePromptID(String promptRef, String tenantId) { * @return */ private String populatePlaceholders(String content, PromptParameters promptParameters, String promptRef) { + StringSubstitutor substitutor = new StringSubstitutor(); if (!promptParameters.isEmpty() && content.contains(PROMPT_PARAMETER_PLACEHOLDER)) { - StringSubstitutor substitutor = new StringSubstitutor( - promptParameters.getParameters(promptRef), - PROMPT_PARAMETER_PLACEHOLDER, - "}" - ); + substitutor = new StringSubstitutor(promptParameters.getParameters(promptRef), PROMPT_PARAMETER_PLACEHOLDER, "}"); content = substitutor.replace(content); + } else if (!promptParameters.isEmpty() && content.contains("{{") && content.contains("}}")) { + substitutor = new StringSubstitutor(promptParameters.getParameters(promptRef), "{{", "}}"); } + content = substitutor.replace(content); // this checks if all the required input values are provided by users and all the placeholder variables are replaced. - if (content.contains(PROMPT_PARAMETER_PLACEHOLDER)) { + if (content.contains(PROMPT_PARAMETER_PLACEHOLDER) || (content.contains("{{") && content.contains("}}"))) { throw new InvalidPullPromptSyntaxException("Failed to replace all the placeholders"); } return content; @@ -460,6 +489,49 @@ public MLPrompt getPrompt(GetDataObjectRequest getDataObjectRequest) { } } + /** + * Updates the prompt based on the update contents and replace the old prompt with updated prompt from the index + * + * @param updateDataObjectRequest the updateRequest that needs to be handled + * @param promptId The prompt ID of a prompt that needs to be updated + * @param listener a listener to be notified of the response + */ + public void updatePromptIndex( + UpdateDataObjectRequest updateDataObjectRequest, + String promptId, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((updateDataObjectResponse, throwable) -> { + context.restore(); + handleUpdateResponse(updateDataObjectResponse, throwable, promptId, listener); + }); + } + } + + /** + * Handles the response from the update prompt request. If the response is successful, notify the listener + * with the UpdateResponse. Otherwise, notify the failure exception to the listener. + * + * @param updateDataObjectResponse The response from the update prompt request + * @param throwable The exception that occurred during the update prompt request + * @param listener The listener to be notified of the response + */ + private void handleUpdateResponse( + UpdateDataObjectResponse updateDataObjectResponse, + Throwable throwable, + String promptId, + ActionListener listener + ) { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + handleFailure(cause, promptId, listener, "Failed to update ML prompt {}"); + return; + } + UpdateResponse updateResponse = updateDataObjectResponse.updateResponse(); + listener.onResponse(updateResponse); + } + /** * A class that represents a list of messages. */ diff --git a/plugin/src/main/java/org/opensearch/ml/prompt/PromptImportable.java b/plugin/src/main/java/org/opensearch/ml/prompt/PromptImportable.java new file mode 100644 index 0000000000..73ac41638c --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/prompt/PromptImportable.java @@ -0,0 +1,11 @@ +package org.opensearch.ml.prompt; + +import java.util.List; + +import org.opensearch.ml.common.prompt.MLPrompt; +import org.opensearch.ml.common.transport.prompt.MLImportPromptInput; + +public interface PromptImportable { + + public List importPrompts(MLImportPromptInput mlImportPromptInput); +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLImportPromptAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLImportPromptAction.java new file mode 100644 index 0000000000..db955d90dd --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLImportPromptAction.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.prompt.MLImportPromptAction; +import org.opensearch.ml.common.transport.prompt.MLImportPromptInput; +import org.opensearch.ml.common.transport.prompt.MLImportPromptRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLImportPromptAction extends BaseRestHandler { + private static final String ML_IMPORT_PROMPT_ACTION = "ml_import_prompt_actio"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLImportPromptAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_IMPORT_PROMPT_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/prompts/_import", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLImportPromptRequest mlImportPromptRequest = getRequest(request); + return channel -> client.execute(MLImportPromptAction.INSTANCE, mlImportPromptRequest, new RestToXContentListener<>(channel)); + } + + @VisibleForTesting + MLImportPromptRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new IOException("Import Prompt request has empty body"); + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLImportPromptInput mlImportPromptInput = MLImportPromptInput.parse(parser); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + mlImportPromptInput.setTenantId(tenantId); + return new MLImportPromptRequest(mlImportPromptInput); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/prompt/GetPromptTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prompt/GetPromptTransportActionTests.java index 0c1e9728ea..2bb05bfb23 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prompt/GetPromptTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prompt/GetPromptTransportActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.prompt.MLPromptGetRequest; import org.opensearch.ml.common.transport.prompt.MLPromptGetResponse; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; @@ -84,6 +85,9 @@ public class GetPromptTransportActionTests extends OpenSearchTestCase { @Mock private MLPromptManager mlPromptManager; + @Mock + private EncryptorImpl encryptor; + @Captor private ArgumentCaptor getDataObjectRequestArgumentCaptor; @@ -98,7 +102,15 @@ public void setup() throws IOException { when(getResponse.getSourceAsString()).thenReturn("{}"); getPromptTransportAction = spy( - new GetPromptTransportAction(transportService, actionFilters, client, sdkClient, mlFeatureEnabledSetting, mlPromptManager) + new GetPromptTransportAction( + transportService, + actionFilters, + client, + sdkClient, + encryptor, + mlFeatureEnabledSetting, + mlPromptManager + ) ); threadContext = new ThreadContext(Settings.EMPTY); @@ -113,6 +125,7 @@ public void testConstructor() { actionFilters, client, sdkClient, + encryptor, mlFeatureEnabledSetting, mlPromptManager ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prompt/TransportCreatePromptActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prompt/TransportCreatePromptActionTests.java index 608c5715fd..92d89370c8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prompt/TransportCreatePromptActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prompt/TransportCreatePromptActionTests.java @@ -43,6 +43,7 @@ import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput; import org.opensearch.ml.common.transport.prompt.MLCreatePromptRequest; import org.opensearch.ml.common.transport.prompt.MLCreatePromptResponse; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.ml.utils.TestHelper; @@ -101,6 +102,9 @@ public class TransportCreatePromptActionTests extends OpenSearchTestCase { @Mock private MLPromptManager mlPromptManager; + @Mock + private MLEngine mlEngine; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -115,6 +119,7 @@ public void setup() throws IOException { mlIndicesHandler, client, sdkClient, + mlEngine, mlPromptManager, mlFeatureEnabledSetting ) @@ -149,6 +154,7 @@ public void testConstructor() { mlIndicesHandler, client, sdkClient, + mlEngine, mlPromptManager, mlFeatureEnabledSetting ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prompt/UpdatePromptTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prompt/UpdatePromptTransportActionTests.java index 0d2b4ce9bf..0d5b5662b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prompt/UpdatePromptTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prompt/UpdatePromptTransportActionTests.java @@ -45,6 +45,7 @@ import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput; import org.opensearch.ml.common.transport.prompt.MLUpdatePromptRequest; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.prompt.MLPromptManager; import org.opensearch.ml.utils.TestHelper; import org.opensearch.remote.metadata.client.GetDataObjectRequest; @@ -96,6 +97,9 @@ public class UpdatePromptTransportActionTests extends OpenSearchTestCase { UpdateResponse updateResponse; + @Mock + EncryptorImpl encryptor; + @Captor private ArgumentCaptor getDataObjectRequestArgumentCaptor; @@ -106,7 +110,15 @@ public void setup() throws IOException { sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); updatePromptTransportAction = spy( - new UpdatePromptTransportAction(transportService, actionFilters, client, sdkClient, mlFeatureEnabledSetting, mlPromptManager) + new UpdatePromptTransportAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting, + mlPromptManager, + encryptor + ) ); threadContext = new ThreadContext(Settings.EMPTY); when(client.threadPool()).thenReturn(threadPool); @@ -131,7 +143,8 @@ public void testConstructor() { client, sdkClient, mlFeatureEnabledSetting, - mlPromptManager + mlPromptManager, + encryptor ); assertNotNull(updatePromptTransportAction); }