7
7
8
8
import static org .mockito .ArgumentMatchers .any ;
9
9
import static org .mockito .Mockito .argThat ;
10
+ import static org .mockito .Mockito .doThrow ;
10
11
import static org .mockito .Mockito .spy ;
11
12
import static org .mockito .Mockito .times ;
13
+ import static org .mockito .Mockito .verify ;
12
14
import static org .mockito .Mockito .when ;
13
15
import static org .opensearch .ml .common .connector .AbstractConnector .ACCESS_KEY_FIELD ;
14
16
import static org .opensearch .ml .common .connector .AbstractConnector .SECRET_KEY_FIELD ;
32
34
import org .opensearch .common .settings .Settings ;
33
35
import org .opensearch .common .util .concurrent .ThreadContext ;
34
36
import org .opensearch .core .action .ActionListener ;
37
+ import org .opensearch .core .xcontent .XContentBuilder ;
35
38
import org .opensearch .ingest .TestTemplateService ;
36
39
import org .opensearch .ml .common .FunctionName ;
37
40
import org .opensearch .ml .common .connector .AwsConnector ;
41
44
import org .opensearch .ml .common .connector .RetryBackoffPolicy ;
42
45
import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
43
46
import org .opensearch .ml .common .input .MLInput ;
47
+ import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
48
+ import org .opensearch .ml .common .input .parameter .clustering .KMeansParams ;
44
49
import org .opensearch .ml .common .input .parameter .textembedding .AsymmetricTextEmbeddingParameters ;
45
50
import org .opensearch .ml .common .input .parameter .textembedding .SparseEmbeddingFormat ;
46
51
import org .opensearch .ml .common .output .model .ModelTensors ;
@@ -68,6 +73,9 @@ public class RemoteConnectorExecutorTest {
68
73
@ Mock
69
74
ActionListener <Tuple <Integer , ModelTensors >> actionListener ;
70
75
76
+ @ Mock
77
+ private MLAlgoParams mlInputParams ;
78
+
71
79
@ Before
72
80
public void setUp () {
73
81
MockitoAnnotations .openMocks (this );
@@ -174,6 +182,62 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
174
182
assert exception .getMessage ().contains ("Some parameter placeholder not filled in payload: role" );
175
183
}
176
184
185
+ @ Test
186
+ public void executePreparePayloadAndInvoke_PassingParameter () {
187
+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
188
+ Connector connector = getConnector (parameters );
189
+ AwsConnectorExecutor executor = getExecutor (connector );
190
+
191
+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
192
+ .builder ()
193
+ .parameters (Map .of ("input" , "You are a ${parameters.role}" ))
194
+ .actionType (PREDICT )
195
+ .build ();
196
+ String actionType = inputDataSet .getActionType ().toString ();
197
+ AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
198
+ .builder ()
199
+ .sparseEmbeddingFormat (SparseEmbeddingFormat .WORD )
200
+ .embeddingContentType (null )
201
+ .build ();
202
+ MLInput mlInput = MLInput
203
+ .builder ()
204
+ .algorithm (FunctionName .TEXT_EMBEDDING )
205
+ .parameters (inputParams )
206
+ .inputDataset (inputDataSet )
207
+ .build ();
208
+
209
+ Exception exception = Assert
210
+ .assertThrows (
211
+ IllegalArgumentException .class ,
212
+ () -> executor .preparePayloadAndInvoke (actionType , mlInput , null , actionListener )
213
+ );
214
+ assert exception .getMessage ().contains ("Some parameter placeholder not filled in payload: role" );
215
+ }
216
+
217
+ @ Test
218
+ public void executePreparePayloadAndInvoke_GetParamsIOException () throws Exception {
219
+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
220
+ Connector connector = getConnector (parameters );
221
+ AwsConnectorExecutor executor = getExecutor (connector );
222
+
223
+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
224
+ .builder ()
225
+ .parameters (Map .of ("input" , "test input" ))
226
+ .actionType (PREDICT )
227
+ .build ();
228
+ String actionType = inputDataSet .getActionType ().toString ();
229
+ doThrow (new IOException ("UT test IOException" )).when (mlInputParams ).toXContent (any (XContentBuilder .class ), any ());
230
+ MLInput mlInput = MLInput
231
+ .builder ()
232
+ .algorithm (FunctionName .TEXT_EMBEDDING )
233
+ .parameters (mlInputParams )
234
+ .inputDataset (inputDataSet )
235
+ .build ();
236
+
237
+ executor .preparePayloadAndInvoke (actionType , mlInput , null , actionListener );
238
+ verify (actionListener ).onFailure (argThat (e -> e instanceof IOException && e .getMessage ().contains ("UT test IOException" )));
239
+ }
240
+
177
241
@ Test
178
242
public void executeGetParams_MissingParameter () {
179
243
Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
@@ -199,7 +263,7 @@ public void executeGetParams_MissingParameter() {
199
263
.build ();
200
264
201
265
try {
202
- Map <String , String > paramsMap = executor .getParams (mlInput );
266
+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
203
267
Map <String , String > expectedMap = new HashMap <>();
204
268
expectedMap .put ("sparse_embedding_format" , "WORD" );
205
269
Assert .assertEquals (expectedMap , paramsMap );
@@ -233,7 +297,7 @@ public void executeGetParams_PassingParameter() {
233
297
.build ();
234
298
235
299
try {
236
- Map <String , String > paramsMap = executor .getParams (mlInput );
300
+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
237
301
Map <String , String > expectedMap = new HashMap <>();
238
302
expectedMap .put ("sparse_embedding_format" , "WORD" );
239
303
expectedMap .put ("content_type" , "PASSAGE" );
@@ -242,4 +306,40 @@ public void executeGetParams_PassingParameter() {
242
306
e .printStackTrace ();
243
307
}
244
308
}
309
+
310
+ @ Test
311
+ public void executeGetParams_ConvertToString () {
312
+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
313
+ Connector connector = getConnector (parameters );
314
+ AwsConnectorExecutor executor = getExecutor (connector );
315
+
316
+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
317
+ .builder ()
318
+ .parameters (Map .of ("input" , "${parameters.input}" ))
319
+ .actionType (PREDICT )
320
+ .build ();
321
+ KMeansParams inputParams = KMeansParams
322
+ .builder ()
323
+ .centroids (5 )
324
+ .iterations (100 )
325
+ .distanceType (KMeansParams .DistanceType .EUCLIDEAN )
326
+ .build ();
327
+ MLInput mlInput = MLInput
328
+ .builder ()
329
+ .algorithm (FunctionName .TEXT_EMBEDDING )
330
+ .parameters (inputParams )
331
+ .inputDataset (inputDataSet )
332
+ .build ();
333
+
334
+ try {
335
+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
336
+ Map <String , String > expectedMap = new HashMap <>();
337
+ expectedMap .put ("centroids" , "5" );
338
+ expectedMap .put ("iterations" , "100" );
339
+ expectedMap .put ("distance_type" , "EUCLIDEAN" );
340
+ Assert .assertEquals (expectedMap , paramsMap );
341
+ } catch (IOException e ) {
342
+ e .printStackTrace ();
343
+ }
344
+ }
245
345
}
0 commit comments