Skip to content

Commit 7d7f381

Browse files
committed
add javadoc & move importPrompts logic out of AbstractPromptManagement
Signed-off-by: seungwon cho <[email protected]>
1 parent 0b65cc9 commit 7d7f381

File tree

5 files changed

+151
-32
lines changed

5 files changed

+151
-32
lines changed

plugin/src/main/java/org/opensearch/ml/action/prompt/ImportPromptTransportAction.java

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
import java.util.Map;
1818
import java.util.concurrent.atomic.AtomicInteger;
1919

20+
import org.opensearch.OpenSearchStatusException;
2021
import org.opensearch.action.index.IndexResponse;
2122
import org.opensearch.action.search.SearchResponse;
2223
import org.opensearch.action.support.ActionFilters;
2324
import org.opensearch.action.support.HandledTransportAction;
2425
import org.opensearch.common.inject.Inject;
2526
import org.opensearch.common.util.concurrent.ThreadContext;
2627
import org.opensearch.core.action.ActionListener;
28+
import org.opensearch.core.rest.RestStatus;
2729
import org.opensearch.ml.common.prompt.MLPrompt;
2830
import org.opensearch.ml.common.prompt.PromptExtraConfig;
2931
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
@@ -35,6 +37,7 @@
3537
import org.opensearch.ml.engine.indices.MLIndicesHandler;
3638
import org.opensearch.ml.prompt.AbstractPromptManagement;
3739
import org.opensearch.ml.prompt.MLPromptManager;
40+
import org.opensearch.ml.prompt.PromptImportable;
3841
import org.opensearch.ml.utils.TenantAwareHelper;
3942
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
4043
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
@@ -94,28 +97,41 @@ protected void doExecute(Task task, MLImportPromptRequest mlImportPromptRequest,
9497
promptManagementType,
9598
PromptExtraConfig.builder().publicKey(publicKey).accessKey(accessKey).build()
9699
);
97-
List<MLPrompt> mlPromptList = promptManagement.importPrompts(mlImportPromptInput);
98-
Map<String, String> responseBody = new HashMap<>();
99-
if (mlPromptList.isEmpty()) {
100-
listener.onResponse(new MLImportPromptResponse(responseBody));
101-
return;
102-
}
103-
AtomicInteger remainingMLPrompts = new AtomicInteger(mlPromptList.size());
104-
for (MLPrompt mlPrompt : mlPromptList) {
105-
mlPrompt.encrypt(promptManagementType, mlEngine::encrypt, tenantId);
106-
handleDuplicateName(mlPrompt, tenantId, ActionListener.wrap(promptId -> {
107-
if (promptId == null) {
108-
indexPrompt(mlPrompt, responseBody, remainingMLPrompts, listener);
109-
} else {
110-
updateImportResponseBody(promptId, mlPrompt.getName(), responseBody, remainingMLPrompts, listener);
111-
}
112-
}, listener::onFailure));
100+
101+
if (!(promptManagement instanceof PromptImportable importer)) {
102+
throw new OpenSearchStatusException("Import prompt is not supported for MLPromptManagement", RestStatus.BAD_REQUEST);
103+
} else {
104+
List<MLPrompt> mlPromptList = importer.importPrompts(mlImportPromptInput);
105+
Map<String, String> responseBody = new HashMap<>();
106+
if (mlPromptList.isEmpty()) {
107+
listener.onResponse(new MLImportPromptResponse(responseBody));
108+
return;
109+
}
110+
AtomicInteger remainingMLPrompts = new AtomicInteger(mlPromptList.size());
111+
for (MLPrompt mlPrompt : mlPromptList) {
112+
mlPrompt.encrypt(promptManagementType, mlEngine::encrypt, tenantId);
113+
handleConflictingName(mlPrompt, tenantId, ActionListener.wrap(promptId -> {
114+
if (promptId == null) {
115+
indexPrompt(mlPrompt, responseBody, remainingMLPrompts, listener);
116+
} else {
117+
updateImportResponseBody(promptId, mlPrompt.getName(), responseBody, remainingMLPrompts, listener);
118+
}
119+
}, listener::onFailure));
120+
}
113121
}
114122
} catch (Exception e) {
115123
handleFailure(e, null, listener, "Failed to import " + promptManagementType + " Prompts into System Index");
116124
}
117125
}
118126

127+
/**
128+
* Store prompt into system index
129+
*
130+
* @param prompt prompt that needs to be stored into the system index
131+
* @param responseBody response body that will be return upon success in the format of prompt name to prompt id
132+
* @param remainingMLPrompts remaining prompt to be stored into the system index
133+
* @param listener actionListener that will be notified upon success or failure of the prompt creation
134+
*/
119135
private void indexPrompt(
120136
MLPrompt prompt,
121137
Map<String, String> responseBody,
@@ -140,10 +156,26 @@ private void indexPrompt(
140156
}, e -> { handleFailure(e, null, listener, "Failed to init ML prompt index"); }));
141157
}
142158

159+
/**
160+
* Builds putRequest to write prompt into index
161+
*
162+
* @param prompt prompt that needs to be stored into the system index
163+
* @return PutDataObjectRequest
164+
*/
143165
private PutDataObjectRequest buildPromptPutRequest(MLPrompt prompt) {
144166
return PutDataObjectRequest.builder().tenantId(prompt.getTenantId()).index(ML_PROMPT_INDEX).dataObject(prompt).build();
145167
}
146168

169+
/**
170+
* Handles PutResponse after prompt is indexed
171+
*
172+
* @param putResponse response received after prompt is indexed
173+
* @param throwable throwable
174+
* @param name prompt name that is indexed
175+
* @param responseBody response body that will be return upon success in the format of prompt name to prompt id
176+
* @param remainingMLPrompts remaining prompt to be stored into the system index
177+
* @param listener actionListener that will be notified upon success or failure of the prompt creation
178+
*/
147179
private void handlePromptPutResponse(
148180
PutDataObjectResponse putResponse,
149181
Throwable throwable,
@@ -165,6 +197,15 @@ private void handlePromptPutResponse(
165197
}
166198
}
167199

200+
/**
201+
* Update the response body with the prompt name and prompt id upon successful import
202+
*
203+
* @param promptId prompt id returned after prompt is successfully indexed into the system index
204+
* @param name name of the prompt that is stored
205+
* @param responseBody response body that will be return upon success in the format of prompt name to prompt id
206+
* @param remainingMLPrompts remaining prompt to be stored into the system index
207+
* @param listener actionListener that will be notified upon success or failure of the prompt creation
208+
*/
168209
private void updateImportResponseBody(
169210
String promptId,
170211
String name,
@@ -179,7 +220,16 @@ private void updateImportResponseBody(
179220
}
180221
}
181222

182-
private void handleDuplicateName(MLPrompt importingPrompt, String tenantId, ActionListener<String> wrappedListener) throws IOException {
223+
/**
224+
* Search name field on prompt system index.
225+
*
226+
* @param importingPrompt prompt that needs to be imported into prompt system index
227+
* @param tenantId tenant id
228+
* @param wrappedListener listener that will be notified with prompt id upon success or failure of the prompt creation
229+
* @throws IOException if search hits, meaning conflicting name exist
230+
*/
231+
private void handleConflictingName(MLPrompt importingPrompt, String tenantId, ActionListener<String> wrappedListener)
232+
throws IOException {
183233
String name = importingPrompt.getName();
184234
SearchResponse searchResponse = mlPromptManager.searchPromptByName(name, tenantId);
185235
if (searchResponse != null

plugin/src/main/java/org/opensearch/ml/prompt/AbstractPromptManagement.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,10 @@
88
import static org.opensearch.ml.common.prompt.MLPrompt.LANGFUSE;
99
import static org.opensearch.ml.common.prompt.MLPrompt.MLPROMPT;
1010

11-
import java.util.List;
12-
1311
import org.opensearch.core.xcontent.ToXContentObject;
1412
import org.opensearch.ml.common.prompt.MLPrompt;
1513
import org.opensearch.ml.common.prompt.PromptExtraConfig;
1614
import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput;
17-
import org.opensearch.ml.common.transport.prompt.MLImportPromptInput;
1815
import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput;
1916
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
2017

@@ -38,7 +35,5 @@ public static AbstractPromptManagement init(String promptManagementType, PromptE
3835

3936
public abstract void getPrompt(MLPrompt mlPrompt);
4037

41-
public abstract List<MLPrompt> importPrompts(MLImportPromptInput mlImportPromptInput);
42-
4338
public abstract UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptInput, MLPrompt mlPrompt);
4439
}

plugin/src/main/java/org/opensearch/ml/prompt/LangfusePromptManagement.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
@Log4j2
5353
@Getter
54-
public class LangfusePromptManagement extends AbstractPromptManagement {
54+
public class LangfusePromptManagement extends AbstractPromptManagement implements PromptImportable {
5555
public static final String PUBLIC_KEY_FIELD = "public_key";
5656
public static final String ACCESS_KEY_FIELD = "access_key";
5757
public static final String LANGFUSE_URL = "https://us.cloud.langfuse.com";
@@ -72,10 +72,23 @@ public LangfusePromptManagement(String publicKey, String accessKey) {
7272
this.langfuseClient = initLangfuseClient(this.publicKey, this.accessKey);
7373
}
7474

75+
/**
76+
* Initialize Langfuse Client that is used to invoke Langfuse API to Langfuse Server
77+
*
78+
* @param username
79+
* @param password
80+
* @return
81+
*/
7582
public LangfuseClient initLangfuseClient(String username, String password) {
7683
return LangfuseClient.builder().url(LANGFUSE_URL).credentials(username, password).build();
7784
}
7885

86+
/**
87+
* Create Langfuse Prompt in Langfuse Server
88+
*
89+
* @param mlCreatePromptInput input that contains metadata to create a prompt
90+
* @return MLPrompt that will be used to create Langfuse Prompt
91+
*/
7992
@Override
8093
public MLPrompt createPrompt(MLCreatePromptInput mlCreatePromptInput) {
8194
PromptExtraConfig promptExtraConfig = mlCreatePromptInput.getPromptExtraConfig();
@@ -122,6 +135,13 @@ public MLPrompt createPrompt(MLCreatePromptInput mlCreatePromptInput) {
122135
}
123136
}
124137

138+
/**
139+
* Build Text Prompt Request
140+
*
141+
* @param mlCreatePromptInput MLCreatePromptInput that contains metadata to create Langfuse Text Prompt
142+
* @param promptExtraConfig Prompt Extra Config that contains credentials and other metadatas to construct Text Prompt
143+
* @return CreatePromptRequest
144+
*/
125145
private CreatePromptRequest buildTextPromptRequest(MLCreatePromptInput mlCreatePromptInput, PromptExtraConfig promptExtraConfig) {
126146
CreateTextPromptRequest textRequest = CreateTextPromptRequest
127147
.builder()
@@ -134,6 +154,13 @@ private CreatePromptRequest buildTextPromptRequest(MLCreatePromptInput mlCreateP
134154
return CreatePromptRequest.text(textRequest);
135155
}
136156

157+
/**
158+
* Build Chat Prompt Request
159+
*
160+
* @param mlCreatePromptInput MLCreatePromptInput that contains metadata to create Langfuse Chat Prompt
161+
* @param promptExtraConfig Prompt Extra Config that contains credentials and other metadatas to construct Chat Prompt
162+
* @return CreatePromptRequest
163+
*/
137164
private CreatePromptRequest buildChatPromptRequest(MLCreatePromptInput mlCreatePromptInput, PromptExtraConfig promptExtraConfig) {
138165
List<ChatMessage> langfusePromptTemplate = new ArrayList<>();
139166
Map<String, String> mlPromptTemplate = mlCreatePromptInput.getPrompt();
@@ -167,6 +194,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
167194
return builder;
168195
}
169196

197+
/**
198+
* Retrieve prompt from Langfuse Server
199+
*
200+
* @param mlPrompt Prompt that contains credentials and prompt name that is used to retrieve Langfuse Prompt
201+
*/
170202
@Override
171203
public void getPrompt(MLPrompt mlPrompt) {
172204
mlPrompt.setPromptExtraConfig(null); // won't include credentials in response body
@@ -195,6 +227,13 @@ public void getPrompt(MLPrompt mlPrompt) {
195227
}
196228
}
197229

230+
/**
231+
* Build Langfuse Prompt from TextPrompt
232+
*
233+
* @param textPrompt TextPrompt
234+
* @param mlPrompt Prompt that is used to be deserialized into from retrieved Langfuse Prompt
235+
* @param promptWithInitialVersion Initial version of LangfusePrompt
236+
*/
198237
private void buildMLPromptFromTextPrompt(TextPrompt textPrompt, MLPrompt mlPrompt, TextPrompt promptWithInitialVersion) {
199238
mlPrompt.setVersion(String.valueOf(textPrompt.getVersion()));
200239
mlPrompt.setPrompt(Map.of(USER_ROLE, textPrompt.getPrompt()));
@@ -208,6 +247,13 @@ private void buildMLPromptFromTextPrompt(TextPrompt textPrompt, MLPrompt mlPromp
208247
setTimeInstants(textPrompt.toString(), mlPrompt);
209248
}
210249

250+
/**
251+
* Build Langfuse Prompt from ChatPrompt
252+
*
253+
* @param chatPrompt ChatPrompt
254+
* @param mlPrompt Prompt that is used to be deserialized into from retrieved Langfuse Prompt
255+
* @param promptWithInitialVersion Initial version of LangfusePrompt
256+
*/
211257
private void buildMLPromptFromChatPrompt(ChatPrompt chatPrompt, MLPrompt mlPrompt, ChatPrompt promptWithInitialVersion) {
212258
mlPrompt.setVersion(String.valueOf(chatPrompt.getVersion()));
213259
mlPrompt.setTags(chatPrompt.getTags());
@@ -231,6 +277,12 @@ private void buildMLPromptFromChatPrompt(ChatPrompt chatPrompt, MLPrompt mlPromp
231277
setTimeInstants(chatPrompt.toString(), mlPrompt);
232278
}
233279

280+
/**
281+
* Set time instant based on fetch source
282+
*
283+
* @param fetchSource response sent from langfuse server upon successful import in JSON format
284+
* @param mlPrompt Prompt
285+
*/
234286
private void setTimeInstants(String fetchSource, MLPrompt mlPrompt) {
235287
int version = 0;
236288
String createdTime = null;
@@ -277,6 +329,18 @@ private String getLangfuseClientExceptionMessage(LangfuseClientApiException lang
277329
return message;
278330
}
279331

332+
/**
333+
* Import Langfuse prompt based on user input
334+
*
335+
* <p>
336+
* 1. Import a langfuse prompt by name
337+
* 2. Import a langfuse prompt or list of langfuse prompts by shared tag
338+
* 3. Import a langfuse prompt or list of langfuse prompts by setting a limit
339+
* </p>
340+
*
341+
* @param mlImportPromptInput MLImportPromptInput that contains importing details
342+
* @return list of imported langfuse prompts
343+
*/
280344
@Override
281345
public List<MLPrompt> importPrompts(MLImportPromptInput mlImportPromptInput) {
282346
String name = mlImportPromptInput.getName();
@@ -328,6 +392,13 @@ public List<MLPrompt> importPrompts(MLImportPromptInput mlImportPromptInput) {
328392
}
329393
}
330394

395+
/**
396+
* Update the prompt based on the update content
397+
*
398+
* @param mlUpdatePromptInput content that needs to be updated
399+
* @param mlPrompt prompt that contains content before update
400+
* @return updateDataObjectRequest
401+
*/
331402
@Override
332403
public UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptInput, MLPrompt mlPrompt) {
333404
getPrompt(mlPrompt);
@@ -350,6 +421,7 @@ public UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptIn
350421
updateContent.getPromptExtraConfig().setLabels(mlUpdatePromptInput.getExtraConfig().getLabels());
351422
}
352423

424+
// Langfuse can only be updated via create endpoint
353425
createPrompt(updateContent);
354426
MLUpdatePromptInput input = MLUpdatePromptInput.builder().build();
355427
return UpdateDataObjectRequest

plugin/src/main/java/org/opensearch/ml/prompt/MLPromptManagement.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,10 @@
99

1010
import java.io.IOException;
1111
import java.time.Instant;
12-
import java.util.List;
1312

14-
import org.opensearch.OpenSearchStatusException;
15-
import org.opensearch.core.rest.RestStatus;
1613
import org.opensearch.core.xcontent.XContentBuilder;
1714
import org.opensearch.ml.common.prompt.MLPrompt;
1815
import org.opensearch.ml.common.transport.prompt.MLCreatePromptInput;
19-
import org.opensearch.ml.common.transport.prompt.MLImportPromptInput;
2016
import org.opensearch.ml.common.transport.prompt.MLUpdatePromptInput;
2117
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
2218

@@ -66,11 +62,6 @@ public UpdateDataObjectRequest updatePrompt(MLUpdatePromptInput mlUpdatePromptIn
6662
.build();
6763
}
6864

69-
@Override
70-
public List<MLPrompt> importPrompts(MLImportPromptInput mlImportPromptInput) {
71-
throw new OpenSearchStatusException("Import prompt is not supported for MLPromptManagement", RestStatus.BAD_REQUEST);
72-
}
73-
7465
@Override
7566
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
7667
return builder;
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.opensearch.ml.prompt;
2+
3+
import java.util.List;
4+
5+
import org.opensearch.ml.common.prompt.MLPrompt;
6+
import org.opensearch.ml.common.transport.prompt.MLImportPromptInput;
7+
8+
public interface PromptImportable {
9+
10+
public List<MLPrompt> importPrompts(MLImportPromptInput mlImportPromptInput);
11+
}

0 commit comments

Comments
 (0)