Skip to content

Commit

Permalink
refactor: predict_mnist.java
Browse files Browse the repository at this point in the history
  • Loading branch information
tadayosi committed Dec 16, 2024
1 parent b217e6d commit afedbc7
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions examples/predict_mnist.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,36 @@ public static void main(String... args) throws Exception {
System.out.println("Directory: " + file.getFileName());
return;
}
try {
var data = preprocess(file);
var inputs = TensorProto.newBuilder()
.setDtype(DataType.DT_FLOAT)
.setTensorShape(TensorShapeProto.newBuilder()
.addDim(Dim.newBuilder().setSize(28))
.addDim(Dim.newBuilder().setSize(28)))
.addAllFloatVal(data)
.build();
var request = PredictRequest.newBuilder()
.setModelSpec(ModelSpec.newBuilder()
.setName("mnist")
.setVersion(Int64Value.of(1)))
.putInputs("keras_tensor", inputs)
.build();
var response = client.predict(request);
var output = response.getOutputsOrThrow("output_0");
var answer = argmax(output);
System.out.println(" %s => %s".formatted(file.getFileName(), answer));
} catch (IOException e) {
e.printStackTrace();
}
predict(client, file);
});
}
}

static void predict(TensorFlowServingClient client, Path file) {
try {
var data = preprocess(file);
var inputs = TensorProto.newBuilder()
.setDtype(DataType.DT_FLOAT)
.setTensorShape(TensorShapeProto.newBuilder()
.addDim(Dim.newBuilder().setSize(28))
.addDim(Dim.newBuilder().setSize(28)))
.addAllFloatVal(data)
.build();
var request = PredictRequest.newBuilder()
.setModelSpec(ModelSpec.newBuilder()
.setName("mnist")
.setVersion(Int64Value.of(1)))
.putInputs("keras_tensor", inputs)
.build();
var response = client.predict(request);
var output = response.getOutputsOrThrow("output_0");
var answer = argmax(output);
System.out.println(" %s => %s".formatted(file.getFileName(), answer));
} catch (IOException e) {
e.printStackTrace();
}
}

static List<Float> preprocess(Path file) throws IOException {
var image = ImageIO.read(file.toFile());
var width = image.getWidth();
Expand All @@ -72,7 +76,7 @@ static List<Float> preprocess(Path file) throws IOException {
for (var y = 0; y < height; y++) {
for (var x = 0; x < width; x++) {
var rgb = image.getRGB(x, y);
normalised.add( (rgb & 0xFF) / 255.0f);
normalised.add((rgb & 0xFF) / 255.0f);
}
}
return normalised;
Expand Down

0 comments on commit afedbc7

Please sign in to comment.