17
17
import java .util .Map ;
18
18
import java .util .concurrent .atomic .AtomicInteger ;
19
19
20
+ import org .opensearch .OpenSearchStatusException ;
20
21
import org .opensearch .action .index .IndexResponse ;
21
22
import org .opensearch .action .search .SearchResponse ;
22
23
import org .opensearch .action .support .ActionFilters ;
23
24
import org .opensearch .action .support .HandledTransportAction ;
24
25
import org .opensearch .common .inject .Inject ;
25
26
import org .opensearch .common .util .concurrent .ThreadContext ;
26
27
import org .opensearch .core .action .ActionListener ;
28
+ import org .opensearch .core .rest .RestStatus ;
27
29
import org .opensearch .ml .common .prompt .MLPrompt ;
28
30
import org .opensearch .ml .common .prompt .PromptExtraConfig ;
29
31
import org .opensearch .ml .common .settings .MLFeatureEnabledSetting ;
35
37
import org .opensearch .ml .engine .indices .MLIndicesHandler ;
36
38
import org .opensearch .ml .prompt .AbstractPromptManagement ;
37
39
import org .opensearch .ml .prompt .MLPromptManager ;
40
+ import org .opensearch .ml .prompt .PromptImportable ;
38
41
import org .opensearch .ml .utils .TenantAwareHelper ;
39
42
import org .opensearch .remote .metadata .client .GetDataObjectRequest ;
40
43
import org .opensearch .remote .metadata .client .PutDataObjectRequest ;
@@ -94,28 +97,41 @@ protected void doExecute(Task task, MLImportPromptRequest mlImportPromptRequest,
94
97
promptManagementType ,
95
98
PromptExtraConfig .builder ().publicKey (publicKey ).accessKey (accessKey ).build ()
96
99
);
97
- List <MLPrompt > mlPromptList = promptManagement .importPrompts (mlImportPromptInput );
98
- Map <String , String > responseBody = new HashMap <>();
99
- if (mlPromptList .isEmpty ()) {
100
- listener .onResponse (new MLImportPromptResponse (responseBody ));
101
- return ;
102
- }
103
- AtomicInteger remainingMLPrompts = new AtomicInteger (mlPromptList .size ());
104
- for (MLPrompt mlPrompt : mlPromptList ) {
105
- mlPrompt .encrypt (promptManagementType , mlEngine ::encrypt , tenantId );
106
- handleDuplicateName (mlPrompt , tenantId , ActionListener .wrap (promptId -> {
107
- if (promptId == null ) {
108
- indexPrompt (mlPrompt , responseBody , remainingMLPrompts , listener );
109
- } else {
110
- updateImportResponseBody (promptId , mlPrompt .getName (), responseBody , remainingMLPrompts , listener );
111
- }
112
- }, listener ::onFailure ));
100
+
101
+ if (!(promptManagement instanceof PromptImportable importer )) {
102
+ throw new OpenSearchStatusException ("Import prompt is not supported for MLPromptManagement" , RestStatus .BAD_REQUEST );
103
+ } else {
104
+ List <MLPrompt > mlPromptList = importer .importPrompts (mlImportPromptInput );
105
+ Map <String , String > responseBody = new HashMap <>();
106
+ if (mlPromptList .isEmpty ()) {
107
+ listener .onResponse (new MLImportPromptResponse (responseBody ));
108
+ return ;
109
+ }
110
+ AtomicInteger remainingMLPrompts = new AtomicInteger (mlPromptList .size ());
111
+ for (MLPrompt mlPrompt : mlPromptList ) {
112
+ mlPrompt .encrypt (promptManagementType , mlEngine ::encrypt , tenantId );
113
+ handleConflictingName (mlPrompt , tenantId , ActionListener .wrap (promptId -> {
114
+ if (promptId == null ) {
115
+ indexPrompt (mlPrompt , responseBody , remainingMLPrompts , listener );
116
+ } else {
117
+ updateImportResponseBody (promptId , mlPrompt .getName (), responseBody , remainingMLPrompts , listener );
118
+ }
119
+ }, listener ::onFailure ));
120
+ }
113
121
}
114
122
} catch (Exception e ) {
115
123
handleFailure (e , null , listener , "Failed to import " + promptManagementType + " Prompts into System Index" );
116
124
}
117
125
}
118
126
127
+ /**
128
+ * Store prompt into system index
129
+ *
130
+ * @param prompt prompt that needs to be stored into the system index
131
+ * @param responseBody response body that will be return upon success in the format of prompt name to prompt id
132
+ * @param remainingMLPrompts remaining prompt to be stored into the system index
133
+ * @param listener actionListener that will be notified upon success or failure of the prompt creation
134
+ */
119
135
private void indexPrompt (
120
136
MLPrompt prompt ,
121
137
Map <String , String > responseBody ,
@@ -140,10 +156,26 @@ private void indexPrompt(
140
156
}, e -> { handleFailure (e , null , listener , "Failed to init ML prompt index" ); }));
141
157
}
142
158
159
+ /**
160
+ * Builds putRequest to write prompt into index
161
+ *
162
+ * @param prompt prompt that needs to be stored into the system index
163
+ * @return PutDataObjectRequest
164
+ */
143
165
private PutDataObjectRequest buildPromptPutRequest (MLPrompt prompt ) {
144
166
return PutDataObjectRequest .builder ().tenantId (prompt .getTenantId ()).index (ML_PROMPT_INDEX ).dataObject (prompt ).build ();
145
167
}
146
168
169
+ /**
170
+ * Handles PutResponse after prompt is indexed
171
+ *
172
+ * @param putResponse response received after prompt is indexed
173
+ * @param throwable throwable
174
+ * @param name prompt name that is indexed
175
+ * @param responseBody response body that will be return upon success in the format of prompt name to prompt id
176
+ * @param remainingMLPrompts remaining prompt to be stored into the system index
177
+ * @param listener actionListener that will be notified upon success or failure of the prompt creation
178
+ */
147
179
private void handlePromptPutResponse (
148
180
PutDataObjectResponse putResponse ,
149
181
Throwable throwable ,
@@ -165,6 +197,15 @@ private void handlePromptPutResponse(
165
197
}
166
198
}
167
199
200
+ /**
201
+ * Update the response body with the prompt name and prompt id upon successful import
202
+ *
203
+ * @param promptId prompt id returned after prompt is successfully indexed into the system index
204
+ * @param name name of the prompt that is stored
205
+ * @param responseBody response body that will be return upon success in the format of prompt name to prompt id
206
+ * @param remainingMLPrompts remaining prompt to be stored into the system index
207
+ * @param listener actionListener that will be notified upon success or failure of the prompt creation
208
+ */
168
209
private void updateImportResponseBody (
169
210
String promptId ,
170
211
String name ,
@@ -179,7 +220,16 @@ private void updateImportResponseBody(
179
220
}
180
221
}
181
222
182
- private void handleDuplicateName (MLPrompt importingPrompt , String tenantId , ActionListener <String > wrappedListener ) throws IOException {
223
+ /**
224
+ * Search name field on prompt system index.
225
+ *
226
+ * @param importingPrompt prompt that needs to be imported into prompt system index
227
+ * @param tenantId tenant id
228
+ * @param wrappedListener listener that will be notified with prompt id upon success or failure of the prompt creation
229
+ * @throws IOException if search hits, meaning conflicting name exist
230
+ */
231
+ private void handleConflictingName (MLPrompt importingPrompt , String tenantId , ActionListener <String > wrappedListener )
232
+ throws IOException {
183
233
String name = importingPrompt .getName ();
184
234
SearchResponse searchResponse = mlPromptManager .searchPromptByName (name , tenantId );
185
235
if (searchResponse != null
0 commit comments