10
10
import org .elasticsearch .common .io .stream .Writeable ;
11
11
import org .elasticsearch .inference .TaskType ;
12
12
import org .elasticsearch .test .AbstractWireSerializingTestCase ;
13
+ import org .elasticsearch .test .ESTestCase ;
13
14
14
15
import java .io .IOException ;
15
16
17
+ import static org .hamcrest .Matchers .equalTo ;
18
+
16
19
public class ModelStatsTests extends AbstractWireSerializingTestCase <ModelStats > {
17
20
18
21
@ Override
@@ -27,12 +30,32 @@ protected ModelStats createTestInstance() {
27
30
28
31
@ Override
29
32
protected ModelStats mutateInstance (ModelStats modelStats ) throws IOException {
30
- ModelStats newModelStats = new ModelStats (modelStats );
31
- newModelStats .add ();
32
- return newModelStats ;
33
+ String service = modelStats .service ();
34
+ TaskType taskType = modelStats .taskType ();
35
+ long count = modelStats .count ();
36
+ return switch (randomInt (2 )) {
37
+ case 0 -> new ModelStats (randomValueOtherThan (service , ESTestCase ::randomIdentifier ), taskType , count );
38
+ case 1 -> new ModelStats (service , randomValueOtherThan (taskType , () -> randomFrom (TaskType .values ())), count );
39
+ case 2 -> new ModelStats (service , taskType , randomValueOtherThan (count , ESTestCase ::randomLong ));
40
+ default -> throw new IllegalArgumentException ();
41
+ };
42
+ }
43
+
44
+ public void testAdd () {
45
+ ModelStats stats = new ModelStats ("test_service" , randomFrom (TaskType .values ()));
46
+ assertThat (stats .count (), equalTo (0L ));
47
+
48
+ stats .add ();
49
+ assertThat (stats .count (), equalTo (1L ));
50
+
51
+ int iterations = randomIntBetween (1 , 10 );
52
+ for (int i = 0 ; i < iterations ; i ++) {
53
+ stats .add ();
54
+ }
55
+ assertThat (stats .count (), equalTo (1L + iterations ));
33
56
}
34
57
35
58
public static ModelStats createRandomInstance () {
36
- return new ModelStats (randomIdentifier (), TaskType . values ()[ randomInt ( TaskType .values (). length - 1 )], randomInt ( 10 ));
59
+ return new ModelStats (randomIdentifier (), randomFrom ( TaskType .values ()), randomLong ( ));
37
60
}
38
61
}
0 commit comments