Commit f474f58: update examples and readme
Parent: 9692323

15 files changed (+377, -356 lines)

README.md (196 additions, 145 deletions)
````diff
@@ -3,7 +3,7 @@
 [![Release](https://jitpack.io/v/tadayosi/tensorflow-serving-client-java.svg)](<https://jitpack.io/#tadayosi/tensorflow-serving-client-java>)
 [![Test](https://github.com/tadayosi/tensorflow-serving-client-java/actions/workflows/test.yml/badge.svg)](https://github.com/tadayosi/tensorflow-serving-client-java/actions/workflows/test.yml)
 
-TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [TensorFlow Serving](https://github.com/tensorflow/serving). It supports the following [TensorFlow Serving REST API](https://www.tensorflow.org/tfx/serving/api_rest):
+TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [TensorFlow Serving](https://github.com/tensorflow/serving). It supports the following [TensorFlow Serving Client API (gRPC)](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/apis):
 
 - [Model status API](https://www.tensorflow.org/tfx/serving/api_rest#model_status_api)
 - [Model Metadata API](https://www.tensorflow.org/tfx/serving/api_rest#model_metadata_api)
@@ -33,194 +33,245 @@ TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [Tensor
   <dependency>
     <groupId>com.github.tadayosi</groupId>
     <artifactId>tensorflow-serving-client-java</artifactId>
-    <version>v0.3</version>
+    <version>v0.1</version>
   </dependency>
 ```
 
 ## Usage
 
-### Inference
+> [!IMPORTANT]
+> TFSC4J uses the gRPC port (default: `8500`) to communicate with the TensorFlow model server.
 
-- Prediction:
+To create a client:
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
-
-  byte[] image = Files.readAllBytes(Path.of("0.png"));
-  Object result = client.inference().predictions("mnist_v2", image);
-  System.out.println(result);
-  // => 0
-  ```
-
-- With the inference API endpoint other than <http://localhost:8080>:
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+```
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.builder()
-      .inferenceAddress("http://localhost:12345")
-      .build();
-  ```
+By default, the client connects to `localhost:8500`, but if you want to connect to a different target URI (e.g. `example.com:8080`), instantiate a client as follows:
 
-- With token authorization:
+```java
+TensorFlowServingClient client = TensorFlowServingClient.builder()
+    .target("example.com:8080")
+    .build();
+```
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.builder()
-      .inferenceKey("<inference-key>")
-      .build();
-  ```
+### Model status API
 
-### Management
+To get the status of a model:
 
-- Register a model:
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+GetModelStatusRequest request = GetModelStatusRequest.newBuilder()
+    .setModelSpec(ModelSpec.newBuilder()
+        .setName("half_plus_two")
+        .setVersion(Int64Value.of(123)))
+    .build();
+GetModelStatusResponse response = client.getModelStatus(request);
+System.out.println(response);
+```
 
-  Response response = client.management().registerModel(
-      "https://torchserve.pytorch.org/mar_files/mnist_v2.mar",
-      RegisterModelOptions.empty());
-  System.out.println(response.getStatus());
-  // => "Model "mnist_v2" Version: 2.0 registered with 0 initial workers. Use scale workers API to add workers for the model."
-  ```
+Output:
 
-- Scale workers for a model:
+```console
+model_version_status {
+  version: 123
+  state: AVAILABLE
+  status {
+  }
+}
+```
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+### Model Metadata API
 
-  Response response = client.management().setAutoScale(
-      "mnist_v2",
-      SetAutoScaleOptions.builder()
-          .minWorker(1)
-          .maxWorker(2)
-          .build());
-  System.out.println(response.getStatus());
-  // => "Processing worker updates..."
-  ```
+To get the metadata of a model:
 
-- Describe a model:
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+GetModelMetadataRequest request = GetModelMetadataRequest.newBuilder()
+    .setModelSpec(ModelSpec.newBuilder()
+        .setName("half_plus_two")
+        .setVersion(Int64Value.of(123)))
+    .addMetadataField("signature_def") // metadata_field is mandatory
+    .build();
+GetModelMetadataResponse response = client.getModelMetadata(request);
+System.out.println(response);
+```
 
-  List<ModelDetail> model = client.management().describeModel("mnist_v2");
-  System.out.println(model.get(0));
-  // =>
-  // ModelDetail {
-  //   modelName: mnist_v2
-  //   modelVersion: 2.0
-  //   ...
-  ```
+Output:
 
-- Unregister a model:
+```console
+model_spec {
+  name: "half_plus_two"
+  version {
+    value: 123
+  }
+}
+metadata {
+  key: "signature_def"
+  value {
+    type_url: "type.googleapis.com/tensorflow.serving.SignatureDefMap"
+    value: "..."
+  }
+}
+```
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+### Classify API
+
+To classify:
+
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+
+ClassificationRequest request = ClassificationRequest.newBuilder()
+    .setModelSpec(ModelSpec.newBuilder()
+        .setName("half_plus_two")
+        .setVersion(Int64Value.of(123))
+        .setSignatureName("classify_x_to_y"))
+    .setInput(Input.newBuilder()
+        .setExampleList(ExampleList.newBuilder()
+            .addExamples(Example.newBuilder()
+                .setFeatures(Features.newBuilder()
+                    .putFeature("x", Feature.newBuilder()
+                        .setFloatList(FloatList.newBuilder().addValue(1.0f))
+                        .build())))))
+    .build();
+ClassificationResponse response = client.classify(request);
+System.out.println(response);
+```
 
-  Response response = client.management().unregisterModel(
-      "mnist_v2",
-      UnregisterModelOptions.empty());
-  System.out.println(response.getStatus());
-  // => "Model "mnist_v2" unregistered"
-  ```
+Output:
 
-- List models:
+```console
+result {
+  classifications {
+    classes {
+      score: 2.5
+    }
+  }
+}
+model_spec {
+  name: "half_plus_two"
+  version {
+    value: 123
+  }
+  signature_name: "classify_x_to_y"
+}
+```
````

````diff
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+### Regress API
+
+To regress:
+
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+
+RegressionRequest request = RegressionRequest.newBuilder()
+    .setModelSpec(ModelSpec.newBuilder()
+        .setName("half_plus_two")
+        .setVersion(Int64Value.of(123))
+        .setSignatureName("regress_x_to_y"))
+    .setInput(Input.newBuilder()
+        .setExampleList(ExampleList.newBuilder()
+            .addExamples(Example.newBuilder()
+                .setFeatures(Features.newBuilder()
+                    .putFeature("x", Feature.newBuilder()
+                        .setFloatList(FloatList.newBuilder().addValue(1.0f))
+                        .build())))))
+    .build();
+RegressionResponse response = client.regress(request);
+System.out.println(response);
+```
 
-  ModelList models = client.management().listModels(10, null);
-  System.out.println(models);
-  // =>
-  // ModelList {
-  //   nextPageToken: null
-  //   models: [Model {
-  //     modelName: mnist_v2
-  //     modelUrl: https://torchserve.pytorch.org/mar_files/mnist_v2.mar
-  //   },
-  //   ...
-  ```
+Output:
 
-- Set default version for a model:
-
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
-
-  Response response = client.management().setDefault("mnist_v2", "2.0");
-  System.out.println(response.getStatus());
-  // => "Default version successfully updated for model "mnist_v2" to "2.0""
-  ```
-
-- With the management API endpoint other than <http://localhost:8081>:
+```console
+result {
+  regressions {
+    value: 2.5
+  }
+}
+model_spec {
+  name: "half_plus_two"
+  version {
+    value: 123
+  }
+  signature_name: "regress_x_to_y"
+}
+```
````

````diff
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.builder()
-      .managementAddress("http://localhost:12345")
-      .build();
-  ```
-
-- With token authorization:
+### Predict API
+
+To predict:
+
+```java
+TensorFlowServingClient client = TensorFlowServingClient.newInstance();
+
+PredictRequest request = PredictRequest.newBuilder()
+    .setModelSpec(ModelSpec.newBuilder()
+        .setName("half_plus_two")
+        .setVersion(Int64Value.of(123)))
+    .putInputs("x", TensorProto.newBuilder()
+        .setDtype(DataType.DT_FLOAT)
+        .setTensorShape(TensorShapeProto.newBuilder()
+            .addDim(Dim.newBuilder().setSize(3)))
+        .addFloatVal(1.0f)
+        .addFloatVal(2.0f)
+        .addFloatVal(5.0f)
+        .build())
+    .build();
+PredictResponse response = client.predict(request);
+System.out.println(response);
+```
 
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.builder()
-      .managementKey("<management-key>")
-      .build();
-  ```
-
-### Metrics
+Output:
 
-- Get metrics in Prometheus format:
-
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.newInstance();
-
-  String metrics = client.metrics().metrics();
-  System.out.println(metrics);
-  // =>
-  // # HELP MemoryUsed Torchserve prometheus gauge metric with unit: Megabytes
-  // # TYPE MemoryUsed gauge
-  // MemoryUsed{Level="Host",Hostname="3a9b51d41fbf",} 2075.09765625
-  // ...
-  ```
-
-- With the metrics API endpoint other than <http://localhost:8082>:
-
-  ```java
-  TensorFlowServingClient client = TensorFlowServingClient.builder()
-      .metricsAddress("http://localhost:12345")
-      .build();
-  ```
+```console
+outputs {
+  key: "y"
+  value {
+    dtype: DT_FLOAT
+    tensor_shape {
+      dim {
+        size: 3
+      }
+    }
+    float_val: 2.5
+    float_val: 3.0
+    float_val: 4.5
+  }
+}
+model_spec {
+  name: "half_plus_two"
+  version {
+    value: 123
+  }
+  signature_name: "serving_default"
+}
+```
````
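All of the new examples in this diff target TensorFlow Serving's `half_plus_two` demo model, which (as its name and the outputs above indicate) computes `y = x / 2 + 2`; that is why inputs `1.0`, `2.0`, `5.0` produce `float_val` results `2.5`, `3.0`, `4.5`, and why the classify and regress examples both return `2.5` for `x = 1.0`. A quick plain-Java sanity check of that arithmetic, with no serving dependencies:

```java
public class HalfPlusTwoCheck {
    // The demo model's function: y = x / 2 + 2
    static float halfPlusTwo(float x) {
        return x / 2 + 2;
    }

    public static void main(String[] args) {
        float[] inputs = {1.0f, 2.0f, 5.0f};
        for (float x : inputs) {
            // Matches the float_val entries in the predict output above
            System.out.println(halfPlusTwo(x));
        }
    }
}
```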

````diff
 ## Configuration
 
-### tsc4j.properties
+### tfsc4j.properties
 
 ```properties
-inference.key = <inference-key>
-inference.address = http://localhost:8080
-# inference.address takes precedence over inference.port if it's defined
-inference.port = 8080
-
-management.key = <management-key>
-management.address = http://localhost:8081
-# management.address takes precedence over management.port if it's defined
-management.port = 8081
-
-metrics.address = http://localhost:8082
-# metrics.address takes precedence over metrics.port if it's defined
-metrics.port = 8082
+target = <target>
+credentials = <credentials>
 ```
 
 ### System properties
 
-You can configure the TSC4J properties via system properties with prefix `tsc4j.`.
+You can configure the TFSC4J properties via system properties with prefix `tfsc4j.`.
 
-For instance, you can configure `inference.address` with the `tsc4j.inference.address` system property.
+For instance, you can configure `target` with the `tfsc4j.target` system property.
 
 ### Environment variables
 
-You can also configure the TSC4J properties via environment variables with prefix `TSC4J_`.
+You can also configure the TFSC4J properties via environment variables with prefix `TFSC4J_`.
 
-For instance, you can configure `inference.address` with the `TSC4J_INFERENCE_ADDRESS` environment variable.
+For instance, you can configure `target` with the `TFSC4J_TARGET` environment variable.
 
 ## Examples
````

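The updated configuration section names three sources for the same keys: `tfsc4j.properties`, system properties prefixed `tfsc4j.`, and environment variables prefixed `TFSC4J_`. The diff does not spell out the lookup order between them, so the precedence below (system property, then environment variable, then default) is an assumption for illustration only; the key-name mappings themselves come from the README text above. A minimal plain-Java sketch, independent of TFSC4J:

```java
public class Tfsc4jConfigSketch {
    // Mapping per the README: "target" -> system property "tfsc4j.target"
    static String toSystemProperty(String key) {
        return "tfsc4j." + key;
    }

    // Mapping per the README: "target" -> environment variable "TFSC4J_TARGET"
    static String toEnvVar(String key) {
        return "TFSC4J_" + key.toUpperCase().replace('.', '_');
    }

    // ASSUMED precedence (not stated in the diff): system property,
    // then environment variable, then the supplied default.
    static String resolve(String key, String defaultValue) {
        String fromSysProp = System.getProperty(toSystemProperty(key));
        if (fromSysProp != null) {
            return fromSysProp;
        }
        String fromEnv = System.getenv(toEnvVar(key));
        if (fromEnv != null) {
            return fromEnv;
        }
        return defaultValue;
    }

    public static void main(String[] args) {
        System.out.println(toSystemProperty("target")); // tfsc4j.target
        System.out.println(toEnvVar("target"));         // TFSC4J_TARGET
        // Falls back to the client's documented default gRPC target
        System.out.println(resolve("target", "localhost:8500"));
    }
}
```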