Skip to content

Commit 3dc23f4

Browse files
committed
add tests
1 parent 1bc6399 commit 3dc23f4

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

src/test/java/com/github/tadayosi/tensorflow/serving/client/TensorFlowServingClientTest.java

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,25 @@
22

33
import com.google.protobuf.Int64Value;
44
import org.junit.jupiter.api.Test;
5+
import org.tensorflow.example.Example;
6+
import org.tensorflow.example.Feature;
7+
import org.tensorflow.example.Features;
8+
import org.tensorflow.example.FloatList;
9+
import org.tensorflow.framework.DataType;
10+
import org.tensorflow.framework.TensorProto;
11+
import org.tensorflow.framework.TensorShapeProto;
512
import org.testcontainers.junit.jupiter.Testcontainers;
13+
import tensorflow.serving.Classification;
14+
import tensorflow.serving.GetModelMetadata;
615
import tensorflow.serving.GetModelStatus;
16+
import tensorflow.serving.InputOuterClass;
717
import tensorflow.serving.Model;
18+
import tensorflow.serving.Predict;
19+
import tensorflow.serving.RegressionOuterClass;
820

921
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertNotNull;
23+
import static org.junit.jupiter.api.Assertions.assertTrue;
1024
import static tensorflow.serving.GetModelStatus.ModelVersionStatus.State.AVAILABLE;
1125

1226
@Testcontainers
@@ -30,4 +44,94 @@ void testGetModelStatus() {
3044
assertEquals(AVAILABLE, modelVersionStatus.getState());
3145
}
3246

47+
@Test
48+
void testModelMetadata() {
49+
var request = GetModelMetadata.GetModelMetadataRequest.newBuilder()
50+
.setModelSpec(Model.ModelSpec.newBuilder()
51+
.setName(DEFAULT_MODEL)
52+
.setVersion(Int64Value.of(DEFAULT_MODEL_VERSION)))
53+
.addMetadataField("signature_def")
54+
.build();
55+
var response = client.getModelMetadata(request);
56+
57+
assertEquals(1, response.getMetadataCount());
58+
var modelSpec = response.getModelSpec();
59+
assertEquals(DEFAULT_MODEL, modelSpec.getName());
60+
assertEquals(DEFAULT_MODEL_VERSION, modelSpec.getVersion().getValue());
61+
var metadata = response.getMetadataOrThrow("signature_def");
62+
assertNotNull(metadata);
63+
}
64+
65+
@Test
66+
void testClassify() {
67+
var request = Classification.ClassificationRequest.newBuilder()
68+
.setModelSpec(Model.ModelSpec.newBuilder()
69+
.setName(DEFAULT_MODEL)
70+
.setVersion(Int64Value.of(DEFAULT_MODEL_VERSION))
71+
.setSignatureName("classify_x_to_y"))
72+
.setInput(InputOuterClass.Input.newBuilder()
73+
.setExampleList(InputOuterClass.ExampleList.newBuilder()
74+
.addExamples(Example.newBuilder()
75+
.setFeatures(Features.newBuilder()
76+
.putFeature("x", Feature.newBuilder()
77+
.setFloatList(FloatList.newBuilder().addValue(1.0f))
78+
.build())))))
79+
.build();
80+
var response = client.classify(request);
81+
82+
assertTrue(response.hasResult());
83+
var result = response.getResult();
84+
assertEquals(1, result.getClassificationsCount());
85+
assertEquals(1, result.getClassifications(0).getClassesCount());
86+
assertEquals(2.5f, result.getClassifications(0).getClasses(0).getScore(), 0.0001);
87+
}
88+
89+
@Test
90+
void testRegress() {
91+
var request = RegressionOuterClass.RegressionRequest.newBuilder()
92+
.setModelSpec(Model.ModelSpec.newBuilder()
93+
.setName(DEFAULT_MODEL)
94+
.setVersion(Int64Value.of(DEFAULT_MODEL_VERSION))
95+
.setSignatureName("regress_x_to_y"))
96+
.setInput(InputOuterClass.Input.newBuilder()
97+
.setExampleList(InputOuterClass.ExampleList.newBuilder()
98+
.addExamples(Example.newBuilder()
99+
.setFeatures(Features.newBuilder()
100+
.putFeature("x", Feature.newBuilder()
101+
.setFloatList(FloatList.newBuilder().addValue(1.0f))
102+
.build())))))
103+
.build();
104+
var response = client.regress(request);
105+
106+
assertTrue(response.hasResult());
107+
var result = response.getResult();
108+
assertEquals(1, result.getRegressionsCount());
109+
assertEquals(2.5f, result.getRegressions(0).getValue(), 0.0001);
110+
}
111+
112+
@Test
113+
void testPredict() {
114+
var request = Predict.PredictRequest.newBuilder()
115+
.setModelSpec(Model.ModelSpec.newBuilder()
116+
.setName(DEFAULT_MODEL)
117+
.setVersion(Int64Value.of(DEFAULT_MODEL_VERSION)))
118+
.putInputs("x", TensorProto.newBuilder()
119+
.setDtype(DataType.DT_FLOAT)
120+
.setTensorShape(TensorShapeProto.newBuilder()
121+
.addDim(TensorShapeProto.Dim.newBuilder().setSize(3)))
122+
.addFloatVal(1.0f)
123+
.addFloatVal(2.0f)
124+
.addFloatVal(5.0f)
125+
.build())
126+
.build();
127+
var response = client.predict(request);
128+
129+
assertEquals(1, response.getOutputsCount());
130+
var output = response.getOutputsOrThrow("y");
131+
assertEquals(3, output.getTensorShape().getDim(0).getSize());
132+
assertEquals(2.5f, output.getFloatVal(0), 0.0001);
133+
assertEquals(3.0f, output.getFloatVal(1), 0.0001);
134+
assertEquals(4.5f, output.getFloatVal(2), 0.0001);
135+
}
136+
33137
}

src/test/java/com/github/tadayosi/tensorflow/serving/client/TensorFlowServingTestSupport.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import org.junit.jupiter.api.BeforeEach;
44
import org.testcontainers.containers.GenericContainer;
55
import org.testcontainers.containers.wait.strategy.Wait;
6-
import org.testcontainers.images.builder.Transferable;
76
import org.testcontainers.junit.jupiter.Container;
87
import org.testcontainers.utility.DockerImageName;
98
import org.testcontainers.utility.MountableFile;

0 commit comments

Comments
 (0)