22
33import com .google .protobuf .Int64Value ;
44import org .junit .jupiter .api .Test ;
5+ import org .tensorflow .example .Example ;
6+ import org .tensorflow .example .Feature ;
7+ import org .tensorflow .example .Features ;
8+ import org .tensorflow .example .FloatList ;
9+ import org .tensorflow .framework .DataType ;
10+ import org .tensorflow .framework .TensorProto ;
11+ import org .tensorflow .framework .TensorShapeProto ;
512import org .testcontainers .junit .jupiter .Testcontainers ;
13+ import tensorflow .serving .Classification ;
14+ import tensorflow .serving .GetModelMetadata ;
615import tensorflow .serving .GetModelStatus ;
16+ import tensorflow .serving .InputOuterClass ;
717import tensorflow .serving .Model ;
18+ import tensorflow .serving .Predict ;
19+ import tensorflow .serving .RegressionOuterClass ;
820
921import static org .junit .jupiter .api .Assertions .assertEquals ;
22+ import static org .junit .jupiter .api .Assertions .assertNotNull ;
23+ import static org .junit .jupiter .api .Assertions .assertTrue ;
1024import static tensorflow .serving .GetModelStatus .ModelVersionStatus .State .AVAILABLE ;
1125
1226@ Testcontainers
@@ -30,4 +44,94 @@ void testGetModelStatus() {
3044 assertEquals (AVAILABLE , modelVersionStatus .getState ());
3145 }
3246
47+ @ Test
48+ void testModelMetadata () {
49+ var request = GetModelMetadata .GetModelMetadataRequest .newBuilder ()
50+ .setModelSpec (Model .ModelSpec .newBuilder ()
51+ .setName (DEFAULT_MODEL )
52+ .setVersion (Int64Value .of (DEFAULT_MODEL_VERSION )))
53+ .addMetadataField ("signature_def" )
54+ .build ();
55+ var response = client .getModelMetadata (request );
56+
57+ assertEquals (1 , response .getMetadataCount ());
58+ var modelSpec = response .getModelSpec ();
59+ assertEquals (DEFAULT_MODEL , modelSpec .getName ());
60+ assertEquals (DEFAULT_MODEL_VERSION , modelSpec .getVersion ().getValue ());
61+ var metadata = response .getMetadataOrThrow ("signature_def" );
62+ assertNotNull (metadata );
63+ }
64+
65+ @ Test
66+ void testClassify () {
67+ var request = Classification .ClassificationRequest .newBuilder ()
68+ .setModelSpec (Model .ModelSpec .newBuilder ()
69+ .setName (DEFAULT_MODEL )
70+ .setVersion (Int64Value .of (DEFAULT_MODEL_VERSION ))
71+ .setSignatureName ("classify_x_to_y" ))
72+ .setInput (InputOuterClass .Input .newBuilder ()
73+ .setExampleList (InputOuterClass .ExampleList .newBuilder ()
74+ .addExamples (Example .newBuilder ()
75+ .setFeatures (Features .newBuilder ()
76+ .putFeature ("x" , Feature .newBuilder ()
77+ .setFloatList (FloatList .newBuilder ().addValue (1.0f ))
78+ .build ())))))
79+ .build ();
80+ var response = client .classify (request );
81+
82+ assertTrue (response .hasResult ());
83+ var result = response .getResult ();
84+ assertEquals (1 , result .getClassificationsCount ());
85+ assertEquals (1 , result .getClassifications (0 ).getClassesCount ());
86+ assertEquals (2.5f , result .getClassifications (0 ).getClasses (0 ).getScore (), 0.0001 );
87+ }
88+
89+ @ Test
90+ void testRegress () {
91+ var request = RegressionOuterClass .RegressionRequest .newBuilder ()
92+ .setModelSpec (Model .ModelSpec .newBuilder ()
93+ .setName (DEFAULT_MODEL )
94+ .setVersion (Int64Value .of (DEFAULT_MODEL_VERSION ))
95+ .setSignatureName ("regress_x_to_y" ))
96+ .setInput (InputOuterClass .Input .newBuilder ()
97+ .setExampleList (InputOuterClass .ExampleList .newBuilder ()
98+ .addExamples (Example .newBuilder ()
99+ .setFeatures (Features .newBuilder ()
100+ .putFeature ("x" , Feature .newBuilder ()
101+ .setFloatList (FloatList .newBuilder ().addValue (1.0f ))
102+ .build ())))))
103+ .build ();
104+ var response = client .regress (request );
105+
106+ assertTrue (response .hasResult ());
107+ var result = response .getResult ();
108+ assertEquals (1 , result .getRegressionsCount ());
109+ assertEquals (2.5f , result .getRegressions (0 ).getValue (), 0.0001 );
110+ }
111+
112+ @ Test
113+ void testPredict () {
114+ var request = Predict .PredictRequest .newBuilder ()
115+ .setModelSpec (Model .ModelSpec .newBuilder ()
116+ .setName (DEFAULT_MODEL )
117+ .setVersion (Int64Value .of (DEFAULT_MODEL_VERSION )))
118+ .putInputs ("x" , TensorProto .newBuilder ()
119+ .setDtype (DataType .DT_FLOAT )
120+ .setTensorShape (TensorShapeProto .newBuilder ()
121+ .addDim (TensorShapeProto .Dim .newBuilder ().setSize (3 )))
122+ .addFloatVal (1.0f )
123+ .addFloatVal (2.0f )
124+ .addFloatVal (5.0f )
125+ .build ())
126+ .build ();
127+ var response = client .predict (request );
128+
129+ assertEquals (1 , response .getOutputsCount ());
130+ var output = response .getOutputsOrThrow ("y" );
131+ assertEquals (3 , output .getTensorShape ().getDim (0 ).getSize ());
132+ assertEquals (2.5f , output .getFloatVal (0 ), 0.0001 );
133+ assertEquals (3.0f , output .getFloatVal (1 ), 0.0001 );
134+ assertEquals (4.5f , output .getFloatVal (2 ), 0.0001 );
135+ }
136+
33137}
0 commit comments