Commit f474f58: update examples and readme

tadayosi committed Nov 27, 2024 (1 parent: 9692323)

Showing 15 changed files with 377 additions and 356 deletions.

README.md: 341 changes (196 additions & 145 deletions)

# TensorFlow Serving Client for Java (TFSC4J)

[![Release](https://jitpack.io/v/tadayosi/tensorflow-serving-client-java.svg)](<https://jitpack.io/#tadayosi/tensorflow-serving-client-java>)
[![Test](https://github.com/tadayosi/tensorflow-serving-client-java/actions/workflows/test.yml/badge.svg)](https://github.com/tadayosi/tensorflow-serving-client-java/actions/workflows/test.yml)

TensorFlow Serving Client for Java (TFSC4J) is a Java client library for [TensorFlow Serving](https://github.com/tensorflow/serving). It supports the following [TensorFlow Serving Client API (gRPC)](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/apis):

- [Model status API](https://www.tensorflow.org/tfx/serving/api_rest#model_status_api)
- [Model Metadata API](https://www.tensorflow.org/tfx/serving/api_rest#model_metadata_api)
## Installation

Add the following dependency to your `pom.xml`:

```xml
<dependency>
<groupId>com.github.tadayosi</groupId>
<artifactId>tensorflow-serving-client-java</artifactId>
<version>v0.1</version>
</dependency>
```

## Usage

> [!IMPORTANT]
> TFSC4J uses the gRPC port (default: `8500`) to communicate with the TensorFlow model server.
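
For reference, here is a typical way to run the model server with that gRPC port published, using the standard `tensorflow/serving` Docker image (a sketch; the model name and host path are placeholders for your own setup):

```console
docker run --rm -t -p 8500:8500 \
  -v /path/to/half_plus_two:/models/half_plus_two \
  -e MODEL_NAME=half_plus_two \
  tensorflow/serving
```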

To create a client:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();
```

By default, the client connects to `localhost:8500`, but if you want to connect to a different target URI (e.g. `example.com:8080`), instantiate a client as follows:

```java
TensorFlowServingClient client = TensorFlowServingClient.builder()
    .target("example.com:8080")
    .build();
```

### Model status API

To get the status of a model:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();

GetModelStatusRequest request = GetModelStatusRequest.newBuilder()
    .setModelSpec(ModelSpec.newBuilder()
        .setName("half_plus_two")
        .setVersion(Int64Value.of(123)))
    .build();
GetModelStatusResponse response = client.getModelStatus(request);
System.out.println(response);
```

Output:

```console
model_version_status {
  version: 123
  state: AVAILABLE
  status {
  }
}
```
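
If you need the status programmatically rather than printed, the response exposes the usual generated protobuf accessors. A minimal sketch, assuming the `ModelVersionStatus` type generated from TensorFlow Serving's `get_model_status.proto` and standard protobuf Java conventions for the repeated `model_version_status` field:

```java
// Continuing from the example above: inspect the first version entry.
ModelVersionStatus status = response.getModelVersionStatus(0);
System.out.println(status.getVersion()); // => 123
System.out.println(status.getState()); // => AVAILABLE
```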

### Model Metadata API

To get the metadata of a model:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();

GetModelMetadataRequest request = GetModelMetadataRequest.newBuilder()
    .setModelSpec(ModelSpec.newBuilder()
        .setName("half_plus_two")
        .setVersion(Int64Value.of(123)))
    .addMetadataField("signature_def") // metadata_field is mandatory
    .build();
GetModelMetadataResponse response = client.getModelMetadata(request);
System.out.println(response);
```

Output:

```console
model_spec {
name: "half_plus_two"
version {
value: 123
}
}
metadata {
key: "signature_def"
value {
type_url: "type.googleapis.com/tensorflow.serving.SignatureDefMap"
value: "..."
}
}
```
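
The `signature_def` entry is a packed `google.protobuf.Any`. A sketch of unpacking it, assuming the `SignatureDefMap` class generated from TensorFlow Serving's `get_model_metadata.proto` and the standard generated map accessors:

```java
// Continuing from the example above: unpack the signature_def metadata.
// Note: unpack() may throw InvalidProtocolBufferException.
SignatureDefMap signatureDefMap = response
    .getMetadataOrThrow("signature_def")
    .unpack(SignatureDefMap.class);
System.out.println(signatureDefMap.getSignatureDefMap().keySet());
```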

### Classify API

To classify:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();

ClassificationRequest request = ClassificationRequest.newBuilder()
    .setModelSpec(ModelSpec.newBuilder()
        .setName("half_plus_two")
        .setVersion(Int64Value.of(123))
        .setSignatureName("classify_x_to_y"))
    .setInput(Input.newBuilder()
        .setExampleList(ExampleList.newBuilder()
            .addExamples(Example.newBuilder()
                .setFeatures(Features.newBuilder()
                    .putFeature("x", Feature.newBuilder()
                        .setFloatList(FloatList.newBuilder().addValue(1.0f))
                        .build())))))
    .build();
ClassificationResponse response = client.classify(request);
System.out.println(response);
```

Output:

```console
result {
  classifications {
    classes {
      score: 2.5
    }
  }
}
model_spec {
  name: "half_plus_two"
  version {
    value: 123
  }
  signature_name: "classify_x_to_y"
}
```
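
To read the score directly instead of printing the whole response, a minimal sketch (assuming the standard generated accessors for `ClassificationResult`):

```java
// Continuing from the example above: the first class of the first result.
float score = response.getResult()
    .getClassifications(0)
    .getClasses(0)
    .getScore();
System.out.println(score); // => 2.5
```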

### Regress API

To regress:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();

RegressionRequest request = RegressionRequest.newBuilder()
    .setModelSpec(ModelSpec.newBuilder()
        .setName("half_plus_two")
        .setVersion(Int64Value.of(123))
        .setSignatureName("regress_x_to_y"))
    .setInput(Input.newBuilder()
        .setExampleList(ExampleList.newBuilder()
            .addExamples(Example.newBuilder()
                .setFeatures(Features.newBuilder()
                    .putFeature("x", Feature.newBuilder()
                        .setFloatList(FloatList.newBuilder().addValue(1.0f))
                        .build())))))
    .build();
RegressionResponse response = client.regress(request);
System.out.println(response);
```

Output:

```console
result {
  regressions {
    value: 2.5
  }
}
model_spec {
  name: "half_plus_two"
  version {
    value: 123
  }
  signature_name: "regress_x_to_y"
}
```
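
Likewise, a sketch for reading the regressed value (assuming the standard generated accessors for `RegressionResult`):

```java
// Continuing from the example above: the first regression value.
float value = response.getResult().getRegressions(0).getValue();
System.out.println(value); // => 2.5
```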

### Predict API

To predict:

```java
TensorFlowServingClient client = TensorFlowServingClient.newInstance();

PredictRequest request = PredictRequest.newBuilder()
    .setModelSpec(ModelSpec.newBuilder()
        .setName("half_plus_two")
        .setVersion(Int64Value.of(123)))
    .putInputs("x", TensorProto.newBuilder()
        .setDtype(DataType.DT_FLOAT)
        .setTensorShape(TensorShapeProto.newBuilder()
            .addDim(Dim.newBuilder().setSize(3)))
        .addFloatVal(1.0f)
        .addFloatVal(2.0f)
        .addFloatVal(5.0f)
        .build())
    .build();
PredictResponse response = client.predict(request);
System.out.println(response);
```

Output:

```console
outputs {
key: "y"
value {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 3
}
}
float_val: 2.5
float_val: 3.0
float_val: 4.5
}
}
model_spec {
name: "half_plus_two"
version {
value: 123
}
signature_name: "serving_default"
}
```
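
To consume the output tensor values directly, a minimal sketch (assuming the standard generated accessors for the `outputs` map and `TensorProto`):

```java
// Continuing from the example above: read the "y" tensor's float values.
List<Float> values = response.getOutputsOrThrow("y").getFloatValList();
System.out.println(values); // => [2.5, 3.0, 4.5]
```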

## Configuration

### tfsc4j.properties

```properties
target = <target>
credentials = <credentials>
```

### System properties

You can configure the TFSC4J properties via system properties with prefix `tfsc4j.`.

For instance, you can configure `target` with the `tfsc4j.target` system property.
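
For example (a sketch; it assumes the property is set before the client is created):

```java
// Hypothetical target; any host:port reachable from your application works.
System.setProperty("tfsc4j.target", "example.com:8080");
TensorFlowServingClient client = TensorFlowServingClient.newInstance();
```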

### Environment variables

You can also configure the TFSC4J properties via environment variables with prefix `TFSC4J_`.

For instance, you can configure `target` with the `TFSC4J_TARGET` environment variable.
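
For example (a sketch for a POSIX shell; `your-app.jar` is a placeholder for your application):

```console
export TFSC4J_TARGET=example.com:8080
java -jar your-app.jar
```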

## Examples

See the [examples](./examples) directory of this repository.