|
3 | 3 | [](<https://jitpack.io/#tadayosi/tensorflow-serving-client-java>) |
4 | 4 | [](https://github.com/tadayosi/tensorflow-serving-client-java/actions/workflows/test.yml) |
5 | 5 |
|
6 | | -TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [TensorFlow Serving](https://github.com/tensorflow/serving). It supports the following [TensorFlow Serving REST API](https://www.tensorflow.org/tfx/serving/api_rest): |
| 6 | +TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [TensorFlow Serving](https://github.com/tensorflow/serving). It supports the following [TensorFlow Serving Client API (gRPC)](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/apis): |
7 | 7 |
|
8 | 8 | - [Model status API](https://www.tensorflow.org/tfx/serving/api_rest#model_status_api) |
9 | 9 | - [Model Metadata API](https://www.tensorflow.org/tfx/serving/api_rest#model_metadata_api) |
@@ -33,194 +33,245 @@ TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [Tensor |
33 | 33 | <dependency> |
34 | 34 | <groupId>com.github.tadayosi</groupId> |
35 | 35 | <artifactId>tensorflow-serving-client-java</artifactId> |
36 | | - <version>v0.3</version> |
| 36 | + <version>v0.1</version> |
37 | 37 | </dependency> |
38 | 38 | ``` |
39 | 39 |
|
40 | 40 | ## Usage |
41 | 41 |
|
42 | | -### Inference |
| 42 | +> [!IMPORTANT] |
| 43 | +> TFSC4J uses the gRPC port (default: `8500`) to communicate with the TensorFlow model server. |
43 | 44 |
|
44 | | -- Prediction: |
| 45 | +To create a client: |
45 | 46 |
|
46 | | - ```java |
47 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
48 | | - |
49 | | - byte[] image = Files.readAllBytes(Path.of("0.png")); |
50 | | - Object result = client.inference().predictions("mnist_v2", image); |
51 | | - System.out.println(result); |
52 | | - // => 0 |
53 | | - ``` |
54 | | - |
55 | | -- With the inference API endpoint other than <http://localhost:8080>: |
| 47 | +```java |
| 48 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 49 | +``` |
56 | 50 |
|
57 | | - ```java |
58 | | - TensorFlowServingClient client = TensorFlowServingClient.builder() |
59 | | - .inferenceAddress("http://localhost:12345") |
60 | | - .build(); |
61 | | - ``` |
| 51 | +By default, the client connects to `localhost:8500`, but if you want to connect to a different target URI (e.g. `example.com:8080`), instantiate a client as follows: |
62 | 52 |
|
63 | | -- With token authorization: |
| 53 | +```java |
| 54 | +TensorFlowServingClient client = TensorFlowServingClient.builder() |
| 55 | + .target("example.com:8080") |
| 56 | + .build(); |
| 57 | +``` |
64 | 58 |
|
65 | | - ```java |
66 | | - TensorFlowServingClient client = TensorFlowServingClient.builder() |
67 | | - .inferenceKey("<inference-key>") |
68 | | - .build(); |
69 | | - ``` |
| 59 | +### Model status API |
70 | 60 |
|
71 | | -### Management |
| 61 | +To get the status of a model: |
72 | 62 |
|
73 | | -- Register a model: |
| 63 | +```java |
| 64 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
74 | 65 |
|
75 | | - ```java |
76 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 66 | +GetModelStatusRequest request = GetModelStatusRequest.newBuilder() |
| 67 | + .setModelSpec(ModelSpec.newBuilder() |
| 68 | + .setName("half_plus_two") |
| 69 | + .setVersion(Int64Value.of(123))) |
| 70 | + .build(); |
| 71 | +GetModelStatusResponse response = client.getModelStatus(request); |
| 72 | +System.out.println(response); |
| 73 | +``` |
77 | 74 |
|
78 | | - Response response = client.management().registerModel( |
79 | | - "https://torchserve.pytorch.org/mar_files/mnist_v2.mar", |
80 | | - RegisterModelOptions.empty()); |
81 | | - System.out.println(response.getStatus()); |
82 | | - // => "Model "mnist_v2" Version: 2.0 registered with 0 initial workers. Use scale workers API to add workers for the model." |
83 | | - ``` |
| 75 | +Output: |
84 | 76 |
|
85 | | -- Scale workers for a model: |
| 77 | +```console |
| 78 | +model_version_status { |
| 79 | + version: 123 |
| 80 | + state: AVAILABLE |
| 81 | + status { |
| 82 | + } |
| 83 | +} |
| 84 | +``` |
86 | 85 |
|
87 | | - ```java |
88 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 86 | +### Model Metadata API |
89 | 87 |
|
90 | | - Response response = client.management().setAutoScale( |
91 | | - "mnist_v2", |
92 | | - SetAutoScaleOptions.builder() |
93 | | - .minWorker(1) |
94 | | - .maxWorker(2) |
95 | | - .build()); |
96 | | - System.out.println(response.getStatus()); |
97 | | - // => "Processing worker updates..." |
98 | | - ``` |
| 88 | +To get the metadata of a model: |
99 | 89 |
|
100 | | -- Describe a model: |
| 90 | +```java |
| 91 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
101 | 92 |
|
102 | | - ```java |
103 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 93 | +GetModelMetadataRequest request = GetModelMetadataRequest.newBuilder() |
| 94 | + .setModelSpec(ModelSpec.newBuilder() |
| 95 | + .setName("half_plus_two") |
| 96 | + .setVersion(Int64Value.of(123))) |
| 97 | + .addMetadataField("signature_def") // metadata_field is mandatory |
| 98 | + .build(); |
| 99 | +GetModelMetadataResponse response = client.getModelMetadata(request); |
| 100 | +System.out.println(response); |
| 101 | +``` |
104 | 102 |
|
105 | | - List<ModelDetail> model = client.management().describeModel("mnist_v2"); |
106 | | - System.out.println(model.get(0)); |
107 | | - // => |
108 | | - // ModelDetail { |
109 | | - // modelName: mnist_v2 |
110 | | - // modelVersion: 2.0 |
111 | | - // ... |
112 | | - ``` |
| 103 | +Output: |
113 | 104 |
|
114 | | -- Unregister a model: |
| 105 | +```console |
| 106 | +model_spec { |
| 107 | + name: "half_plus_two" |
| 108 | + version { |
| 109 | + value: 123 |
| 110 | + } |
| 111 | +} |
| 112 | +metadata { |
| 113 | + key: "signature_def" |
| 114 | + value { |
| 115 | + type_url: "type.googleapis.com/tensorflow.serving.SignatureDefMap" |
| 116 | + value: "..." |
| 117 | + } |
| 118 | +} |
| 119 | +``` |
115 | 120 |
|
116 | | - ```java |
117 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 121 | +### Classify API |
| 122 | + |
| 123 | +To classify: |
| 124 | + |
| 125 | +```java |
| 126 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 127 | + |
| 128 | +ClassificationRequest request = ClassificationRequest.newBuilder() |
| 129 | + .setModelSpec(ModelSpec.newBuilder() |
| 130 | + .setName("half_plus_two") |
| 131 | + .setVersion(Int64Value.of(123)) |
| 132 | + .setSignatureName("classify_x_to_y")) |
| 133 | + .setInput(Input.newBuilder() |
| 134 | + .setExampleList(ExampleList.newBuilder() |
| 135 | + .addExamples(Example.newBuilder() |
| 136 | + .setFeatures(Features.newBuilder() |
| 137 | + .putFeature("x", Feature.newBuilder() |
| 138 | + .setFloatList(FloatList.newBuilder().addValue(1.0f)) |
| 139 | + .build()))))) |
| 140 | + .build(); |
| 141 | +ClassificationResponse response = client.classify(request); |
| 142 | +System.out.println(response); |
| 143 | +``` |
118 | 144 |
|
119 | | - Response response = client.management().unregisterModel( |
120 | | - "mnist_v2", |
121 | | - UnregisterModelOptions.empty()); |
122 | | - System.out.println(response.getStatus()); |
123 | | - // => "Model "mnist_v2" unregistered" |
124 | | - ``` |
| 145 | +Output: |
125 | 146 |
|
126 | | -- List models: |
| 147 | +```console |
| 148 | +result { |
| 149 | + classifications { |
| 150 | + classes { |
| 151 | + score: 2.5 |
| 152 | + } |
| 153 | + } |
| 154 | +} |
| 155 | +model_spec { |
| 156 | + name: "half_plus_two" |
| 157 | + version { |
| 158 | + value: 123 |
| 159 | + } |
| 160 | + signature_name: "classify_x_to_y" |
| 161 | +} |
| 162 | +``` |
127 | 163 |
|
128 | | - ```java |
129 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 164 | +### Regress API |
| 165 | + |
| 166 | +To regress: |
| 167 | + |
| 168 | +```java |
| 169 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 170 | + |
| 171 | +RegressionRequest request = RegressionRequest.newBuilder() |
| 172 | + .setModelSpec(ModelSpec.newBuilder() |
| 173 | + .setName("half_plus_two") |
| 174 | + .setVersion(Int64Value.of(123)) |
| 175 | + .setSignatureName("regress_x_to_y")) |
| 176 | + .setInput(Input.newBuilder() |
| 177 | + .setExampleList(ExampleList.newBuilder() |
| 178 | + .addExamples(Example.newBuilder() |
| 179 | + .setFeatures(Features.newBuilder() |
| 180 | + .putFeature("x", Feature.newBuilder() |
| 181 | + .setFloatList(FloatList.newBuilder().addValue(1.0f)) |
| 182 | + .build()))))) |
| 183 | + .build(); |
| 184 | +RegressionResponse response = client.regress(request); |
| 185 | +System.out.println(response); |
| 186 | +``` |
130 | 187 |
|
131 | | - ModelList models = client.management().listModels(10, null); |
132 | | - System.out.println(models); |
133 | | - // => |
134 | | - // ModelList { |
135 | | - // nextPageToken: null |
136 | | - // models: [Model { |
137 | | - // modelName: mnist_v2 |
138 | | - // modelUrl: https://torchserve.pytorch.org/mar_files/mnist_v2.mar |
139 | | - // }, |
140 | | - // ... |
141 | | - ``` |
| 188 | +Output: |
142 | 189 |
|
143 | | -- Set default version for a model: |
144 | | - |
145 | | - ```java |
146 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
147 | | - |
148 | | - Response response = client.management().setDefault("mnist_v2", "2.0"); |
149 | | - System.out.println(response.getStatus()); |
150 | | - // => "Default version successfully updated for model "mnist_v2" to "2.0"" |
151 | | - ``` |
152 | | - |
153 | | -- With the management API endpoint other than <http://localhost:8081>: |
| 190 | +```console |
| 191 | +result { |
| 192 | + regressions { |
| 193 | + value: 2.5 |
| 194 | + } |
| 195 | +} |
| 196 | +model_spec { |
| 197 | + name: "half_plus_two" |
| 198 | + version { |
| 199 | + value: 123 |
| 200 | + } |
| 201 | + signature_name: "regress_x_to_y" |
| 202 | +} |
| 203 | +``` |
154 | 204 |
|
155 | | - ```java |
156 | | - TensorFlowServingClient client = TensorFlowServingClient.builder() |
157 | | - .managementAddress("http://localhost:12345") |
158 | | - .build(); |
159 | | - ``` |
160 | | - |
161 | | -- With token authorization: |
| 205 | +### Predict API |
| 206 | + |
| 207 | +To predict: |
| 208 | + |
| 209 | +```java |
| 210 | +TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
| 211 | + |
| 212 | +PredictRequest request = PredictRequest.newBuilder() |
| 213 | + .setModelSpec(ModelSpec.newBuilder() |
| 214 | + .setName("half_plus_two") |
| 215 | + .setVersion(Int64Value.of(123))) |
| 216 | + .putInputs("x", TensorProto.newBuilder() |
| 217 | + .setDtype(DataType.DT_FLOAT) |
| 218 | + .setTensorShape(TensorShapeProto.newBuilder() |
| 219 | + .addDim(Dim.newBuilder().setSize(3))) |
| 220 | + .addFloatVal(1.0f) |
| 221 | + .addFloatVal(2.0f) |
| 222 | + .addFloatVal(5.0f) |
| 223 | + .build()) |
| 224 | + .build(); |
| 225 | +PredictResponse response = client.predict(request); |
| 226 | +System.out.println(response); |
| 227 | +``` |
162 | 228 |
|
163 | | - ```java |
164 | | - TensorFlowServingClient client = TensorFlowServingClient.builder() |
165 | | - .managementKey("<management-key>") |
166 | | - .build(); |
167 | | - ``` |
168 | | - |
169 | | -### Metrics |
| 229 | +Output: |
170 | 230 |
|
171 | | -- Get metrics in Prometheus format: |
172 | | - |
173 | | - ```java |
174 | | - TensorFlowServingClient client = TensorFlowServingClient.newInstance(); |
175 | | - |
176 | | - String metrics = client.metrics().metrics(); |
177 | | - System.out.println(metrics); |
178 | | - // => |
179 | | - // # HELP MemoryUsed Torchserve prometheus gauge metric with unit: Megabytes |
180 | | - // # TYPE MemoryUsed gauge |
181 | | - // MemoryUsed{Level="Host",Hostname="3a9b51d41fbf",} 2075.09765625 |
182 | | - // ... |
183 | | - ``` |
184 | | - |
185 | | -- With the metrics API endpoint other than <http://localhost:8082>: |
186 | | - |
187 | | - ```java |
188 | | - TensorFlowServingClient client = TensorFlowServingClient.builder() |
189 | | - .metricsAddress("http://localhost:12345") |
190 | | - .build(); |
191 | | - ``` |
| 231 | +```console |
| 232 | +outputs { |
| 233 | + key: "y" |
| 234 | + value { |
| 235 | + dtype: DT_FLOAT |
| 236 | + tensor_shape { |
| 237 | + dim { |
| 238 | + size: 3 |
| 239 | + } |
| 240 | + } |
| 241 | + float_val: 2.5 |
| 242 | + float_val: 3.0 |
| 243 | + float_val: 4.5 |
| 244 | + } |
| 245 | +} |
| 246 | +model_spec { |
| 247 | + name: "half_plus_two" |
| 248 | + version { |
| 249 | + value: 123 |
| 250 | + } |
| 251 | + signature_name: "serving_default" |
| 252 | +} |
| 253 | +``` |
192 | 254 |
|
193 | 255 | ## Configuration |
194 | 256 |
|
195 | | -### tsc4j.properties |
| 257 | +### tfsc4j.properties |
196 | 258 |
|
197 | 259 | ```properties |
198 | | -inference.key = <inference-key> |
199 | | -inference.address = http://localhost:8080 |
200 | | -# inference.address takes precedence over inference.port if it's defined |
201 | | -inference.port = 8080 |
202 | | - |
203 | | -management.key = <management-key> |
204 | | -management.address = http://localhost:8081 |
205 | | -# management.address takes precedence over management.port if it's defined |
206 | | -management.port = 8081 |
207 | | - |
208 | | -metrics.address = http://localhost:8082 |
209 | | -# metrics.address takes precedence over metrics.port if it's defined |
210 | | -metrics.port = 8082 |
| 260 | +target = <target> |
| 261 | +credentials = <credentials> |
211 | 262 | ``` |
212 | 263 |
|
213 | 264 | ### System properties |
214 | 265 |
|
215 | | -You can configure the TSC4J properties via system properties with prefix `tsc4j.`. |
| 266 | +You can configure the TFSC4J properties via system properties with prefix `tfsc4j.`. |
216 | 267 |
|
217 | | -For instance, you can configure `inference.address` with the `tsc4j.inference.address` system property. |
| 268 | +For instance, you can configure `target` with the `tfsc4j.target` system property. |
218 | 269 |
|
219 | 270 | ### Environment variables |
220 | 271 |
|
221 | | -You can also configure the TSC4J properties via environment variables with prefix `TSC4J_`. |
| 272 | +You can also configure the TFSC4J properties via environment variables with prefix `TFSC4J_`. |
222 | 273 |
|
223 | | -For instance, you can configure `inference.address` with the `TSC4J_INFERENCE_ADDRESS` environment variable. |
| 274 | +For instance, you can configure `target` with the `TFSC4J_TARGET` environment variable. |
224 | 275 |
|
225 | 276 | ## Examples |
226 | 277 |
|
|
0 commit comments