Skip to content

Commit 66582a3

Browse files
Improvements for inference ModelStatsTest (#134707)
This commit improves instance mutation, adds a test for `ModelStats.add()`, and some more minor changes.
1 parent a2b7754 commit 66582a3

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import org.elasticsearch.common.io.stream.Writeable;
1111
import org.elasticsearch.inference.TaskType;
1212
import org.elasticsearch.test.AbstractWireSerializingTestCase;
13+
import org.elasticsearch.test.ESTestCase;
1314

1415
import java.io.IOException;
1516

17+
import static org.hamcrest.Matchers.equalTo;
18+
1619
public class ModelStatsTests extends AbstractWireSerializingTestCase<ModelStats> {
1720

1821
@Override
@@ -27,12 +30,32 @@ protected ModelStats createTestInstance() {
2730

2831
@Override
2932
protected ModelStats mutateInstance(ModelStats modelStats) throws IOException {
30-
ModelStats newModelStats = new ModelStats(modelStats);
31-
newModelStats.add();
32-
return newModelStats;
33+
String service = modelStats.service();
34+
TaskType taskType = modelStats.taskType();
35+
long count = modelStats.count();
36+
return switch (randomInt(2)) {
37+
case 0 -> new ModelStats(randomValueOtherThan(service, ESTestCase::randomIdentifier), taskType, count);
38+
case 1 -> new ModelStats(service, randomValueOtherThan(taskType, () -> randomFrom(TaskType.values())), count);
39+
case 2 -> new ModelStats(service, taskType, randomValueOtherThan(count, ESTestCase::randomLong));
40+
default -> throw new IllegalArgumentException();
41+
};
42+
}
43+
44+
public void testAdd() {
45+
ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values()));
46+
assertThat(stats.count(), equalTo(0L));
47+
48+
stats.add();
49+
assertThat(stats.count(), equalTo(1L));
50+
51+
int iterations = randomIntBetween(1, 10);
52+
for (int i = 0; i < iterations; i++) {
53+
stats.add();
54+
}
55+
assertThat(stats.count(), equalTo(1L + iterations));
3356
}
3457

3558
public static ModelStats createRandomInstance() {
36-
return new ModelStats(randomIdentifier(), TaskType.values()[randomInt(TaskType.values().length - 1)], randomInt(10));
59+
return new ModelStats(randomIdentifier(), randomFrom(TaskType.values()), randomLong());
3760
}
3861
}

0 commit comments

Comments
 (0)