update = accuracyHelper(labels, predictions);
- totalInstances.compute(key, (k, v) -> v + update.getKey());
- correctInstances.compute(key, (k, v) -> v + update.getValue().sum().getLong());
+ NDArray value = update.getValue();
+ NDArray sum = value.sum();
+ long correct = sum.getLong();
+ for (String key : keys) {
+ totalInstances.compute(key, (k, v) -> v + update.getKey());
+ correctInstances.compute(key, (k, v) -> v + correct);
+ }
+ value.close();
+ sum.close();
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java
index 4af9e5de3d1..ab2d554142d 100644
--- a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java
+++ b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java
@@ -63,10 +63,18 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
+ updateAccumulators(new String[] {key}, labels, predictions);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
NDArray boundingBoxError = evaluate(labels, predictions);
float update = boundingBoxError.sum().getFloat();
- totalInstances.compute(key, (k, v) -> v + boundingBoxError.size());
- ssdBoxPredictionError.compute(key, (k, v) -> v + update);
+ for (String key : keys) {
+ totalInstances.compute(key, (k, v) -> v + boundingBoxError.size());
+ ssdBoxPredictionError.compute(key, (k, v) -> v + update);
+ }
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java
index 6d2c5995601..c373471f6cf 100644
--- a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java
+++ b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java
@@ -74,6 +74,25 @@ public String getName() {
*/
public abstract void addAccumulator(String key);
+ /**
+ * Updates the evaluator with the given keys based on a {@link NDList} of labels and
+ * predictions.
+ *
+ * <p>This is a synchronized operation. You should only call it at the end of a batch or epoch.
+ *
+ * <p>This is an alternative to {@link #updateAccumulator(String, NDList, NDList)} that
+ * may be more efficient when updating multiple accumulators at once.
+ *
+ * @param keys the keys of all the accumulators to update
+ * @param labels a {@code NDList} of labels
+ * @param predictions a {@code NDList} of predictions
+ */
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
+ for (String key : keys) {
+ updateAccumulator(key, labels, predictions);
+ }
+ }
+
/**
* Updates the evaluator with the given key based on a {@link NDList} of labels and predictions.
*
diff --git a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java
index a7fe08b610e..aa12cae628c 100644
--- a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java
+++ b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java
@@ -67,6 +67,12 @@ public void updateAccumulator(String key, NDList labels, NDList predictions) {
evaluator.updateAccumulator(key, getLabels(labels), getPredictions(predictions));
}
+ /** {@inheritDoc} */
+ @Override
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
+ evaluator.updateAccumulators(keys, getLabels(labels), getPredictions(predictions));
+ }
+
/** {@inheritDoc} */
@Override
public void resetAccumulator(String key) {
diff --git a/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java b/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java
new file mode 100644
index 00000000000..51b5288e838
--- /dev/null
+++ b/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java
@@ -0,0 +1,288 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.training.listener;
+
+import ai.djl.Device;
+import ai.djl.Model;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Parameter;
+import ai.djl.training.Trainer;
+import ai.djl.util.NativeResource;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** {@link TrainingListener} that records algebraic operations as Python code. */
+public class AlgebraicListener extends TrainingListenerAdapter {
+
+ private static AlgebraicListener currentListener;
+
+ private static final Logger logger = LoggerFactory.getLogger(AlgebraicListener.class);
+
+ private final Map<Object, Node> nodeMap = new ConcurrentHashMap<>();
+ private final Map<Object, Node> nodeMapForParameters = new ConcurrentHashMap<>();
+
+ @SuppressWarnings("PMD.UseConcurrentHashMap")
+ private final Map<String, Integer> losses = new LinkedHashMap<>();
+
+ @SuppressWarnings("PMD.UseConcurrentHashMap")
+ private final Map<String, Integer> predictions = new LinkedHashMap<>();
+
+ private Map<String, String> parameters;
+ private String outputFile;
+ private AtomicInteger parametersOpCount = new AtomicInteger(0);
+
+ private int numEpoch;
+
+ /**
+ * New listener to record algebraic operations into the given file.
+ *
+ * @param outputFile file to store output - will be overridden if exist
+ */
+ public AlgebraicListener(String outputFile) {
+ this.outputFile = outputFile;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onEpoch(Trainer trainer) {
+ numEpoch++;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBatch(Trainer trainer, BatchData batchData) {
+ writeParameters(trainer.getModel());
+ AtomicInteger opCount = new AtomicInteger(parametersOpCount.get());
+ for (Device device : batchData.getLabels().keySet()) {
+ NDList data = batchData.getData().get(device);
+ NDList preds = batchData.getPredictions().get(device);
+ NDList labels = batchData.getLabels().get(device);
+ NDArray loss = batchData.getLoss().get(device);
+ if (data != null) {
+ setLeaf(data, "x");
+ }
+ if (preds != null) {
+ writePredictions(preds, opCount);
+ }
+ if (preds != null) {
+ setLeaf(preds, "prediction");
+ }
+ if (labels != null) {
+ setLeaf(labels, "label");
+ }
+ if (loss != null) {
+ writeLoss(loss, opCount);
+ }
+ }
+ nodeMap.clear();
+ nodeMap.putAll(nodeMapForParameters);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBegin(Trainer trainer) {
+ setCurrentListener(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingEnd(Trainer trainer) {
+ try (OutputStream out = Files.newOutputStream(Paths.get(outputFile))) {
+ describe(out);
+ } catch (IOException e) {
+ logger.error("Failed logging algebraic operations", e);
+ }
+ parameters.clear();
+ predictions.clear();
+ losses.clear();
+ nodeMap.clear();
+ nodeMapForParameters.clear();
+ setCurrentListener(null);
+ }
+
+ private void setLeaf(NDArray x, String name) {
+ Node node = get(x);
+ if (node == null) {
+ return;
+ }
+ node.name = name;
+ node.isLeaf = true;
+ }
+
+ private void setLeaf(NDList data, String name) {
+ for (NDArray x : data) {
+ setLeaf(x, name);
+ }
+ }
+
+ private void writePredictions(NDList preds, AtomicInteger opCount) {
+ String tuple =
+ preds.stream()
+ .map(this::getArrayName)
+ .collect(Collectors.joining(", ", "return tf.tuple([", "])"));
+ if (preds.size() == 1) {
+ tuple = "return result";
+ }
+ String python =
+ preds.stream()
+ .map(pred -> get(pred).toPythonFunctionBody(opCount, getArrayName(pred)))
+ .collect(Collectors.joining("\n", "", "\n" + Node.indent(tuple)));
+ predictions.compute(python, (key, count) -> count == null ? 1 : count + 1);
+ }
+
+ private String getArrayName(NDArray pred) {
+ return pred.getName() != null ? pred.getName() : "result";
+ }
+
+ private void writeLoss(NDArray loss, AtomicInteger opCount) {
+ String python =
+ get(loss).toPythonFunctionBody(opCount, "result")
+ + "\n"
+ + Node.indent("return result");
+ losses.compute(python, (key, count) -> count == null ? 1 : count + 1);
+ }
+
+ private void describe(OutputStream out) throws IOException {
+ PrintStream writer = new PrintStream(out, true, StandardCharsets.US_ASCII.name());
+ writer.println("class MyModel(tf.keras.Model):");
+ writer.println(" def __init__(self, **kwargs):");
+ writer.println(" super().__init__(**kwargs)");
+ for (Entry<String, String> param : parameters.entrySet()) {
+ writer.println(Node.indent(param.getKey() + " = tf.Variable("));
+ writer.println(Node.indent(Node.indent(param.getValue())));
+ writer.println(Node.indent(")"));
+ }
+ writer.println("");
+ for (Entry<String, Integer> pred : predictions.entrySet()) {
+ writer.println("## " + pred.getValue());
+ writer.println(" def call(self, x):");
+ writer.println(pred.getKey());
+ }
+ writer.println("");
+ for (Entry<String, Integer> loss : losses.entrySet()) {
+ writer.println("## " + loss.getValue());
+ writer.println("def loss(label, prediction):");
+ writer.println(loss.getKey());
+ }
+ writer.println("");
+ writer.println(String.format("# number of epochs was %s", numEpoch));
+ writer.println(String.format("# number of prediction functions is %s", predictions.size()));
+ writer.println(String.format("# number of loss functions is %s", losses.size()));
+ writer.println("");
+ }
+
+ private void writeParameters(Model model) {
+ if (parameters != null) {
+ return;
+ }
+ parameters = new LinkedHashMap<>();
+ for (Pair<String, Parameter> pair : model.getBlock().getParameters()) {
+ NDArray array = pair.getValue().getArray();
+
+ Node init = get(array);
+ String initialization;
+ if (pair.getKey().endsWith("Conv2d_weight")) {
+ int[] perm = {2, 3, 1, 0};
+ PairList<String, String> param =
+ new PairList<>(Collections.singletonMap("axes", Arrays.toString(perm)));
+ Node transpose = new Node("_np_transpose", param, init);
+ transpose.outputShape =
+ new Shape(IntStream.of(perm).mapToLong(init.outputShape::get).toArray());
+ initialization = transpose.toPythonExpression(null, parametersOpCount);
+ init.outputShape = transpose.outputShape;
+ } else {
+ initialization =
+ init.toPythonExpression(null, parametersOpCount)
+ + (pair.getValue().requiresGradient()
+ ? ""
+ : "\n, trainable = False");
+ }
+ String pythonClassVariable = "self._" + pair.getKey();
+ parameters.put(pythonClassVariable, initialization);
+ setLeaf(array, pythonClassVariable);
+ nodeMapForParameters.put(key(array), init);
+ }
+ }
+
+ /**
+ * Records an algebraic operation that is executed with the given parameters.
+ *
+ * @param name the name of the operation
+ * @param src the input to the operation
+ * @param dest the output of the operation
+ * @param param parameters for the operation
+ */
+ public static void record(
+ String name, NDArray[] src, NDArray[] dest, PairList<String, Object> param) {
+ if (currentListener != null) {
+ currentListener.recordInternal(name, src, dest, param);
+ }
+ }
+
+ private void recordInternal(
+ String name, NDArray[] src, NDArray[] dest, PairList<String, Object> param) {
+ Node n = new Node(name, param != null ? param : new PairList<>(), new Node[src.length]);
+ int index = 0;
+ for (NDArray array : src) {
+ Node node = get(array);
+ if (node == null) {
+ node =
+ new Node(
+ array.getName() != null
+ ? array.getName()
+ : "UNKNOWN_ARRAY" + array.getShape(),
+ new PairList<>());
+ nodeMap.put(key(array), n);
+ node.outputShape = array.getShape();
+ }
+ n.src[index++] = node;
+ }
+ for (NDArray array : dest) {
+ nodeMap.put(key(array), n);
+ n.outputShape = array.getShape();
+ }
+ }
+
+ private Node get(NDArray array) {
+ return nodeMap.get(key(array));
+ }
+
+ private Object key(NDArray array) {
+ return ((NativeResource<?>) array).getHandle();
+ }
+
+ private static void setCurrentListener(AlgebraicListener algebraicListener) {
+ currentListener = algebraicListener;
+ }
+}
diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
new file mode 100644
index 00000000000..6c013c37715
--- /dev/null
+++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
@@ -0,0 +1,281 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.training.listener;
+
+import ai.djl.training.Trainer;
+import ai.djl.training.TrainingResult;
+
+import java.time.Duration;
+
+/**
+ * Listener that allows the training to be stopped early if the validation loss is not improving, or
+ * if time has expired.
+ *
+ * <p>Usage: Add this listener to the training config, and add it as the last one.
+ *
+ * <pre>
+ * new DefaultTrainingConfig(...)
+ *         .addTrainingListeners(EarlyStoppingListener.builder()
+ *                 .setEpochPatience(1)
+ *                 .setEarlyStopPctImprovement(1)
+ *                 .setMaxDuration(Duration.ofMinutes(42))
+ *                 .setMinEpochs(1)
+ *                 .build()
+ *         );
+ * </pre>
+ *
+ * <p>Then surround the fit with a try catch that catches the {@link
+ * EarlyStoppingListener.EarlyStoppedException}.
+ * Example:
+ *
+ * <pre>
+ * try {
+ *     EasyTrain.fit(trainer, 5, trainDataset, testDataset);
+ * } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ *     // handle early stopping
+ *     log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
+ * }
+ * </pre>
+ *
+ *
+ * <p>Note: Ensure that Metrics are set on the trainer.
+ */
+public final class EarlyStoppingListener implements TrainingListener {
+ private final double objectiveSuccess;
+
+ private final int minEpochs;
+ private final long maxMillis;
+ private final double earlyStopPctImprovement;
+ private final int epochPatience;
+
+ private long startTimeMills;
+ private double prevLoss;
+ private int numberOfEpochsWithoutImprovements;
+
+ private EarlyStoppingListener(
+ double objectiveSuccess,
+ int minEpochs,
+ long maxMillis,
+ double earlyStopPctImprovement,
+ int earlyStopPatience) {
+ this.objectiveSuccess = objectiveSuccess;
+ this.minEpochs = minEpochs;
+ this.maxMillis = maxMillis;
+ this.earlyStopPctImprovement = earlyStopPctImprovement;
+ this.epochPatience = earlyStopPatience;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onEpoch(Trainer trainer) {
+ int currentEpoch = trainer.getTrainingResult().getEpoch();
+ // stopping criteria
+ final double loss = getLoss(trainer.getTrainingResult());
+ if (currentEpoch >= minEpochs) {
+ if (loss < objectiveSuccess) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "validation loss %s < objectiveSuccess %s",
+ loss, objectiveSuccess));
+ }
+ long elapsedMillis = System.currentTimeMillis() - startTimeMills;
+ if (elapsedMillis >= maxMillis) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
+ }
+ // consider early stopping?
+ if (Double.isFinite(prevLoss)) {
+ double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0;
+ boolean improved = loss <= goalImprovement; // false if any NANs
+ if (improved) {
+ numberOfEpochsWithoutImprovements = 0;
+ } else {
+ numberOfEpochsWithoutImprovements++;
+ if (numberOfEpochsWithoutImprovements >= epochPatience) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "failed to achieve %s%% improvement %s times in a row",
+ earlyStopPctImprovement, epochPatience));
+ }
+ }
+ }
+ }
+ if (Double.isFinite(loss)) {
+ prevLoss = loss;
+ }
+ }
+
+ private static double getLoss(TrainingResult trainingResult) {
+ Float vLoss = trainingResult.getValidateLoss();
+ if (vLoss != null) {
+ return vLoss;
+ }
+ Float tLoss = trainingResult.getTrainLoss();
+ if (tLoss == null) {
+ return Double.NaN;
+ }
+ return tLoss;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBatch(Trainer trainer, BatchData batchData) {
+ // do nothing
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onValidationBatch(Trainer trainer, BatchData batchData) {
+ // do nothing
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingBegin(Trainer trainer) {
+ this.startTimeMills = System.currentTimeMillis();
+ this.prevLoss = Double.NaN;
+ this.numberOfEpochsWithoutImprovements = 0;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainingEnd(Trainer trainer) {
+ // do nothing
+ }
+
+ /**
+ * Creates a builder to build a {@link EarlyStoppingListener}.
+ *
+ * @return a new builder
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder for a {@link EarlyStoppingListener}. */
+ public static final class Builder {
+ private final double objectiveSuccess;
+ private int minEpochs;
+ private long maxMillis;
+ private double earlyStopPctImprovement;
+ private int epochPatience;
+
+ /** Constructs a {@link Builder} with default values. */
+ public Builder() {
+ this.objectiveSuccess = 0;
+ this.minEpochs = 0;
+ this.maxMillis = Long.MAX_VALUE;
+ this.earlyStopPctImprovement = 0;
+ this.epochPatience = 0;
+ }
+
+ /**
+ * Set the minimum # epochs, defaults to 0.
+ *
+ * @param minEpochs the minimum # epochs
+ * @return this builder
+ */
+ public Builder optMinEpochs(int minEpochs) {
+ this.minEpochs = minEpochs;
+ return this;
+ }
+
+ /**
+ * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms.
+ *
+ * @param duration the maximum duration a training run should take
+ * @return this builder
+ */
+ public Builder optMaxDuration(Duration duration) {
+ this.maxMillis = duration.toMillis();
+ return this;
+ }
+
+ /**
+ * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE.
+ *
+ * @param maxMillis the maximum # milliseconds a training run should take
+ * @return this builder
+ */
+ public Builder optMaxMillis(int maxMillis) {
+ this.maxMillis = maxMillis;
+ return this;
+ }
+
+ /**
+ * Consider early stopping if not x% improvement, defaults to 0.
+ *
+ * @param earlyStopPctImprovement the percentage improvement to consider early stopping,
+ * must be between 0 and 100.
+ * @return this builder
+ */
+ public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
+ this.earlyStopPctImprovement = earlyStopPctImprovement;
+ return this;
+ }
+
+ /**
+ * Stop if insufficient improvement for x epochs in a row, defaults to 0.
+ *
+ * @param epochPatience the number of epochs without improvement to consider stopping, must
+ * be greater than 0.
+ * @return this builder
+ */
+ public Builder optEpochPatience(int epochPatience) {
+ this.epochPatience = epochPatience;
+ return this;
+ }
+
+ /**
+ * Builds a {@link EarlyStoppingListener} with the specified values.
+ *
+ * @return a new {@link EarlyStoppingListener}
+ */
+ public EarlyStoppingListener build() {
+ return new EarlyStoppingListener(
+ objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience);
+ }
+ }
+
+ /**
+ * Thrown when training is stopped early, the message will contain the reason why it is stopped
+ * early.
+ */
+ public static class EarlyStoppedException extends RuntimeException {
+ private static final long serialVersionUID = 1L;
+ private final int stopEpoch;
+
+ /**
+ * Constructs an {@link EarlyStoppedException} with the specified message and epoch.
+ *
+ * @param stopEpoch the epoch at which training was stopped early
+ * @param message the message/reason why training was stopped early
+ */
+ public EarlyStoppedException(int stopEpoch, String message) {
+ super(message);
+ this.stopEpoch = stopEpoch;
+ }
+
+ /**
+ * Gets the epoch at which training was stopped early.
+ *
+ * @return the epoch at which training was stopped early.
+ */
+ public int getStopEpoch() {
+ return stopEpoch;
+ }
+ }
+}
diff --git a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java
index 1dbfe4117cd..2556a026259 100644
--- a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java
+++ b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java
@@ -144,9 +144,7 @@ private void updateEvaluators(Trainer trainer, BatchData batchData, String[] acc
for (Device device : batchData.getLabels().keySet()) {
NDList labels = batchData.getLabels().get(device);
NDList predictions = batchData.getPredictions().get(device);
- for (String accumulator : accumulators) {
- evaluator.updateAccumulator(accumulator, labels, predictions);
- }
+ evaluator.updateAccumulators(accumulators, labels, predictions);
}
}
}
diff --git a/api/src/main/java/ai/djl/training/listener/Node.java b/api/src/main/java/ai/djl/training/listener/Node.java
new file mode 100644
index 00000000000..8dfa569e26a
--- /dev/null
+++ b/api/src/main/java/ai/djl/training/listener/Node.java
@@ -0,0 +1,463 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.training.listener;
+
+import ai.djl.ndarray.types.Shape;
+import ai.djl.util.PairList;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/** One node of the computational graph. */
+class Node {
+
+ String name;
+ final Node[] src;
+ final PairList<String, Object> param;
+ boolean isLeaf;
+ Shape outputShape;
+
+ public Node(String name, PairList<String, Object> param, Node... src) {
+ this.name = name;
+ this.param = param;
+ this.src = src;
+ }
+
+ String toPythonExpression(Map<Node, String> locals, AtomicInteger opCount) {
+ return toPythonExpression(locals, opCount, false) + " # " + outputShape;
+ }
+
+ String toPythonExpression(Map<Node, String> locals, AtomicInteger opCount, boolean useLocals) {
+ if (isLeaf) {
+ return name;
+ }
+ if (useLocals && locals != null && locals.containsKey(this)) {
+ return locals.get(this);
+ }
+ switch (name) {
+ case "pick":
+ {
+ Object[][] args = {{0}, {1, "indices"}, {"axis", "batch_dims"}};
+ return format("tf.gather", args, locals, opCount);
+ }
+ case "where":
+ {
+ Object[][] args = {{0}, {1, "x"}, {2, "y"}};
+ return format("tf.where", args, locals, opCount);
+ }
+ case "_npi_slice":
+ {
+ Object[][] args = {
+ {0}, {"begin", "begin"}, {"end", "end"}, {"step", "strides"}
+ };
+ return format("tf.strided_slice", args, locals, opCount);
+ }
+ case "_npi_concatenate":
+ {
+ Object[][] args = {{-1}, {"axis", "axis"}};
+ return format("tf.concat", args, locals, opCount);
+ }
+ case "_np_squeeze":
+ {
+ Object[][] args = {{0}, {"axis", "axis"}};
+ return format("tf.squeeze", args, locals, opCount);
+ }
+ case "_npi_stack":
+ {
+ Object[][] args = {{-1}, {"axis", "axis"}};
+ return format("tf.stack", args, locals, opCount);
+ }
+ case "_npi_split":
+ {
+ Object[][] args = {
+ {0}, {"axis", "axis"}, {"num_outputs", "num_or_size_splits"}
+ };
+ return format("tf.split", args, locals, opCount);
+ }
+ case "_npi_swapaxes":
+ {
+ Object[][] args = {{0}, {"dim1", "axis1"}, {"dim2", "axis2"}};
+ return format("tf.experimental.numpy.swapaxes", args, locals, opCount);
+ }
+ case "_np_repeat":
+ {
+ Object[][] args = {{0}, {"repeats", "repeats"}, {"axis", "axis"}};
+ return format("tf.repeat", args, locals, opCount);
+ }
+ case "_npi_copyto":
+ {
+ return src[0].toPythonExpression(locals, opCount, true);
+ }
+ case "_npi_expand_dims":
+ {
+ Object[][] args = {{0}, {"axis", "axis"}};
+ return format("tf.expand_dims", args, locals, opCount);
+ }
+ case "_npx_log_softmax":
+ {
+ Object[][] args = {{0}, {"axis", "axis"}};
+ return format("tf.nn.log_softmax", args, locals, opCount);
+ }
+ case "_npi_zeros":
+ {
+ Object[][] args = {{"shape", "shape"}, {"dtype", "dtype", "tf.dtypes.%s"}};
+ return format("tf.zeros", args, locals, opCount);
+ }
+ case "_npi_ones":
+ {
+ Object[][] args = {{"shape", "shape"}, {"dtype", "dtype", "tf.dtypes.%s"}};
+ return format("tf.ones", args, locals, opCount);
+ }
+ case "_npi_normal":
+ {
+ Object[][] args = {
+ {"size", "shape"},
+ {"loc", "mean"},
+ {"scale", "stddev"},
+ {"dtype", "dtype", "tf.dtypes.%s"}
+ };
+ return format("tf.random.normal", args, locals, opCount);
+ }
+ case "_npi_uniform":
+ {
+ Object[][] args = {
+ {"low", "minval"},
+ {"high", "maxval"},
+ {"shape", "shape"},
+ {"dtype", "dtype", "tf.dtypes.%s"}
+ };
+ return format("tf.random.uniform", args, locals, opCount);
+ }
+ case "_np_reshape":
+ {
+ Object[][] args = {{0}, {"newshape", "shape"}};
+ return format("tf.reshape", args, locals, opCount);
+ }
+ case "_np_transpose":
+ {
+ Object[][] args = {{0}, {"axes", "perm"}};
+ return format("tf.transpose", args, locals, opCount);
+ }
+ case "_npx_activation":
+ {
+ Object[][] args = {{0}};
+ String op =
+ this.param.get("act_type").toString().replace("softrelu", "softplus");
+ return format("tf.nn." + op, args, locals, opCount);
+ }
+ case "_npx_convolution":
+ {
+ String padding = "(0, 0)".equals(this.param.get("pad")) ? "'VALID'" : "'SAME'";
+ Object[][] args = {
+ {0},
+ {1, "filters"},
+ {"stride", "strides"},
+ {"pad", "padding", padding},
+ {"dilate", "dilations"},
+ {null, "data_format", "'NCHW'"}
+ };
+ return addBias(
+ format("tf.nn.convolution", args, locals, opCount),
+ true,
+ locals,
+ opCount);
+ }
+ case "_npx_pooling":
+ {
+ if ("True".equals(this.param.get("global_pool"))) {
+ String op =
+ "avg".equals(this.param.get("pool_type"))
+ ? "reduce_mean"
+ : "reduce_max";
+ Object[][] args = {{0}, {null, "axis", "[2, 3]"}};
+ return format("tf." + op, args, locals, opCount);
+ }
+ String padding = "(0, 0)".equals(this.param.get("pad")) ? "'VALID'" : "'SAME'";
+ String poolingType =
+ "avg".equals(this.param.get("pool_type")) ? "'AVG'" : "'MAX'";
+ Object[][] args = {
+ {0},
+ {"kernel", "window_shape"},
+ {"pool_type", "pooling_type", poolingType},
+ {"stride", "strides"},
+ {"pad", "padding", padding},
+ {"dilate", "dilations"},
+ {null, "data_format", "'NCHW'"}
+ };
+ return format("tf.nn.pool", args, locals, opCount);
+ }
+ case "_npx_batch_norm":
+ {
+ Object[][] args = {
+ {0},
+ {1, "scale"},
+ {2, "offset"},
+ {3, "mean"},
+ {4, "variance"},
+ {"eps", "epsilon"},
+ {null, "is_training", "True"},
+ {"momentum", "exponential_avg_factor"},
+ {null, "data_format", "'NCHW'"}
+ };
+ return format("tf.compat.v1.nn.fused_batch_norm", args, locals, opCount);
+ }
+
+ case "_npx_embedding":
+ {
+ Object[][] args = {
+ {0, "ids"},
+ {1, "params"}
+ };
+ return format("tf.nn.embedding_lookup", args, locals, opCount);
+ }
+ case "_npx_fully_connected":
+ {
+ Object[][] args = {{0}, {1, "b"}, {null, "transpose_b", "True"}};
+ return addBias(
+ format("tf.matmul", args, locals, opCount), false, locals, opCount);
+ }
+ case "_npi_matmul":
+ {
+ Object[][] args = {{0}, {1}};
+ return addBias(
+ format("tf.matmul", args, locals, opCount), false, locals, opCount);
+ }
+ case "_npi_not_equal_scalar":
+ {
+ Object[][] args = {{0}, {"scalar", "y"}};
+ return format("tf.not_equal", args, locals, opCount);
+ }
+ case "_rdiv_scalar":
+ {
+ Object[][] args = {{0}, {"scalar", "y"}};
+ return format("tf.divide", args, locals, opCount);
+ }
+ case "_npi_add_scalar":
+ {
+ Object[][] args = {{0}, {"scalar", "y"}};
+ return format("tf.add", args, locals, opCount);
+ }
+ case "_npi_add":
+ {
+ Object[][] args = {{0}, {1}};
+ return format("tf.add", args, locals, opCount);
+ }
+ case "_npi_subtract":
+ {
+ Object[][] args = {{0}, {1}};
+ return format("tf.subtract", args, locals, opCount);
+ }
+ case "_npi_mean":
+ {
+ Object[][] args = {{0}, {"axis", "axis"}, {"keepdims", "keepdims"}};
+ return format("tf.reduce_mean", args, locals, opCount);
+ }
+ case "gammaln":
+ {
+ Object[][] args = {{0}};
+ return format("tf.gammaln", args, locals, opCount);
+ }
+ case "_np_sum":
+ {
+ Object[][] args = {{0}, {"axis", "axis"}, {"keepdims", "keepdims"}};
+ return format("tf.reduce_sum", args, locals, opCount);
+ }
+ case "_npi_maximum_scalar":
+ {
+ Object[][] args = {{0}, {"scalar", "y"}};
+ return format("tf.maximum", args, locals, opCount);
+ }
+ case "_npi_multiply_scalar":
+ {
+ Object[][] args = {{0}, {"scalar", "y"}};
+ return format("tf.multiply", args, locals, opCount);
+ }
+ case "_npi_multiply":
+ {
+ Object[][] args = {{0}, {1}};
+ return format("tf.multiply", args, locals, opCount);
+ }
+ case "_npi_true_divide":
+ {
+ Object[][] args = {{0}, {1}};
+ return format("tf.divide", args, locals, opCount);
+ }
+ case "_npi_greater":
+ {
+ Object[][] args = {{0}, {1}};
+ return format("tf.greater", args, locals, opCount);
+ }
+ case "_npi_negative":
+ {
+ Object[][] args = {{0}};
+ return format("tf.negative", args, locals, opCount);
+ }
+ case "_npi_absolute":
+ {
+ Object[][] args = {{0}};
+ return format("tf.abs", args, locals, opCount);
+ }
+ case "_npi_log":
+ {
+ Object[][] args = {{0}};
+ return format("tf.log", args, locals, opCount);
+ }
+ case "_npi_exp":
+ {
+ Object[][] args = {{0}};
+ return format("tf.exp", args, locals, opCount);
+ }
+ default:
+ {
+ Stream<Object[]> srcStream =
+ IntStream.range(0, src.length).mapToObj(i -> new Object[] {i});
+ Stream<Object[]> paramStream =
+ param.stream().map(p -> new Object[] {p.getKey(), p.getKey()});
+ Object[][] args =
+ Stream.concat(srcStream, paramStream).toArray(Object[][]::new);
+ return format(name, args, locals, opCount);
+ }
+ }
+ }
+
+ /**
+ * Constructs a Python expression for the given operation and formatting arguments.
+ *
+ * @param op tensorflow operation name
+ * @param args array of array of:
+ * [0]: index for {@link #src} or {@link #param} to retrieve argument value, or null
+ *
+ * [1]: tensorflow parameter name
+ * [2]: format of argument
+ * [3]: output shape of argument
+ * @param locals nodes stored in local Python variables
+ * @param opCount operation counter
+ * @return the Python expression
+ */
+ private String format(
+ String op, Object[][] args, Map<Node, String> locals, AtomicInteger opCount) {
+ StringBuilder sb = new StringBuilder(op + "(\n");
+ for (Object[] arg : args) {
+ String s = arg.length >= 3 ? String.valueOf(arg[2]) : "%s";
+ Shape shape = arg.length >= 4 ? (Shape) arg[3] : null;
+ if (Integer.valueOf(-1).equals(arg[0])) {
+ s =
+ Stream.of(src)
+ .map(node -> node.toPythonExpression(locals, opCount, true))
+ .map(Node::indent)
+ .collect(Collectors.joining(",\n", "[\n", "\n]"));
+ } else if (arg[0] instanceof Integer && src.length > (int) arg[0]) {
+ Node node = src[(int) arg[0]];
+ s = String.format(s, node.toPythonExpression(locals, opCount, true));
+ shape = node.outputShape;
+ } else if (this.param.get(String.valueOf(arg[0])) != null) {
+ s = String.format(s, this.param.get(String.valueOf(arg[0])));
+ } else if (arg[0] != null) {
+ continue; // cannot resolve index, so skip
+ }
+ if (s.startsWith("(") && s.endsWith(")")) {
+ s = String.format("[%s]", s.substring(1, s.length() - 1));
+ }
+ if (arg.length >= 2 && arg[1] != null) {
+ s = String.format("%s=%s", arg[1], s);
+ }
+ sb.append(indent(s) + "," + (shape != null ? " # " + shape : "") + "\n");
+ }
+ sb.append(
+ indent(
+ String.format(
+ "name='%s_%s_',",
+ op.substring(op.lastIndexOf('.') + 1), opCount.incrementAndGet())));
+ sb.append("\n)");
+ return sb.toString();
+ }
+
+ private String addBias(
+ String result,
+ boolean setChannelFirst,
+ Map<Node, String> locals,
+ AtomicInteger opCount) {
+ if (src.length == 3) {
+ Object[][] args = {
+ {null, null, result, this.outputShape},
+ {2, "bias"},
+ {null, "data_format", setChannelFirst ? "'NCHW'" : "None"}
+ };
+ return format("tf.nn.bias_add", args, locals, opCount);
+ }
+ return result;
+ }
+
+ private void identifyMultipleUsages(Map<Node, Integer> usages) {
+ if (isLeaf) {
+ return;
+ }
+ if (usages.compute(this, (key, count) -> count == null ? 1 : count + 1) >= 2) {
+ return;
+ }
+ for (Node node : src) {
+ node.identifyMultipleUsages(usages);
+ }
+ // reposition behind src nodes
+ usages.put(this, usages.remove(this));
+ }
+
+ String toPythonFunctionBody(AtomicInteger opCount, String result) {
+ @SuppressWarnings("PMD.UseConcurrentHashMap")
+ Map<Node, Integer> usages = new LinkedHashMap<>();
+ identifyMultipleUsages(usages);
+ Map<Node, String> locals = new ConcurrentHashMap<>();
+ List<String> statements = new ArrayList<>();
+ int val = 1;
+ int batchnorm = 1;
+ for (Map.Entry<Node, Integer> usage : usages.entrySet()) {
+ Node node = usage.getKey();
+ if (usage.getValue() >= 2) {
+ // save the result of an expression that is used multiple times in local variable
+ locals.put(node, "val".concat(Integer.toString(val++)));
+ } else if ("_npx_batch_norm".equals(node.name)) {
+ // local required to assign locals 'running_mean' and 'running_var' at the same time
+ locals.put(node, "batchnorm".concat(Integer.toString(batchnorm++)));
+ }
+ }
+ for (Map.Entry<Node, Integer> usage : usages.entrySet()) {
+ Node node = usage.getKey();
+ if (usage.getValue() >= 2) {
+ statements.add(
+ String.format(
+ "%s = %s",
+ locals.get(node), node.toPythonExpression(locals, opCount)));
+ } else if ("_npx_batch_norm".equals(node.name)) {
+ statements.add(
+ String.format(
+ "(%s, running_mean, running_var) = %s",
+ locals.get(node), node.toPythonExpression(locals, opCount)));
+ statements.add(String.format("%s.assign(running_mean)", node.src[3].name));
+ statements.add(String.format("%s.assign(running_var)", node.src[4].name));
+ }
+ }
+ statements.add(String.format("%s = %s", result, toPythonExpression(locals, opCount)));
+ return statements.stream().map(Node::indent).collect(Collectors.joining(" \n"));
+ }
+
+ static String indent(String val) {
+ return val.replaceAll("(?m)^", " ");
+ }
+}
diff --git a/api/src/main/java/ai/djl/training/listener/TrainingListener.java b/api/src/main/java/ai/djl/training/listener/TrainingListener.java
index 3d81601f20f..c228bdade2b 100644
--- a/api/src/main/java/ai/djl/training/listener/TrainingListener.java
+++ b/api/src/main/java/ai/djl/training/listener/TrainingListener.java
@@ -13,11 +13,13 @@
package ai.djl.training.listener;
import ai.djl.Device;
+import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
/**
* {@code TrainingListener} offers an interface that performs some actions when certain events have
@@ -163,6 +165,20 @@ static TrainingListener[] logging(String outputDir) {
new TimeMeasureTrainingListener(outputDir)
};
}
+
+ /**
+ * Returns a listener for logging algebraic operations.
+ *
+ * @param outputFile the output file to store the algebraic log. Can be null which skips
+ * algebraic logging.
+ * @return the new set of listeners
+ */
+ static TrainingListener[] algebraicLogging(String outputFile) {
+ if (outputFile == null) {
+ return new TrainingListener[] {}; // algebraic logging disabled
+ }
+ return new TrainingListener[] {new AlgebraicListener(outputFile)};
+ }
}
/** A class to pass data from the batch into the training listeners. */
@@ -171,6 +187,8 @@ class BatchData {
private Batch batch;
private Map labels;
private Map predictions;
+ private Map data;
+ private Map loss;
/**
* Constructs a new {@link BatchData}.
@@ -183,6 +201,8 @@ public BatchData(Batch batch, Map labels, Map pr
this.batch = batch;
this.labels = labels;
this.predictions = predictions;
+ this.data = new ConcurrentHashMap<>();
+ this.loss = new ConcurrentHashMap<>();
}
/**
@@ -211,5 +231,23 @@ public Map getLabels() {
public Map getPredictions() {
return predictions;
}
+
+ /**
+ * Returns the data for each device.
+ *
+ * @return the data for each device
+ */
+ public Map getData() {
+ return data;
+ }
+
+ /**
+ * Returns the loss for each device.
+ *
+ * @return the loss for each device
+ */
+ public Map getLoss() {
+ return loss;
+ }
}
}
diff --git a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java
index 2a46416190a..2e2cdcb8c86 100644
--- a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java
+++ b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java
@@ -80,10 +80,10 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
- public void updateAccumulator(String key, NDList labels, NDList predictions) {
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
for (int i = 0; i < components.size(); i++) {
Pair inputs = inputForComponent(i, labels, predictions);
- components.get(i).updateAccumulator(key, inputs.getKey(), inputs.getValue());
+ components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue());
}
}
diff --git a/api/src/main/java/ai/djl/training/loss/Loss.java b/api/src/main/java/ai/djl/training/loss/Loss.java
index a661a3e9a0e..bcf39d23b39 100644
--- a/api/src/main/java/ai/djl/training/loss/Loss.java
+++ b/api/src/main/java/ai/djl/training/loss/Loss.java
@@ -385,10 +385,18 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
+ updateAccumulators(new String[] {key}, labels, predictions);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
// this is a synchronized operation, only call it at end of batch or epoch
float update = evaluate(labels, predictions).sum().getFloat();
- totalInstances.compute(key, (k, v) -> v + 1);
- totalLoss.compute(key, (k, v) -> v + update);
+ for (String key : keys) {
+ totalInstances.compute(key, (k, v) -> v + 1);
+ totalLoss.compute(key, (k, v) -> v + update);
+ }
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java
index 3f3bb1b2d6e..f026bd431c9 100644
--- a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java
+++ b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java
@@ -29,10 +29,17 @@ public final class PaddingStackBatchifier implements Batchifier {
private static final long serialVersionUID = 1L;
+ @SuppressWarnings("serial")
private List arraysToPad;
+
+ @SuppressWarnings("serial")
private List dimsToPad;
+
private transient List paddingSuppliers;
+
+ @SuppressWarnings("serial")
private List paddingSizes;
+
private boolean includeValidLengths;
private PaddingStackBatchifier(Builder builder) {
diff --git a/api/src/main/java/ai/djl/util/Ec2Utils.java b/api/src/main/java/ai/djl/util/Ec2Utils.java
index 178c3d7efe7..5408182964f 100644
--- a/api/src/main/java/ai/djl/util/Ec2Utils.java
+++ b/api/src/main/java/ai/djl/util/Ec2Utils.java
@@ -97,7 +97,7 @@ public static String readMetadata(String key) {
* @param engine the default engine name
*/
public static void callHome(String engine) {
- if (Boolean.getBoolean("offline")
+ if (Utils.isOfflineMode()
|| Boolean.parseBoolean(Utils.getEnvOrSystemProperty("OPT_OUT_TRACKING"))
|| System.currentTimeMillis() - lastCheckIn < ONE_DAY) {
return;
diff --git a/api/src/main/java/ai/djl/util/StringPair.java b/api/src/main/java/ai/djl/util/StringPair.java
new file mode 100644
index 00000000000..a42e739614b
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/StringPair.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util;
+
+/** A class containing the string key-value pair. */
+public class StringPair extends Pair {
+
+ /**
+ * Constructs a {@code Pair} instance with key and value.
+ *
+ * @param key the key
+ * @param value the value
+ */
+ public StringPair(String key, String value) {
+ super(key, value);
+ }
+}
diff --git a/api/src/main/java/ai/djl/util/TarUtils.java b/api/src/main/java/ai/djl/util/TarUtils.java
new file mode 100644
index 00000000000..c02b278788f
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/TarUtils.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.commons.compress.utils.CloseShieldFilterInputStream;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+
+/** Utilities for working with tar files. */
+public final class TarUtils {
+
+ private TarUtils() {}
+
+ /**
+ * Un-compress a tar ball from InputStream.
+ *
+ * @param is the InputStream
+ * @param dir the target directory
+ * @param gzip whether the tar ball is gzip compressed
+ * @throws IOException for failures to untar the input directory
+ */
+ public static void untar(InputStream is, Path dir, boolean gzip) throws IOException {
+ InputStream bis;
+ if (gzip) {
+ bis = new GzipCompressorInputStream(new BufferedInputStream(is));
+ } else {
+ bis = new BufferedInputStream(is);
+ }
+ bis = new CloseShieldFilterInputStream(bis);
+ try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) {
+ TarArchiveEntry entry;
+ while ((entry = tis.getNextEntry()) != null) {
+ String entryName = entry.getName();
+ if (entryName.contains("..")) {
+ throw new IOException("Malicious zip entry: " + entryName);
+ }
+ Path file = dir.resolve(entryName).toAbsolutePath();
+ if (entry.isDirectory()) {
+ Files.createDirectories(file);
+ } else {
+ Path parentFile = file.getParent();
+ if (parentFile == null) {
+ throw new AssertionError("Parent path should never be null: " + file);
+ }
+ Files.createDirectories(parentFile);
+ Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ }
+ }
+}
diff --git a/api/src/main/java/ai/djl/util/Utils.java b/api/src/main/java/ai/djl/util/Utils.java
index c8e1bd514ac..270958d5b40 100644
--- a/api/src/main/java/ai/djl/util/Utils.java
+++ b/api/src/main/java/ai/djl/util/Utils.java
@@ -357,6 +357,20 @@ public static Path getCacheDir() {
return Paths.get(cacheDir);
}
+ /**
+ * Returns if offline mode is enabled.
+ *
+ * @return true if offline mode is enabled
+ */
+ public static boolean isOfflineMode() {
+ String mode = getenv("DJL_OFFLINE", System.getProperty("ai.djl.offline"));
+ if (mode != null) {
+ return Boolean.parseBoolean(mode);
+ }
+ // backward compatible
+ return Boolean.getBoolean("offline");
+ }
+
/**
* Returns nested model directory if the directory contains only one subdirectory.
*
@@ -481,7 +495,7 @@ public static InputStream openUrl(String url) throws IOException {
*/
public static InputStream openUrl(URL url) throws IOException {
String protocol = url.getProtocol();
- if (Boolean.getBoolean("offline")
+ if (isOfflineMode()
&& ("http".equalsIgnoreCase(protocol) || "https".equalsIgnoreCase(protocol))) {
throw new IOException("Offline model is enabled.");
}
diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
index b0b8e3e4247..b30a208f6ab 100644
--- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
+++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
@@ -22,7 +22,11 @@
import org.slf4j.LoggerFactory;
import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
import java.lang.management.MemoryUsage;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
@@ -33,6 +37,8 @@ public final class CudaUtils {
private static final CudaLibrary LIB = loadLibrary();
+ private static String[] gpuInfo;
+
private CudaUtils() {}
/**
@@ -49,7 +55,15 @@ public static boolean hasCuda() {
*
* @return the number of GPUs available in the system
*/
+ @SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getGpuCount() {
+ if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
+ if (gpuInfo == null) {
+ gpuInfo = execute(-1); // NOPMD
+ }
+ return Integer.parseInt(gpuInfo[0]);
+ }
+
if (LIB == null) {
return 0;
}
@@ -79,7 +93,19 @@ public static int getGpuCount() {
*
* @return the version of CUDA runtime
*/
+ @SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getCudaVersion() {
+ if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
+ if (gpuInfo == null) {
+ gpuInfo = execute(-1);
+ }
+ int version = Integer.parseInt(gpuInfo[1]);
+ if (version == -1) {
+ throw new IllegalArgumentException("No cuda device found.");
+ }
+ return version;
+ }
+
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
@@ -95,9 +121,6 @@ public static int getCudaVersion() {
* @return the version string of CUDA runtime
*/
public static String getCudaVersionString() {
- if (LIB == null) {
- throw new IllegalStateException("No cuda library is loaded.");
- }
int version = getCudaVersion();
int major = version / 1000;
int minor = (version / 10) % 10;
@@ -111,6 +134,14 @@ public static String getCudaVersionString() {
* @return the CUDA compute capability
*/
public static String getComputeCapability(int device) {
+ if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
+ String[] ret = execute(device);
+ if (ret.length != 3) {
+ throw new IllegalArgumentException(ret[0]);
+ }
+ return ret[0];
+ }
+
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
@@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}
+ if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
+ String[] ret = execute(device.getDeviceId());
+ if (ret.length != 3) {
+ throw new IllegalArgumentException(ret[0]);
+ }
+ long total = Long.parseLong(ret[1]);
+ long used = Long.parseLong(ret[2]);
+ return new MemoryUsage(-1, used, used, total);
+ }
+
if (LIB == null) {
throw new IllegalStateException("No GPU device detected.");
}
@@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) {
return new MemoryUsage(-1, committed, committed, total[0]);
}
+ /**
+ * The main entrypoint to get CUDA information with command line.
+ *
+ * @param args the command line arguments.
+ */
+ @SuppressWarnings("PMD.SystemPrintln")
+ public static void main(String[] args) {
+ int gpuCount = getGpuCount();
+ if (args.length == 0) {
+ if (gpuCount <= 0) {
+ System.out.println("0,-1");
+ return;
+ }
+ int cudaVersion = getCudaVersion();
+ System.out.println(gpuCount + "," + cudaVersion);
+ return;
+ }
+ try {
+ int deviceId = Integer.parseInt(args[0]);
+ if (deviceId < 0 || deviceId >= gpuCount) {
+ System.out.println("Invalid device: " + deviceId);
+ return;
+ }
+ MemoryUsage mem = getGpuMemory(Device.gpu(deviceId));
+ String cc = getComputeCapability(deviceId);
+ System.out.println(cc + ',' + mem.getMax() + ',' + mem.getUsed());
+ } catch (NumberFormatException e) {
+ System.out.println("Invalid device: " + args[0]);
+ }
+ }
+
private static CudaLibrary loadLibrary() {
try {
+ if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
+ return null;
+ }
if (System.getProperty("os.name").startsWith("Win")) {
String path = Utils.getenv("PATH");
if (path == null) {
@@ -187,15 +262,40 @@ private static CudaLibrary loadLibrary() {
} catch (UnsatisfiedLinkError e) {
logger.debug("cudart library not found.");
logger.trace("", e);
- return null;
- } catch (IncompatibleClassChangeError e) {
+ } catch (LinkageError e) {
logger.warn("You have a conflict version of JNA in the classpath.");
logger.debug("", e);
- return null;
} catch (SecurityException e) {
logger.warn("Access denied during loading cudart library.");
logger.trace("", e);
- return null;
+ }
+ return null;
+ }
+
+ private static String[] execute(int deviceId) {
+ try {
+ String javaHome = System.getProperty("java.home");
+ String classPath = System.getProperty("java.class.path");
+ String os = System.getProperty("os.name");
+ List cmd = new ArrayList<>(4);
+ if (os.startsWith("Win")) {
+ cmd.add(javaHome + "\\bin\\java.exe");
+ } else {
+ cmd.add(javaHome + "/bin/java");
+ }
+ cmd.add("-cp");
+ cmd.add(classPath);
+ cmd.add("ai.djl.util.cuda.CudaUtils");
+ if (deviceId >= 0) {
+ cmd.add(String.valueOf(deviceId));
+ }
+ Process ps = new ProcessBuilder(cmd).redirectErrorStream(true).start();
+ try (InputStream is = ps.getInputStream()) {
+ String line = Utils.toString(is).trim();
+ return line.split(",");
+ }
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Failed get GPU information", e);
}
}
diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java
index 92a0474c6e7..a69a502739b 100644
--- a/api/src/test/java/ai/djl/DeviceTest.java
+++ b/api/src/test/java/ai/djl/DeviceTest.java
@@ -13,6 +13,7 @@
package ai.djl;
+import ai.djl.Device.MultiDevice;
import ai.djl.engine.Engine;
import org.testng.Assert;
@@ -37,6 +38,9 @@ public void testDevice() {
System.setProperty("test_key", "test");
Engine.debugEnvironment();
+
+ Assert.assertEquals(1, Device.cpu().getDevices().size());
+ Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size());
}
@Test
@@ -54,5 +58,9 @@ public void testDeviceName() {
Device defaultDevice = Engine.getInstance().defaultDevice();
Assert.assertEquals(Device.fromName(""), defaultDevice);
Assert.assertEquals(Device.fromName(null), defaultDevice);
+
+ Assert.assertEquals(
+ Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1)));
+ Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3));
}
}
diff --git a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java
index 8c140688124..a8b2bdfab62 100644
--- a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java
+++ b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java
@@ -15,32 +15,38 @@
import org.testng.Assert;
import org.testng.annotations.Test;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
public class PublisherBytesSupplierTest {
@Test
- public void test() {
+ public void test() throws ExecutionException, InterruptedException {
AtomicInteger contentCount = new AtomicInteger();
PublisherBytesSupplier supplier = new PublisherBytesSupplier();
- // Add to supplier without subscriber
- supplier.appendContent(new byte[] {1}, false);
- Assert.assertEquals(contentCount.get(), 0);
+ new Thread(
+ () -> {
+ // Add to supplier without subscriber
+ supplier.appendContent(new byte[] {1}, false);
+ // Add to supplier with subscriber
+ supplier.appendContent(new byte[] {1}, true);
+ })
+ .start();
// Subscribing with data should trigger subscriptions
- supplier.subscribe(
- d -> {
- if (d == null) {
- // Do nothing on completion
- return;
- }
- contentCount.getAndIncrement();
- });
- Assert.assertEquals(contentCount.get(), 1);
+ CompletableFuture future =
+ supplier.subscribe(
+ d -> {
+ if (d == null) {
+ // Do nothing on completion
+ return;
+ }
+ contentCount.getAndIncrement();
+ });
- // Add to supplier with subscriber
- supplier.appendContent(new byte[] {1}, true);
+ future.get();
Assert.assertEquals(contentCount.get(), 2);
}
}
diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java
new file mode 100644
index 00000000000..8fbbae7301b
--- /dev/null
+++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.cv.translator;
+
+import ai.djl.Model;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.translate.BasicTranslator;
+import ai.djl.translate.Translator;
+
+import org.testng.Assert;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+
+public class YoloV8TranslatorFactoryTest {
+
+ private YoloV8TranslatorFactory factory;
+
+ @BeforeClass
+ public void setUp() {
+ factory = new YoloV8TranslatorFactory();
+ }
+
+ @Test
+ public void testGetSupportedTypes() {
+ Assert.assertEquals(factory.getSupportedTypes().size(), 5);
+ }
+
+ @Test
+ public void testNewInstance() {
+ Map arguments = new HashMap<>();
+ try (Model model = Model.newInstance("test")) {
+ Translator translator1 =
+ factory.newInstance(Image.class, DetectedObjects.class, model, arguments);
+ Assert.assertTrue(translator1 instanceof YoloV8Translator);
+
+ Translator translator2 =
+ factory.newInstance(Path.class, DetectedObjects.class, model, arguments);
+ Assert.assertTrue(translator2 instanceof BasicTranslator);
+
+ Translator translator3 =
+ factory.newInstance(URL.class, DetectedObjects.class, model, arguments);
+ Assert.assertTrue(translator3 instanceof BasicTranslator);
+
+ Translator translator4 =
+ factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments);
+ Assert.assertTrue(translator4 instanceof BasicTranslator);
+
+ Translator translator5 =
+ factory.newInstance(Input.class, Output.class, model, arguments);
+ Assert.assertTrue(translator5 instanceof ImageServingTranslator);
+
+ Assert.assertThrows(
+ IllegalArgumentException.class,
+ () -> factory.newInstance(Image.class, Output.class, model, arguments));
+ }
+ }
+}
diff --git a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java
index 0e38c2d8be6..e89f2244203 100644
--- a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java
+++ b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java
@@ -107,7 +107,7 @@ private static byte[] encode(NDArray array) throws IOException {
private static NDArray decode(NDManager manager, byte[] data) throws IOException {
try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) {
- return NDSerializer.decodeNumpy(manager, bis);
+ return NDList.decode(manager, bis).get(0);
}
}
diff --git a/api/src/test/java/ai/djl/repository/ZooTest.java b/api/src/test/java/ai/djl/repository/ZooTest.java
index 2b44f967144..29fc10391aa 100644
--- a/api/src/test/java/ai/djl/repository/ZooTest.java
+++ b/api/src/test/java/ai/djl/repository/ZooTest.java
@@ -17,6 +17,7 @@
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -48,4 +49,11 @@ public void testInvalidCriteria()
Criteria, ?> criteria = Criteria.builder().build();
criteria.loadModel();
}
+
+ @Test
+ public void testModelZooResolver() {
+ ModelZoo.setModelZooResolver(groupId -> null);
+ ModelZoo zoo = ModelZoo.getModelZoo("unknown");
+ Assert.assertNull(zoo);
+ }
}
diff --git a/api/src/test/java/ai/djl/util/SecurityManagerTest.java b/api/src/test/java/ai/djl/util/SecurityManagerTest.java
index fd9b5db72bc..1e9eb17f63c 100644
--- a/api/src/test/java/ai/djl/util/SecurityManagerTest.java
+++ b/api/src/test/java/ai/djl/util/SecurityManagerTest.java
@@ -74,8 +74,11 @@ public void checkPermission(Permission perm) {
}
};
System.setSecurityManager(sm);
-
- Assert.assertFalse(CudaUtils.hasCuda());
- Assert.assertEquals(CudaUtils.getGpuCount(), 0);
+ try {
+ Assert.assertFalse(CudaUtils.hasCuda());
+ Assert.assertEquals(CudaUtils.getGpuCount(), 0);
+ } finally {
+ System.setSecurityManager(null);
+ }
}
}
diff --git a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
index de1c5cb4a20..a598d8482e6 100644
--- a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
+++ b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
@@ -20,8 +20,6 @@
import org.testng.annotations.Test;
import java.lang.management.MemoryUsage;
-import java.util.Arrays;
-import java.util.List;
public class CudaUtilsTest {
@@ -30,6 +28,9 @@ public class CudaUtilsTest {
@Test
public void testCudaUtils() {
if (!CudaUtils.hasCuda()) {
+ Assert.assertThrows(CudaUtils::getCudaVersionString);
+ Assert.assertThrows(() -> CudaUtils.getComputeCapability(0));
+ Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu()));
return;
}
// Possible to have CUDA and not have a GPU.
@@ -37,16 +38,24 @@ public void testCudaUtils() {
return;
}
- int cudaVersion = CudaUtils.getCudaVersion();
+ String cudaVersion = CudaUtils.getCudaVersionString();
String smVersion = CudaUtils.getComputeCapability(0);
MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu());
logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion);
logger.info("Memory usage: {}", memoryUsage);
- Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required.");
+ Assert.assertNotNull(cudaVersion);
+ Assert.assertNotNull(smVersion);
+ }
- List supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75");
- Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion);
+ @Test
+ public void testCudaUtilsWithFolk() {
+ System.setProperty("ai.djl.util.cuda.folk", "true");
+ try {
+ testCudaUtils();
+ } finally {
+ System.clearProperty("ai.djl.util.cuda.folk");
+ }
}
}
diff --git a/apt.txt b/apt.txt
index 7083f85c374..c89953ff1f9 100644
--- a/apt.txt
+++ b/apt.txt
@@ -1 +1 @@
-openjdk-11-jdk
+openjdk-17-jdk
diff --git a/basicdataset/README.md b/basicdataset/README.md
index 37bab679551..217f58d22b3 100644
--- a/basicdataset/README.md
+++ b/basicdataset/README.md
@@ -29,7 +29,7 @@ You can pull the module from the central Maven repository by including the follo
ai.djl
basicdataset
- 0.23.0
+ 0.26.0
```
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
index a92a9b6a3d4..deef04907be 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
@@ -30,6 +30,7 @@
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.util.Map;
/**
@@ -118,8 +119,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
byte[] buf = Utils.toByteArray(is);
try (NDArray array =
manager.create(
- new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) {
- array.set(buf);
+ ByteBuffer.wrap(buf),
+ new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1),
+ DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
@@ -132,8 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
}
byte[] buf = Utils.toByteArray(is);
- try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
- array.set(buf);
+ try (NDArray array =
+ manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
index 164ba9876cb..5503e721caa 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
@@ -30,6 +30,7 @@
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.util.Map;
/**
@@ -111,8 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
}
byte[] buf = Utils.toByteArray(is);
- try (NDArray array = manager.create(new Shape(length, 28, 28, 1), DataType.UINT8)) {
- array.set(buf);
+ try (NDArray array =
+ manager.create(
+ ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
@@ -123,10 +125,9 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
if (is.skip(8) != 8) {
throw new AssertionError("Failed skip data.");
}
-
byte[] buf = Utils.toByteArray(is);
- try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
- array.set(buf);
+ try (NDArray array =
+ manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java
index 42fc1744451..b04ae800a10 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java
@@ -44,6 +44,7 @@ public ListFeatures(int initialCapacity) {
*
* @param source the source list
*/
+ @SuppressWarnings("this-escape")
public ListFeatures(List source) {
super(source.size());
addAll(source);
diff --git a/bom/README.md b/bom/README.md
index 44519846712..c98b9d1fbe1 100644
--- a/bom/README.md
+++ b/bom/README.md
@@ -22,7 +22,7 @@ will need to mention the type as pom and the scope as import) as the following:
ai.djl
bom
- 0.23.0
+ 0.26.0
pom
import
@@ -38,7 +38,7 @@ will need to mention the type as pom and the scope as import) as the following:
ai.djl
bom
- 0.23.0
+ 0.26.0
pom
import
@@ -65,7 +65,7 @@ will need to mention the type as pom and the scope as import) as the following:
- First you need add BOM into your build.gradle file as the following:
```
- implementation platform("ai.djl:bom:0.23.0")
+ implementation platform("ai.djl:bom:0.26.0")
```
- Then you import the desired DJL modules into to you pom.xml file (no version is needed):
diff --git a/bom/build.gradle b/bom/build.gradle
index 4708978b5b5..31317316138 100644
--- a/bom/build.gradle
+++ b/bom/build.gradle
@@ -28,6 +28,7 @@ dependencies {
api "ai.djl.fasttext:fasttext-engine:${version}"
api "ai.djl.hadoop:hadoop:${version}"
api "ai.djl.huggingface:tokenizers:${version}"
+ api "ai.djl.llama:llama:${version}"
api "ai.djl.ml.lightgbm:lightgbm:${version}"
api "ai.djl.ml.xgboost:xgboost-gpu:${version}"
api "ai.djl.ml.xgboost:xgboost:${version}"
@@ -115,15 +116,12 @@ publishing {
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "win-x86_64", "${pytorch_version}")
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-x86_64", "${pytorch_version}")
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-aarch64", "${pytorch_version}")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "linux-x86_64", "1.12.1")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "win-x86_64", "1.12.1")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116-precxx11", "linux-x86_64", "1.12.1")
+ addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "linux-x86_64", "${pytorch_version}")
+ addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "win-x86_64", "${pytorch_version}")
+ addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121-precxx11", "linux-x86_64", "${pytorch_version}")
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "linux-x86_64", "1.13.1")
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "win-x86_64", "1.13.1")
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117-precxx11", "linux-x86_64", "1.13.1")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "linux-x86_64", "${pytorch_version}")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "win-x86_64", "${pytorch_version}")
- addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118-precxx11", "linux-x86_64", "${pytorch_version}")
addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "osx-x86_64", "${tensorflow_version}")
addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "linux-x86_64", "${tensorflow_version}")
addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "win-x86_64", "${tensorflow_version}")
diff --git a/build.gradle b/build.gradle
index f98b86c4e51..ca6f7e68133 100644
--- a/build.gradle
+++ b/build.gradle
@@ -44,6 +44,7 @@ configure(javaProjects()) {
targetCompatibility = JavaVersion.VERSION_11
options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static,-removal" << "-Werror"
}
+ javadoc.options.addStringOption("Xdoclint:none", "-quiet")
apply plugin: 'eclipse'
@@ -88,7 +89,7 @@ configure(javaProjects()) {
systemProperty "disableProgressBar", "true"
systemProperty "nightly", System.getProperty("nightly", "false")
if (gradle.startParameter.offline) {
- systemProperty "offline", "true"
+ systemProperty "ai.djl.offline", "true"
}
// This is used to avoid overriding on default engine for modules:
// mxnet-engine, mxnet-model-zoo, api (MockEngine), basicdataset, fasttext, etc
diff --git a/djl-zero/README.md b/djl-zero/README.md
index 2d2c473cc88..34acbac07b9 100644
--- a/djl-zero/README.md
+++ b/djl-zero/README.md
@@ -49,6 +49,6 @@ You can pull the module from the central Maven repository by including the follo
    <groupId>ai.djl</groupId>
    <artifactId>djl-zero</artifactId>
-    <version>0.23.0</version>
+    <version>0.26.0</version>
```
diff --git a/docker/README.md b/docker/README.md
index 5b5bd01be2b..0df33be9f83 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -1,10 +1,12 @@
# Docker Resources
+
DJL provides docker files that you can use to setup containers with the appropriate environment for certain platforms.
We recommend setting up a docker container with the provided Dockerfile when developing for the following
platforms and/or engines.
## Windows
+
You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/windows/Dockerfile) provided by us.
Please note that this docker will only work with Windows server 2019 by default. If you want it to work with other
versions of Windows, you need to pass the version as an argument as follows:
@@ -14,19 +16,20 @@ docker build --build-arg version=
```
## TensorRT
+
You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) provided by us.
This docker file is a modification of the one provided by NVIDIA in
-[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK11.
-By default this sets up a container using Ubuntu 18.04 and CUDA 11.6.2. You can build the container with other versions as follows,
+[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK17.
+By default this sets up a container using Ubuntu 18.04 and CUDA 11.6.2. You can build the container with other versions as follows,
but keep in mind the TensorRT software requirements outlined [here](https://github.com/NVIDIA/TensorRT#prerequisites):
```bash
docker build --build-arg OS_VERSION=<os_version> --build-arg CUDA_VERSION=<cuda_version>
```
-To run the container, we recommend using `nvidia-docker run ...` to ensure cuda driver and runtime are compatible.
+To run the container, we recommend using `nvidia-docker run ...` to ensure cuda driver and runtime are compatible.
-We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you
-need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert onnx models to
-uff tensorrt models. When following that guide, make sure to use the DJL provided
-[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK11 in the docker container.
+We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you
+need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert onnx models to
+uff tensorrt models. When following that guide, make sure to use the DJL provided
+[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK17 in the docker container.
diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile
index b715899e2f1..b777d5a69ed 100644
--- a/docker/spark/Dockerfile
+++ b/docker/spark/Dockerfile
@@ -13,7 +13,7 @@ FROM 314815235551.dkr.ecr.us-east-2.amazonaws.com/sagemaker-spark-processing:3.3
LABEL maintainer="djl-dev@amazon.com"
# Install dependencies
-ARG DJL_VERSION=0.23.0
+ARG DJL_VERSION=0.26.0
ARG JNA_VERSION=5.13.0
ARG JAVACV_VERSION=1.5.9
ARG JAVACPP_VERSION=1.5.9
diff --git a/docker/tensorrt/Dockerfile b/docker/tensorrt/Dockerfile
index 3a99bb9cb5d..94f81230e19 100644
--- a/docker/tensorrt/Dockerfile
+++ b/docker/tensorrt/Dockerfile
@@ -42,7 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
fakeroot \
dh-make \
build-essential \
- openjdk-11-jdk && \
+ openjdk-17-jdk && \
apt-get clean -y && rm -rf /var/lib/apt/lists/*
# Install python3
diff --git a/docker/windows/Dockerfile b/docker/windows/Dockerfile
index 31567b3168b..10989e8a4c8 100644
--- a/docker/windows/Dockerfile
+++ b/docker/windows/Dockerfile
@@ -11,4 +11,4 @@ RUN powershell -Command \
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \
choco feature disable --name showDownloadProgress
-RUN choco install -y openjdk11
+RUN choco install -y openjdk17
diff --git a/docs/README.md b/docs/README.md
index cdd02661c78..81d547e92f2 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -20,14 +20,14 @@ Note: when searching in JavaDoc, if your access is denied, please try removing t
- [Troubleshooting](development/troubleshooting.md)
- [Inference Optimization](development/inference_performance_optimization.md)
-## [Jupyter notebook tutorials](../jupyter/README.md)
-
-- **[Beginner Jupyter Tutorial](../jupyter/tutorial/README.md)**
-- [Run object detection with model zoo](../jupyter/object_detection_with_model_zoo.ipynb)
-- [Load pre-trained PyTorch model](../jupyter/load_pytorch_model.ipynb)
-- [Load pre-trained Apache MXNet model](../jupyter/load_mxnet_model.ipynb)
-- [Transfer learning example](../jupyter/transfer_learning_on_cifar10.ipynb)
-- [Question answering example](../jupyter/BERTQA.ipynb)
+## [Jupyter notebook tutorials](http://docs.djl.ai/docs/demos/jupyter/index.html)
+
+- **[Beginner Jupyter Tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html)**
+- [Run object detection with model zoo](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html)
+- [Load pre-trained PyTorch model](http://docs.djl.ai/docs/demos/jupyter/load_pytorch_model.html)
+- [Load pre-trained Apache MXNet model](http://docs.djl.ai/docs/demos/jupyter/load_mxnet_model.html)
+- [Transfer learning example](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html)
+- [Question answering example](http://docs.djl.ai/docs/demos/jupyter/BERTQA.html)
## [API Examples](../examples/README.md)
diff --git a/docs/development/example_dataset.md b/docs/development/example_dataset.md
index 35e071f728b..63583c2fdeb 100644
--- a/docs/development/example_dataset.md
+++ b/docs/development/example_dataset.md
@@ -1,4 +1,4 @@
-## Example CSV Dataset
+# Custom CSV Dataset Example
If the provided Datasets don't meet your requirements, you can also easily extend our dataset to create your own customized dataset.
@@ -24,8 +24,8 @@ api group: 'org.apache.commons', name: 'commons-csv', version: '1.7'
In order to extend the dataset, the following dependencies are required:
```
-api "ai.djl:api:0.23.0"
-api "ai.djl:basicdataset:0.23.0"
+api "ai.djl:api:0.26.0"
+api "ai.djl:basicdataset:0.26.0"
```
There are four parts we need to implement for CSVDataset.
diff --git a/docs/development/external_libraries.md b/docs/development/external_libraries.md
index 7f57fec3165..701fb9d0a03 100644
--- a/docs/development/external_libraries.md
+++ b/docs/development/external_libraries.md
@@ -1,5 +1,4 @@
-
-## DJL external dependencies
+# DJL external dependencies
This document contains external libraries that DJL depends on and their versions.
diff --git a/docs/development/profiler.md b/docs/development/profiler.md
index 6db5739483c..4a2a9f626e4 100644
--- a/docs/development/profiler.md
+++ b/docs/development/profiler.md
@@ -1,4 +1,4 @@
-## Profiler (Experimental)
+# Engine Profiler Support
Currently, DJL supports experimental profilers for developers that
investigate the performance of operator execution as well as memory consumption.
diff --git a/docs/development/setup.md b/docs/development/setup.md
index e4eb73b2501..fb290eb0e3a 100644
--- a/docs/development/setup.md
+++ b/docs/development/setup.md
@@ -10,13 +10,13 @@ you can use the $JAVA_HOME environment variable to control which version of Java
For ubuntu:
```bash
-sudo apt-get install openjdk-11-jdk
+sudo apt-get install openjdk-17-jdk
```
For centos
```bash
-sudo yum install java-11-openjdk
+sudo yum install java-17-openjdk
```
For Mac:
@@ -24,7 +24,7 @@ For Mac:
```bash
brew tap homebrew/cask-versions
brew update
-brew install --cask temurin11
+brew install --cask zulu17
```
You can also download and install [Oracle JDK](https://www.oracle.com/technetwork/java/javase/overview/index.html)
diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md
index ff03d32648e..1a04592dc12 100644
--- a/docs/development/troubleshooting.md
+++ b/docs/development/troubleshooting.md
@@ -105,6 +105,11 @@ For more information, please refer to [DJL Cache Management](cache_management.md
It happened when you had a wrong version with DJL and Deep Engines.
You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue.
+### 1.6 Manual initialization
+
+If you are using manual engine initialization, you must both register an engine and set it as the default.
+This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`.
+
## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception.
The following exception may appear after running the `./gradlew clean` command:
diff --git a/docs/hybrid_engine.md b/docs/hybrid_engine.md
index 58bdbe69cb4..ddde08337ee 100644
--- a/docs/hybrid_engine.md
+++ b/docs/hybrid_engine.md
@@ -21,17 +21,17 @@ to run in a hybrid mode:
To use it along with Apache MXNet for additional API support, add the following two dependencies:
```
-runtimeOnly "ai.djl.mxnet:mxnet-engine:0.23.0"
+runtimeOnly "ai.djl.mxnet:mxnet-engine:0.26.0"
```
You can also use PyTorch or TensorFlow Engine as the supplemental engine by adding their corresponding dependencies.
```
-runtimeOnly "ai.djl.pytorch:pytorch-engine:0.23.0"
+runtimeOnly "ai.djl.pytorch:pytorch-engine:0.26.0"
```
```
-runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.23.0"
+runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.26.0"
```
## How Hybrid works
diff --git a/docs/interactive_tool.md b/docs/interactive_tool.md
index ed102fedc8d..d7d267db710 100644
--- a/docs/interactive_tool.md
+++ b/docs/interactive_tool.md
@@ -63,7 +63,7 @@ After that, click `run` and you should see the following result:
Finally, you can get the running project setup by clicking `Get Template`. This will bring you a gradle project that can be used in your local machine.
-## [Java Jupyter Notebook](../jupyter/README.md)
+## [Java Jupyter Notebook](http://docs.djl.ai/docs/demos/jupyter/index.html)
Wait a second, are we talking about hosting Jupyter Notebook in python?
No, it’s Java 11, only.
@@ -71,9 +71,9 @@ No, it’s Java 11, only.
![jupyter](https://djl-ai.s3.amazonaws.com/web-data/images/jupyter.gif)
Inspired by Spencer Park’s [IJava project](https://github.com/SpencerPark/IJava), we integrated DJL with Jupyter Notebooks.
-For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](../jupyter/README.md#setup).
+For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](http://docs.djl.ai/docs/demos/jupyter/index.html#setup).
After that, use the Jupyter Notebook freely in your hosted server. You can do all kinds of work, like block building and plotting a graph.
-There are [tutorials and instructions](../jupyter/README.md#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java.
+There are [tutorials and instructions](http://docs.djl.ai/docs/demos/jupyter/index.html#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java.
## About Future Lab
diff --git a/docs/load_model.md b/docs/load_model.md
index 621d7514605..653ba3e91d7 100644
--- a/docs/load_model.md
+++ b/docs/load_model.md
@@ -181,7 +181,7 @@ Here is a few tips you can use to help you debug model loading issue:
See [here](development/configure_logging.md#configure-logging-level) for how to enable debug log
#### List models programmatically in your code
-You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models.
+You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.26.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models.
#### List available models using DJL command line
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index c911bf43b2d..6511e9a865e 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -61,15 +61,15 @@ nav:
- 'docs/faq.md'
- Tutorials:
- Beginner Tutorial:
- - 'jupyter/tutorial/01_create_your_first_network.ipynb'
- - 'jupyter/tutorial/02_train_your_first_model.ipynb'
- - 'jupyter/tutorial/03_image_classification_with_your_model.ipynb'
+ - 'docs/demos/jupyter/tutorial/01_create_your_first_network.ipynb'
+ - 'docs/demos/jupyter/tutorial/02_train_your_first_model.ipynb'
+ - 'docs/demos/jupyter/tutorial/03_image_classification_with_your_model.ipynb'
- 'docs/d2l.md'
- - 'jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb'
- - 'jupyter/transfer_learning_on_cifar10.ipynb'
+ - 'docs/demos/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb'
+ - 'docs/demos/jupyter/transfer_learning_on_cifar10.ipynb'
- Load your own BERT:
- - BERT with MXNet: 'jupyter/mxnet/load_your_own_mxnet_bert.ipynb'
- - BERT with PyTorch: 'jupyter/pytorch/load_your_own_pytorch_bert.ipynb'
+ - BERT with MXNet: 'docs/demos/jupyter/mxnet/load_your_own_mxnet_bert.ipynb'
+ - BERT with PyTorch: 'docs/demos/jupyter/pytorch/load_your_own_pytorch_bert.ipynb'
- Guides:
- Models:
- 'docs/load_model.md'
@@ -97,25 +97,25 @@ nav:
- PyTorch NDArray Operators: 'docs/pytorch/pytorch-djl-ndarray-cheatsheet.md'
- PyTorch Model Zoo: 'engines/pytorch/pytorch-model-zoo/README.md'
- Import PyTorch Model: 'docs/pytorch/how_to_convert_your_model_to_torchscript.md'
- - Load a PyTorch Model: 'jupyter/load_pytorch_model.ipynb'
+ - Load a PyTorch Model: 'docs/demos/jupyter/load_pytorch_model.ipynb'
- TensorFlow:
- Overview: 'engines/tensorflow/README.md'
- TensorFlow Engine: 'engines/tensorflow/tensorflow-engine/README.md'
- TensorFlow Model Zoo: 'engines/tensorflow/tensorflow-model-zoo/README.md'
- Import TensorFlow Model: 'docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md'
- - Load a TensorFlow Model: 'jupyter/tensorflow/pneumonia_detection.ipynb'
+ - Load a TensorFlow Model: 'docs/demos/jupyter/tensorflow/pneumonia_detection.ipynb'
- Apache MXNet:
- Overview: 'engines/mxnet/README.md'
- MXNet Engine: 'engines/mxnet/mxnet-engine/README.md'
- MXNet Model Zoo: 'engines/mxnet/mxnet-model-zoo/README.md'
- Import Gluon Model: 'docs/mxnet/how_to_convert_your_model_to_symbol.md'
- - Load a MXNet Model: 'jupyter/load_mxnet_model.ipynb'
+ - Load a MXNet Model: 'docs/demos/jupyter/load_mxnet_model.ipynb'
- Backend Optimizer for MXNet: 'docs/mxnet/mxnet_backend_optimizer.md'
- Hybrid engines:
- Hybrid engine overview: 'docs/hybrid_engine.md'
- ONNX Runtime:
- Overview: 'engines/onnxruntime/onnxruntime-engine/README.md'
- - Load a ONNX Model: 'jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb'
+ - Load a ONNX Model: 'docs/demos/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb'
- PaddlePaddle:
- Overview: 'engines/paddlepaddle/README.md'
- PaddlePaddle Engine: 'engines/paddlepaddle/paddlepaddle-engine/README.md'
@@ -124,11 +124,11 @@ nav:
- English: 'docs/paddlepaddle/how_to_create_paddlepaddle_model.md'
- 中文: 'docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md'
- Facemask detection using PaddlePaddle:
- - English: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb'
- - 中文: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb'
+ - English: 'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb'
+ - 中文: 'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb'
- PaddleOCR example:
- - English: 'jupyter/paddlepaddle/paddle_ocr_java.ipynb'
- - 中文: 'jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb'
+ - English: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java.ipynb'
+ - 中文: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb'
- XGBoost: 'engines/ml/xgboost/README.md'
- LightGBM: 'engines/ml/lightgbm/README.md'
- TensorRT: 'engines/tensorrt/README.md'
@@ -153,15 +153,34 @@ nav:
- 'docs/serving/serving/docs/inference.md'
- 'docs/serving/serving/docs/modes.md'
- 'docs/serving/serving/docs/console.md'
- - 'docs/serving/serving/docs/configuration.md'
- - 'docs/serving/serving/docs/configurations.md'
- - 'docs/serving/serving/docs/workflows.md'
+ - Configuration:
+ - 'docs/serving/serving/docs/configuration.md'
+ - 'docs/serving/serving/docs/configurations_global.md'
+ - 'docs/serving/serving/docs/configurations.md'
+ - 'docs/serving/serving/docs/workflows.md'
+ - 'docs/serving/serving/docs/configurations_model.md'
- 'docs/serving/serving/docs/architecture.md'
- HTTP API:
- 'docs/serving/serving/docs/inference_api.md'
- 'docs/serving/serving/docs/management_api.md'
- 'docs/serving/serving/docs/plugin_management.md'
- 'docs/serving/wlm/README.md'
+ - Large Model Inference:
+ - 'docs/serving/serving/docs/large_model_inference.md'
+ - 'docs/serving/serving/docs/lmi/configurations_large_model_inference_containers.md'
+ - 'docs/serving/serving/docs/lmi/lmi_environment_variable_instruction.md'
+ - 'docs/serving/serving/docs/lmi/lmi_input_output_schema.md'
+ - Tutorials:
+ - 'docs/serving/serving/docs/lmi/tutorials/seq_scheduler_tutorial.md'
+ - 'docs/serving/serving/docs/lmi/tutorials/trtllm_aot_tutorial.md'
+ - 'docs/serving/serving/docs/lmi/tutorials/trtllm_manual_convert_tutorial.md'
+ - Tuning guides:
+ - 'docs/serving/serving/docs/lmi/tuning_guides/deepspeed_tuning_guide.md'
+ - 'docs/serving/serving/docs/lmi/tuning_guides/lmi_dist_tuning_guide.md'
+ - 'docs/serving/serving/docs/lmi/tuning_guides/tnx_tuning_guide.md'
+ - 'docs/serving/serving/docs/lmi/tuning_guides/trtllm_tuning_guide.md'
+ - SageMaker LMI containers resources:
+ - 'docs/demos/aws/sagemaker/large-model-inference/README.md'
- Demos:
- Demos: 'docs/demos/README.md'
- AWS:
diff --git a/docs/mxnet/how_to_convert_your_model_to_symbol.md b/docs/mxnet/how_to_convert_your_model_to_symbol.md
index be178afe437..57a5b8a9b05 100644
--- a/docs/mxnet/how_to_convert_your_model_to_symbol.md
+++ b/docs/mxnet/how_to_convert_your_model_to_symbol.md
@@ -1,4 +1,4 @@
-## How to convert your Gluon model to an MXNet Symbol
+# How to convert your Gluon model to an MXNet Symbol
DJL currently supports symbolic model loading from MXNet.
A gluon [HybridBlock](https://mxnet.apache.org/api/python/docs/api/gluon/hybrid_block.html) can be converted into a symbol for loading by doing as follows:
diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md
index 042acbd2d61..b78d4406946 100644
--- a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md
+++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md
@@ -157,5 +157,5 @@ predictor.predict(list);
As mentioned, you need to find out what is the input for the model, like images usually interpret as NCHW (batch_size, channel, height, width).
-However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb).
+However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html).
diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md
index 74e5dec634f..5f79d713783 100644
--- a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md
+++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md
@@ -156,4 +156,4 @@ predictor.predict(list);
在这里,你需要知道模型的输入输出格式, 比如图片经常表达成 NCHW (批大小, RGB通道, 高度, 宽度)的多维矩阵。
-虽然这样可以让模型跑起来, 但是最好还是结合 DJL 的 `Translator` class 使用。你可以在 [这里](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb) 找到一些示例代码。
+虽然这样可以让模型跑起来, 但是最好还是结合 DJL 的 `Translator` class 使用。你可以在 [这里](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html) 找到一些示例代码。
diff --git a/docs/pytorch/how_to_convert_your_model_to_torchscript.md b/docs/pytorch/how_to_convert_your_model_to_torchscript.md
index 4dd4b3102d7..f90ee468764 100644
--- a/docs/pytorch/how_to_convert_your_model_to_torchscript.md
+++ b/docs/pytorch/how_to_convert_your_model_to_torchscript.md
@@ -1,4 +1,4 @@
-## How to convert your PyTorch model to TorchScript
+# How to convert your PyTorch model to TorchScript
There are two ways to convert your model to TorchScript: tracing and scripting.
We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation.
diff --git a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md
index 7416ec50bab..37d24276d82 100644
--- a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md
+++ b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md
@@ -1,4 +1,4 @@
-## PyTorch NDArray operators
+# PyTorch NDArray operators
In the following examples, we assume
diff --git a/docs/quick_start.md b/docs/quick_start.md
index f352a39156a..85a94494b2d 100644
--- a/docs/quick_start.md
+++ b/docs/quick_start.md
@@ -1,7 +1,7 @@
# Quick start
Deep Java Library (DJL) is designed to be easy to get started with and simple to use.
-The easiest way to learn DJL is to read the [beginner tutorial](../jupyter/tutorial/README.md) or
+The easiest way to learn DJL is to read the [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html) or
our [examples](../examples/README.md).
You can also view our 1.5 hour long (in 8 x ~10 minute segments) DJL 101 tutorial video series:
@@ -22,7 +22,7 @@ See [DJL Future Labs](interactive_tool.md)
## Beginner tutorial
-To get started, we recommend that you follow our short [beginner tutorial](../jupyter/tutorial/README.md). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model.
+To get started, we recommend that you follow our short [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model.
## Run examples
@@ -33,7 +33,7 @@ All of our examples are executed by a simple command. For detailed command line
- [Train your first model](../examples/docs/train_mnist_mlp.md)
- [Single-shot Object Detection inference example](../examples/docs/object_detection.md)
- [More examples](https://github.com/deepjavalibrary/djl/tree/master/examples)
-- [Jupyter examples](../jupyter/README.md)
+- [Jupyter examples](http://docs.djl.ai/docs/demos/jupyter/index.html)
## Other resources
diff --git a/docs/telemetry.md b/docs/telemetry.md
index d6ff9b20bc1..256adf00a49 100644
--- a/docs/telemetry.md
+++ b/docs/telemetry.md
@@ -20,5 +20,5 @@ System.setProperty("OPT_OUT_TRACKING", "true")
Usage tracking is also disable in `offline` mode:
```java
-System.setProperty("offline", "true")
+System.setProperty("ai.djl.offline", "true")
```
diff --git a/engines/llama/.gitignore b/engines/llama/.gitignore
new file mode 100644
index 00000000000..3428b3b2f53
--- /dev/null
+++ b/engines/llama/.gitignore
@@ -0,0 +1,3 @@
+jnilib/
+llama.cpp/
+models/
diff --git a/engines/llama/CMakeLists.txt b/engines/llama/CMakeLists.txt
new file mode 100644
index 00000000000..d1fc8131db8
--- /dev/null
+++ b/engines/llama/CMakeLists.txt
@@ -0,0 +1,23 @@
+cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
+
+project(djl_llama CXX)
+
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(BUILD_SHARED_LIBS ON)
+
+set(JAVA_AWT_LIBRARY NotNeeded)
+set(JAVA_AWT_INCLUDE_PATH NotNeeded)
+find_package(JNI REQUIRED)
+
+add_subdirectory(llama.cpp)
+include(build-args.cmake)
+add_library(djl_llama SHARED src/main/native/ai_djl_llama.cpp)
+
+target_include_directories(djl_llama PRIVATE
+ ${JNI_INCLUDE_DIRS}
+ src/main/native
+ llama.cpp
+ llama.cpp/common
+ build/include)
+target_link_libraries(djl_llama PRIVATE common llama ${LLAMA_EXTRA_LIBS})
+target_compile_features(djl_llama PRIVATE cxx_std_11)
diff --git a/engines/llama/build-args.cmake b/engines/llama/build-args.cmake
new file mode 100644
index 00000000000..dee0db659cd
--- /dev/null
+++ b/engines/llama/build-args.cmake
@@ -0,0 +1,639 @@
+if (APPLE)
+ set(LLAMA_METAL_DEFAULT ON)
+else()
+ set(LLAMA_METAL_DEFAULT OFF)
+endif()
+
+# general
+option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
+
+# instruction set specific
+if (LLAMA_NATIVE)
+ set(INS_ENB OFF)
+else()
+ set(INS_ENB ON)
+endif()
+
+option(LLAMA_AVX "llama: enable AVX" ${INS_ENB})
+option(LLAMA_AVX2 "llama: enable AVX2" ${INS_ENB})
+option(LLAMA_AVX512 "llama: enable AVX512" OFF)
+option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
+option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
+option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
+# in MSVC F16C is implied with AVX2/AVX512
+if (NOT MSVC)
+ option(LLAMA_F16C "llama: enable F16C" ${INS_ENB})
+endif()
+
+# 3rd party libs
+option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
+option(LLAMA_BLAS "llama: use BLAS" OFF)
+set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
+option(LLAMA_CUBLAS "llama: use CUDA" OFF)
+#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
+option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
+option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
+set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
+set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
+option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
+set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+ "llama: max. batch size for using peer access")
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
+option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
+option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
+option(LLAMA_MPI "llama: use MPI" OFF)
+option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
+
+
+#
+# Compile flags
+#
+
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+find_package(Threads REQUIRED)
+include(CheckCXXCompilerFlag)
+
+# enable libstdc++ assertions for debug builds
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
+endif()
+
+if (NOT MSVC)
+ if (LLAMA_SANITIZE_THREAD)
+ add_compile_options(-fsanitize=thread)
+ link_libraries(-fsanitize=thread)
+ endif()
+
+ if (LLAMA_SANITIZE_ADDRESS)
+ add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+ link_libraries(-fsanitize=address)
+ endif()
+
+ if (LLAMA_SANITIZE_UNDEFINED)
+ add_compile_options(-fsanitize=undefined)
+ link_libraries(-fsanitize=undefined)
+ endif()
+endif()
+
+if (APPLE AND LLAMA_ACCELERATE)
+ find_library(ACCELERATE_FRAMEWORK Accelerate)
+ if (ACCELERATE_FRAMEWORK)
+ message(STATUS "Accelerate framework found")
+
+ add_compile_definitions(GGML_USE_ACCELERATE)
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
+ add_compile_definitions(ACCELERATE_LAPACK_ILP64)
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
+ else()
+ message(WARNING "Accelerate framework not found")
+ endif()
+endif()
+
+if (LLAMA_METAL)
+ find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+ find_library(METAL_FRAMEWORK Metal REQUIRED)
+ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+
+ message(STATUS "Metal framework found")
+ set(GGML_HEADERS_METAL ggml-metal.h)
+ set(GGML_SOURCES_METAL ggml-metal.m)
+
+ add_compile_definitions(GGML_USE_METAL)
+ if (LLAMA_METAL_NDEBUG)
+ add_compile_definitions(GGML_METAL_NDEBUG)
+ endif()
+
+ # get full path to the file
+ #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
+
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
+ ${FOUNDATION_LIBRARY}
+ ${METAL_FRAMEWORK}
+ ${METALKIT_FRAMEWORK}
+ )
+endif()
+if (LLAMA_BLAS)
+ if (LLAMA_STATIC)
+ set(BLA_STATIC ON)
+ endif()
+ if (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.22)
+ set(BLA_SIZEOF_INTEGER 8)
+ endif()
+
+ set(BLA_VENDOR ${LLAMA_BLAS_VENDOR})
+ find_package(BLAS)
+
+ if (BLAS_FOUND)
+ message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
+
+ if ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
+ # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
+ find_package(PkgConfig REQUIRED)
+ if (${LLAMA_BLAS_VENDOR} MATCHES "Generic")
+ pkg_check_modules(DepBLAS REQUIRED blas)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS")
+ pkg_check_modules(DepBLAS REQUIRED openblas)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME")
+ pkg_check_modules(DepBLAS REQUIRED blis)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS")
+ pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS")
+ pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel")
+ # all Intel* libraries share the same include path
+ pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+ elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC")
+ # this doesn't provide pkg-config
+ # suggest to assign BLAS_INCLUDE_DIRS on your own
+ if ("${NVHPC_VERSION}" STREQUAL "")
+ message(WARNING "Better to set NVHPC_VERSION")
+ else()
+ set(DepBLAS_FOUND ON)
+ set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
+ endif()
+ endif()
+ if (DepBLAS_FOUND)
+ set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
+ else()
+ message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
+ " detected by pkgconfig, trying to find cblas.h from possible paths...")
+ find_path(BLAS_INCLUDE_DIRS
+ NAMES cblas.h
+ HINTS
+ /usr/include
+ /usr/local/include
+ /usr/include/openblas
+ /opt/homebrew/opt/openblas/include
+ /usr/local/opt/openblas/include
+ /usr/include/x86_64-linux-gnu/openblas/include
+ )
+ endif()
+ endif()
+
+ message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
+ add_compile_options(${BLAS_LINKER_FLAGS})
+ add_compile_definitions(GGML_USE_OPENBLAS)
+ if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel"))
+ add_compile_definitions(GGML_BLAS_USE_MKL)
+ endif()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
+ set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
+
+ else()
+ message(WARNING "BLAS not found, please refer to "
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+ " to set correct LLAMA_BLAS_VENDOR")
+ endif()
+endif()
+
+if (LLAMA_QKK_64)
+ add_compile_definitions(GGML_QKK_64)
+endif()
+
+if (LLAMA_CUBLAS)
+ cmake_minimum_required(VERSION 3.17)
+
+ find_package(CUDAToolkit)
+ if (CUDAToolkit_FOUND)
+ message(STATUS "cuBLAS found")
+
+ enable_language(CUDA)
+
+ set(GGML_HEADERS_CUDA ggml-cuda.h)
+ set(GGML_SOURCES_CUDA ggml-cuda.cu)
+
+ add_compile_definitions(GGML_USE_CUBLAS)
+# if (LLAMA_CUDA_CUBLAS)
+# add_compile_definitions(GGML_CUDA_CUBLAS)
+# endif()
+ if (LLAMA_CUDA_FORCE_DMMV)
+ add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+ endif()
+ if (LLAMA_CUDA_FORCE_MMQ)
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+ endif()
+ add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+ add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+ if (DEFINED LLAMA_CUDA_DMMV_Y)
+ add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility
+ endif()
+ if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
+ add_compile_definitions(GGML_CUDA_F16)
+ endif()
+ add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
+
+ if (LLAMA_STATIC)
+ if (WIN32)
+ # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+ else ()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ endif()
+ else()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+ endif()
+
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+ # 52 == lowest CUDA 12 standard
+ # 60 == f16 CUDA intrinsics
+ # 61 == integer CUDA intrinsics
+ # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+ if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
+ set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
+ else()
+ set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+ #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
+ endif()
+ endif()
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
+ else()
+ message(WARNING "cuBLAS not found")
+ endif()
+endif()
+
+if (LLAMA_MPI)
+ cmake_minimum_required(VERSION 3.10)
+ find_package(MPI)
+ if (MPI_C_FOUND)
+ message(STATUS "MPI found")
+ set(GGML_HEADERS_MPI ggml-mpi.h)
+ set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h)
+ add_compile_definitions(GGML_USE_MPI)
+ add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS})
+ if (NOT MSVC)
+ add_compile_options(-Wno-cast-qual)
+ endif()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES})
+ set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS})
+ # Even if you're only using the C header, C++ programs may bring in MPI
+ # C++ functions, so more linkage is needed
+ if (MPI_CXX_FOUND)
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES})
+ endif()
+ else()
+ message(WARNING "MPI not found")
+ endif()
+endif()
+
+if (LLAMA_CLBLAST)
+ find_package(CLBlast)
+ if (CLBlast_FOUND)
+ message(STATUS "CLBlast found")
+
+ set(GGML_HEADERS_OPENCL ggml-opencl.h)
+ set(GGML_SOURCES_OPENCL ggml-opencl.cpp)
+
+ add_compile_definitions(GGML_USE_CLBLAST)
+
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
+ else()
+ message(WARNING "CLBlast not found")
+ endif()
+endif()
+
+if (LLAMA_HIPBLAS)
+ list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+ if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+ message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+ endif()
+ if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+ message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+ endif()
+
+ find_package(hip)
+ find_package(hipblas)
+ find_package(rocblas)
+
+ if (${hipblas_FOUND} AND ${hip_FOUND})
+ message(STATUS "HIP and hipBLAS found")
+ add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+ add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+ if (BUILD_SHARED_LIBS)
+ set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ endif()
+ if (LLAMA_CUDA_FORCE_DMMV)
+ target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
+ endif()
+ if (LLAMA_CUDA_FORCE_MMQ)
+ target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ)
+ endif()
+ target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+ target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+ target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+ set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+ target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+ if (LLAMA_STATIC)
+ message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+ endif()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
+ else()
+ message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+ endif()
+endif()
+
+function(get_flags CCID CCVER)
+ set(C_FLAGS "")
+ set(CXX_FLAGS "")
+
+ if (CCID MATCHES "Clang")
+ set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
+ set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
+
+ if (
+ (CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+ (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
+ )
+ set(C_FLAGS ${C_FLAGS} -Wdouble-promotion)
+ endif()
+ elseif (CCID STREQUAL "GNU")
+ set(C_FLAGS -Wdouble-promotion)
+ set(CXX_FLAGS -Wno-array-bounds)
+
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+ set(CXX_FLAGS ${CXX_FLAGS} -Wno-format-truncation)
+ endif()
+ if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+ set(CXX_FLAGS ${CXX_FLAGS} -Wextra-semi)
+ endif()
+ endif()
+
+ set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
+ set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
+
+if (LLAMA_ALL_WARNINGS)
+ if (NOT MSVC)
+ set(WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
+ set(C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
+ -Werror=implicit-int -Werror=implicit-function-declaration)
+ set(CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
+
+ set(C_FLAGS ${WARNING_FLAGS} ${C_FLAGS})
+ set(CXX_FLAGS ${WARNING_FLAGS} ${CXX_FLAGS})
+
+ get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
+
+ add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
+ "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
+ else()
+ # todo : msvc
+ set(C_FLAGS "")
+ set(CXX_FLAGS "")
+ endif()
+endif()
+
+if (LLAMA_CUBLAS)
+ set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math)
+ if (NOT MSVC)
+ set(CUDA_FLAGS ${CUDA_FLAGS} -Wno-pedantic)
+ endif()
+
+ if (LLAMA_ALL_WARNINGS AND NOT MSVC)
+ set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+ if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+ set(NVCC_CMD ${NVCC_CMD} -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+ endif()
+
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler --version
+ OUTPUT_VARIABLE CUDA_CCFULLVER
+ ERROR_QUIET
+ )
+
+ if (NOT CUDA_CCFULLVER MATCHES clang)
+ set(CUDA_CCID "GNU")
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+ OUTPUT_VARIABLE CUDA_CCVER
+ ERROR_QUIET
+ )
+ else()
+ if (CUDA_CCFULLVER MATCHES Apple)
+ set(CUDA_CCID "AppleClang")
+ else()
+ set(CUDA_CCID "Clang")
+ endif()
+ string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+ endif()
+
+ message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+ get_flags(${CUDA_CCID} ${CUDA_CCVER})
+ list(JOIN GF_CXX_FLAGS " " CUDA_CXX_FLAGS) # pass host compiler flags as a single argument
+ if (NOT CUDA_CXX_FLAGS STREQUAL "")
+ set(CUDA_FLAGS ${CUDA_FLAGS} -Xcompiler ${CUDA_CXX_FLAGS})
+ endif()
+ endif()
+
+ add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+endif()
+
+if (WIN32)
+ add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+
+ if (BUILD_SHARED_LIBS)
+ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+ endif()
+endif()
+
+if (LLAMA_LTO)
+ include(CheckIPOSupported)
+ check_ipo_supported(RESULT result OUTPUT output)
+ if (result)
+ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+ else()
+ message(WARNING "IPO is not supported: ${output}")
+ endif()
+endif()
+
+# this version of Apple ld64 is buggy
+execute_process(
+ COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+ ERROR_VARIABLE output
+ OUTPUT_QUIET
+)
+if (output MATCHES "dyld-1015\.7")
+ add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
+# Architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+# feel free to update the Makefile for your architecture and send a pull request or issue
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+if (MSVC)
+ string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
+ message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
+else ()
+ set(CMAKE_GENERATOR_PLATFORM_LWR "")
+endif ()
+
+if (NOT MSVC)
+ if (LLAMA_STATIC)
+ add_link_options(-static)
+ if (MINGW)
+ add_link_options(-static-libgcc -static-libstdc++)
+ endif()
+ endif()
+ if (LLAMA_GPROF)
+ add_compile_options(-pg)
+ endif()
+endif()
+
+if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64"))
+ message(STATUS "ARM detected")
+ if (MSVC)
+ add_compile_definitions(__ARM_NEON)
+ add_compile_definitions(__ARM_FEATURE_FMA)
+ add_compile_definitions(__ARM_FEATURE_DOTPROD)
+ # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16
+ add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
+ else()
+ check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
+ if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
+ add_compile_options(-mfp16-format=ieee)
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
+ # Raspberry Pi 1, Zero
+ add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access)
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
+ # Raspberry Pi 2
+ add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
+ # Raspberry Pi 3, 4, Zero 2 (32-bit)
+ add_compile_options(-mno-unaligned-access)
+ endif()
+ endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
+ message(STATUS "x86 detected")
+ if (MSVC)
+ # instruction set detection for MSVC only
+ if (LLAMA_NATIVE)
+ include(${llama.cpp_SOURCE_DIR}/cmake/FindSIMD.cmake)
+ endif ()
+ if (LLAMA_AVX512)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
+ # MSVC has no compile-time flags enabling specific
+ # AVX512 extensions, neither it defines the
+ # macros corresponding to the extensions.
+ # Do it manually.
+ if (LLAMA_AVX512_VBMI)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
+ endif()
+ if (LLAMA_AVX512_VNNI)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+ endif()
+ elseif (LLAMA_AVX2)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
+ elseif (LLAMA_AVX)
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
+ endif()
+ else()
+ if (LLAMA_NATIVE)
+ add_compile_options(-march=native)
+ endif()
+ if (LLAMA_F16C)
+ add_compile_options(-mf16c)
+ endif()
+ if (LLAMA_FMA)
+ add_compile_options(-mfma)
+ endif()
+ if (LLAMA_AVX)
+ add_compile_options(-mavx)
+ endif()
+ if (LLAMA_AVX2)
+ add_compile_options(-mavx2)
+ endif()
+ if (LLAMA_AVX512)
+ add_compile_options(-mavx512f)
+ add_compile_options(-mavx512bw)
+ endif()
+ if (LLAMA_AVX512_VBMI)
+ add_compile_options(-mavx512vbmi)
+ endif()
+ if (LLAMA_AVX512_VNNI)
+ add_compile_options(-mavx512vnni)
+ endif()
+ endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+ message(STATUS "PowerPC detected")
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+ add_compile_options(-mcpu=powerpc64le)
+ else()
+ add_compile_options(-mcpu=native -mtune=native)
+ #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+ endif()
+else()
+ message(STATUS "Unknown architecture")
+endif()
+
+if (MINGW)
+ # Target Windows 8 for PrefetchVirtualMemory
+ add_compile_definitions(_WIN32_WINNT=0x602)
+endif()
+
+#
+# POSIX conformance
+#
+
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+add_compile_definitions(_XOPEN_SOURCE=600)
+
+# Somehow in OpenBSD whenever POSIX conformance is specified
+# some string functions rely on locale_t availability,
+# which was introduced in POSIX.1-2008, forcing us to go higher
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+ remove_definitions(-D_XOPEN_SOURCE=600)
+ add_compile_definitions(_XOPEN_SOURCE=700)
+endif()
+
+# Data types, macros and functions related to controlling CPU affinity and
+# some memory allocation are available on Linux through GNU extensions in libc
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ add_compile_definitions(_GNU_SOURCE)
+endif()
+
+# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
+# and on macOS its availability depends on enabling Darwin extensions
+# similarly on DragonFly, enabling BSD extensions is necessary
+if (
+ CMAKE_SYSTEM_NAME MATCHES "Darwin" OR
+ CMAKE_SYSTEM_NAME MATCHES "iOS" OR
+ CMAKE_SYSTEM_NAME MATCHES "tvOS" OR
+ CMAKE_SYSTEM_NAME MATCHES "DragonFly"
+)
+ add_compile_definitions(_DARWIN_C_SOURCE)
+endif()
+
+# alloca is a non-standard interface that is not visible on BSDs when
+# POSIX conformance is specified, but not all of them provide a clean way
+# to enable it in such cases
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
+ add_compile_definitions(__BSD_VISIBLE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
+ add_compile_definitions(_NETBSD_SOURCE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+ add_compile_definitions(_BSD_SOURCE)
+endif()
diff --git a/engines/llama/build.cmd b/engines/llama/build.cmd
new file mode 100644
index 00000000000..93c422028bc
--- /dev/null
+++ b/engines/llama/build.cmd
@@ -0,0 +1,23 @@
+@rem https://chocolatey.org/docs/installation#install-with-cmdexe
+@rem to install rust java etc..
+@rem choco install jdk17 -y
+
+set VERSION="%1"
+
+if exist "llama.cpp" (
+ echo Found "llama.cpp"
+) else (
+ git clone https://github.com/ggerganov/llama.cpp.git -b %VERSION%
+)
+
+if exist build rd /q /s build
+md build\classes
+cd build
+javac -sourcepath ..\src\main\java\ ..\src\main\java\ai\djl\llama\jni\LlamaLibrary.java -h include -d classes
+cmake ..
+cmake --build . --config Release
+
+@rem for nightly ci
+md jnilib\win-x86_64
+copy Release\djl_llama.dll jnilib\win-x86_64\
+copy bin\Release\llama.dll jnilib\win-x86_64\
diff --git a/engines/llama/build.gradle b/engines/llama/build.gradle
new file mode 100644
index 00000000000..e340758d18e
--- /dev/null
+++ b/engines/llama/build.gradle
@@ -0,0 +1,107 @@
+import java.util.zip.GZIPInputStream
+
+group "ai.djl.llama"
+
+dependencies {
+ api project(":api")
+
+ testImplementation project(":testing")
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+}
+
+compileJava.dependsOn(processResources)
+
+processResources {
+ outputs.dir file("${project.projectDir}/build/classes/java/main/native/lib")
+ doLast {
+ def url = "https://publish.djl.ai/llama/${llamacpp_version}/jnilib/${djl_version}"
+ def files = new String[]{
+ "linux-x86_64/libdjl_llama.so",
+ "linux-x86_64/libllama.so",
+ "linux-aarch64/libdjl_llama.so",
+ "linux-aarch64/libllama.so",
+ "osx-x86_64/libdjl_llama.dylib",
+ "osx-x86_64/libllama.dylib",
+ "osx-x86_64/ggml-metal.metal",
+ "osx-aarch64/libdjl_llama.dylib",
+ "osx-aarch64/libllama.dylib",
+ "osx-aarch64/ggml-metal.metal",
+ "win-x86_64/djl_llama.dll",
+ "win-x86_64/llama.dll",
+ }
+ def jnilibDir = "${project.projectDir}/jnilib/${djl_version}"
+ files.each { entry ->
+ def file = new File("${jnilibDir}/${entry}")
+ if (file.exists()) {
+ project.logger.lifecycle("prebuilt or cached file found for ${entry}")
+ } else if (!project.hasProperty("jni")) {
+ project.logger.lifecycle("Downloading ${url}/${entry}")
+ file.getParentFile().mkdirs()
+ def downloadPath = new URL("${url}/${entry}")
+ downloadPath.withInputStream { i -> file.withOutputStream { it << i } }
+ }
+ }
+ copy {
+ from jnilibDir
+ into "${project.projectDir}/build/classes/java/main/native/lib"
+ }
+
+ // write properties
+ def propFile = file("${project.projectDir}/build/classes/java/main/native/lib/llama.properties")
+ propFile.text = "version=${llamacpp_version}-${version}\n"
+
+ url = "https://mlrepo.djl.ai/model/nlp/text_generation/ai/djl/huggingface/gguf/models.json.gz"
+ def prefix = "${project.projectDir}/build/classes/java/main/nlp/text_generation"
+ def file = new File("${prefix}/ai.djl.huggingface.gguf.json")
+ if (file.exists()) {
+ project.logger.lifecycle("gguf index file already exists")
+ } else {
+ project.logger.lifecycle("Downloading gguf index file")
+ file.getParentFile().mkdirs()
+ def downloadPath = new URL(url)
+ downloadPath.withInputStream { i -> file.withOutputStream { it << new GZIPInputStream(i) } }
+ }
+ }
+}
+
+publishing {
+ publications {
+ maven(MavenPublication) {
+ pom {
+ name = "DJL NLP utilities for Llama.cpp"
+ description = "Deep Java Library (DJL) NLP utilities for llama.cpp"
+ url = "http://www.djl.ai/engines/${project.name}"
+ }
+ }
+ }
+}
+
+apply from: file("${rootProject.projectDir}/tools/gradle/cpp-formatter.gradle")
+
+tasks.register('compileJNI') {
+ doFirst {
+ if (System.properties['os.name'].toLowerCase(Locale.ROOT).contains("mac")
+ || System.properties['os.name'].toLowerCase(Locale.ROOT).contains("linux")) {
+ def arch = System.properties["os.arch"] == "amd64" ? "x86_64" : System.properties["os.arch"]
+ exec {
+ commandLine "bash", "build.sh", llamacpp_version, arch
+ }
+ } else {
+ exec {
+ commandLine "${project.projectDir}/build.cmd", llamacpp_version, "x86_64"
+ }
+ }
+
+ // for ci to upload to S3
+ def ciDir = "${project.projectDir}/jnilib/${djl_version}/"
+ copy {
+ from "${project.projectDir}/build/jnilib"
+ into ciDir
+ }
+ delete System.getProperty("user.home") + "/.djl.ai/llama"
+ }
+}
+
+clean.doFirst {
+ delete System.getProperty("user.home") + "/.djl.ai/llama"
+}
diff --git a/engines/llama/build.sh b/engines/llama/build.sh
new file mode 100755
index 00000000000..1cf7151cde4
--- /dev/null
+++ b/engines/llama/build.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+
+set -e
+WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+NUM_PROC=1
+if [[ -n $(command -v nproc) ]]; then
+ NUM_PROC=$(nproc)
+elif [[ -n $(command -v sysctl) ]]; then
+ NUM_PROC=$(sysctl -n hw.ncpu)
+fi
+PLATFORM=$(uname | tr '[:upper:]' '[:lower:]')
+
+VERSION=$1
+ARCH=$2
+
+pushd $WORK_DIR
+if [ ! -d "llama.cpp" ]; then
+ git clone https://github.com/ggerganov/llama.cpp.git -b $VERSION
+fi
+
+if [ ! -d "build" ]; then
+ mkdir build
+fi
+cd build
+
+rm -rf classes
+mkdir classes
+javac -sourcepath ../src/main/java/ ../src/main/java/ai/djl/llama/jni/LlamaLibrary.java -h include -d classes
+cmake ..
+cmake --build . --config Release -- -j "${NUM_PROC}"
+
+popd
+
+# for nightly ci
+if [[ $PLATFORM == 'darwin' ]]; then
+ mkdir -p build/jnilib/osx-$ARCH
+ cp -f build/libdjl_llama.dylib build/jnilib/osx-$ARCH/
+ cp -f build/llama.cpp/libllama.dylib build/jnilib/osx-$ARCH/
+ cp -f llama.cpp/ggml-metal.metal build/jnilib/osx-$ARCH/
+elif [[ $PLATFORM == 'linux' ]]; then
+ mkdir -p build/jnilib/linux-$ARCH
+ cp -f build/libdjl_llama.so build/jnilib/linux-$ARCH/
+ cp -f build/llama.cpp/libllama.so build/jnilib/linux-$ARCH/
+fi
diff --git a/engines/llama/gradlew b/engines/llama/gradlew
new file mode 120000
index 00000000000..343e0d2caa4
--- /dev/null
+++ b/engines/llama/gradlew
@@ -0,0 +1 @@
+../../gradlew
\ No newline at end of file
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java
new file mode 100644
index 00000000000..75fdf5a5d8c
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+package ai.djl.llama.engine;
+
+import ai.djl.Device;
+import ai.djl.Model;
+import ai.djl.engine.Engine;
+import ai.djl.engine.EngineException;
+import ai.djl.llama.jni.LibUtils;
+import ai.djl.ndarray.NDManager;
+import ai.djl.util.Platform;
+import ai.djl.util.passthrough.PassthroughNDManager;
+
+/** The {@code LlamaEngine} is an implementation of the {@link Engine} based on the llama.cpp. */
+public final class LlamaEngine extends Engine {
+
+ public static final String ENGINE_NAME = "Llama";
+ static final int RANK = 10;
+
+ private Engine alternativeEngine;
+ private boolean initialized;
+
+ private LlamaEngine() {
+ try {
+ LibUtils.loadLibrary();
+ } catch (EngineException e) { // NOPMD
+ throw e;
+ } catch (Throwable t) {
+ throw new EngineException("Failed to load llama.cpp native library", t);
+ }
+ }
+
+ static Engine newInstance() {
+ return new LlamaEngine();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Engine getAlternativeEngine() {
+ if (!initialized && !Boolean.getBoolean("ai.djl.llama.disable_alternative")) {
+ Engine engine = Engine.getInstance();
+ if (engine.getRank() < getRank()) {
+ // alternativeEngine should not have the same rank as Llama
+ alternativeEngine = engine;
+ }
+ initialized = true;
+ }
+ return alternativeEngine;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String getEngineName() {
+ return ENGINE_NAME;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int getRank() {
+ return RANK;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String getVersion() {
+ Platform platform = Platform.detectPlatform("llama");
+ return platform.getVersion();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean hasCapability(String capability) {
+ return false;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Model newModel(String name, Device device) {
+ return new LlamaModel(name, newBaseManager(device));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDManager newBaseManager() {
+ return newBaseManager(null);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDManager newBaseManager(Device device) {
+ return PassthroughNDManager.INSTANCE;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return getEngineName() + ':' + getVersion() + ", " + getEngineName() + ':' + getVersion();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java
new file mode 100644
index 00000000000..ca5cc646498
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.engine.Engine;
+import ai.djl.engine.EngineProvider;
+
+/** {@code LlamaEngineProvider} is the Llama implementation of {@link EngineProvider}. */
+public class LlamaEngineProvider implements EngineProvider {
+
+ /** {@inheritDoc} */
+ @Override
+ public String getEngineName() {
+ return LlamaEngine.ENGINE_NAME;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int getEngineRank() {
+ return LlamaEngine.RANK;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Engine getEngine() {
+ return InstanceHolder.INSTANCE;
+ }
+
+ private static class InstanceHolder {
+ static final Engine INSTANCE = LlamaEngine.newInstance();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java
new file mode 100644
index 00000000000..4b4d332fc9f
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java
@@ -0,0 +1,430 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.llama.jni.InputParameters;
+
+import com.google.gson.annotations.SerializedName;
+
+import java.util.Map;
+
+/** A class that holds the input data for a Llama model. */
+public class LlamaInput {
+
+ private String inputs;
+ private String prefix;
+ private String suffix;
+ private Parameters parameters;
+
+ /**
+ * Returns the input prompt.
+ *
+ * @return the input prompt
+ */
+ public String getInputs() {
+ return inputs;
+ }
+
+ /**
+ * Sets the input prompt.
+ *
+ * @param inputs the input prompt
+ */
+ public void setInputs(String inputs) {
+ this.inputs = inputs;
+ }
+
+ /**
+ * Returns the prompt prefix.
+ *
+ * @return the prompt prefix
+ */
+ public String getPrefix() {
+ return prefix;
+ }
+
+ /**
+ * Sets the prompt prefix.
+ *
+ * @param prefix the prompt prefix
+ */
+ public void setPrefix(String prefix) {
+ this.prefix = prefix;
+ }
+
+ /**
+ * Returns the prompt suffix.
+ *
+ * @return the prompt suffix
+ */
+ public String getSuffix() {
+ return suffix;
+ }
+
+ /**
+ * Sets the prompt suffix.
+ *
+ * @param suffix the prompt suffix
+ */
+ public void setSuffix(String suffix) {
+ this.suffix = suffix;
+ }
+
+ /**
+ * Returns the input parameters.
+ *
+ * @return the input parameters
+ */
+ public Parameters getParameters() {
+ if (parameters == null) {
+ parameters = new Parameters();
+ }
+ return parameters;
+ }
+
+ /**
+ * Sets the input parameters.
+ *
+ * @param parameters the input parameters
+ */
+ public void setParameters(Parameters parameters) {
+ this.parameters = parameters;
+ }
+
+ /** The input parameters class. */
+ public static final class Parameters {
+
+ @SerializedName("max_new_tokens")
+ private int nPredict;
+
+ @SerializedName("number_keep")
+ private int nKeep;
+
+ @SerializedName("number_probabilities")
+ private int nProbs;
+
+ @SerializedName("top_k")
+ private int topK;
+
+ @SerializedName("top_p")
+ private float topP;
+
+ @SerializedName("tfs_z")
+ private float tfsZ;
+
+ @SerializedName("typical_p")
+ private float typicalP;
+
+ @SerializedName("temperature")
+ private float temperature;
+
+ @SerializedName("repeat_penalty")
+ private float repeatPenalty;
+
+ @SerializedName("repeat_last_n")
+ private int repeatLastN;
+
+ @SerializedName("frequency_penalty")
+ private float frequencyPenalty;
+
+ @SerializedName("presence_penalty")
+ private float presencePenalty;
+
+ @SerializedName("penalize_nl")
+ private boolean penalizeNl;
+
+ @SerializedName("ignore_eos")
+ private boolean ignoreEos;
+
+ @SerializedName("mirostat")
+ private int mirostat;
+
+ @SerializedName("mirostat_tau")
+ private float mirostatTau;
+
+ @SerializedName("mirostat_eta")
+ private float mirostatEta;
+
+ @SerializedName("number_beams")
+ private int nBeams;
+
+ @SerializedName("seed")
+ private int seed;
+
+ @SerializedName("logit_bias")
+ private Map<Integer, Float> logitBias;
+
+ @SerializedName("grammar")
+ private String grammar;
+
+ @SerializedName("anti_prompt")
+ private String[] antiPrompt;
+
+ /**
+ * Sets the max new tokens.
+ *
+ * @param maxNewTokens the max new tokens
+ */
+ public void setMaxNewTokens(int maxNewTokens) {
+ this.nPredict = maxNewTokens;
+ }
+
+ /**
+ * Sets the number of keep.
+ *
+ * @param nKeep the number of keep
+ */
+ public void setNumberKeep(int nKeep) {
+ this.nKeep = nKeep;
+ }
+
+ /**
+ * Sets the number of probabilities.
+ *
+ * @param nProbs the number of probabilities
+ */
+ public void setNumberProbabilities(int nProbs) {
+ this.nProbs = nProbs;
+ }
+
+ /**
+ * Sets the top K.
+ *
+ * @param topK the top K
+ */
+ public void setTopK(int topK) {
+ this.topK = topK;
+ }
+
+ /**
+ * Sets the top P.
+ *
+ * @param topP the top P
+ */
+ public void setTopP(float topP) {
+ this.topP = topP;
+ }
+
+ /**
+ * Sets the tfs Z.
+ *
+ * @param tfsZ the tfs Z
+ */
+ public void setTfsZ(float tfsZ) {
+ this.tfsZ = tfsZ;
+ }
+
+ /**
+ * Sets the typical P.
+ *
+ * @param typicalP the typical P
+ */
+ public void setTypicalP(float typicalP) {
+ this.typicalP = typicalP;
+ }
+
+ /**
+ * Sets the temperature.
+ *
+ * @param temperature the temperature
+ */
+ public void setTemperature(float temperature) {
+ this.temperature = temperature;
+ }
+
+ /**
+ * Sets the repeat penalty.
+ *
+ * @param repeatPenalty the repeat penalty
+ */
+ public void setRepeatPenalty(float repeatPenalty) {
+ this.repeatPenalty = repeatPenalty;
+ }
+
+ /**
+ * Sets the repeat last N.
+ *
+ * @param repeatLastN the repeat last N
+ */
+ public void setRepeatLastN(int repeatLastN) {
+ this.repeatLastN = repeatLastN;
+ }
+
+ /**
+ * Sets the frequency penalty.
+ *
+ * @param frequencyPenalty the frequency penalty
+ */
+ public void setFrequencyPenalty(float frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ /**
+ * Sets the presence penalty.
+ *
+ * @param presencePenalty the presence penalty
+ */
+ public void setPresencePenalty(float presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ /**
+ * Sets the penalize nl.
+ *
+ * @param penalizeNl the penalize nl
+ */
+ public void setPenalizeNl(boolean penalizeNl) {
+ this.penalizeNl = penalizeNl;
+ }
+
+ /**
+ * Sets if ignore EOS.
+ *
+ * @param ignoreEos if ignore EOS
+ */
+ public void setIgnoreEos(boolean ignoreEos) {
+ this.ignoreEos = ignoreEos;
+ }
+
+ /**
+ * Sets the mirostat.
+ *
+ * @param mirostat the mirostat
+ */
+ public void setMirostat(int mirostat) {
+ this.mirostat = mirostat;
+ }
+
+ /**
+ * Sets the mirostat TAU.
+ *
+ * @param mirostatTau the mirostat TAU
+ */
+ public void setMirostatTau(float mirostatTau) {
+ this.mirostatTau = mirostatTau;
+ }
+
+ /**
+ * Sets the mirostat ETA.
+ *
+ * @param mirostatEta the mirostat ETA
+ */
+ public void setMirostatEta(float mirostatEta) {
+ this.mirostatEta = mirostatEta;
+ }
+
+ /**
+ * Sets the number of beams.
+ *
+ * @param nBeams the number of beams
+ */
+ public void setNumberBeams(int nBeams) {
+ this.nBeams = nBeams;
+ }
+
+ /**
+ * Sets the seed.
+ *
+ * @param seed the seed
+ */
+ public void setSeed(int seed) {
+ this.seed = seed;
+ }
+
+ /**
+ * Sets the logit bias.
+ *
+ * @param logitBias the logit bias
+ */
+ public void setLogitBias(Map<Integer, Float> logitBias) {
+ this.logitBias = logitBias;
+ }
+
+ /**
+ * Sets the grammar template.
+ *
+ * @param grammar the grammar template
+ */
+ public void setGrammar(String grammar) {
+ this.grammar = grammar;
+ }
+
+ /**
+ * Sets the anti prompt.
+ *
+ * @param antiPrompt the anti prompt
+ */
+ public void setAntiPrompt(String[] antiPrompt) {
+ this.antiPrompt = antiPrompt;
+ }
+
+ /**
+ * Returns the {@link InputParameters} object.
+ *
+ * @return the {@link InputParameters} object
+ */
+ public InputParameters toInputParameters() {
+ setDefaultValue();
+ return new InputParameters(
+ nPredict,
+ nKeep,
+ nProbs,
+ topK,
+ topP,
+ tfsZ,
+ typicalP,
+ temperature,
+ repeatPenalty,
+ repeatLastN,
+ frequencyPenalty,
+ presencePenalty,
+ penalizeNl,
+ ignoreEos,
+ mirostat,
+ mirostatTau,
+ mirostatEta,
+ nBeams,
+ seed,
+ logitBias,
+ grammar,
+ antiPrompt);
+ }
+
+ private void setDefaultValue() {
+ if (nPredict == 0) {
+ nPredict = -1;
+ }
+ if (topK == 0) {
+ topK = 40;
+ }
+ if (topP == 0) {
+ topP = 0.95f;
+ }
+ if (tfsZ == 0) {
+ tfsZ = 1f;
+ }
+ if (typicalP == 0) {
+ typicalP = 1f;
+ }
+ if (temperature == 0) {
+ temperature = 0.8f;
+ }
+ if (repeatPenalty == 0) {
+ repeatPenalty = 1.10f;
+ }
+ if (repeatLastN == 0) {
+ repeatLastN = 64;
+ }
+ }
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java
new file mode 100644
index 00000000000..0ff3c6d70c0
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.BaseModel;
+import ai.djl.Model;
+import ai.djl.llama.jni.LlamaLibrary;
+import ai.djl.llama.jni.ModelParameters;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.nn.Blocks;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Map;
+
+/** {@code LlamaModel} is the llama.cpp implementation of {@link Model}. */
+public class LlamaModel extends BaseModel {
+
+ private long handle = -1;
+
+ /**
+ * Constructs a new Model on a given device.
+ *
+ * @param name the model name
+ * @param manager the {@link NDManager} to holds the NDArray
+ */
+ LlamaModel(String name, NDManager manager) {
+ super(name);
+ this.manager = manager;
+ this.manager.setName("llamaModel");
+ dataType = DataType.FLOAT32;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
+ setModelDir(modelPath);
+ wasLoaded = true;
+ if (block != null) {
+ throw new UnsupportedOperationException("Llama does not support dynamic blocks");
+ }
+
+ if (prefix == null) {
+ prefix = modelName;
+ }
+
+ // search for .gguf file with prefix, folder name or "model.gguf"
+ Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.gguf");
+ if (modelFile == null) {
+ throw new FileNotFoundException(".gguf file not found in: " + modelPath);
+ }
+
+ ModelParameters param = new ModelParameters(options);
+ handle = LlamaLibrary.loadModel(modelFile.toString(), param);
+ block = Blocks.identityBlock();
+ }
+
+ long getHandle() {
+ return handle;
+ }
+
+ private Path findModelFile(String... prefixes) {
+ if (Files.isRegularFile(modelDir)) {
+ Path file = modelDir;
+ modelDir = modelDir.getParent();
+ String fileName = file.toFile().getName();
+ if (fileName.endsWith(".gguf")) {
+ modelName = fileName.substring(0, fileName.length() - 5);
+ } else {
+ modelName = fileName;
+ }
+ return file;
+ }
+ for (String prefix : prefixes) {
+ Path modelFile = modelDir.resolve(prefix);
+ if (Files.isRegularFile(modelFile)) {
+ return modelFile;
+ }
+ if (!prefix.endsWith(".gguf")) {
+ modelFile = modelDir.resolve(prefix + ".gguf");
+ if (Files.isRegularFile(modelFile)) {
+ return modelFile;
+ }
+ }
+ }
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ if (handle == -1) {
+ return;
+ }
+ LlamaLibrary.delete(handle);
+ handle = -1;
+ super.close();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java
new file mode 100644
index 00000000000..c8d3692b160
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.inference.streaming.IteratorBytesSupplier;
+import ai.djl.llama.jni.InputParameters;
+import ai.djl.llama.jni.LlamaLibrary;
+import ai.djl.llama.jni.Token;
+import ai.djl.llama.jni.TokenIterator;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.BytesSupplier;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.JsonUtils;
+
+import java.util.Iterator;
+
+/** Built-in {@code Translator} that provides preprocessing and postprocessing for llama.cpp. */
+public class LlamaTranslator<I, O> implements NoBatchifyTranslator<I, O> {
+
+ private long handle;
+
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(TranslatorContext ctx) {
+ LlamaModel model = (LlamaModel) ctx.getModel();
+ handle = model.getHandle();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, I input) {
+ if (input instanceof String) {
+ ctx.setAttachment("out", generate((String) input));
+ } else if (input instanceof LlamaInput) {
+ ctx.setAttachment("out", generate((LlamaInput) input));
+ } else if (input instanceof Input) {
+ String prompt = ((Input) input).getData().getAsString();
+ TokenIterator it = generate(prompt);
+ Output output = new Output();
+ output.add(new IteratorBytesSupplier(new OutputIterator(it)));
+ ctx.setAttachment("out", output);
+ }
+ return new NDList();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ @SuppressWarnings("unchecked")
+ public O processOutput(TranslatorContext ctx, NDList list) {
+ return (O) ctx.getAttachment("out");
+ }
+
+ private TokenIterator generate(String input) {
+ LlamaInput in = JsonUtils.GSON.fromJson(input, LlamaInput.class);
+ return generate(in);
+ }
+
+ private TokenIterator generate(LlamaInput in) {
+ InputParameters param = in.getParameters().toInputParameters();
+ String prefix = in.getPrefix();
+ String suffix = in.getSuffix();
+ String inputs = in.getInputs();
+ if (prefix != null && suffix != null) {
+ LlamaLibrary.infill(handle, prefix, suffix, param);
+ } else if (inputs != null && !inputs.isEmpty()) {
+ LlamaLibrary.generate(handle, inputs, param);
+ } else {
+ throw new IllegalArgumentException("Unsupported input format");
+ }
+ return new TokenIterator(handle);
+ }
+
+ private static final class OutputIterator implements Iterator<BytesSupplier> {
+
+ private TokenIterator it;
+
+ public OutputIterator(TokenIterator it) {
+ this.it = it;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean hasNext() {
+ return it.hasNext();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public BytesSupplier next() {
+ Token token = it.next();
+ return BytesSupplier.wrap(JsonUtils.GSON.toJson(token) + "\n");
+ }
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java
new file mode 100644
index 00000000000..089b5055b51
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.Model;
+import ai.djl.llama.jni.TokenIterator;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+
+import java.io.Serializable;
+import java.lang.reflect.Type;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates a {@link LlamaTranslator} instance. */
+public class LlamaTranslatorFactory implements TranslatorFactory, Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
+
+ static {
+ SUPPORTED_TYPES.add(new Pair<>(String.class, TokenIterator.class));
+ SUPPORTED_TYPES.add(new Pair<>(LlamaInput.class, TokenIterator.class));
+ SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Set<Pair<Type, Type>> getSupportedTypes() {
+ return SUPPORTED_TYPES;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean isSupported(Class<?> input, Class<?> output) {
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public <I, O> Translator<I, O> newInstance(
+ Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
+ return new LlamaTranslator<>();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java
new file mode 100644
index 00000000000..226e7a6ddb8
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains classes to interface with the underlying Llama Engine. */
+package ai.djl.llama.engine;
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java
new file mode 100644
index 00000000000..d13abc5ef90
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java
@@ -0,0 +1,314 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+import java.util.Map;
+
+/** A class holds input parameters. */
+@SuppressWarnings({"PMD.UnusedPrivateField", "PMD.UnusedAssignment"})
+public class InputParameters {
+
+ private int nPredict;
+ private int nKeep;
+ private int nProbs;
+ private int topK;
+ private float topP;
+ private float tfsZ;
+ private float typicalP;
+ private float temperature;
+ private float repeatPenalty;
+ private int repeatLastN;
+ private float frequencyPenalty;
+ private float presencePenalty;
+ private boolean penalizeNl;
+ private boolean ignoreEos;
+ private int mirostat;
+ private float mirostatTau;
+ private float mirostatEta;
+ private int nBeams;
+ private int seed;
+ private Map<Integer, Float> logitBias;
+ private String grammar;
+ private String[] antiPrompt;
+
+ /**
+ * Constructs new {@code InputParameters} instance.
+ *
+ * @param nPredict the max new tokens
+ * @param nKeep the number of keep
+ * @param nProbs the number of probabilities
+ * @param topK the top K
+ * @param topP the top P
+ * @param tfsZ the tfs Z
+ * @param typicalP the typical P
+ * @param temperature the temperature
+ * @param repeatPenalty the repeat penalty
+ * @param repeatLastN the repeat last N
+ * @param frequencyPenalty the frequency penalty
+ * @param presencePenalty the presence penalty
+ * @param penalizeNl the penalize nl
+ * @param ignoreEos the ignore EOS
+ * @param mirostat the mirostat
+ * @param mirostatTau the mirostat TAU
+ * @param mirostatEta the mirostat ETA
+ * @param nBeams the number of beams
+ * @param seed the seed
+ * @param logitBias the logit bias
+ * @param grammar the grammar
+ * @param antiPrompt the anti prompt
+ */
+ public InputParameters(
+ int nPredict,
+ int nKeep,
+ int nProbs,
+ int topK,
+ float topP,
+ float tfsZ,
+ float typicalP,
+ float temperature,
+ float repeatPenalty,
+ int repeatLastN,
+ float frequencyPenalty,
+ float presencePenalty,
+ boolean penalizeNl,
+ boolean ignoreEos,
+ int mirostat,
+ float mirostatTau,
+ float mirostatEta,
+ int nBeams,
+ int seed,
+ Map<Integer, Float> logitBias,
+ String grammar,
+ String[] antiPrompt) {
+ this.nPredict = nPredict;
+ this.nKeep = nKeep;
+ this.nProbs = nProbs;
+ this.topK = topK;
+ this.topP = topP;
+ this.tfsZ = tfsZ;
+ this.typicalP = typicalP;
+ this.temperature = temperature;
+ this.repeatPenalty = repeatPenalty;
+ this.repeatLastN = repeatLastN;
+ this.frequencyPenalty = frequencyPenalty;
+ this.presencePenalty = presencePenalty;
+ this.penalizeNl = penalizeNl;
+ this.ignoreEos = ignoreEos;
+ this.mirostat = mirostat;
+ this.mirostatTau = mirostatTau;
+ this.mirostatEta = mirostatEta;
+ this.nBeams = nBeams;
+ this.seed = seed;
+ this.logitBias = logitBias;
+ this.grammar = grammar;
+ this.antiPrompt = antiPrompt;
+ }
+
+ /**
+ * Returns the max new tokens.
+ *
+ * @return the max new tokens
+ */
+ public int getMaxNewTokens() {
+ return nPredict;
+ }
+
+ /**
+ * Returns the number of keep.
+ *
+ * @return the number of keep
+ */
+ public int getNumberKeep() {
+ return nKeep;
+ }
+
+ /**
+ * Returns the number of probabilities.
+ *
+ * @return the number of probabilities
+ */
+ public int getNumberProbabilities() {
+ return nProbs;
+ }
+
+ /**
+ * Returns the top K.
+ *
+ * @return the top K
+ */
+ public int getTopK() {
+ return topK;
+ }
+
+ /**
+ * Return the top P.
+ *
+ * @return the top P
+ */
+ public float getTopP() {
+ return topP;
+ }
+
+ /**
+ * Return the TfsZ.
+ *
+ * @return the TfsZ
+ */
+ public float getTfsZ() {
+ return tfsZ;
+ }
+
+ /**
+ * Return the typical P.
+ *
+ * @return the typical P
+ */
+ public float getTypicalP() {
+ return typicalP;
+ }
+
+ /**
+ * Return the temperature.
+ *
+ * @return the temperature
+ */
+ public float getTemperature() {
+ return temperature;
+ }
+
+ /**
+ * Return the repeat penalty.
+ *
+ * @return the repeat penalty
+ */
+ public float getRepeatPenalty() {
+ return repeatPenalty;
+ }
+
+ /**
+ * Return the repeat last N.
+ *
+ * @return the repeat last N
+ */
+ public int getRepeatLastN() {
+ return repeatLastN;
+ }
+
+ /**
+ * Return the frequency penalty.
+ *
+ * @return the frequency penalty
+ */
+ public float getFrequencyPenalty() {
+ return frequencyPenalty;
+ }
+
+ /**
+ * Return the presence penalty.
+ *
+ * @return the presence penalty
+ */
+ public float getPresencePenalty() {
+ return presencePenalty;
+ }
+
+ /**
+ * Return the penalize NL.
+ *
+ * @return the penalize NL
+ */
+ public boolean isPenalizeNl() {
+ return penalizeNl;
+ }
+
+ /**
+ * Returns {@code true} if ignore EOS.
+ *
+ * @return {@code true} if ignore EOS
+ */
+ public boolean isIgnoreEos() {
+ return ignoreEos;
+ }
+
+ /**
+ * Returns the mirostat.
+ *
+ * @return the mirostat
+ */
+ public int getMirostat() {
+ return mirostat;
+ }
+
+ /**
+ * Returns the mirostat TAU.
+ *
+ * @return the mirostat TAU
+ */
+ public float getMirostatTau() {
+ return mirostatTau;
+ }
+
+ /**
+ * Returns the mirostat ETA.
+ *
+ * @return the mirostat ETA
+ */
+ public float getMirostatEta() {
+ return mirostatEta;
+ }
+
+ /**
+ * Returns the number of beams.
+ *
+ * @return the number of beams
+ */
+ public int getNumberBeams() {
+ return nBeams;
+ }
+
+ /**
+ * Returns the seed.
+ *
+ * @return the seed
+ */
+ public int getSeed() {
+ return seed;
+ }
+
+ /**
+ * Returns the logit bias.
+ *
+ * @return the logit bias
+ */
+ public Map<Integer, Float> getLogitBias() {
+ return logitBias;
+ }
+
+ /**
+ * Returns the grammar template.
+ *
+ * @return the grammar template
+ */
+ public String getGrammar() {
+ return grammar;
+ }
+
+ /**
+ * Returns the anti-prompt.
+ *
+ * @return the anti-prompt
+ */
+ public String[] getAntiPrompt() {
+ return antiPrompt;
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java
new file mode 100644
index 00000000000..3792864c346
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+import ai.djl.util.ClassLoaderUtils;
+import ai.djl.util.Platform;
+import ai.djl.util.Utils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Utilities for finding the llama.cpp native binary on the System. */
+public final class LibUtils {
+
+ private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
+
+ private static final String LIB_NAME = System.mapLibraryName("djl_llama");
+ private static final String LLAMA_NAME = System.mapLibraryName("llama");
+
+ private LibUtils() {}
+
+ /** Loads llama.cpp native library. */
+ public static void loadLibrary() {
+ List<String> libs = new ArrayList<>(3);
+ libs.add(LLAMA_NAME);
+ libs.add(LIB_NAME);
+ if (System.getProperty("os.name").startsWith("Mac")) {
+ libs.add("ggml-metal.metal");
+ }
+ Path dir = copyJniLibraryFromClasspath(libs.toArray(new String[0]));
+ logger.debug("Loading llama.cpp library from: {}", dir);
+
+ for (int i = 0; i < 2; ++i) {
+ String lib = libs.get(i);
+ String path = dir.resolve(lib).toString();
+ logger.debug("Loading native library: {}", path);
+ String nativeHelper = System.getProperty("ai.djl.llama.native_helper");
+ if (nativeHelper != null && !nativeHelper.isEmpty()) {
+ ClassLoaderUtils.nativeLoad(nativeHelper, path);
+ }
+ System.load(path); // NOPMD
+ }
+ }
+
+ private static Path copyJniLibraryFromClasspath(String... libs) {
+ Path cacheDir = Utils.getEngineCacheDir("llama");
+ Platform platform = Platform.detectPlatform("llama");
+ String classifier = platform.getClassifier();
+ String version = platform.getVersion();
+ Path dir = cacheDir.resolve(version + '-' + classifier);
+ Path path = dir.resolve(LIB_NAME);
+ logger.debug("Using cache dir: {}", dir);
+ if (Files.exists(path)) {
+ return dir.toAbsolutePath();
+ }
+
+ Path tmp = null;
+ try {
+ Files.createDirectories(cacheDir);
+ tmp = Files.createTempDirectory(cacheDir, "tmp");
+
+ for (String libName : libs) {
+ String libPath = "native/lib/" + classifier + "/" + libName;
+ logger.info("Extracting {} to cache ...", libPath);
+ try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) {
+ Path target = tmp.resolve(libName);
+ Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ Utils.moveQuietly(tmp, dir);
+ return dir.toAbsolutePath();
+ } catch (IOException e) {
+ throw new IllegalStateException("Cannot copy jni files", e);
+ } finally {
+ if (tmp != null) {
+ Utils.deleteQuietly(tmp);
+ }
+ }
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java
new file mode 100644
index 00000000000..5d40fa29830
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+/** Native library for llama.cpp. */
+@SuppressWarnings("MissingJavadocMethod")
+public final class LlamaLibrary {
+
+ private LlamaLibrary() {}
+
+ public static native long loadModel(String filePath, ModelParameters param);
+
+ public static native void generate(long handle, String prompt, InputParameters param);
+
+ public static native void infill(
+ long handle, String prefix, String suffix, InputParameters param);
+
+ public static native Token getNext(long handle, long count, long pos);
+
+ public static native float[] embed(long handle, String prompt);
+
+ public static native int[] encode(long handle, String prompt);
+
+ public static native byte[] decodeBytes(long handle, int[] tokens);
+
+ public static native void delete(long handle);
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java
new file mode 100644
index 00000000000..e3e440474a8
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+import java.util.Map;
+
+/** A class holds llama.cpp model loading parameters. */
+@SuppressWarnings("PMD.SingularField")
+public final class ModelParameters {
+
+ private int nThreads;
+ private int nCtx;
+ private int nBatch;
+ private int nGpuLayers;
+ private int mainGpu;
+ private float ropeFreqBase;
+ private float ropeFreqScale;
+ private boolean mulMatQ;
+ private boolean f16Kv;
+ private boolean logitsAll;
+ private boolean vocabOnly;
+ private boolean useMmap;
+ private boolean useMlock;
+ private boolean embedding;
+ private boolean memoryF16;
+ private boolean memTest;
+ private boolean numa;
+ private boolean verbosePrompt;
+ private float[] tensorSplit;
+ private String loraAdapter;
+ private String loraBase;
+
+ /**
+ * Constructs a new {@code ModelParameters} instance.
+ *
+ * @param options the model loading options
+ */
+ public ModelParameters(Map<String, ?> options) {
+ nThreads = intValue(options, "number_threads", Runtime.getRuntime().availableProcessors());
+ nCtx = intValue(options, "max_context_length", 512);
+ nBatch = intValue(options, "max_rolling_batch", 512);
+ nGpuLayers = intValue(options, "number_gpu_layers", -1);
+ mainGpu = intValue(options, "tensor_parallel_degree", 0);
+ ropeFreqBase = floatValue(options, "rope_freq_base");
+ ropeFreqScale = floatValue(options, "rope_freq_scale");
+ f16Kv = booleanValue(options, "f16_kv");
+ mulMatQ = booleanValue(options, "mulmat_q", true);
+ logitsAll = booleanValue(options, "logits_all");
+ vocabOnly = booleanValue(options, "vocab_only");
+ useMmap = booleanValue(options, "use_mmap", true);
+ useMlock = booleanValue(options, "use_mlock");
+ embedding = booleanValue(options, "embedding");
+ memoryF16 = booleanValue(options, "memory_f16", true);
+ memTest = booleanValue(options, "mem_test");
+ numa = booleanValue(options, "numa");
+ verbosePrompt = booleanValue(options, "verbose_prompt");
+ String val = stringValue(options, "tensor_split");
+ if (val != null && !val.isEmpty()) {
+ String[] tokens = val.split(",");
+ tensorSplit = new float[tokens.length];
+ for (int i = 0; i < tokens.length; ++i) {
+ tensorSplit[i] = Float.parseFloat(tokens[i].trim());
+ }
+ }
+ loraAdapter = stringValue(options, "lora_adapter");
+ loraBase = stringValue(options, "lora_base");
+ }
+
+ private static int intValue(Map<String, ?> arguments, String key, int def) {
+ Object value = arguments.get(key);
+ if (value == null) {
+ return def;
+ }
+ return (int) Double.parseDouble(value.toString());
+ }
+
+ private static float floatValue(Map<String, ?> arguments, String key) {
+ Object value = arguments.get(key);
+ if (value == null) {
+ return 0f;
+ }
+ return (float) Double.parseDouble(value.toString());
+ }
+
+ private static boolean booleanValue(Map<String, ?> arguments, String key) {
+ return booleanValue(arguments, key, false);
+ }
+
+ private static boolean booleanValue(Map<String, ?> arguments, String key, boolean def) {
+ Object value = arguments.get(key);
+ if (value == null) {
+ return def;
+ }
+ return Boolean.parseBoolean(value.toString());
+ }
+
+ private static String stringValue(Map<String, ?> arguments, String key) {
+ Object value = arguments.get(key);
+ if (value == null) {
+ return null;
+ }
+ return value.toString();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/Token.java b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java
new file mode 100644
index 00000000000..b8d74306b56
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+import ai.djl.util.JsonUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+
+/** The output token class. */
+public final class Token {
+
+ private int token;
+ private String text;
+ private Map<Integer, Float> probabilities;
+ transient long count;
+ transient long pos;
+ transient boolean hasNext;
+
+ /**
+ * Constructs a new {@code Token} instance.
+ *
+ * @param token the token id
+ * @param generated the token text
+ * @param probabilities the token probabilities
+ * @param count the generated token count
+ * @param pos the token index
+ * @param hasNext has more tokens
+ */
+ public Token(
+ int token,
+ byte[] generated,
+ Map<Integer, Float> probabilities,
+ long count,
+ long pos,
+ boolean hasNext) {
+ this.token = token;
+ this.text = new String(generated, StandardCharsets.UTF_8);
+ this.probabilities = probabilities;
+ this.count = count;
+ this.pos = pos;
+ this.hasNext = hasNext;
+ }
+
+ /**
+ * Returns the token id.
+ *
+ * @return the token id
+ */
+ public int getToken() {
+ return token;
+ }
+
+ /**
+ * Returns the token text.
+ *
+ * @return the token text
+ */
+ public String getText() {
+ return text;
+ }
+
+ /**
+ * Returns the token probabilities.
+ *
+ * @return the token probabilities
+ */
+ public Map<Integer, Float> getProbabilities() {
+ return probabilities;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return JsonUtils.GSON.toJson(this) + '\n';
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java
new file mode 100644
index 00000000000..cab6575d8f7
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.jni;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/** A iterator class holds generated tokens. */
+public class TokenIterator implements Iterator<Token> {
+
+ private static final Logger logger = LoggerFactory.getLogger(TokenIterator.class);
+
+ private static AtomicBoolean active = new AtomicBoolean();
+
+ private long handle;
+ private long count;
+ private long pos;
+ private boolean hasNext;
+
+ /**
+ * Constructs a new {@code TokenIterator} instance.
+ *
+ * @param handle the llama.cpp handle
+ */
+ public TokenIterator(long handle) {
+ this.handle = handle;
+ hasNext = true;
+ if (!active.compareAndSet(false, true)) {
+ active.set(true);
+ logger.warn("Previous inference has been reset");
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean hasNext() {
+ return hasNext;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Token next() {
+ if (!hasNext) {
+ throw new NoSuchElementException();
+ }
+ Token token = LlamaLibrary.getNext(handle, count, pos);
+ count = token.count;
+ pos = token.pos;
+ hasNext = token.hasNext;
+ if (!hasNext) {
+ active.set(false);
+ }
+ return token;
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java
new file mode 100644
index 00000000000..6f429aceda2
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+/** Contains classes to interface with the native llama.cpp code. */
+package ai.djl.llama.jni;
diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java
new file mode 100644
index 00000000000..69d4f200ba9
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.zoo;
+
+import ai.djl.Application;
+import ai.djl.repository.Repository;
+import ai.djl.repository.zoo.ModelLoader;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.util.ClassLoaderUtils;
+import ai.djl.util.JsonUtils;
+import ai.djl.util.Utils;
+
+import com.google.gson.reflect.TypeToken;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.Writer;
+import java.lang.reflect.Type;
+import java.net.URI;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.time.Duration;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.zip.GZIPInputStream;
+
+/** LlamaModelZoo is a repository that contains llama.cpp models. */
+public class LlamaModelZoo extends ModelZoo {
+
+ // NOTE(review): generic type parameters (e.g. Set<String>, Map<String, ModelDetail>,
+ // TypeToken<...>) appear stripped by text extraction throughout this class —
+ // restore from the upstream source before applying.
+ private static final Logger logger = LoggerFactory.getLogger(LlamaModelZoo.class);
+
+ // Remote DJL model repository that hosts the GGUF index and artifacts.
+ private static final String REPO = "https://mlrepo.djl.ai/";
+ private static final Repository REPOSITORY = Repository.newInstance("gguf", REPO);
+ private static final String GROUP_ID = "ai.djl.huggingface.gguf";
+
+ // The cached models.json index is considered fresh for one day.
+ private static final long ONE_DAY = Duration.ofDays(1).toMillis();
+
+ // Set once init() has registered all model loaders.
+ // NOTE(review): not volatile/synchronized — confirm zoo access is single-threaded.
+ private boolean initialized;
+
+ LlamaModelZoo() {}
+
+ /** {@inheritDoc} */
+ @Override
+ public String getGroupId() {
+ return GROUP_ID;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Set getSupportedEngines() {
+ return Collections.singleton("Llama");
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Collection getModelLoaders() {
+ // Lazily populate the zoo from the remote index on first use.
+ init();
+ return super.getModelLoaders();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public ModelLoader getModelLoader(String name) {
+ init();
+ return super.getModelLoader(name);
+ }
+
+ // Registers one model per GGUF quantization key listed in the downloaded index.
+ private void init() {
+ if (!initialized) {
+ Application app = Application.NLP.TEXT_GENERATION;
+ Map map = listModels(app);
+ for (Map.Entry entry : map.entrySet()) {
+ String artifactId = entry.getKey();
+ Map gguf = entry.getValue().getGguf();
+ if (gguf != null) {
+ for (String key : gguf.keySet()) {
+ addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1", key));
+ }
+ }
+ }
+ initialized = true;
+ }
+ }
+
+ // Loads the artifactId -> ModelDetail index: serves a cache younger than one
+ // day (or any cache when offline), otherwise downloads models.json.gz; on
+ // download failure falls back to a stale cache, then to a bundled resource.
+ private Map listModels(Application app) {
+ try {
+ String path = "model/" + app.getPath() + "/ai/djl/huggingface/gguf/";
+ Path dir = Utils.getCacheDir().resolve("cache/repo/" + path);
+ if (Files.notExists(dir)) {
+ Files.createDirectories(dir);
+ } else if (!Files.isDirectory(dir)) {
+ // NOTE(review): message grammar ("Failed initialize") and string
+ // concatenation instead of SLF4J "{}" parameterization — consider fixing.
+ logger.warn("Failed initialize cache directory: " + dir);
+ return Collections.emptyMap();
+ }
+ Type type = new TypeToken>() {}.getType();
+
+ Path file = dir.resolve("models.json");
+ if (Files.exists(file)) {
+ long lastModified = Files.getLastModifiedTime(file).toMillis();
+ // Offline mode always trusts the cache regardless of age.
+ if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) {
+ try (Reader reader = Files.newBufferedReader(file)) {
+ return JsonUtils.GSON.fromJson(reader, type);
+ }
+ }
+ }
+
+ URL url = URI.create(REPO).resolve(path + "models.json.gz").toURL();
+ // Write to a temp file first so the cache is replaced atomically.
+ Path tmp = Files.createTempFile(dir, "models", ".tmp");
+ try (GZIPInputStream gis = new GZIPInputStream(Utils.openUrl(url))) {
+ String json = Utils.toString(gis);
+ try (Writer writer = Files.newBufferedWriter(tmp)) {
+ writer.write(json);
+ }
+ Utils.moveQuietly(tmp, file);
+ return JsonUtils.GSON.fromJson(json, type);
+ } catch (IOException e) {
+ // NOTE(review): exception e is dropped here — consider logging it.
+ logger.warn("Failed to download Huggingface gguf index: {}", app);
+ // Fallback 1: a stale on-disk cache is better than nothing.
+ if (Files.exists(file)) {
+ try (Reader reader = Files.newBufferedReader(file)) {
+ return JsonUtils.GSON.fromJson(reader, type);
+ }
+ }
+
+ // Fallback 2: index bundled on the classpath.
+ String resource = app.getPath() + "/" + GROUP_ID + ".json";
+ try (InputStream is = ClassLoaderUtils.getResourceAsStream(resource)) {
+ String json = Utils.toString(is);
+ try (Writer writer = Files.newBufferedWriter(tmp)) {
+ writer.write(json);
+ }
+ Utils.moveQuietly(tmp, file);
+ return JsonUtils.GSON.fromJson(json, type);
+ }
+ } finally {
+ Utils.deleteQuietly(tmp);
+ }
+ } catch (IOException e) {
+ logger.warn("Failed load gguf index file", e);
+ }
+
+ return Collections.emptyMap();
+ }
+
+ // Gson-deserialized entry of models.json: maps quantization name -> file info.
+ private static final class ModelDetail {
+
+ private Map gguf;
+
+ public Map getGguf() {
+ return gguf;
+ }
+
+ public void setGguf(Map gguf) {
+ this.gguf = gguf;
+ }
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java
new file mode 100644
index 00000000000..ba2b04722c1
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.zoo;
+
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooProvider;
+
+/**
+ * A Huggingface llama.cpp model zoo provider that implements the {@link
+ * ai.djl.repository.zoo.ZooProvider} interface.
+ */
+public class LlamaZooProvider implements ZooProvider {
+
+ /** {@inheritDoc} */
+ @Override
+ public ModelZoo getModelZoo() {
+ // Construction is cheap: the zoo loads its remote index lazily on first use.
+ return new LlamaModelZoo();
+ }
+}
diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java
new file mode 100644
index 00000000000..a9c1df64cd0
--- /dev/null
+++ b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+/** Contains the built-in {@link ai.djl.llama.zoo.LlamaModelZoo}. */
+package ai.djl.llama.zoo;
diff --git a/engines/llama/src/main/javadoc/overview.html b/engines/llama/src/main/javadoc/overview.html
new file mode 100644
index 00000000000..05dec7d0bd4
--- /dev/null
+++ b/engines/llama/src/main/javadoc/overview.html
@@ -0,0 +1,14 @@
+
+
+
+
+
+This document is the API specification for the Deep Java Library (DJL) Llama Engine.
+
+
+ The Llama Engine module contains the Llama.cpp implementation of the DJL EngineProvider.
+ See here for more details.
+
+
+
+
diff --git a/engines/llama/src/main/native/ai_djl_llama.cpp b/engines/llama/src/main/native/ai_djl_llama.cpp
new file mode 100644
index 00000000000..1d6072751f2
--- /dev/null
+++ b/engines/llama/src/main/native/ai_djl_llama.cpp
@@ -0,0 +1,1025 @@
+#include
+#include
+#include
+#include
+
+#include "ai_djl_llama_jni_LlamaLibrary.h"
+#include "common.h"
+#include "grammar-parser.h"
+#include "llama.h"
+#include "sampling.h"
+
+// classes
+static jclass c_lib_utils = 0;
+static jclass c_model_params = 0;
+static jclass c_input_params = 0;
+static jclass c_token = 0;
+static jclass c_standard_charsets = 0;
+static jclass c_string = 0;
+static jclass c_hash_map = 0;
+static jclass c_map = 0;
+static jclass c_set = 0;
+static jclass c_entry = 0;
+static jclass c_integer = 0;
+static jclass c_float = 0;
+static jclass c_logger = 0;
+static jclass c_engine_exception = 0;
+
+// constructors
+static jmethodID cc_token = 0;
+static jmethodID cc_hash_map = 0;
+static jmethodID cc_integer = 0;
+static jmethodID cc_float = 0;
+
+// methods
+static jmethodID m_get_bytes = 0;
+static jmethodID m_entry_set = 0;
+static jmethodID m_set_iterator = 0;
+static jmethodID m_iterator_has_next = 0;
+static jmethodID m_iterator_next = 0;
+static jmethodID m_entry_key = 0;
+static jmethodID m_entry_value = 0;
+static jmethodID m_map_put = 0;
+static jmethodID m_int_value = 0;
+static jmethodID m_float_value = 0;
+static jmethodID m_log_debug = 0;
+static jmethodID m_log_info = 0;
+static jmethodID m_log_warn = 0;
+static jmethodID m_log_error = 0;
+
+// fields
+static jfieldID f_logger = 0;
+// inference parameters
+static jfieldID f_n_predict = 0;
+static jfieldID f_n_keep = 0;
+static jfieldID f_n_probs = 0;
+static jfieldID f_logit_bias = 0;
+static jfieldID f_top_k = 0;
+static jfieldID f_top_p = 0;
+static jfieldID f_tfs_z = 0;
+static jfieldID f_typical_p = 0;
+static jfieldID f_temperature = 0;
+static jfieldID f_repeat_penalty = 0;
+static jfieldID f_repeat_last_n = 0;
+static jfieldID f_frequency_penalty = 0;
+static jfieldID f_presence_penalty = 0;
+static jfieldID f_penalize_nl = 0;
+static jfieldID f_ignore_eos = 0;
+static jfieldID f_mirostat = 0;
+static jfieldID f_mirostat_tau = 0;
+static jfieldID f_mirostat_eta = 0;
+static jfieldID f_n_beams = 0;
+static jfieldID f_grammar = 0;
+static jfieldID f_antiprompt = 0;
+static jfieldID f_infer_seed = 0;
+// model parameters
+static jfieldID f_n_threads = 0;
+static jfieldID f_n_ctx = 0;
+static jfieldID f_n_batch = 0;
+static jfieldID f_n_gpu_layers = 0;
+static jfieldID f_main_gpu = 0;
+static jfieldID f_tensor_split = 0;
+static jfieldID f_rope_freq_base = 0;
+static jfieldID f_rope_freq_scale = 0;
+static jfieldID f_mul_mat_q = 0;
+static jfieldID f_f16_kv = 0;
+static jfieldID f_logits_all = 0;
+static jfieldID f_vocab_only = 0;
+static jfieldID f_use_mmap = 0;
+static jfieldID f_use_mlock = 0;
+static jfieldID f_embedding = 0;
+static jfieldID f_lora_adapter = 0;
+static jfieldID f_lora_base = 0;
+static jfieldID f_memory_f16 = 0;
+static jfieldID f_mem_test = 0;
+static jfieldID f_numa = 0;
+static jfieldID f_verbose_prompt = 0;
+// log level
+static jfieldID f_utf_8 = 0;
+// objects
+static jobject o_utf_8 = 0;
+static jobject o_logger = 0;
+
+static JavaVM *g_vm = nullptr;
+
+static void null_log_callback(enum ggml_log_level level, const char *text, void *user_data) {}
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
+ JNIEnv *env = 0;
+
+ if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) {
+ return JNI_ERR;
+ }
+
+ log_disable();
+ llama_log_set(null_log_callback, nullptr);
+
+ // find classes
+ c_input_params = env->FindClass("ai/djl/llama/jni/InputParameters");
+ c_model_params = env->FindClass("ai/djl/llama/jni/ModelParameters");
+ c_lib_utils = env->FindClass("ai/djl/llama/jni/LibUtils");
+ c_token = env->FindClass("ai/djl/llama/jni/Token");
+ c_engine_exception = env->FindClass("ai/djl/engine/EngineException");
+ c_logger = env->FindClass("org/slf4j/Logger");
+ c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets");
+ c_string = env->FindClass("java/lang/String");
+ c_hash_map = env->FindClass("java/util/HashMap");
+ c_map = env->FindClass("java/util/Map");
+ c_set = env->FindClass("java/util/Set");
+ c_entry = env->FindClass("java/util/Map$Entry");
+ c_integer = env->FindClass("java/lang/Integer");
+ c_float = env->FindClass("java/lang/Float");
+
+ // create references
+ c_input_params = (jclass) env->NewGlobalRef(c_input_params);
+ c_model_params = (jclass) env->NewGlobalRef(c_model_params);
+ c_lib_utils = (jclass) env->NewGlobalRef(c_lib_utils);
+ c_token = (jclass) env->NewGlobalRef(c_token);
+ c_engine_exception = (jclass) env->NewGlobalRef(c_engine_exception);
+ c_logger = (jclass) env->NewGlobalRef(c_logger);
+ c_string = (jclass) env->NewGlobalRef(c_string);
+ c_hash_map = (jclass) env->NewGlobalRef(c_hash_map);
+ c_map = (jclass) env->NewGlobalRef(c_map);
+ c_set = (jclass) env->NewGlobalRef(c_set);
+ c_entry = (jclass) env->NewGlobalRef(c_entry);
+ c_integer = (jclass) env->NewGlobalRef(c_integer);
+ c_float = (jclass) env->NewGlobalRef(c_float);
+
+ // find constructors
+ cc_token = env->GetMethodID(c_token, "", "(I[BLjava/util/Map;JJZ)V");
+ cc_hash_map = env->GetMethodID(c_hash_map, "", "()V");
+ cc_integer = env->GetMethodID(c_integer, "", "(I)V");
+ cc_float = env->GetMethodID(c_float, "", "(F)V");
+
+ // find methods
+ m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B");
+ m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;");
+ m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;");
+ m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;");
+ m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
+ m_int_value = env->GetMethodID(c_integer, "intValue", "()I");
+ m_float_value = env->GetMethodID(c_float, "floatValue", "()F");
+ m_log_debug = env->GetMethodID(c_logger, "debug", "(Ljava/lang/String;)V");
+ m_log_info = env->GetMethodID(c_logger, "info", "(Ljava/lang/String;)V");
+ m_log_warn = env->GetMethodID(c_logger, "warn", "(Ljava/lang/String;)V");
+ m_log_error = env->GetMethodID(c_logger, "error", "(Ljava/lang/String;)V");
+
+ // find fields
+ f_logger = env->GetStaticFieldID(c_lib_utils, "logger", "Lorg/slf4j/Logger;");
+
+ f_n_predict = env->GetFieldID(c_input_params, "nPredict", "I");
+ f_n_keep = env->GetFieldID(c_input_params, "nKeep", "I");
+ f_n_probs = env->GetFieldID(c_input_params, "nProbs", "I");
+ f_logit_bias = env->GetFieldID(c_input_params, "logitBias", "Ljava/util/Map;");
+ f_top_k = env->GetFieldID(c_input_params, "topK", "I");
+ f_top_p = env->GetFieldID(c_input_params, "topP", "F");
+ f_tfs_z = env->GetFieldID(c_input_params, "tfsZ", "F");
+ f_typical_p = env->GetFieldID(c_input_params, "typicalP", "F");
+ f_temperature = env->GetFieldID(c_input_params, "temperature", "F");
+ f_repeat_penalty = env->GetFieldID(c_input_params, "repeatPenalty", "F");
+ f_repeat_last_n = env->GetFieldID(c_input_params, "repeatLastN", "I");
+ f_frequency_penalty = env->GetFieldID(c_input_params, "frequencyPenalty", "F");
+ f_presence_penalty = env->GetFieldID(c_input_params, "presencePenalty", "F");
+ f_penalize_nl = env->GetFieldID(c_input_params, "penalizeNl", "Z");
+ f_ignore_eos = env->GetFieldID(c_input_params, "ignoreEos", "Z");
+ f_mirostat = env->GetFieldID(c_input_params, "mirostat", "I");
+ f_mirostat_tau = env->GetFieldID(c_input_params, "mirostatTau", "F");
+ f_mirostat_eta = env->GetFieldID(c_input_params, "mirostatEta", "F");
+ f_n_beams = env->GetFieldID(c_input_params, "nBeams", "I");
+ f_grammar = env->GetFieldID(c_input_params, "grammar", "Ljava/lang/String;");
+ f_antiprompt = env->GetFieldID(c_input_params, "antiPrompt", "[Ljava/lang/String;");
+ f_infer_seed = env->GetFieldID(c_input_params, "seed", "I");
+
+ f_n_threads = env->GetFieldID(c_model_params, "nThreads", "I");
+ f_n_ctx = env->GetFieldID(c_model_params, "nCtx", "I");
+ f_n_batch = env->GetFieldID(c_model_params, "nBatch", "I");
+ f_n_gpu_layers = env->GetFieldID(c_model_params, "nGpuLayers", "I");
+ f_main_gpu = env->GetFieldID(c_model_params, "mainGpu", "I");
+ f_tensor_split = env->GetFieldID(c_model_params, "tensorSplit", "[F");
+ f_rope_freq_base = env->GetFieldID(c_model_params, "ropeFreqBase", "F");
+ f_rope_freq_scale = env->GetFieldID(c_model_params, "ropeFreqScale", "F");
+ f_mul_mat_q = env->GetFieldID(c_model_params, "mulMatQ", "Z");
+ f_f16_kv = env->GetFieldID(c_model_params, "f16Kv", "Z");
+ f_logits_all = env->GetFieldID(c_model_params, "logitsAll", "Z");
+ f_vocab_only = env->GetFieldID(c_model_params, "vocabOnly", "Z");
+ f_use_mmap = env->GetFieldID(c_model_params, "useMmap", "Z");
+ f_use_mlock = env->GetFieldID(c_model_params, "useMlock", "Z");
+ f_embedding = env->GetFieldID(c_model_params, "embedding", "Z");
+ f_lora_adapter = env->GetFieldID(c_model_params, "loraAdapter", "Ljava/lang/String;");
+ f_lora_base = env->GetFieldID(c_model_params, "loraBase", "Ljava/lang/String;");
+ f_memory_f16 = env->GetFieldID(c_model_params, "memoryF16", "Z");
+ f_mem_test = env->GetFieldID(c_model_params, "memTest", "Z");
+ f_numa = env->GetFieldID(c_model_params, "numa", "Z");
+ f_verbose_prompt = env->GetFieldID(c_model_params, "verbosePrompt", "Z");
+
+ f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;");
+ o_utf_8 = env->NewStringUTF("UTF-8");
+ o_utf_8 = (jobject) env->NewGlobalRef(o_utf_8);
+ o_logger = env->GetStaticObjectField(c_lib_utils, f_logger);
+ o_logger = (jobject) env->NewGlobalRef(o_logger);
+
+ if (env->ExceptionCheck()) {
+ env->ExceptionDescribe();
+ return JNI_ERR;
+ }
+
+ return JNI_VERSION_1_1;
+}
+
+// Releases the JNI global references created in JNI_OnLoad when the library unloads.
+// NOTE(review): c_lib_utils and o_logger receive NewGlobalRef in JNI_OnLoad but are
+// never deleted here — confirm whether that leak is intentional.
+JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) {
+ JNIEnv *env = 0;
+
+ if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) {
+ return;
+ }
+
+ env->DeleteGlobalRef(c_input_params);
+ env->DeleteGlobalRef(c_model_params);
+ env->DeleteGlobalRef(c_token);
+ env->DeleteGlobalRef(c_string);
+ env->DeleteGlobalRef(c_hash_map);
+ env->DeleteGlobalRef(c_map);
+ env->DeleteGlobalRef(c_set);
+ env->DeleteGlobalRef(c_entry);
+ env->DeleteGlobalRef(c_integer);
+ env->DeleteGlobalRef(c_float);
+ env->DeleteGlobalRef(c_logger);
+ env->DeleteGlobalRef(c_engine_exception);
+
+ env->DeleteGlobalRef(o_utf_8);
+}
+
+// Routes a native ggml log message to the Java SLF4J logger held in o_logger,
+// mapping the ggml level to the corresponding SLF4J method.
+static void log(JNIEnv *env, enum ggml_log_level level, const char *text) {
+ jstring java_text = env->NewStringUTF(text);
+
+ switch (level) {
+ case GGML_LOG_LEVEL_ERROR:
+ env->CallVoidMethod(o_logger, m_log_error, java_text);
+ break;
+ case GGML_LOG_LEVEL_WARN:
+ env->CallVoidMethod(o_logger, m_log_warn, java_text);
+ break;
+ case GGML_LOG_LEVEL_INFO:
+ env->CallVoidMethod(o_logger, m_log_info, java_text);
+ break;
+ default:
+ // Any other level (including debug) goes to Logger.debug.
+ env->CallVoidMethod(o_logger, m_log_debug, java_text);
+ break;
+ }
+ // Free the local ref promptly; native callbacks may fire outside a JNI frame.
+ env->DeleteLocalRef(java_text);
+}
+
+// Convenience overload forwarding a std::string message to the const char* logger.
+static void log(JNIEnv *env, enum ggml_log_level level, std::string text) { log(env, level, text.c_str()); }
+
+// Converts a Java String to a std::string via String.getBytes(UTF-8), which
+// avoids JNI's modified-UTF-8 pitfalls for non-BMP characters.
+static std::string parse_jstring(JNIEnv *env, jstring java_string) {
+ const jbyteArray string_bytes = (jbyteArray) env->CallObjectMethod(java_string, m_get_bytes, o_utf_8);
+
+ size_t length = (size_t) env->GetArrayLength(string_bytes);
+ jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr);
+
+ std::string string = std::string((char *) byte_elements, length);
+
+ // JNI_ABORT: bytes were only read, no need to copy them back.
+ env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT);
+ env->DeleteLocalRef(string_bytes);
+
+ return string;
+}
+
+// Unboxes a java.lang.Integer; a null reference yields 0.
+static int parse_jinteger(JNIEnv *env, jobject java_integer) {
+ if (!java_integer) return 0;
+ return env->CallIntMethod(java_integer, m_int_value);
+}
+
+// Unboxes a java.lang.Float; a null reference yields 0.
+static float parse_jfloat(JNIEnv *env, jobject java_float) {
+ if (!java_float) return 0;
+ return env->CallFloatMethod(java_float, m_float_value);
+}
+
+// Copies a std::string's raw bytes into a new Java byte[].
+// NOTE(review): the reinterpret_cast template argument (likely <const jbyte *>)
+// appears stripped by extraction — restore from upstream.
+static jbyteArray parse_jbytes(JNIEnv *env, std::string string) {
+ jsize len = string.size();
+ jbyteArray bytes = env->NewByteArray(len);
+ env->SetByteArrayRegion(bytes, 0, len, reinterpret_cast(string.c_str()));
+ return bytes;
+}
+
+// completion token output with probabilities
+// Holds one sampled token plus the top-n candidate probabilities for it.
+// NOTE(review): the element type of probs (likely std::vector<token_prob>)
+// appears stripped by extraction.
+struct completion_token_output {
+ struct token_prob {
+ llama_token tok;
+ float prob;
+ };
+
+ std::vector probs;
+ llama_token tok;
+};
+
+// Returns the length of the longest common prefix of the two token vectors,
+// used to decide how much of the evaluated KV cache can be reused.
+static size_t common_part(const std::vector &a, const std::vector &b) {
+ size_t i;
+ for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {
+ }
+ return i;
+}
+
+// How stop words are matched: a complete match or a trailing partial match
+// (used to hold back output that might be the start of a stop word).
+enum stop_type {
+ STOP_FULL,
+ STOP_PARTIAL,
+};
+
+// True if str ends with the given suffix (empty suffix always matches).
+static bool ends_with(const std::string &str, const std::string &suffix) {
+ return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+}
+
+// If text ends with a proper prefix of the stop word, returns the index in
+// text where that partial prefix begins; otherwise returns std::string::npos.
+// Longer prefixes are preferred because char_index counts down.
+static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
+ if (!text.empty() && !stop.empty()) {
+ const char text_last_char = text.back();
+ for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+ if (stop[char_index] == text_last_char) {
+ const std::string current_partial = stop.substr(0, char_index + 1);
+ if (ends_with(text, current_partial)) {
+ return text.size() - char_index - 1;
+ }
+ }
+ }
+ }
+ return std::string::npos;
+}
+
+// Concatenates the text pieces of a token iterator range into one string.
+// NOTE(review): template parameter list (likely <typename Iter>) appears
+// stripped by extraction.
+template
+static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) {
+ std::string ret;
+ for (; begin != end; ++begin) {
+ ret += llama_token_to_piece(ctx, *begin);
+ }
+ return ret;
+}
+
+// format incomplete utf-8 multibyte character for output
+// A token of -1 maps to the empty string; a lone UTF-8 continuation/lead byte
+// is rendered as a "byte: \xNN" placeholder instead of invalid text.
+static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) {
+ std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+ // if the size is 1 and first bit is 1, meaning it's a partial character
+ // (size > 1 meaning it's already a known token)
+ if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
+ std::stringstream ss;
+ ss << std::hex << (out[0] & 0xff);
+ std::string res(ss.str());
+ out = "byte: \\x" + res;
+ }
+ return out;
+}
+
+struct jllama_context {
+ bool has_next_token = false;
+ std::string generated_text;
+ std::vector generated_token_probs;
+
+ size_t num_prompt_tokens = 0;
+ size_t num_tokens_predicted = 0;
+ size_t n_past = 0;
+ size_t n_remain = 0;
+
+ std::string prompt;
+ std::vector embd;
+ std::vector last_n_tokens;
+
+ llama_model *model = nullptr;
+ llama_context *ctx = nullptr;
+ gpt_params params;
+ llama_sampling_context ctx_sampling;
+ int n_ctx;
+
+ grammar_parser::parse_state parsed_grammar;
+ llama_grammar *grammar = nullptr;
+
+ bool truncated = false;
+ bool stopped_eos = false;
+ bool stopped_word = false;
+ bool stopped_limit = false;
+ std::string stopping_word;
+ int32_t multibyte_pending = 0;
+
+ std::mutex mutex;
+
+ std::unique_lock lock() { return std::unique_lock(mutex); }
+
+ ~jllama_context() {
+ if (ctx) {
+ llama_free(ctx);
+ ctx = nullptr;
+ }
+ if (model) {
+ llama_free_model(model);
+ model = nullptr;
+ }
+ if (grammar) {
+ llama_grammar_free(grammar);
+ grammar = nullptr;
+ }
+ }
+
+ void rewind() {
+ params.antiprompt.clear();
+ params.sparams.grammar.clear();
+ num_prompt_tokens = 0;
+ num_tokens_predicted = 0;
+ generated_text = "";
+ generated_text.reserve(n_ctx);
+ generated_token_probs.clear();
+ truncated = false;
+ stopped_eos = false;
+ stopped_word = false;
+ stopped_limit = false;
+ stopping_word = "";
+ multibyte_pending = 0;
+ n_remain = 0;
+ n_past = 0;
+
+ if (grammar != nullptr) {
+ llama_grammar_free(grammar);
+ grammar = nullptr;
+ ctx_sampling = *llama_sampling_init(params.sparams);
+ }
+ }
+
+ bool loadModel(const gpt_params ¶ms_) {
+ params = params_;
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (model == nullptr) {
+ return false;
+ }
+ n_ctx = llama_n_ctx(ctx);
+ last_n_tokens.resize(n_ctx);
+ std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+ return true;
+ }
+
+ std::vector tokenize(std::string prompt, bool add_bos) const {
+ return ::llama_tokenize(ctx, prompt, add_bos);
+ }
+
+ bool loadGrammar(JNIEnv *env) {
+ if (!params.sparams.grammar.empty()) {
+ parsed_grammar = grammar_parser::parse(params.sparams.grammar.c_str());
+ // will be empty (default) if there are parse errors
+ if (parsed_grammar.rules.empty()) {
+ log(env, GGML_LOG_LEVEL_ERROR, "grammar parse error");
+ return false;
+ }
+ grammar_parser::print_grammar(stderr, parsed_grammar);
+
+ {
+ auto it = params.sparams.logit_bias.find(llama_token_eos(model));
+ if (it != params.sparams.logit_bias.end() && it->second == -INFINITY) {
+ log(env, GGML_LOG_LEVEL_WARN, "EOS token is disabled, which will cause most grammars to fail");
+ }
+ }
+
+ std::vector grammar_rules(parsed_grammar.c_rules());
+ grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ }
+ ctx_sampling = *llama_sampling_init(params.sparams);
+ return true;
+ }
+
+ void loadInfill(JNIEnv *env) {
+ bool suff_rm_leading_spc = true;
+ if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+ params.input_suffix.erase(0, 1);
+ suff_rm_leading_spc = false;
+ }
+
+ auto prefix_tokens = tokenize(params.input_prefix, false);
+ auto suffix_tokens = tokenize(params.input_suffix, false);
+ const int space_token = 29871;
+ if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
+ suffix_tokens.erase(suffix_tokens.begin());
+ }
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
+ prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
+ prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+ prefix_tokens.push_back(llama_token_middle(model));
+ auto prompt_tokens = prefix_tokens;
+
+ num_prompt_tokens = prompt_tokens.size();
+
+ if (params.n_keep < 0) {
+ params.n_keep = (int) num_prompt_tokens;
+ }
+ params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+ // if input prompt is too big, truncate like normal
+ if (num_prompt_tokens >= (size_t) params.n_ctx) {
+ // todo we probably want to cut from both sides
+ const int n_left = (params.n_ctx - params.n_keep) / 2;
+ std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+ const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+ new_tokens.insert(
+ new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+ std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+ log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left));
+
+ truncated = true;
+ prompt_tokens = new_tokens;
+ } else {
+ const size_t ps = num_prompt_tokens;
+ std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+ std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+ }
+
+ // compare the evaluated prompt with the new prompt
+ n_past = common_part(embd, prompt_tokens);
+ embd = prompt_tokens;
+
+ if (n_past == num_prompt_tokens) {
+ // we have to evaluate at least 1 token to generate logits.
+ n_past--;
+ }
+
+ // since #3228 we now have to manually manage the KV cache
+ llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+
+ has_next_token = true;
+ }
+
+ void loadPrompt(JNIEnv *env) {
+ auto prompt_tokens = tokenize(prompt, true); // always add BOS
+
+ num_prompt_tokens = prompt_tokens.size();
+
+ if (params.n_keep < 0) {
+ params.n_keep = (int) num_prompt_tokens;
+ }
+ params.n_keep = std::min(n_ctx - 4, params.n_keep);
+
+ // if input prompt is too big, truncate like normal
+ if (num_prompt_tokens >= (size_t) n_ctx) {
+ const int n_left = (n_ctx - params.n_keep) / 2;
+ std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+ const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+ new_tokens.insert(
+ new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+ std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+ log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left));
+
+ truncated = true;
+ prompt_tokens = new_tokens;
+ } else {
+ const size_t ps = num_prompt_tokens;
+ std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+ std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+ }
+
+ // compare the evaluated prompt with the new prompt
+ n_past = common_part(embd, prompt_tokens);
+
+ embd = prompt_tokens;
+ if (n_past == num_prompt_tokens) {
+ // we have to evaluate at least 1 token to generate logits.
+ n_past--;
+ }
+
+ // since #3228 we now have to manually manage the KV cache
+ llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+
+ has_next_token = true;
+ }
+
+ void beginCompletion() {
+ // number of tokens to keep when resetting context
+ n_remain = params.n_predict;
+ llama_set_rng_seed(ctx, params.seed);
+ }
+
+ completion_token_output nextToken(JNIEnv *env) {
+ completion_token_output result;
+ result.tok = -1;
+
+ if (embd.size() >= (size_t) n_ctx) {
+ // Shift context
+
+ const int n_left = n_past - params.n_keep - 1;
+ const int n_discard = n_left / 2;
+
+ llama_kv_cache_seq_rm(ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1);
+ llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+ for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) {
+ embd[i - n_discard] = embd[i];
+ }
+ embd.resize(embd.size() - n_discard);
+
+ n_past -= n_discard;
+
+ truncated = true;
+ log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left));
+ }
+
+ bool tg = true;
+ while (n_past < embd.size()) {
+ int n_eval = (int) embd.size() - n_past;
+ tg = n_eval == 1;
+ if (n_eval > params.n_batch) {
+ n_eval = params.n_batch;
+ }
+
+ if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) {
+ log(env, GGML_LOG_LEVEL_ERROR, "failed to eval n_eval=" + std::to_string(n_eval));
+ has_next_token = false;
+ return result;
+ }
+ n_past += n_eval;
+ }
+
+ if (params.n_predict == 0) {
+ has_next_token = false;
+ result.tok = llama_token_eos(model);
+ return result;
+ }
+
+ {
+ // out of user input, sample next token
+ result.tok = llama_sampling_sample(&ctx_sampling, ctx, NULL);
+
+ llama_token_data_array candidates_p = {ctx_sampling.cur.data(), ctx_sampling.cur.size(), false};
+
+ const int32_t n_probs = params.sparams.n_probs;
+ if (params.sparams.temp <= 0 && n_probs > 0) {
+ // For llama_sample_token_greedy we need to sort candidates
+ llama_sample_softmax(ctx, &candidates_p);
+ }
+
+ for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
+ result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+ }
+
+ llama_sampling_accept(&ctx_sampling, ctx, result.tok, true);
+ if (tg) {
+ num_tokens_predicted++;
+ }
+ }
+
+ // add it to the context
+ embd.push_back(result.tok);
+ // decrement remaining sampling budget
+ --n_remain;
+
+ if (!embd.empty() && embd.back() == llama_token_eos(model)) {
+ // stopping_word = llama_token_to_piece(ctx, embd.back());
+ has_next_token = false;
+ stopped_eos = true;
+ return result;
+ }
+
+ has_next_token = params.n_predict == -1 || n_remain != 0;
+ return result;
+ }
+
+ size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type) {
+ size_t stop_pos = std::string::npos;
+ for (const std::string &word : params.antiprompt) {
+ size_t pos;
+ if (type == STOP_FULL) {
+ const size_t tmp = word.size() + last_token_size;
+ const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+ pos = text.find(word, from_pos);
+ } else {
+ pos = find_partial_stop_string(word, text);
+ }
+ if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+ if (type == STOP_FULL) {
+ stopping_word = word;
+ stopped_word = true;
+ has_next_token = false;
+ }
+ stop_pos = pos;
+ }
+ }
+ return stop_pos;
+ }
+
    // Generates one token, appends its text to generated_text, and tracks
    // pending UTF-8 continuation bytes so callers never observe a partially
    // emitted multi-byte character.
    completion_token_output doCompletion(JNIEnv *env) {
        auto token_with_probs = nextToken(env);

        // tok == -1 signals "no token produced"; emit nothing in that case
        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
        generated_text += token_text;

        if (params.sparams.n_probs > 0) {
            // keep per-token probabilities so getNext can report them later
            generated_token_probs.push_back(token_with_probs);
        }

        if (multibyte_pending > 0) {
            // still inside a multi-byte character started by a previous token
            multibyte_pending -= token_text.size();
        } else if (token_text.size() == 1) {
            // a single emitted byte may be the lead byte of a UTF-8 sequence;
            // record how many continuation bytes are still expected
            const char c = token_text[0];
            // 2-byte characters: 110xxxxx 10xxxxxx
            if ((c & 0xE0) == 0xC0) {
                multibyte_pending = 1;
                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
            } else if ((c & 0xF0) == 0xE0) {
                multibyte_pending = 2;
                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
            } else if ((c & 0xF8) == 0xF0) {
                multibyte_pending = 3;
            } else {
                multibyte_pending = 0;
            }
        }

        // Do not stop in the middle of a multi-byte character; grant one more
        // decode step so the character can be completed.
        if (multibyte_pending > 0 && !has_next_token) {
            has_next_token = true;
            n_remain++;
        }

        if (!has_next_token && n_remain == 0) {
            // generation ended because the token budget was exhausted
            stopped_limit = true;
        }

        return token_with_probs;
    }
+
+ std::vector getEmbedding(JNIEnv *env) {
+ static const int n_embd = llama_n_embd(model);
+ if (!params.embedding) {
+ log(env, GGML_LOG_LEVEL_ERROR, "embedding disabled");
+ return std::vector(n_embd, 0.0f);
+ }
+ const float *data = llama_get_embeddings(ctx);
+ std::vector embedding(data, data + n_embd);
+ return embedding;
+ }
+};
+
+static gpt_params parse_model_params(JNIEnv *env, jobject jparams, jstring java_file_path) {
+ gpt_params params;
+
+ params.model = parse_jstring(env, java_file_path);
+ params.n_threads = env->GetIntField(jparams, f_n_threads);
+ params.n_ctx = env->GetIntField(jparams, f_n_ctx);
+ params.n_batch = env->GetIntField(jparams, f_n_batch);
+ params.n_gpu_layers = env->GetIntField(jparams, f_n_gpu_layers);
+ params.main_gpu = env->GetIntField(jparams, f_main_gpu);
+ params.rope_freq_base = env->GetFloatField(jparams, f_rope_freq_base);
+ params.rope_freq_scale = env->GetFloatField(jparams, f_rope_freq_scale);
+ params.mul_mat_q = env->GetBooleanField(jparams, f_mul_mat_q);
+ params.embedding = env->GetBooleanField(jparams, f_embedding);
+ params.escape = env->GetIntField(jparams, f_n_predict);
+ params.use_mmap = env->GetBooleanField(jparams, f_use_mmap);
+ params.use_mlock = env->GetBooleanField(jparams, f_use_mlock);
+ params.numa = env->GetBooleanField(jparams, f_numa);
+ params.verbose_prompt = env->GetBooleanField(jparams, f_verbose_prompt);
+
+ if (params.model_alias == "unknown") {
+ params.model_alias = params.model;
+ }
+
+ return params;
+}
+
+static void setup_infer_params(JNIEnv *env, jllama_context *llama, jobject jparams) {
+ auto ¶ms = llama->params;
+
+ params.seed = env->GetIntField(jparams, f_infer_seed);
+ params.n_predict = env->GetIntField(jparams, f_n_predict);
+ params.n_keep = env->GetIntField(jparams, f_n_keep);
+
+ auto &sparams = params.sparams;
+
+ sparams.top_k = env->GetIntField(jparams, f_top_k);
+ sparams.top_p = env->GetFloatField(jparams, f_top_p);
+ sparams.tfs_z = env->GetFloatField(jparams, f_tfs_z);
+ sparams.typical_p = env->GetFloatField(jparams, f_typical_p);
+ sparams.temp = env->GetFloatField(jparams, f_temperature);
+ sparams.penalty_repeat = env->GetFloatField(jparams, f_repeat_penalty);
+ sparams.n_prev = env->GetIntField(jparams, f_repeat_last_n);
+ sparams.penalty_freq = env->GetFloatField(jparams, f_frequency_penalty);
+ sparams.penalty_present = env->GetFloatField(jparams, f_presence_penalty);
+ sparams.penalize_nl = env->GetBooleanField(jparams, f_penalize_nl);
+ sparams.mirostat = env->GetIntField(jparams, f_mirostat);
+ sparams.mirostat_tau = env->GetFloatField(jparams, f_mirostat_tau);
+ sparams.mirostat_eta = env->GetFloatField(jparams, f_mirostat_eta);
+ sparams.n_probs = env->GetIntField(jparams, f_n_probs);
+
+ jstring j_grammar = (jstring) env->GetObjectField(jparams, f_grammar);
+ if (j_grammar != nullptr) {
+ sparams.grammar = parse_jstring(env, j_grammar);
+ env->DeleteLocalRef(j_grammar);
+ if (!llama->loadGrammar(env)) {
+ env->ThrowNew(c_engine_exception, "could not load grammar");
+ }
+ }
+
+ sparams.logit_bias.clear();
+ jboolean ignore_eos = env->GetBooleanField(jparams, f_ignore_eos);
+ if (ignore_eos) {
+ sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+ }
+
+ jobject logit_bias = env->GetObjectField(jparams, f_logit_bias);
+ if (logit_bias != nullptr) {
+ jobject entry_set = env->CallObjectMethod(logit_bias, m_entry_set);
+ jobject iterator = env->CallObjectMethod(entry_set, m_set_iterator);
+ while (env->CallBooleanMethod(iterator, m_iterator_has_next)) {
+ jobject entry = env->CallObjectMethod(iterator, m_iterator_next);
+ jobject key = env->CallObjectMethod(entry, m_entry_key);
+ jobject value = env->CallObjectMethod(entry, m_entry_value);
+
+ int tok = parse_jinteger(env, key);
+ float bias = parse_jfloat(env, value);
+ sparams.logit_bias[tok] = bias;
+
+ env->DeleteLocalRef(entry);
+ env->DeleteLocalRef(key);
+ env->DeleteLocalRef(value);
+ }
+ }
+
+ params.antiprompt.clear();
+ jobjectArray antiprompt = (jobjectArray) env->GetObjectField(jparams, f_antiprompt);
+ if (antiprompt != nullptr) {
+ jsize array_length = env->GetArrayLength(antiprompt);
+ for (jsize i = 0; i < array_length; i++) {
+ jstring java_string = (jstring) env->GetObjectArrayElement(antiprompt, i);
+ if (java_string != nullptr) {
+ std::string string = parse_jstring(env, java_string);
+ params.antiprompt.push_back(string);
+ env->DeleteLocalRef(java_string);
+ }
+ }
+ }
+
+ llama->ctx_sampling = *llama_sampling_init(params.sparams);
+}
+
// Prepares the context for plain prompt completion: the full prompt is used
// and the infill prefix/suffix are cleared before applying the shared
// inference parameters.
static void setup_answering(JNIEnv *env, jllama_context *llama, jstring prompt, jobject params) {
    llama->prompt = parse_jstring(env, prompt);
    llama->params.input_prefix = "";
    llama->params.input_suffix = "";
    setup_infer_params(env, llama, params);
}
+
// Prepares the context for infilling (fill-in-the-middle): the prompt is
// cleared and the prefix/suffix strings are installed before applying the
// shared inference parameters.
static void setup_infilling(JNIEnv *env, jllama_context *llama, jstring prefix, jstring suffix, jobject params) {
    llama->prompt = "";
    llama->params.input_prefix = parse_jstring(env, prefix);
    llama->params.input_suffix = parse_jstring(env, suffix);
    setup_infer_params(env, llama, params);
}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_llama_jni_LlamaLibrary_loadModel(
+ JNIEnv *env, jclass clazz, jstring file_path, jobject jparams) {
+ gpt_params params = parse_model_params(env, jparams, file_path);
+
+ jllama_context *llama = new jllama_context;
+ llama_backend_init(false);
+
+ if (!llama->loadModel(params)) {
+ env->ThrowNew(c_engine_exception, "could not load model from given file path");
+ return 0;
+ }
+
+ return reinterpret_cast(llama);
+}
+
+JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_generate(
+ JNIEnv *env, jclass clazz, jlong handle, jstring prompt, jobject params) {
+ auto *llama = reinterpret_cast(handle);
+
+ llama->rewind();
+ llama_reset_timings(llama->ctx);
+ setup_answering(env, llama, prompt, params);
+
+ llama->loadPrompt(env);
+ llama->beginCompletion();
+}
+
+JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_infill(
+ JNIEnv *env, jclass clazz, jlong handle, jstring prefix, jstring suffix, jobject params) {
+ auto *llama = reinterpret_cast(handle);
+
+ llama->rewind();
+
+ llama_reset_timings(llama->ctx);
+
+ setup_infilling(env, llama, prefix, suffix, params);
+
+ llama->loadInfill(env);
+ llama->beginCompletion();
+}
+
+JNIEXPORT jobject JNICALL Java_ai_djl_llama_jni_LlamaLibrary_getNext(
+ JNIEnv *env, jclass clazz, jlong handle, jlong sent_count, jlong sent_token_probs_index) {
+ auto *llama = reinterpret_cast(handle);
+
+ completion_token_output token_with_probs;
+ while (llama->has_next_token) {
+ token_with_probs = llama->doCompletion(env);
+ if (token_with_probs.tok >= 0 && llama->multibyte_pending <= 0) {
+ break;
+ }
+ }
+ const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+
+ size_t pos = std::min((size_t) sent_count, llama->generated_text.size());
+
+ const std::string str_test = llama->generated_text.substr(pos);
+ bool is_stop_full = false;
+ size_t stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+ if (stop_pos != std::string::npos) {
+ is_stop_full = true;
+ llama->generated_text.erase(llama->generated_text.begin() + pos + stop_pos, llama->generated_text.end());
+ pos = std::min((size_t) sent_count, llama->generated_text.size());
+ } else {
+ is_stop_full = false;
+ stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL);
+ }
+
+ std::string to_send;
+ if (stop_pos == std::string::npos ||
+ // Send rest of the text if we are at the end of the generation
+ (!llama->has_next_token && !is_stop_full && stop_pos > 0)) {
+ to_send = llama->generated_text.substr(pos, std::string::npos);
+
+ sent_count += to_send.size();
+ std::vector probs_output = {};
+
+ if (llama->params.sparams.n_probs > 0) {
+ const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+ size_t probs_pos = std::min((size_t) sent_token_probs_index, llama->generated_token_probs.size());
+ size_t probs_stop_pos =
+ std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
+ if (probs_pos < probs_stop_pos) {
+ probs_output = std::vector(
+ llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos);
+ }
+ sent_token_probs_index = probs_stop_pos;
+ }
+ } else {
+ to_send = "";
+ }
+
+ jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
+ for (const auto &tp : token_with_probs.probs) {
+ jobject jtoken = env->NewObject(c_integer, cc_integer, tp.tok);
+ jobject jprob = env->NewObject(c_float, cc_float, tp.prob);
+ env->CallObjectMethod(o_probabilities, m_map_put, jtoken, jprob);
+ }
+
+ jbyteArray jbytes = parse_jbytes(env, to_send);
+ return env->NewObject(c_token, cc_token, token_with_probs.tok, jbytes, o_probabilities, sent_count,
+ sent_token_probs_index, llama->has_next_token);
+}
+
+JNIEXPORT jfloatArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_embed(
+ JNIEnv *env, jclass clazz, jlong handle, jstring java_prompt) {
+ auto *llama = reinterpret_cast(handle);
+
+ llama->rewind();
+ llama_reset_timings(llama->ctx);
+ llama->prompt = parse_jstring(env, java_prompt);
+ llama->params.n_predict = 0;
+ llama->loadPrompt(env);
+ llama->beginCompletion();
+ llama->doCompletion(env);
+
+ static const int n_embd = llama_n_embd(llama->model);
+ const float *data = llama_get_embeddings(llama->ctx);
+ std::vector embedding(data, data + n_embd);
+
+ jfloatArray java_embedding = env->NewFloatArray(embedding.size());
+ env->SetFloatArrayRegion(java_embedding, 0, embedding.size(), reinterpret_cast(embedding.data()));
+
+ return java_embedding;
+}
+
+JNIEXPORT jintArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_encode(
+ JNIEnv *env, jclass clazz, jlong handle, jstring jprompt) {
+ auto *llama = reinterpret_cast(handle);
+
+ std::string prompt = parse_jstring(env, jprompt);
+ std::vector tokens = llama->tokenize(prompt, false);
+
+ jintArray java_tokens = env->NewIntArray(tokens.size());
+ env->SetIntArrayRegion(java_tokens, 0, tokens.size(), reinterpret_cast(tokens.data()));
+
+ return java_tokens;
+}
+
+JNIEXPORT jbyteArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_decodeBytes(
+ JNIEnv *env, jclass clazz, jlong handle, jintArray java_tokens) {
+ auto *llama = reinterpret_cast(handle);
+
+ jsize length = env->GetArrayLength(java_tokens);
+ jint *elements = env->GetIntArrayElements(java_tokens, nullptr);
+ std::vector tokens(elements, elements + length);
+ std::string text = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
+
+ env->ReleaseIntArrayElements(java_tokens, elements, 0);
+
+ return parse_jbytes(env, text);
+}
+
+JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_delete(JNIEnv *env, jclass clazz, jlong handle) {
+ auto *llama = reinterpret_cast(handle);
+ delete llama;
+}
diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
new file mode 100644
index 00000000000..d2f8ca8e42c
--- /dev/null
+++ b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
@@ -0,0 +1 @@
+ai.djl.llama.engine.LlamaEngineProvider
diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider
new file mode 100644
index 00000000000..92f6245340f
--- /dev/null
+++ b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider
@@ -0,0 +1 @@
+ai.djl.llama.zoo.LlamaZooProvider
diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java
new file mode 100644
index 00000000000..429cd569392
--- /dev/null
+++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.llama.engine.LlamaInput.Parameters;
+import ai.djl.llama.jni.InputParameters;
+import ai.djl.util.JsonUtils;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.io.Reader;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Map;
+
+public class LlamaInputTest {
+
+ @Test
+ public void testInputParameters() throws IOException {
+ Path file = Paths.get("src/test/resources/inputs.json");
+ try (Reader reader = Files.newBufferedReader(file)) {
+ LlamaInput in = JsonUtils.GSON.fromJson(reader, LlamaInput.class);
+ checkParameters(in);
+ }
+
+ Parameters param = new Parameters();
+ LlamaInput in = new LlamaInput();
+ in.setInputs("prompt");
+ in.setPrefix("prefix");
+ in.setSuffix("suffix");
+ in.setParameters(param);
+ param.setMaxNewTokens(2);
+ param.setNumberKeep(2);
+ param.setNumberProbabilities(2);
+ param.setTopK(2);
+ param.setTopP(2f);
+ param.setTfsZ(2f);
+ param.setTypicalP(2f);
+ param.setTemperature(2f);
+ param.setRepeatPenalty(2f);
+ param.setRepeatLastN(2);
+ param.setFrequencyPenalty(2f);
+ param.setFrequencyPenalty(2f);
+ param.setPresencePenalty(2f);
+ param.setPenalizeNl(true);
+ param.setIgnoreEos(true);
+ param.setMirostat(2);
+ param.setMirostatTau(2f);
+ param.setMirostatEta(2f);
+ param.setNumberBeams(5);
+ param.setSeed(2);
+ Map logitBias = Map.of(2, 0.4f, 3, 0.5f);
+ param.setLogitBias(logitBias);
+ param.setGrammar("grammar");
+ param.setAntiPrompt(new String[] {"User: "});
+ checkParameters(in);
+ }
+
+ private void checkParameters(LlamaInput in) {
+ InputParameters param = in.getParameters().toInputParameters();
+ Assert.assertEquals(param.getMaxNewTokens(), 2);
+ Assert.assertEquals(param.getNumberKeep(), 2);
+ Assert.assertEquals(param.getNumberProbabilities(), 2);
+ Assert.assertEquals(param.getTopK(), 2);
+ Assert.assertEquals(param.getTopP(), 2f);
+ Assert.assertEquals(param.getTfsZ(), 2f);
+ Assert.assertEquals(param.getTypicalP(), 2f);
+ Assert.assertEquals(param.getTemperature(), 2f);
+ Assert.assertEquals(param.getRepeatPenalty(), 2f);
+ Assert.assertEquals(param.getRepeatLastN(), 2);
+ Assert.assertEquals(param.getFrequencyPenalty(), 2f);
+ Assert.assertEquals(param.getFrequencyPenalty(), 2f);
+ Assert.assertEquals(param.getPresencePenalty(), 2f);
+ Assert.assertTrue(param.isPenalizeNl());
+ Assert.assertTrue(param.isIgnoreEos());
+ Assert.assertEquals(param.getMirostat(), 2);
+ Assert.assertEquals(param.getMirostatTau(), 2f);
+ Assert.assertEquals(param.getMirostatEta(), 2f);
+ Assert.assertEquals(param.getNumberBeams(), 5);
+ Assert.assertEquals(param.getSeed(), 2);
+ Map logitBias = param.getLogitBias();
+ Assert.assertNotNull(logitBias);
+ Assert.assertEquals(logitBias.size(), 2);
+ Assert.assertEquals(logitBias.get(2), 0.4f);
+ Assert.assertNotNull(param.getGrammar());
+ Assert.assertNotNull(param.getAntiPrompt()[0], "User: ");
+ }
+}
diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java
new file mode 100644
index 00000000000..7b372ee4258
--- /dev/null
+++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.engine;
+
+import ai.djl.ModelException;
+import ai.djl.engine.Engine;
+import ai.djl.engine.StandardCapabilities;
+import ai.djl.inference.Predictor;
+import ai.djl.llama.jni.Token;
+import ai.djl.llama.jni.TokenIterator;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.NDManager;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.testing.TestRequirements;
+import ai.djl.training.util.DownloadUtils;
+import ai.djl.translate.TranslateException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.net.URI;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+public class LlamaTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(LlamaTest.class);
+
+ @BeforeClass
+ public void setUp() {
+ System.setProperty("DJL_CACHE_DIR", "build/cache");
+ }
+
+ @AfterClass
+ public void tierDown() {
+ System.clearProperty("DJL_CACHE_DIR");
+ }
+
+ @Test
+ public void testLlamaVersion() {
+ Engine engine = Engine.getEngine("Llama");
+ Assert.assertEquals(engine.getVersion(), "b1696-" + Engine.getDjlVersion());
+ Assert.assertNotNull(engine.toString());
+ Assert.assertEquals(engine.getRank(), 10);
+ Assert.assertFalse(engine.hasCapability(StandardCapabilities.CUDA));
+ Assert.assertNull(engine.getAlternativeEngine());
+ try (NDManager manager = engine.newBaseManager()) {
+ Assert.assertNotNull(manager);
+ }
+ }
+
+ @Test
+ public void testLlama() throws TranslateException, ModelException, IOException {
+ TestRequirements.nightly();
+ downloadModel();
+ Path path = Paths.get("models");
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(String.class, TokenIterator.class)
+ .optModelPath(path)
+ .optModelName("tinyllama-1.1b-1t-openorca.Q4_K_M")
+ .optEngine("Llama")
+ .optOption("number_gpu_layers", "43")
+ .optTranslatorFactory(new LlamaTranslatorFactory())
+ .build();
+
+ String prompt =
+ "{\"inputs\": \"<|im_start|>system\n"
+ + "{system_message}<|im_end|>\n"
+ + "<|im_start|>user\n"
+ + "{prompt}<|im_end|>\n"
+ + "<|im_start|>assistant\", \"parameters\": {\"max_new_tokens\": 10}}";
+ try (ZooModel model = criteria.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ TokenIterator it = predictor.predict(prompt);
+ StringBuilder sb = new StringBuilder();
+ while (it.hasNext()) {
+ Token token = it.next();
+ Assert.assertNotNull(token.getText());
+ Assert.assertTrue(token.getToken() >= 0);
+ Assert.assertNotNull(token.getProbabilities());
+ sb.append(token.getText());
+ logger.info("{}", token);
+ }
+ Assert.assertTrue(sb.length() > 1);
+ }
+ }
+
+ @Test
+ public void testLlamaInfill() throws TranslateException, ModelException, IOException {
+ TestRequirements.nightly();
+ downloadModel();
+ Path path = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf");
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(Input.class, Output.class)
+ .optModelPath(path)
+ .optOption("number_gpu_layers", "43")
+ .optEngine("Llama")
+ .optTranslatorFactory(new LlamaTranslatorFactory())
+ .build();
+
+ String prompt =
+ "{\n"
+ + " \"prefix\":\"def remove_non_ascii(s: str) -> str:\n\",\n"
+ + " \"suffix\":\"\n return result\n\",\n"
+ + " \"parameters\":{\n"
+ + " \"max_new_tokens\": 10"
+ + " }\n"
+ + "}";
+ try (ZooModel model = criteria.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ Input in = new Input();
+ in.add("data", prompt);
+ Output out = predictor.predict(in);
+ Assert.assertNotNull(out.getData().getAsString());
+ }
+ }
+
+ private void downloadModel() throws IOException {
+ String url =
+ "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf?download=true";
+ Path dir = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf");
+ DownloadUtils.download(URI.create(url).toURL(), dir, null);
+ }
+}
diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java
new file mode 100644
index 00000000000..b2ee786419f
--- /dev/null
+++ b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+/** Contains test classes for llama engine. */
+package ai.djl.llama.engine;
diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java
new file mode 100644
index 00000000000..fab7bacb9e3
--- /dev/null
+++ b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.llama.zoo;
+
+import ai.djl.repository.zoo.ModelLoader;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.util.Utils;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.nio.file.Paths;
+import java.util.Collection;
+
+public class LlamaModelZooTest {
+
+ @Test
+ public void testLlamaModelZoo() {
+ System.setProperty("DJL_CACHE_DIR", "build/cache");
+ Utils.deleteQuietly(Paths.get("build/cache/cache"));
+ try {
+ ModelZoo zoo = ModelZoo.getModelZoo("ai.djl.huggingface.gguf");
+ Collection models = zoo.getModelLoaders();
+ Assert.assertFalse(models.isEmpty());
+ Assert.assertEquals(zoo.getSupportedEngines().size(), 1);
+ ModelLoader loader = zoo.getModelLoader("TinyLlama/TinyLlama-1.1B-Chat-v0.6");
+ Assert.assertNotNull(loader);
+
+ ModelZoo llamaModelZoo = new LlamaModelZoo();
+ Assert.assertFalse(llamaModelZoo.getModelLoaders().isEmpty());
+ } finally {
+ System.clearProperty("DJL_CACHE_DIR");
+ }
+ }
+
+ @Test
+ public void testOffLine() {
+ System.setProperty("DJL_CACHE_DIR", "build/cache");
+ System.setProperty("ai.djl.offline", "true");
+ Utils.deleteQuietly(Paths.get("build/cache/cache"));
+ try {
+ // static variables cannot not be initialized properly if directly use LlamaModelZoo()
+ ModelZoo.getModelZoo("ai.djl.huggingface.gguf");
+
+ ModelZoo zoo = new LlamaModelZoo();
+ Assert.assertFalse(zoo.getModelLoaders().isEmpty());
+ } finally {
+ System.clearProperty("DJL_CACHE_DIR");
+ System.clearProperty("ai.djl.offline");
+ }
+ }
+}
diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java
new file mode 100644
index 00000000000..145b2ddcca9
--- /dev/null
+++ b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+/** Contains test classes for llama model zoo. */
+package ai.djl.llama.zoo;
diff --git a/engines/llama/src/test/resources/inputs.json b/engines/llama/src/test/resources/inputs.json
new file mode 100644
index 00000000000..ab77386e1b6
--- /dev/null
+++ b/engines/llama/src/test/resources/inputs.json
@@ -0,0 +1,33 @@
+{
+ "prefix": "def remove_non_ascii(s: str) -> str:",
+ "suffix": " return result",
+ "parameters": {
+ "max_new_tokens": 2,
+ "number_keep": 2,
+ "number_probabilities": 2,
+ "top_k": 2,
+ "top_p": 2,
+ "tfs_z": 2,
+ "typical_p": 2,
+ "temperature": 2,
+ "repeat_penalty": 2,
+ "repeat_last_n": 2,
+ "frequency_penalty": 2,
+ "presence_penalty": 2,
+ "penalize_nl": true,
+ "ignore_eos": true,
+ "mirostat": 2,
+ "mirostat_tau": 2,
+ "mirostat_eta": 2,
+ "number_beams": 5,
+ "seed": 2,
+ "logit_bias": {
+ "2": 0.4,
+ "5": 0.6
+ },
+ "grammar": "root ::= (expr \"=\" term \"\\n\")+\nexpr ::= term ([-+*/] term)*\nterm ::= [0-9]",
+ "anti_prompt": [
+ "User: "
+ ]
+ }
+}
diff --git a/engines/ml/lightgbm/README.md b/engines/ml/lightgbm/README.md
index 3ea950c8935..74ab3eba411 100644
--- a/engines/ml/lightgbm/README.md
+++ b/engines/ml/lightgbm/README.md
@@ -36,13 +36,13 @@ LightGBM can only run on top of the Linux/Mac/Windows machine using x86_64.
## Installation
You can pull the LightGBM engine from the central Maven repository by including the following dependency:
-- ai.djl.ml.lightgbm:lightgbm:0.23.0
+- ai.djl.ml.lightgbm:lightgbm:0.26.0
 ```xml
 <dependency>
     <groupId>ai.djl.ml.lightgbm</groupId>
     <artifactId>lightgbm</artifactId>
-    <version>0.23.0</version>
+    <version>0.26.0</version>
     <scope>runtime</scope>
 </dependency>
 ```
diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java
index a253ce3d246..583cd8132b2 100644
--- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java
+++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */
public class LgbmEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (LgbmEngineProvider.class) {
- engine = LgbmEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = LgbmEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java
index 0bb92645a89..826b1a0f900 100644
--- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java
+++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java
@@ -46,6 +46,7 @@ public class LgbmSymbolBlock extends AbstractSymbolBlock implements AutoCloseabl
* @param iterations the number of iterations the model was trained for
* @param handle the Booster handle
*/
+ @SuppressWarnings("this-escape")
public LgbmSymbolBlock(LgbmNDManager manager, int iterations, SWIGTYPE_p_p_void handle) {
this.handle = new AtomicReference<>(handle);
this.iterations = iterations;
diff --git a/engines/ml/xgboost/README.md b/engines/ml/xgboost/README.md
index d69f1830193..df0a7897e3c 100644
--- a/engines/ml/xgboost/README.md
+++ b/engines/ml/xgboost/README.md
@@ -37,13 +37,13 @@ XGBoost can only run on top of the Linux/Mac machine. User can build from source
## Installation
You can pull the XGBoost engine from the central Maven repository by including the following dependency:
-- ai.djl.ml.xgboost:xgboost:0.23.0
+- ai.djl.ml.xgboost:xgboost:0.26.0
```xml
ai.djl.ml.xgboost
xgboost
- 0.23.0
+ 0.26.0
runtime
```
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java
index 19cba32cc71..8b534d5196c 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */
public class XgbEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (XgbEngineProvider.class) {
- engine = XgbEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = XgbEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java
index bf41acb9b6c..1b3c5ae277f 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java
@@ -80,6 +80,8 @@ private Path findModelFile(String prefix) {
String fileName = file.toFile().getName();
if (fileName.endsWith(".json")) {
modelName = fileName.substring(0, fileName.length() - 5);
+ } else if (fileName.endsWith(".xgb")) {
+ modelName = fileName.substring(0, fileName.length() - 4);
} else {
modelName = fileName;
}
@@ -90,13 +92,22 @@ private Path findModelFile(String prefix) {
}
Path modelFile = modelDir.resolve(prefix);
if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) {
- if (prefix.endsWith(".json")) {
+ if (prefix.endsWith(".json") || prefix.endsWith(".xgb")) {
return null;
}
modelFile = modelDir.resolve(prefix + ".json");
- if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) {
- return null;
+ if (Files.isRegularFile(modelFile)) {
+ return modelFile;
+ }
+ modelFile = modelDir.resolve(prefix + ".xgb");
+ if (Files.isRegularFile(modelFile)) {
+ return modelFile;
+ }
+ modelFile = modelDir.resolve("model.xgb");
+ if (Files.isRegularFile(modelFile)) {
+ return modelFile;
}
+ return null;
}
return modelFile;
}
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
index 3b56cbca241..81f9708e72b 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
@@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager {
private static final XgbNDManager SYSTEM_MANAGER = new SystemManager();
private float missingValue = Float.NaN;
+ private int nthread = 1;
private XgbNDManager(NDManager parent, Device device) {
super(parent, device);
@@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) {
this.missingValue = missingValue;
}
+ /**
+ * Sets the default number of threads.
+ *
+ * @param nthread the default number of threads
+ */
+ public void setNthread(int nthread) {
+ this.nthread = nthread;
+ }
+
/** {@inheritDoc} */
@Override
public ByteBuffer allocateDirect(int capacity) {
@@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha
int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray();
float[] data = new float[buffer.remaining()];
((FloatBuffer) buffer).get(data);
- long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data);
+ long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread);
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR);
}
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java
index 1e2bcddd999..43a9e129dea 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java
@@ -45,6 +45,7 @@ public class XgbSymbolBlock extends AbstractSymbolBlock implements AutoCloseable
* @param manager the manager to use for the block
* @param handle the Booster handle
*/
+ @SuppressWarnings("this-escape")
public XgbSymbolBlock(XgbNDManager manager, long handle) {
this.handle = new AtomicReference<>(handle);
this.manager = manager;
diff --git a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
index fefbe7f0716..eb071552fd0 100644
--- a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
+++ b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
@@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth
return handles[0];
}
- public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) {
+ public static long createDMatrixCSR(
+ long[] indptr, int[] indices, float[] array, float missing, int nthread) {
long[] handles = new long[1];
- checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles));
+ checkCall(
+ XGBoostJNI.XGDMatrixCreateFromCSR(
+ indptr, indices, array, 0, missing, nthread, handles));
return handles[0];
}
diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
index 0b09ed6807c..acbfa998867 100644
--- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
+++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
@@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException {
@Test
public void testVersion() {
Engine engine = Engine.getEngine("XGBoost");
- Assert.assertEquals("1.7.5", engine.getVersion());
+ Assert.assertEquals("2.0.3", engine.getVersion());
}
/*
@@ -93,6 +93,7 @@ public void testNDArray() {
try (XgbNDManager manager =
(XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) {
manager.setMissingValue(Float.NaN);
+ manager.setNthread(1);
NDArray zeros = manager.zeros(new Shape(1, 2));
Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray);
diff --git a/engines/mxnet/jnarator/build.gradle b/engines/mxnet/jnarator/build.gradle
index b9cc0d4cd5f..b9fd8ceab14 100644
--- a/engines/mxnet/jnarator/build.gradle
+++ b/engines/mxnet/jnarator/build.gradle
@@ -17,6 +17,11 @@ dependencies {
checkstyleMain.source = 'src/main/java'
pmdMain.source = 'src/main/java'
+compileJava {
+ options.compilerArgs.clear()
+ options.compilerArgs << "--release" << "11" << "-proc:none" << "-Xlint:all,-options,-static"
+}
+
jar {
manifest {
attributes (
diff --git a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java
index 3105ec9cd48..ba3e18fea3b 100644
--- a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java
+++ b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java
@@ -276,6 +276,7 @@ public void writeNativeSize() throws IOException {
writer.append(" public NativeSizeByReference() {\n");
writer.append(" this(new NativeSize(0));\n");
writer.append(" }\n\n");
+ writer.append(" @SuppressWarnings(\"this-escape\")\n");
writer.append(" public NativeSizeByReference(NativeSize value) {\n");
writer.append(" super(NativeSize.SIZE);\n");
writer.append(" setValue(value);\n");
diff --git a/engines/mxnet/mxnet-engine/README.md b/engines/mxnet/mxnet-engine/README.md
index cef559f1e31..66b2c98adc1 100644
--- a/engines/mxnet/mxnet-engine/README.md
+++ b/engines/mxnet/mxnet-engine/README.md
@@ -7,7 +7,7 @@ This module contains the Deep Java Library (DJL) EngineProvider for Apache MXNet
We don't recommend that developers use classes in this module directly. Use of these classes
will couple your code with Apache MXNet and make switching between engines difficult. Even so,
developers are not restricted from using engine-specific features. For more information,
-see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-).
+see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.26.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-).
## Documentation
@@ -33,7 +33,7 @@ You can pull the MXNet engine from the central Maven repository by including the
ai.djl.mxnet
mxnet-engine
- 0.23.0
+ 0.26.0
runtime
```
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java
index 62398b1868e..b1ca8e49aa4 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java
@@ -63,6 +63,7 @@ public class CachedOp extends NativeResource {
* @param dataIndices the input data names required by the model and their corresponding
* location
*/
+ @SuppressWarnings("this-escape")
public CachedOp(
Pointer handle,
MxNDManager manager,
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java
index f30a6a89252..2a5ab970560 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */
public class MxEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (MxEngineProvider.class) {
- engine = MxEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = MxEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
index 87ccba78e96..8b884b3993a 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
@@ -888,6 +888,13 @@ public NDArray atan() {
return manager.invoke("_npi_arctan", this, null);
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray atan2(NDArray other) {
+ other = manager.from(other);
+ return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null);
+ }
+
/** {@inheritDoc} */
@Override
public NDArray sinh() {
@@ -1153,6 +1160,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray fft2(long[] sizes, long[] axes) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray ifft2(long[] sizes, long[] axes) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
@@ -1601,6 +1620,12 @@ public NDArray erfinv() {
return manager.invoke("erfinv", this, null);
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray erf() {
+ return manager.invoke("erf", this, null);
+ }
+
/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
index c7efd80eba3..e1ff0db645a 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
@@ -287,7 +287,7 @@ public NDArray globalMaxPool() {
params.add("pool_type", "max");
params.addParam("global_pool", true);
try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) {
- return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+ return temp.reshape(-1, temp.getShape().size(1));
}
}
@@ -318,7 +318,7 @@ public NDArray globalAvgPool() {
params.add("pool_type", "avg");
params.addParam("global_pool", true);
try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) {
- return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+ return temp.reshape(-1, temp.getShape().size(1));
}
}
@@ -355,7 +355,7 @@ public NDArray globalLpPool(float normType) {
params.addParam("p_value", (int) normType);
params.addParam("global_pool", true);
try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) {
- return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+ return temp.reshape(-1, temp.getShape().size(1));
}
}
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java
index 5f08cf5910c..99e415cf62c 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java
@@ -23,6 +23,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
+import ai.djl.training.listener.AlgebraicListener;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
@@ -338,12 +339,15 @@ public MxNDManager newSubManager(Device dev) {
public void invoke(
String operation, NDArray[] src, NDArray[] dest, PairList params) {
JnaUtils.op(operation).invoke(this, src, dest, params);
+ AlgebraicListener.record(operation, src, dest, params);
}
/** {@inheritDoc} */
@Override
public NDList invoke(String operation, NDList src, PairList params) {
- return new NDList(JnaUtils.op(operation).invoke(this, src.toArray(EMPTY), params));
+ NDArray[] dest = JnaUtils.op(operation).invoke(this, src.toArray(EMPTY), params);
+ AlgebraicListener.record(operation, src.toArray(EMPTY), dest, params);
+ return new NDList(dest);
}
/**
@@ -379,7 +383,9 @@ public void invoke(String operation, NDList src, NDList dest, PairList params) {
- return JnaUtils.op(operation).invoke(this, src, params)[0];
+ NDArray[] dest = JnaUtils.op(operation).invoke(this, src, params);
+ AlgebraicListener.record(operation, src, dest, params);
+ return dest[0];
}
/**
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java
index 36bead164e4..952ca2f0995 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java
@@ -40,6 +40,7 @@ public class MxParameterServer extends NativeResource implements Parame
*
* @param optimizer the optimizer to use for the parameter server updates
*/
+ @SuppressWarnings("this-escape")
public MxParameterServer(Optimizer optimizer) {
super(createdKVStore());
callback = new OptimizerCallback(optimizer);
diff --git a/engines/mxnet/mxnet-model-zoo/README.md b/engines/mxnet/mxnet-model-zoo/README.md
index c4f44fe358c..f32678944c0 100644
--- a/engines/mxnet/mxnet-model-zoo/README.md
+++ b/engines/mxnet/mxnet-model-zoo/README.md
@@ -27,7 +27,7 @@ You can pull the MXNet engine from the central Maven repository by including the
ai.djl.mxnet
mxnet-model-zoo
- 0.23.0
+ 0.26.0
```
diff --git a/engines/mxnet/native/build.gradle b/engines/mxnet/native/build.gradle
index 3f8ee285054..dc9d6e5e12d 100644
--- a/engines/mxnet/native/build.gradle
+++ b/engines/mxnet/native/build.gradle
@@ -89,6 +89,7 @@ flavorNames.each { flavor ->
}
from file("${BINARY_ROOT}/${flavor}/${osName}")
archiveClassifier = "${osName}-x86_64"
+ archiveBaseName = "mxnet-native-${flavor}"
manifest {
attributes("Automatic-Module-Name": "ai.djl.mxnet_native_${flavor}_${osName}")
diff --git a/engines/onnxruntime/onnxruntime-android/README.md b/engines/onnxruntime/onnxruntime-android/README.md
index e304e78d5c3..eba92b84288 100644
--- a/engines/onnxruntime/onnxruntime-android/README.md
+++ b/engines/onnxruntime/onnxruntime-android/README.md
@@ -6,13 +6,13 @@ This module contains the DJL ONNX Runtime engine for Android.
## Installation
You can pull the ONNX Runtime for Android from the central Maven repository by including the following dependency:
-- ai.djl.android:onnxruntime:0.23.0
+- ai.djl.android:onnxruntime:0.26.0
```xml
ai.djl.android
onnxruntime
- 0.23.0
+ 0.26.0
runtime
```
diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md
index c287819d23f..b89b14f4473 100644
--- a/engines/onnxruntime/onnxruntime-engine/README.md
+++ b/engines/onnxruntime/onnxruntime-engine/README.md
@@ -37,13 +37,13 @@ for the official ONNX Runtime project.
## Installation
You can pull the ONNX Runtime engine from the central Maven repository by including the following dependency:
-- ai.djl.onnxruntime:onnxruntime-engine:0.23.0
+- ai.djl.onnxruntime:onnxruntime-engine:0.26.0
```xml
ai.djl.onnxruntime
onnxruntime-engine
- 0.23.0
+ 0.26.0
runtime
```
@@ -61,7 +61,7 @@ Maven:
ai.djl.onnxruntime
onnxruntime-engine
- 0.23.0
+ 0.26.0
runtime
@@ -81,7 +81,7 @@ Maven:
Gradle:
```groovy
-implementation("ai.djl.onnxruntime:onnxruntime-engine:0.23.0") {
+implementation("ai.djl.onnxruntime:onnxruntime-engine:0.26.0") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0"
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
index 89599722435..243377785d8 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
@@ -97,7 +97,7 @@ public int getRank() {
/** {@inheritDoc} */
@Override
public String getVersion() {
- return "1.15.1";
+ return "1.16.3";
}
/** {@inheritDoc} */
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java
index c673b3dcbf1..5616eb80edb 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */
public class OrtEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (OrtEngineProvider.class) {
- engine = OrtEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = OrtEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java
index aa54b43f376..4e8df210d40 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java
@@ -59,6 +59,7 @@ public class OrtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable
* @param session the {@link OrtSession} contains the model information
* @param manager the {@link NDManager} to holds the NDArray
*/
+ @SuppressWarnings("this-escape")
public OrtSymbolBlock(OrtSession session, OrtNDManager manager) {
this.session = session;
this.manager = manager;
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java
index 9d8037cfa8b..d61cb81f1ee 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java
@@ -31,6 +31,7 @@ public class OrtModelZoo extends ModelZoo {
OrtModelZoo() {
addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1"));
+ addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
addModel(REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1"));
}
diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json
new file mode 100644
index 00000000000..1e0169a2561
--- /dev/null
+++ b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json
@@ -0,0 +1,40 @@
+{
+ "metadataVersion": "0.2",
+ "resourceType": "model",
+ "application": "cv/object_detection",
+ "groupId": "ai.djl.onnxruntime",
+ "artifactId": "yolov8n",
+ "name": "yolov8n",
+ "description": "YoloV8 Model",
+ "website": "http://www.djl.ai/engines/onnxruntime/model-zoo",
+ "licenses": {
+ "license": {
+ "name": "The Apache License, Version 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0"
+ }
+ },
+ "artifacts": [
+ {
+ "version": "0.0.1",
+ "snapshot": false,
+ "name": "yolov8n",
+ "arguments": {
+ "width": 640,
+ "height": 640,
+ "resize": true,
+ "rescale": true,
+ "optApplyRatio": true,
+ "threshold": 0.6,
+ "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory"
+ },
+ "files": {
+ "model": {
+ "uri": "0.0.1/yolov8n.zip",
+ "name": "",
+ "sha1Hash": "9fbad7f706713843cbb8c8d6a56c81a640ec6fa2",
+ "size": 11053839
+ }
+ }
+ }
+ ]
+}
diff --git a/engines/paddlepaddle/paddlepaddle-engine/README.md b/engines/paddlepaddle/paddlepaddle-engine/README.md
index 9e65fb76601..6671cfbcd42 100644
--- a/engines/paddlepaddle/paddlepaddle-engine/README.md
+++ b/engines/paddlepaddle/paddlepaddle-engine/README.md
@@ -30,7 +30,7 @@ You can pull the PaddlePaddle engine from the central Maven repository by includ
ai.djl.paddlepaddle
paddlepaddle-engine
- 0.23.0
+ 0.26.0
runtime
```
diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java
index e2b5bdd35a0..e2fb86974f5 100644
--- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java
+++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */
public class PpEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (PpEngineProvider.class) {
- engine = PpEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = PpEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md
index e2c9cf6036c..55d3c67fe50 100644
--- a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md
+++ b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md
@@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency:
ai.djl.paddlepaddle
paddlepaddle-model-zoo
- 0.23.0
+ 0.26.0
```
diff --git a/engines/paddlepaddle/paddlepaddle-native/build.gradle b/engines/paddlepaddle/paddlepaddle-native/build.gradle
index 74a573debad..de1ea58da2b 100644
--- a/engines/paddlepaddle/paddlepaddle-native/build.gradle
+++ b/engines/paddlepaddle/paddlepaddle-native/build.gradle
@@ -213,6 +213,7 @@ flavorNames.each { flavor ->
}
from file("${BINARY_ROOT}/${flavor}/${osName}")
archiveClassifier = "${osName}-x86_64"
+ archiveBaseName = "paddlepaddle-native-${flavor}"
manifest {
attributes("Automatic-Module-Name": "ai.djl.paddlepaddle_native_${flavor}_${osName}")
diff --git a/engines/pytorch/pytorch-engine/README.md b/engines/pytorch/pytorch-engine/README.md
index ef74cf98808..c8571c54781 100644
--- a/engines/pytorch/pytorch-engine/README.md
+++ b/engines/pytorch/pytorch-engine/README.md
@@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder.
## Installation
You can pull the PyTorch engine from the central Maven repository by including the following dependency:
-- ai.djl.pytorch:pytorch-engine:0.23.0
+- ai.djl.pytorch:pytorch-engine:0.26.0
```xml
ai.djl.pytorch
pytorch-engine
- 0.23.0
+ 0.26.0
runtime
```
@@ -46,6 +46,9 @@ The following table illustrates which pytorch version that DJL supports:
| PyTorch engine version | PyTorch native library version |
|------------------------|-------------------------------------------|
+| pytorch-engine:0.26.0 | 1.13.1, 2.0.1, **2.1.1** |
+| pytorch-engine:0.25.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 |
+| pytorch-engine:0.24.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 |
| pytorch-engine:0.23.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 |
| pytorch-engine:0.22.1 | 1.11.0, 1.12.1, **1.13.1**, 2.0.0 |
| pytorch-engine:0.21.0 | 1.11.0, 1.12.1, **1.13.1** |
@@ -110,21 +113,21 @@ export PYTORCH_FLAVOR=cpu
### macOS
For macOS, you can use the following library:
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-x86_64
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-x86_64
```xml
ai.djl.pytorch
pytorch-native-cpu
osx-x86_64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
@@ -134,21 +137,21 @@ For macOS, you can use the following library:
### macOS M1
For macOS M1, you can use the following library:
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-aarch64
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-aarch64
```xml
ai.djl.pytorch
pytorch-native-cpu
osx-aarch64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
@@ -159,29 +162,29 @@ installed on your GPU machine, you can use one of the following library:
#### Linux GPU
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cu118:2.0.1:linux-x86_64 - CUDA 11.8
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cu121:2.1.1:linux-x86_64 - CUDA 12.1
```xml
ai.djl.pytorch
- pytorch-native-cu118
+ pytorch-native-cu121
linux-x86_64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
### Linux CPU
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cpu:2.0.1:linux-x86_64
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cpu:2.1.1:linux-x86_64
```xml
@@ -189,20 +192,20 @@ installed on your GPU machine, you can use one of the following library:
pytorch-native-cpu
linux-x86_64
runtime
- 2.0.1
+ 2.1.1
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
### For aarch64 build
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-aarch64
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-aarch64
```xml
@@ -210,12 +213,12 @@ installed on your GPU machine, you can use one of the following library:
pytorch-native-cpu-precxx11
linux-aarch64
runtime
- 2.0.1
+ 2.1.1
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
@@ -225,22 +228,22 @@ installed on your GPU machine, you can use one of the following library:
We also provide packages for the system like CentOS 7/Ubuntu 14.04 with GLIBC >= 2.17.
All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` in the package that contains `CXXABI_1.3.9` to use the package successfully.
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cu118-precxx11:2.0.1:linux-x86_64 - CUDA 11.8
-- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-x86_64 - CPU
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cu121-precxx11:2.1.1:linux-x86_64 - CUDA 12.1
+- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-x86_64 - CPU
```xml
ai.djl.pytorch
- pytorch-native-cu118-precxx11
+ pytorch-native-cu121-precxx11
linux-x86_64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
@@ -250,13 +253,13 @@ All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` i
ai.djl.pytorch
pytorch-native-cpu-precxx11
linux-x86_64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
@@ -271,29 +274,29 @@ For the Windows platform, you can choose between CPU and GPU.
#### Windows GPU
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cu118:2.0.1:win-x86_64 - CUDA 11.8
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cu121:2.1.1:win-x86_64 - CUDA 12.1
```xml
ai.djl.pytorch
- pytorch-native-cu118
+ pytorch-native-cu121
win-x86_64
- 2.0.1
+ 2.1.1
runtime
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
### Windows CPU
-- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0
-- ai.djl.pytorch:pytorch-native-cpu:2.0.1:win-x86_64
+- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0
+- ai.djl.pytorch:pytorch-native-cpu:2.1.1:win-x86_64
```xml
@@ -301,12 +304,12 @@ For the Windows platform, you can choose between CPU and GPU.
pytorch-native-cpu
win-x86_64
runtime
- 2.0.1
+ 2.1.1
ai.djl.pytorch
pytorch-jni
- 2.0.1-0.23.0
+ 2.1.1-0.26.0
runtime
```
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java
index 57ae6c09d34..24be3e91d7a 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */
public class PtEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (PtEngineProvider.class) {
- engine = PtEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = PtEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
index e72e98c9495..35e95f7de86 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
@@ -18,6 +18,7 @@
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
+import ai.djl.nn.Parameter.Type;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
}
if (wasLoaded) {
// Unfreeze parameters if training directly
- block.freezeParameters(false);
+ block.freezeParameters(
+ false,
+ p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR);
}
for (Pair> pair : initializer) {
if (pair.getKey() != null && pair.getValue() != null) {
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
index 9e36ec35884..499f51ebad5 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
@@ -60,6 +60,7 @@ public class PtNDArray extends NativeResource implements NDArray {
* @param manager the manager to attach the new array to
* @param handle the pointer to the native PyTorch memory
*/
+ @SuppressWarnings("this-escape")
public PtNDArray(PtNDManager manager, long handle) {
super(handle);
this.manager = manager;
@@ -76,6 +77,7 @@ public PtNDArray(PtNDManager manager, long handle) {
* @param handle the pointer to the native PyTorch memory
* @param data the direct buffer of the data
*/
+ @SuppressWarnings("this-escape")
public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
super(handle);
this.manager = manager;
@@ -93,6 +95,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
* @param strs the string array
* @param shape the {@link Shape} of the {@link NDArray}
*/
+ @SuppressWarnings("this-escape")
public PtNDArray(PtNDManager manager, String[] strs, Shape shape) {
super(-1L);
this.manager = manager;
@@ -888,6 +891,12 @@ public PtNDArray atan() {
return JniUtils.atan(this);
}
+ /** {@inheritDoc} */
+ @Override
+ public PtNDArray atan2(NDArray other) {
+ return JniUtils.atan2(this, manager.from(other));
+ }
+
/** {@inheritDoc} */
@Override
public PtNDArray sinh() {
@@ -1097,6 +1106,18 @@ public NDArray stft(
this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex);
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray fft2(long[] sizes, long[] axes) {
+ return JniUtils.fft2(this, sizes, axes);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray ifft2(long[] sizes, long[] axes) {
+ return JniUtils.ifft2(this, sizes, axes);
+ }
+
/** {@inheritDoc} */
@Override
public PtNDArray reshape(Shape shape) {
@@ -1539,6 +1560,12 @@ public PtNDArray erfinv() {
return JniUtils.erfinv(this);
}
+ /** {@inheritDoc} */
+ @Override
+ public PtNDArray erf() {
+ return JniUtils.erf(this);
+ }
+
/** {@inheritDoc} */
@Override
public PtNDArray inverse() {
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
index fa4ee81f26c..b7f92cbd1c3 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
@@ -13,6 +13,7 @@
package ai.djl.pytorch.engine;
import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
@@ -24,6 +25,8 @@
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.jni.JniUtils;
+import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
/** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */
@@ -760,7 +763,152 @@ public NDList multiBoxDetection(
float nmsThreshold,
boolean forceSuppress,
int nmsTopK) {
- throw new UnsupportedOperationException("Not implemented");
+ assert (inputs.size() == 3);
+
+ NDArray clsProb = inputs.get(0);
+ NDArray locPred = inputs.get(1);
+ NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4));
+
+ NDManager ndManager = array.getManager();
+
+ NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f});
+
+ assert (variances.size() == 4); // variance size must be 4
+ final int numClasses = (int) clsProb.size(1);
+ final int numAnchors = (int) clsProb.size(2);
+ final int numBatches = (int) clsProb.size(0);
+
+ final float[] pAnchor = anchors.toFloatArray();
+
+ // [id, prob, xmin, ymin, xmax, ymax]
+ // TODO Move to NDArray-based implementation
+ NDList batchOutputs = new NDList();
+ for (int nbatch = 0; nbatch < numBatches; ++nbatch) {
+ float[][] outputs = new float[numAnchors][6];
+ final float[] pClsProb = clsProb.get(nbatch).toFloatArray();
+ final float[] pLocPred = locPred.get(nbatch).toFloatArray();
+
+ for (int i = 0; i < numAnchors; ++i) {
+ // find the predicted class id and probability
+ float score = -1;
+ int id = 0;
+ for (int j = 1; j < numClasses; ++j) {
+ float temp = pClsProb[j * numAnchors + i];
+ if (temp > score) {
+ score = temp;
+ id = j;
+ }
+ }
+
+ if (id > 0 && score < threshold) {
+ id = 0;
+ }
+
+ // [id, prob, xmin, ymin, xmax, ymax]
+ outputs[i][0] = id - 1;
+ outputs[i][1] = score;
+ int offset = i * 4;
+ float[] pAnchorRow4 = new float[4];
+ pAnchorRow4[0] = pAnchor[offset];
+ pAnchorRow4[1] = pAnchor[offset + 1];
+ pAnchorRow4[2] = pAnchor[offset + 2];
+ pAnchorRow4[3] = pAnchor[offset + 3];
+ float[] pLocPredRow4 = new float[4];
+ pLocPredRow4[0] = pLocPred[offset];
+ pLocPredRow4[1] = pLocPred[offset + 1];
+ pLocPredRow4[2] = pLocPred[offset + 2];
+ pLocPredRow4[3] = pLocPred[offset + 3];
+ float[] outRowLast4 =
+ transformLocations(
+ pAnchorRow4,
+ pLocPredRow4,
+ clip,
+ variances.toFloatArray()[0],
+ variances.toFloatArray()[1],
+ variances.toFloatArray()[2],
+ variances.toFloatArray()[3]);
+ outputs[i][2] = outRowLast4[0];
+ outputs[i][3] = outRowLast4[1];
+ outputs[i][4] = outRowLast4[2];
+ outputs[i][5] = outRowLast4[3];
+ }
+
+ outputs =
+ Arrays.stream(outputs)
+ .filter(o -> o[0] >= 0)
+ .sorted(Comparator.comparing(o -> -o[1]))
+ .toArray(float[][]::new);
+
+ // apply nms
+ for (int i = 0; i < outputs.length; ++i) {
+ for (int j = i + 1; j < outputs.length; ++j) {
+ if (outputs[i][0] == outputs[j][0]) {
+ float[] outputsIRow4 = new float[4];
+ float[] outputsJRow4 = new float[4];
+ outputsIRow4[0] = outputs[i][2];
+ outputsIRow4[1] = outputs[i][3];
+ outputsIRow4[2] = outputs[i][4];
+ outputsIRow4[3] = outputs[i][5];
+ outputsJRow4[0] = outputs[j][2];
+ outputsJRow4[1] = outputs[j][3];
+ outputsJRow4[2] = outputs[j][4];
+ outputsJRow4[3] = outputs[j][5];
+ float iou = calculateOverlap(outputsIRow4, outputsJRow4);
+ if (iou >= nmsThreshold) {
+ outputs[j][0] = -1;
+ }
+ }
+ }
+ }
+ batchOutputs.add(ndManager.create(outputs));
+ } // end iter batch
+
+ NDArray pOutNDArray = NDArrays.stack(batchOutputs);
+ NDList resultNDList = new NDList();
+ resultNDList.add(pOutNDArray);
+ assert (resultNDList.size() == 1);
+ return resultNDList;
+ }
+
+ private float[] transformLocations(
+ final float[] anchors,
+ final float[] locPred,
+ final boolean clip,
+ final float vx,
+ final float vy,
+ final float vw,
+ final float vh) {
+ float[] outRowLast4 = new float[4];
+ // transform predictions to detection results
+ float al = anchors[0];
+ float at = anchors[1];
+ float ar = anchors[2];
+ float ab = anchors[3];
+ float aw = ar - al;
+ float ah = ab - at;
+ float ax = (al + ar) / 2.f;
+ float ay = (at + ab) / 2.f;
+ float px = locPred[0];
+ float py = locPred[1];
+ float pw = locPred[2];
+ float ph = locPred[3];
+ float ox = px * vx * aw + ax;
+ float oy = py * vy * ah + ay;
+ float ow = (float) (Math.exp(pw * vw) * aw / 2);
+ float oh = (float) (Math.exp(ph * vh) * ah / 2);
+ outRowLast4[0] = clip ? Math.max(0f, Math.min(1f, ox - ow)) : (ox - ow);
+ outRowLast4[1] = clip ? Math.max(0f, Math.min(1f, oy - oh)) : (oy - oh);
+ outRowLast4[2] = clip ? Math.max(0f, Math.min(1f, ox + ow)) : (ox + ow);
+ outRowLast4[3] = clip ? Math.max(0f, Math.min(1f, oy + oh)) : (oy + oh);
+ return outRowLast4;
+ }
+
+ private float calculateOverlap(final float[] a, final float[] b) {
+ float w = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
+ float h = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
+ float i = w * h;
+ float u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i;
+ return u <= 0.f ? 0f : (i / u);
}
/** {@inheritDoc} */
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
index 8bc28a2c21b..7075cb05efa 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
@@ -67,6 +67,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable
* @param manager the manager to use for the block
* @param handle the module handle
*/
+ @SuppressWarnings("this-escape")
public PtSymbolBlock(PtNDManager manager, long handle) {
this(manager);
this.handle = new AtomicReference<>(handle);
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
index aad38ae8f0c..40a6a0065bc 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
@@ -1040,6 +1040,18 @@ public static PtNDArray stft(
return new PtNDArray(ndArray.getManager(), handle);
}
+ public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) {
+ return new PtNDArray(
+ ndArray.getManager(),
+ PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes));
+ }
+
+ public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) {
+ return new PtNDArray(
+ ndArray.getManager(),
+ PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes));
+ }
+
public static PtNDArray real(PtNDArray ndArray) {
long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle());
if (handle == -1) {
@@ -1145,6 +1157,12 @@ public static PtNDArray atan(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle()));
}
+ public static PtNDArray atan2(PtNDArray self, PtNDArray other) {
+ return new PtNDArray(
+ self.getManager(),
+ PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle()));
+ }
+
public static PtNDArray sqrt(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle()));
@@ -1334,6 +1352,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle()));
}
+ public static PtNDArray erf(PtNDArray ndArray) {
+ return new PtNDArray(
+ ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle()));
+ }
+
public static PtNDArray inverse(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle()));
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java
index 9d422463910..03835b6ca68 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java
@@ -106,9 +106,16 @@ public static String getLibtorchPath() {
private static void loadLibTorch(LibTorch libTorch) {
Path libDir = libTorch.dir.toAbsolutePath();
- if ("1.8.1".equals(getVersion()) && System.getProperty("os.name").startsWith("Mac")) {
- // PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually
- return;
+ if (Files.exists(libDir.resolve("libstdc++.so.6"))) {
+ String libstd = Utils.getEnvOrSystemProperty("LIBSTDCXX_LIBRARY_PATH");
+ if (libstd != null) {
+ try {
+ logger.info("Loading libstdc++.so.6 from: {}", libstd);
+ System.load(libstd);
+ } catch (UnsatisfiedLinkError e) {
+ logger.warn("Failed loading libstdc++.so.6 from: {}", libstd);
+ }
+ }
}
boolean isCuda = libTorch.flavor.contains("cu");
List deferred =
@@ -120,6 +127,7 @@ private static void loadLibTorch(LibTorch libTorch) {
System.mapLibraryName("torch_cuda_cpp"),
System.mapLibraryName("torch_cuda_cu"),
System.mapLibraryName("torch_cuda"),
+ System.mapLibraryName("nvfuser_codegen"),
System.mapLibraryName("torch"));
Set loadLater = new HashSet<>(deferred);
@@ -133,7 +141,8 @@ private static void loadLibTorch(LibTorch libTorch) {
&& name.contains("cudart")
&& name.contains("nvTools")) {
return false;
- } else if (name.startsWith("libarm_compute-")) {
+ } else if (name.startsWith("libarm_compute-")
+ || name.startsWith("libopenblasp")) {
rank.put(path, 2);
return true;
} else if (name.startsWith("libarm_compute_")) {
diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
index c0f7b553ab2..54fc5419145 100644
--- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -273,6 +273,10 @@ native long torchStft(
boolean normalize,
boolean returnComplex);
+ native long torchFft2(long handle, long[] sizes, long[] axes);
+
+ native long torchIfft2(long handle, long[] sizes, long[] axes);
+
native long torchViewAsReal(long handle);
native long torchViewAsComplex(long handle);
@@ -332,6 +336,8 @@ native long[] torchUnique(
native long torchAtan(long handle);
+ native long torchAtan2(long self, long other);
+
native long torchSqrt(long handle);
native long torchSinh(long handle);
@@ -405,6 +411,8 @@ native long tensorUniform(
native long torchErfinv(long handle);
+ native long torchErf(long handle);
+
native long torchInverse(long self);
native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners);
diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java
similarity index 73%
rename from engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java
rename to engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java
index 617d2cfb809..f6cfda91106 100644
--- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java
+++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java
@@ -18,17 +18,21 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
-public class LibUtilsTest {
+// Ensure this test runs first
+public class ALibUtilsTest {
@BeforeClass
public void setup() {
- System.setProperty(
- "ai.djl.pytorch.native_helper", "ai.djl.pytorch.integration.LibUtilsTest");
+ System.setProperty("ai.djl.pytorch.native_helper", ALibUtilsTest.class.getName());
+ System.setProperty("LIBSTDCXX_LIBRARY_PATH", "/usr/lib/non-exists");
+ System.setProperty("PYTORCH_PRECXX11", "true");
}
@AfterClass
public void teardown() {
System.clearProperty("ai.djl.pytorch.native_helper");
+ System.clearProperty("LIBSTDCXX_LIBRARY_PATH");
+ System.clearProperty("PYTORCH_PRECXX11");
}
@Test
diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java
index 8b4e2326f26..5b6ed349e10 100644
--- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java
+++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java
@@ -13,6 +13,7 @@
package ai.djl.pytorch.integration;
import ai.djl.Device;
+import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
@@ -21,6 +22,10 @@
import org.testng.SkipException;
import org.testng.annotations.Test;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
public class MpsTest {
@Test
@@ -36,4 +41,39 @@ public void testMps() {
Assert.assertEquals(array.getDevice().getDeviceType(), "mps");
}
}
+
+ private static boolean checkMpsCompatible() {
+ return "aarch64".equals(System.getProperty("os.arch"))
+ && System.getProperty("os.name").startsWith("Mac");
+ }
+
+ @Test
+ public void testToTensorMPS() {
+ if (!checkMpsCompatible()) {
+ throw new SkipException("MPS toTensor test requires Apple Silicon macOS.");
+ }
+
+ // Test that toTensor does not fail on MPS (e.g. due to use of float64 for division)
+ try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
+ NDArray array = manager.create(127f).reshape(1, 1, 1, 1);
+ NDArray tensor = array.getNDArrayInternal().toTensor();
+ Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f});
+ }
+ }
+
+ @Test
+ public void testClassificationsMPS() {
+ if (!checkMpsCompatible()) {
+ throw new SkipException("MPS classification test requires Apple Silicon macOS.");
+ }
+
+ // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to
+ // float64)
+ try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
+ List names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth");
+ NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f});
+ Classifications classifications = new Classifications(names, tensor);
+ Assert.assertEquals(classifications.topK(1), Collections.singletonList("Third"));
+ }
+ }
}
diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle
index 450c832e803..c2b0ee9dc7b 100644
--- a/engines/pytorch/pytorch-jni/build.gradle
+++ b/engines/pytorch/pytorch-jni/build.gradle
@@ -24,7 +24,13 @@ processResources {
"osx-x86_64/cpu/libdjl_torch.dylib",
"win-x86_64/cpu/djl_torch.dll"
]
- if (ptVersion.startsWith("2.0.")) {
+ if (ptVersion.startsWith("2.1.")) {
+ files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so")
+ files.add("linux-x86_64/cu121/libdjl_torch.so")
+ files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so")
+ files.add("win-x86_64/cu121/djl_torch.dll")
+ files.add("osx-aarch64/cpu/libdjl_torch.dylib")
+ } else if (ptVersion.startsWith("2.0.")) {
files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so")
files.add("linux-x86_64/cu118/libdjl_torch.so")
files.add("linux-x86_64/cu118-precxx11/libdjl_torch.so")
diff --git a/engines/pytorch/pytorch-model-zoo/README.md b/engines/pytorch/pytorch-model-zoo/README.md
index 8d3113842e1..f598dd2aecd 100644
--- a/engines/pytorch/pytorch-model-zoo/README.md
+++ b/engines/pytorch/pytorch-model-zoo/README.md
@@ -25,7 +25,7 @@ You can pull the PyTorch engine from the central Maven repository by including t
ai.djl.pytorch
pytorch-model-zoo
- 0.23.0
+ 0.26.0
```
diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
index ea70871eff0..abb820cced9 100644
--- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
+++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
@@ -38,6 +38,7 @@ public class PtModelZoo extends ModelZoo {
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1"));
+ addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"));
addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1"));
addModel(REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1"));
diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json
new file mode 100644
index 00000000000..399b79b4889
--- /dev/null
+++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json
@@ -0,0 +1,40 @@
+{
+ "metadataVersion": "0.2",
+ "resourceType": "model",
+ "application": "cv/object_detection",
+ "groupId": "ai.djl.pytorch",
+ "artifactId": "yolov8n",
+ "name": "yolov8n",
+ "description": "YoloV8 Model",
+ "website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo",
+ "licenses": {
+ "license": {
+ "name": "The Apache License, Version 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0"
+ }
+ },
+ "artifacts": [
+ {
+ "version": "0.0.1",
+ "snapshot": false,
+ "name": "yolov8n",
+ "arguments": {
+ "width": 640,
+ "height": 640,
+ "resize": true,
+ "rescale": true,
+ "optApplyRatio": true,
+ "threshold": 0.6,
+ "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory"
+ },
+ "files": {
+ "model": {
+ "uri": "0.0.1/yolov8n.zip",
+ "name": "",
+ "sha1Hash": "a868778452ef8d6d2f9cb7109a9e14a64e851d48",
+ "size": 11183356
+ }
+ }
+ }
+ ]
+}
diff --git a/engines/pytorch/pytorch-native/CMakeLists.txt b/engines/pytorch/pytorch-native/CMakeLists.txt
index 4453186be6f..c53d71dc93e 100644
--- a/engines/pytorch/pytorch-native/CMakeLists.txt
+++ b/engines/pytorch/pytorch-native/CMakeLists.txt
@@ -60,11 +60,12 @@ if(USE_CUDA)
endif()
add_library(djl_torch SHARED ${SOURCE_FILES})
+set_property(TARGET djl_torch PROPERTY CXX_STANDARD 17)
+
# build host
if(NOT BUILD_ANDROID)
target_link_libraries(djl_torch "${TORCH_LIBRARIES}")
target_include_directories(djl_torch PUBLIC build/include ${JNI_INCLUDE_DIRS} ${UTILS_INCLUDE_DIR})
- set_property(TARGET djl_torch PROPERTY CXX_STANDARD 14)
# We have to kill the default rpath and use current dir
set(CMAKE_SKIP_RPATH TRUE)
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle
index b4a195e109f..99a658bf3ed 100644
--- a/engines/pytorch/pytorch-native/build.gradle
+++ b/engines/pytorch/pytorch-native/build.gradle
@@ -24,6 +24,8 @@ if (project.hasProperty("cu11")) {
FLAVOR = "cu117"
} else if (VERSION.startsWith("2.0.")) {
FLAVOR = "cu118"
+ } else if (VERSION.startsWith("2.1.")) {
+ FLAVOR = "cu121"
} else {
throw new GradleException("Unsupported PyTorch version: ${VERSION}")
}
@@ -88,15 +90,17 @@ def prepareNativeLib(String binaryRoot, String ver) {
def officialPytorchUrl = "https://download.pytorch.org/libtorch"
def aarch64PytorchUrl = "https://djl-ai.s3.amazonaws.com/publish/pytorch"
- String cu11
+ String cuda
if (ver.startsWith("1.11.")) {
- cu11 = "cu113"
+ cuda = "cu113"
} else if (ver.startsWith("1.12.")) {
- cu11 = "cu116"
+ cuda = "cu116"
} else if (ver.startsWith("1.13.")) {
- cu11 = "cu117"
+ cuda = "cu117"
} else if (ver.startsWith("2.0.")) {
- cu11 = "cu118"
+ cuda = "cu118"
+ } else if (ver.startsWith("2.1.")) {
+ cuda = "cu121"
} else {
throw new GradleException("Unsupported PyTorch version: ${ver}")
}
@@ -105,10 +109,10 @@ def prepareNativeLib(String binaryRoot, String ver) {
"cpu/libtorch-cxx11-abi-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/linux-x86_64",
"cpu/libtorch-macos-${ver}.zip" : "cpu/osx-x86_64",
"cpu/libtorch-win-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/win-x86_64",
- "${cu11}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cu11}.zip": "${cu11}/linux-x86_64",
- "${cu11}/libtorch-win-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}/win-x86_64",
+ "${cuda}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cuda}.zip": "${cuda}/linux-x86_64",
+ "${cuda}/libtorch-win-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}/win-x86_64",
"cpu/libtorch-shared-with-deps-${ver}%2Bcpu.zip" : "cpu-precxx11/linux-x86_64",
- "${cu11}/libtorch-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}-precxx11/linux-x86_64"
+ "${cuda}/libtorch-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}-precxx11/linux-x86_64"
]
def aarch64Files = [
@@ -138,17 +142,12 @@ def copyNativeLibToOutputDir(Map fileStoreMap, String binaryRoot
from zipTree(file)
into outputDir
}
- // CPU dependencies
- copy {
- from("${outputDir}/libtorch/lib/") {
- include "libc10.*", "c10.dll", "libiomp5*.*", "libarm_compute*.*", "libgomp*.*", "libnvfuser_codegen.so", "libtorch.*", "libtorch_cpu.*", "torch.dll", "torch_cpu.dll", "fbgemm.dll", "asmjit.dll", "uv.dll", "nvfuser_codegen.dll"
- }
- into("${outputDir}/native/lib")
- }
- // GPU dependencies
+ delete "${outputDir}/libtorch/lib/*.lib"
+ delete "${outputDir}/libtorch/lib/*.a"
+
copy {
from("${outputDir}/libtorch/lib/") {
- include "libtorch_cuda*.so", "torch_cuda*.dll", "libc10_cuda.so", "c10_cuda.dll", "libcaffe2_nvrtc.so", "libnvrtc*.so.*", "libcudart*.*", "*nvToolsExt*.*", "cudnn*.dll", "caffe2_nvrtc.dll", "nvrtc64*.dll", "uv.dll", "libcublas*", "zlibwapi.dll"
+ include "libarm_compute*", "libc10_cuda.so", "libc10.*", "libcaffe2_nvrtc.so", "libcu*", "libgfortran-*", "libgomp*", "libiomp*", "libnv*", "libopenblasp-*", "libtorch_cpu.*", "libtorch_cuda*.so", "libtorch.*", "asmjit.dll", "c10_cuda.dll", "c10.dll", "caffe2_nvrtc.dll", "cu*.dll", "fbgemm.dll", "nv*.dll", "torch_cpu.dll", "torch_cuda*.dll", "torch.dll", "uv.dll", "zlibwapi.dll"
}
into("${outputDir}/native/lib")
}
@@ -287,9 +286,9 @@ tasks.register('uploadS3') {
"${BINARY_ROOT}/cpu/win-x86_64/native/lib/",
"${BINARY_ROOT}/cpu-precxx11/linux-aarch64/native/lib/",
"${BINARY_ROOT}/cpu-precxx11/linux-x86_64/native/lib/",
- "${BINARY_ROOT}/cu118/linux-x86_64/native/lib/",
- "${BINARY_ROOT}/cu118/win-x86_64/native/lib/",
- "${BINARY_ROOT}/cu118-precxx11/linux-x86_64/native/lib/"
+ "${BINARY_ROOT}/cu121/linux-x86_64/native/lib/",
+ "${BINARY_ROOT}/cu121/win-x86_64/native/lib/",
+ "${BINARY_ROOT}/cu121-precxx11/linux-x86_64/native/lib/"
]
uploadDirs.each { item ->
fileTree(item).files.name.each {
diff --git a/engines/pytorch/pytorch-native/build.sh b/engines/pytorch/pytorch-native/build.sh
index 78c59d6bf2a..ae0456bec62 100755
--- a/engines/pytorch/pytorch-native/build.sh
+++ b/engines/pytorch/pytorch-native/build.sh
@@ -23,22 +23,22 @@ ARCH=$4
if [[ ! -d "libtorch" ]]; then
if [[ $PLATFORM == 'linux' ]]; then
- if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118)$ ]]; then
+ if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118|cu121)$ ]]; then
echo "$FLAVOR is not supported."
exit 1
fi
if [[ $ARCH == 'aarch64' ]]; then
- curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv
+ curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv > /dev/null
else
- curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv
+ curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv > /dev/null
fi
elif [[ $PLATFORM == 'darwin' ]]; then
if [[ $ARCH == 'aarch64' ]]; then
- curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv
+ curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv > /dev/null
else
- curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv
+ curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv > /dev/null
fi
else
echo "$PLATFORM is not supported."
@@ -62,6 +62,12 @@ mkdir classes
javac -sourcepath ../../pytorch-engine/src/main/java/ ../../pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java -h include -d classes
cmake -DCMAKE_PREFIX_PATH=libtorch -DPT_VERSION=${PT_VERSION} -DUSE_CUDA=$USE_CUDA ..
cmake --build . --config Release -- -j "${NUM_PROC}"
+if [[ "$FLAVOR" = cu* ]]; then
+ # avoid link with libcudart.so.11.0
+ sed -i -r "s/\/usr\/local\/cuda(.{5})?\/lib64\/lib(cudart|nvrtc).so//g" CMakeFiles/djl_torch.dir/link.txt
+ rm libdjl_torch.so
+ . CMakeFiles/djl_torch.dir/link.txt
+fi
if [[ $PLATFORM == 'darwin' ]]; then
install_name_tool -add_rpath @loader_path libdjl_torch.dylib
diff --git a/engines/pytorch/pytorch-native/build_android.sh b/engines/pytorch/pytorch-native/build_android.sh
index b37dd96a86d..72050b20a85 100755
--- a/engines/pytorch/pytorch-native/build_android.sh
+++ b/engines/pytorch/pytorch-native/build_android.sh
@@ -20,7 +20,7 @@ if [[ ! -d libtorch_android/"$FLAVOR" ]]; then
mkdir -p libtorch_android/"$FLAVOR"
cd libtorch_android/"$FLAVOR"
echo "Downloading https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip"
- curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv
+ curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv > /dev/null
mv install/include include
cd -
fi
diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc
index 5a65e1eca69..08932098da9 100644
--- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc
+++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc
@@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft(
API_END_RETURN()
}
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2(
+ JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
+ API_BEGIN()
+ const auto* tensor_ptr = reinterpret_cast(jhandle);
+ const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
+ const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
+ const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes));
+ return reinterpret_cast(result_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2(
+ JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) {
+ API_BEGIN()
+ const auto* tensor_ptr = reinterpret_cast(jhandle);
+ const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js);
+ const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes);
+ const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes));
+ return reinterpret_cast(result_ptr);
+ API_END_RETURN()
+}
+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle,
jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) {
#ifdef V1_11_X
diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc
index 28e40e916be..ccf2616dc65 100644
--- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc
+++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc
@@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv*
API_END_RETURN()
}
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2(
+JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
+ API_BEGIN()
+ const auto* self_ptr = reinterpret_cast(jself);
+ const auto* other_ptr = reinterpret_cast(jother);
+ const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr));
+ return reinterpret_cast(result_ptr);
+ API_END_RETURN()
+}
+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast(jhandle);
@@ -496,6 +506,14 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn
API_END_RETURN()
}
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ const auto* tensor_ptr = reinterpret_cast(jhandle);
+ const auto* result_ptr = new torch::Tensor(tensor_ptr->erf());
+ return reinterpret_cast(result_ptr);
+ API_END_RETURN()
+}
+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) {
API_BEGIN()
const auto* self_ptr = reinterpret_cast(jself);
diff --git a/engines/tensorflow/tensorflow-api/README.md b/engines/tensorflow/tensorflow-api/README.md
index fd2741dc9e4..9e151a274a0 100644
--- a/engines/tensorflow/tensorflow-api/README.md
+++ b/engines/tensorflow/tensorflow-api/README.md
@@ -16,6 +16,6 @@ You can pull the TensorFlow core java API from the central Maven repository by i
ai.djl.tensorflow
tensorflow-api
- 0.23.0
+ 0.26.0
```
diff --git a/engines/tensorflow/tensorflow-engine/README.md b/engines/tensorflow/tensorflow-engine/README.md
index 57bcdda98d7..5a6ac3e6da1 100644
--- a/engines/tensorflow/tensorflow-engine/README.md
+++ b/engines/tensorflow/tensorflow-engine/README.md
@@ -28,13 +28,13 @@ The javadocs output is built in the `build/doc/javadoc` folder.
You can pull the TensorFlow engine from the central Maven repository by including the following dependency:
-- ai.djl.tensorflow:tensorflow-engine:0.23.0
+- ai.djl.tensorflow:tensorflow-engine:0.26.0
```xml
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-engine</artifactId>
- <version>0.23.0</version>
+ <version>0.26.0</version>
<scope>runtime</scope>
```
diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java
index d964ea5c295..fa7813a49fb 100644
--- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java
+++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */
public class TfEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (TfEngineProvider.class) {
- engine = TfEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = TfEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
index 07c31bacd99..419be4c09f6 100644
--- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
+++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
@@ -457,6 +457,12 @@ public NDArray erfinv() {
return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow();
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray erf() {
+ return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow();
+ }
+
/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
@@ -911,6 +917,12 @@ public NDArray atan() {
return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow();
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray atan2(NDArray other) {
+ return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow();
+ }
+
/** {@inheritDoc} */
@Override
public NDArray sinh() {
@@ -1172,6 +1184,18 @@ public NDArray stft(
throw new UnsupportedOperationException("Not implemented yet.");
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray fft2(long[] sizes, long[] axes) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray ifft2(long[] sizes, long[] axes) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
/** {@inheritDoc} */
@Override
public NDArray reshape(Shape shape) {
diff --git a/engines/tensorflow/tensorflow-model-zoo/README.md b/engines/tensorflow/tensorflow-model-zoo/README.md
index b34154fa126..975caa6df82 100644
--- a/engines/tensorflow/tensorflow-model-zoo/README.md
+++ b/engines/tensorflow/tensorflow-model-zoo/README.md
@@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency:
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-model-zoo</artifactId>
- <version>0.23.0</version>
+ <version>0.26.0</version>
```
diff --git a/engines/tensorflow/tensorflow-native/build.gradle b/engines/tensorflow/tensorflow-native/build.gradle
index 8138d93334d..56cd6eed9e2 100644
--- a/engines/tensorflow/tensorflow-native/build.gradle
+++ b/engines/tensorflow/tensorflow-native/build.gradle
@@ -153,6 +153,7 @@ flavorNames.each { flavor ->
}
from file("${BINARY_ROOT}/${flavor}/${osName}")
archiveClassifier = "${osName}-x86_64"
+ archiveBaseName = "tensorflow-native-${flavor}"
manifest {
attributes("Automatic-Module-Name": "ai.djl.tensorflow_native_${flavor}_${osName}")
diff --git a/engines/tensorrt/README.md b/engines/tensorrt/README.md
index 6373386479e..8100b615e24 100644
--- a/engines/tensorrt/README.md
+++ b/engines/tensorrt/README.md
@@ -28,13 +28,13 @@ The javadocs output is generated in the `build/doc/javadoc` folder.
## Installation
You can pull the TensorRT engine from the central Maven repository by including the following dependency:
-- ai.djl.tensorrt:tensorrt:0.23.0
+- ai.djl.tensorrt:tensorrt:0.26.0
```xml
<groupId>ai.djl.tensorrt</groupId>
<artifactId>tensorrt</artifactId>
- <version>0.23.0</version>
+ <version>0.26.0</version>
<scope>runtime</scope>
```
diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java
index 05a7eceeb41..8c90859c6c6 100644
--- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java
+++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */
public class TrtEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (TrtEngineProvider.class) {
- engine = TrtEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = TrtEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/tflite/tflite-engine/README.md b/engines/tflite/tflite-engine/README.md
index b1dd8fc9778..861a66f9aaa 100644
--- a/engines/tflite/tflite-engine/README.md
+++ b/engines/tflite/tflite-engine/README.md
@@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder.
## Installation
You can pull the TensorFlow Lite engine from the central Maven repository by including the following dependency:
-- ai.djl.tflite:tflite-engine:0.23.0
+- ai.djl.tflite:tflite-engine:0.26.0
```xml
<groupId>ai.djl.tflite</groupId>
<artifactId>tflite-engine</artifactId>
- <version>0.23.0</version>
+ <version>0.26.0</version>
<scope>runtime</scope>
```
diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java
index aa0fdb73d21..b46cad53b99 100644
--- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java
+++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java
@@ -18,7 +18,8 @@
/** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */
public class TfLiteEngineProvider implements EngineProvider {
- private static volatile Engine engine; // NOPMD
+ private volatile Engine engine; // NOPMD
+ private volatile boolean initialized; // NOPMD
/** {@inheritDoc} */
@Override
@@ -35,9 +36,12 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
- if (engine == null) {
+ if (!initialized) {
synchronized (TfLiteEngineProvider.class) {
- engine = TfLiteEngine.newInstance();
+ if (!initialized) {
+ initialized = true;
+ engine = TfLiteEngine.newInstance();
+ }
}
}
return engine;
diff --git a/engines/tflite/tflite-native/build.gradle b/engines/tflite/tflite-native/build.gradle
index eb045331c12..3e2a6008f38 100644
--- a/engines/tflite/tflite-native/build.gradle
+++ b/engines/tflite/tflite-native/build.gradle
@@ -155,6 +155,7 @@ flavorNames.each { flavor ->
from file("src/main/resources")
from file("${project.buildDir}/classes/java/main")
archiveClassifier = "${osName}"
+ archiveBaseName = "tflite-native-${flavor}"
manifest {
attributes("Automatic-Module-Name": "ai.djl.tflite_native_${flavor}_${osName}")
diff --git a/examples/docs/image_classification.md b/examples/docs/image_classification.md
index 1f515f9680f..c8f331320a8 100644
--- a/examples/docs/image_classification.md
+++ b/examples/docs/image_classification.md
@@ -6,7 +6,7 @@ In this example, you learn how to implement inference code with Deep Java Librar
The image classification example code can be found at [ImageClassification.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java).
-You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/03_image_classification_with_your_model.ipynb).
+You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/03_image_classification_with_your_model.html).
The Jupyter notebook explains the key concepts in detail.
## Setup Guide
diff --git a/examples/docs/object_detection.md b/examples/docs/object_detection.md
index 7d0898128b9..84286fb6e00 100644
--- a/examples/docs/object_detection.md
+++ b/examples/docs/object_detection.md
@@ -7,7 +7,7 @@ In this example, you learn how to implement inference code with a [ModelZoo mode
The source code can be found at [ObjectDetection.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java).
-You can also use the [Jupyter notebook tutorial](../../jupyter/object_detection_with_model_zoo.ipynb).
+You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html).
The Jupyter notebook explains the key concepts in detail.
## Setup guide
diff --git a/examples/docs/stable_diffusion.md b/examples/docs/stable_diffusion.md
index 7eb544646ee..be3cbb48d6e 100644
--- a/examples/docs/stable_diffusion.md
+++ b/examples/docs/stable_diffusion.md
@@ -1,4 +1,4 @@
-## Stable Diffusion in DJL
+# Stable Diffusion in DJL
[Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) is an open-source model
developed by Stability.ai. It aimed to produce images (artwork, pictures, etc.) based on
diff --git a/examples/docs/train_cifar10_resnet.md b/examples/docs/train_cifar10_resnet.md
index cfaf03f8a61..1cdfcb495c2 100644
--- a/examples/docs/train_cifar10_resnet.md
+++ b/examples/docs/train_cifar10_resnet.md
@@ -5,7 +5,7 @@ In this example, you learn how to train the [CIFAR-10](https://www.cs.toronto.ed
You can find the example source code in: [TrainResnetWithCifar10.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java).
-You can also find the Jupyter notebook tutorial [here](../../jupyter/transfer_learning_on_cifar10.ipynb).
+You can also find the Jupyter notebook tutorial [here](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html).
The Jupyter notebook explains the key concepts in detail.
## Setup guide
diff --git a/examples/docs/train_mnist_mlp.md b/examples/docs/train_mnist_mlp.md
index 72b591d062a..40a32ca365f 100644
--- a/examples/docs/train_mnist_mlp.md
+++ b/examples/docs/train_mnist_mlp.md
@@ -6,7 +6,7 @@ In this example, you learn how to train the MNIST dataset with Deep Java Library
The source code for this example can be found at [TrainMnist.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainMnist.java).
-You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/02_train_your_first_model.ipynb).
+You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/02_train_your_first_model.html).
The Jupyter notebook explains the key concepts in detail.
## Setup guide
diff --git a/examples/pom.xml b/examples/pom.xml
index 9eb2ee32fa0..cc18358e947 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -5,12 +5,12 @@
<groupId>ai.djl</groupId>
<artifactId>examples</artifactId>
- <version>0.24.0-SNAPSHOT</version>
+ <version>0.27.0-SNAPSHOT</version>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
- <djl.version>0.24.0-SNAPSHOT</djl.version>
+ <djl.version>0.27.0-SNAPSHOT</djl.version>
<exec.mainClass>ai.djl.examples.inference.ObjectDetection</exec.mainClass>
@@ -41,7 +41,7 @@
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
- <version>2.18.0</version>
+ <version>2.21.0</version>
<groupId>ai.djl</groupId>
diff --git a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java
index b667cd29f90..093e159bebb 100644
--- a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java
+++ b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java
@@ -34,9 +34,8 @@
* See:
*
*
- * the jupyter
- * demo with more information about BERT.
+ * the jupyter demo with more
+ * information about BERT.
* the docs
* for information about running this example.
diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java
new file mode 100644
index 00000000000..3d2cfb26409
--- /dev/null
+++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.examples.inference;
+
+import ai.djl.ModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import ai.djl.translate.TranslateException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+/** An example of inference using an yolov8 model. */
+public final class Yolov8Detection {
+
+ private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class);
+
+ private Yolov8Detection() {}
+
+ public static void main(String[] args) throws IOException, ModelException, TranslateException {
+ DetectedObjects detection = Yolov8Detection.predict();
+ logger.info("{}", detection);
+ }
+
+ public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
+ Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg");
+ Image img = ImageFactory.getInstance().fromFile(imgPath);
+
+ Criteria<Image, DetectedObjects> criteria =
+ Criteria.builder()
+ .setTypes(Image.class, DetectedObjects.class)
+ .optModelUrls("djl://ai.djl.onnxruntime/yolov8n")
+ .optEngine("OnnxRuntime")
+ .optArgument("width", 640)
+ .optArgument("height", 640)
+ .optArgument("resize", true)
+ .optArgument("toTensor", true)
+ .optArgument("applyRatio", true)
+ .optArgument("threshold", 0.6f)
+ // for performance optimization maxBox parameter can reduce number of
+ // considered boxes from 8400
+ .optArgument("maxBox", 1000)
+ .optTranslatorFactory(new YoloV8TranslatorFactory())
+ .optProgress(new ProgressBar())
+ .build();
+
+ try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
+ Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
+ Path outputPath = Paths.get("build/output");
+ Files.createDirectories(outputPath);
+
+ DetectedObjects detection = predictor.predict(img);
+ if (detection.getNumberOfObjects() > 0) {
+ img.drawBoundingBoxes(detection);
+ Path output = outputPath.resolve("yolov8_detected.png");
+ try (OutputStream os = Files.newOutputStream(output)) {
+ img.save(os, "png");
+ }
+ logger.info("Detected object saved in: {}", output);
+ }
+ return detection;
+ }
+ }
+}
diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
index aa2b12af420..193f6643d56 100644
--- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
+++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
@@ -28,6 +28,7 @@
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.initializer.TruncatedNormalInitializer;
+import ai.djl.training.listener.TrainingListener;
import ai.djl.training.listener.TrainingListener.Defaults;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
@@ -109,6 +110,8 @@ private static TrainingConfig createTrainingConfig(BertArguments arguments) {
return new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
+ .addTrainingListeners(
+ TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile()))
.addTrainingListeners(Defaults.logging());
}
diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java
index 786a71bfbed..85d09145081 100644
--- a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java
+++ b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java
@@ -107,6 +107,8 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
.addEvaluator(new Accuracy())
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
+ .addTrainingListeners(
+ TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile()))
.addTrainingListeners(listener);
}
diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java
index e0143ed524b..33db2efd2ff 100644
--- a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java
+++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java
@@ -214,6 +214,8 @@ private static DefaultTrainingConfig setupTrainingConfig(
.addEvaluator(new Rmsse(distributionOutput))
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
+ .addTrainingListeners(
+ TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile()))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
}
diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java
index 7acb2f3531f..aa6dbc389dc 100644
--- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java
+++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java
@@ -215,6 +215,8 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
+ .addTrainingListeners(
+ TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile()))
.addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir()));
}
diff --git a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java
index bbfa48f6381..e72f0b94fa6 100644
--- a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java
+++ b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java
@@ -38,6 +38,7 @@ public class Arguments {
protected long limit;
protected String modelDir;
protected Map<String, Object> criteria;
+ protected String algebraicLogFile;
protected void initialize() {
epoch = 2;
@@ -45,6 +46,7 @@ protected void initialize() {
outputDir = "build/model";
limit = Long.MAX_VALUE;
modelDir = null;
+ algebraicLogFile = null;
}
protected void setCmd(CommandLine cmd) {
@@ -75,6 +77,9 @@ protected void setCmd(CommandLine cmd) {
Type type = new TypeToken<Map<String, Object>>() {}.getType();
criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type);
}
+ if (cmd.hasOption("algebraic-log")) {
+ algebraicLogFile = cmd.getOptionValue("algebraic-log");
+ }
}
public Arguments parseArgs(String[] args) {
@@ -162,6 +167,15 @@ public Options getOptions() {
.argName("CRITERIA")
.desc("The criteria used for the model.")
.build());
+ options.addOption(
+ Option.builder("a")
+ .longOpt("algebraic-log")
+ .hasArg()
+ .argName("ALGEBRAIC-LOG")
+ .desc(
+ "File to log algebraic operations executed during training as"
+ + " Python program.")
+ .build());
return options;
}
@@ -193,6 +207,10 @@ public String getOutputDir() {
return outputDir;
}
+ public String getAlgebraicLogFile() {
+ return algebraicLogFile;
+ }
+
public long getLimit() {
return limit;
}
diff --git a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java
new file mode 100644
index 00000000000..35e3fc434aa
--- /dev/null
+++ b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.examples.inference;
+
+import ai.djl.ModelException;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.testing.TestRequirements;
+import ai.djl.translate.TranslateException;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+
+public class Yolov8DetectionTest {
+
+ @Test
+ public void testYolov8Detection() throws ModelException, TranslateException, IOException {
+ TestRequirements.engine("MXNet", "PyTorch");
+
+ DetectedObjects result = Yolov8Detection.predict();
+
+ Assert.assertTrue(result.getNumberOfObjects() >= 1);
+ Classifications.Classification obj = result.best();
+ String className = obj.getClassName();
+ Assert.assertEquals(className, "dog");
+ Assert.assertTrue(obj.getProbability() > 0.6);
+ }
+}
diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java
index 2a61e25862e..1a5699836c8 100644
--- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java
+++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java
@@ -27,7 +27,6 @@ public class TrainPikachuTest {
@Test
public void testDetection() throws IOException, MalformedModelException, TranslateException {
- TestRequirements.engine("MXNet");
TestRequirements.nightly();
String[] args;
diff --git a/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java b/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java
new file mode 100644
index 00000000000..a57373b5891
--- /dev/null
+++ b/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.examples.training;
+
+import ai.djl.ModelException;
+import ai.djl.engine.Engine;
+import ai.djl.examples.training.transferlearning.TrainResnetWithCifar10;
+import ai.djl.testing.TestRequirements;
+import ai.djl.training.TrainingResult;
+import ai.djl.translate.TranslateException;
+import ai.djl.util.Utils;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+
+public class TrainWithAlgebraicLogging {
+
+ private static final int SEED = 1234;
+
+ @Test
+ public void testTrainMnist() throws ModelException, TranslateException, IOException {
+ TestRequirements.engine("MXNet");
+
+ Path logDir = Paths.get("build/tmp/algebraiclog");
+ Path algebraicLogFile = logDir.resolve("TrainMnist.py");
+ if (!algebraicLogFile.toFile().delete()) {
+ Files.createDirectories(logDir);
+ }
+
+ String[] args =
+ new String[] {"-g", "1", "-m", "2", "-a", algebraicLogFile.toFile().toString()};
+
+ TrainMnist.runExample(args);
+ Path path = Paths.get("src/test/resources/algebraiclog/TrainMnist.py");
+
+ try (InputStream is = Files.newInputStream(path);
+ InputStream isActual = Files.newInputStream(algebraicLogFile)) {
+ List<String> expected = Utils.readLines(is);
+ List<String> actual = Utils.readLines(isActual);
+ Assert.assertEquals(expected, actual);
+ }
+ }
+
+ @Test
+ public void testTrainResNetImperative() throws ModelException, IOException, TranslateException {
+ TestRequirements.engine("MXNet");
+
+ Path logDir = Paths.get("build/tmp/algebraiclog");
+ Path algebraicLogFile = logDir.resolve("TrainResnetWithCifar10.py");
+ if (!algebraicLogFile.toFile().delete()) {
+ Files.createDirectories(logDir);
+ }
+
+ // Limit max 4 gpu for cifar10 training to make it converge faster.
+ // and only train 10 batch for unit test.
+ String[] args = {
+ "-e", "2", "-g", "4", "-m", "1", "-b", "111", "-a", algebraicLogFile.toFile().toString()
+ };
+
+ Engine.getInstance().setRandomSeed(SEED);
+ TrainingResult result = TrainResnetWithCifar10.runExample(args);
+ Assert.assertNotNull(result);
+
+ Path path = Paths.get("src/test/resources/algebraiclog/TrainResnetWithCifar10.py");
+
+ try (InputStream is = Files.newInputStream(path);
+ InputStream isActual = Files.newInputStream(algebraicLogFile)) {
+ List<String> expected = Utils.readLines(is);
+ List<String> actual = Utils.readLines(isActual);
+ Assert.assertEquals(expected, actual);
+ }
+ }
+}
diff --git a/examples/src/test/java/ai/djl/testing/TestRequirements.java b/examples/src/test/java/ai/djl/testing/TestRequirements.java
index e8c9bd4bdda..01eef756201 100644
--- a/examples/src/test/java/ai/djl/testing/TestRequirements.java
+++ b/examples/src/test/java/ai/djl/testing/TestRequirements.java
@@ -14,6 +14,7 @@
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
+import ai.djl.util.Utils;
import org.testng.SkipException;
@@ -45,7 +46,7 @@ public static void weekly() {
/** Requires a test not be run in offline mode. */
public static void notOffline() {
- if (Boolean.getBoolean("offline")) {
+ if (Utils.isOfflineMode()) {
throw new SkipException("This test can not run while offline");
}
}
diff --git a/examples/src/test/resources/algebraiclog/TrainMnist.py b/examples/src/test/resources/algebraiclog/TrainMnist.py
new file mode 100644
index 00000000000..e5fa7b2763e
--- /dev/null
+++ b/examples/src/test/resources/algebraiclog/TrainMnist.py
@@ -0,0 +1,121 @@
+class MyModel(tf.keras.Model):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._02Linear_weight = tf.Variable(
+ tf.random.normal(
+ shape=[128, 784],
+ mean=0.0,
+ stddev=0.050507627,
+ dtype=tf.dtypes.float32,
+ name='normal_1_',
+ ) # (128, 784)
+ )
+ self._02Linear_bias = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_2_',
+ ) # (128)
+ )
+ self._04Linear_weight = tf.Variable(
+ tf.random.normal(
+ shape=[64, 128],
+ mean=0.0,
+ stddev=0.125,
+ dtype=tf.dtypes.float32,
+ name='normal_3_',
+ ) # (64, 128)
+ )
+ self._04Linear_bias = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_4_',
+ ) # (64)
+ )
+ self._06Linear_weight = tf.Variable(
+ tf.random.normal(
+ shape=[10, 64],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_5_',
+ ) # (10, 64)
+ )
+ self._06Linear_bias = tf.Variable(
+ tf.zeros(
+ shape=[10],
+ dtype=tf.dtypes.float32,
+ name='zeros_6_',
+ ) # (10)
+ )
+
+## 4
+ def call(self, x):
+ result = tf.nn.bias_add(
+ tf.matmul(
+ tf.nn.relu(
+ tf.nn.bias_add(
+ tf.matmul(
+ tf.nn.relu(
+ tf.nn.bias_add(
+ tf.matmul(
+ tf.reshape(
+ x, # (32, 1, 28, 28)
+ shape=[-1, 784],
+ name='reshape_7_',
+ ), # (32, 784)
+ b=self._02Linear_weight, # (128, 784)
+ transpose_b=True,
+ name='matmul_8_',
+ ), # (32, 128)
+ bias=self._02Linear_bias, # (128)
+ data_format=None,
+ name='bias_add_9_',
+ ), # (32, 128)
+ name='relu_10_',
+ ), # (32, 128)
+ b=self._04Linear_weight, # (64, 128)
+ transpose_b=True,
+ name='matmul_11_',
+ ), # (32, 64)
+ bias=self._04Linear_bias, # (64)
+ data_format=None,
+ name='bias_add_12_',
+ ), # (32, 64)
+ name='relu_13_',
+ ), # (32, 64)
+ b=self._06Linear_weight, # (10, 64)
+ transpose_b=True,
+ name='matmul_14_',
+ ), # (32, 10)
+ bias=self._06Linear_bias, # (10)
+ data_format=None,
+ name='bias_add_15_',
+ ) # (32, 10)
+ return result
+
+## 4
+def loss(label, prediction):
+ result = tf.reduce_mean(
+ tf.negative(
+ tf.gather(
+ tf.nn.log_softmax(
+ prediction, # (32, 10)
+ axis=-1,
+ name='log_softmax_16_',
+ ), # (32, 10)
+ indices=label, # (32)
+ batch_dims=1,
+ name='gather_17_',
+ ), # (32, 1)
+ name='negative_18_',
+ ), # (32, 1)
+ name='reduce_mean_19_',
+ ) # ()
+ return result
+
+# number of epochs was 2
+# number of prediction functions is 1
+# number of loss functions is 1
+
diff --git a/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py b/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py
new file mode 100644
index 00000000000..cb3b619f810
--- /dev/null
+++ b/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py
@@ -0,0 +1,4084 @@
+class MyModel(tf.keras.Model):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 3, 3, 3],
+ mean=0.0,
+ stddev=0.27216554,
+ dtype=tf.dtypes.float32,
+ name='normal_1_',
+ ), # (64, 3, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_2_',
+ ) # (3, 3, 3, 64)
+ )
+ self._02ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 64, 1, 1],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_3_',
+ ), # (64, 64, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_4_',
+ ) # (1, 1, 64, 64)
+ )
+ self._02ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_5_',
+ ) # (64)
+ )
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_6_',
+ ) # (64)
+ )
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_7_',
+ ) # (64)
+ )
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_8_',
+ ) # (64)
+ , trainable = False
+ )
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_9_',
+ ) # (64)
+ , trainable = False
+ )
+ self._02ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 64, 3, 3],
+ mean=0.0,
+ stddev=0.058925565,
+ dtype=tf.dtypes.float32,
+ name='normal_10_',
+ ), # (64, 64, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_11_',
+ ) # (3, 3, 64, 64)
+ )
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_12_',
+ ) # (64)
+ )
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_13_',
+ ) # (64)
+ )
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_14_',
+ ) # (64)
+ , trainable = False
+ )
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_15_',
+ ) # (64)
+ , trainable = False
+ )
+ self._02ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 64, 1, 1],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_16_',
+ ), # (256, 64, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_17_',
+ ) # (1, 1, 64, 256)
+ )
+ self._02ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_18_',
+ ) # (256)
+ )
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_19_',
+ ) # (256)
+ )
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_20_',
+ ) # (256)
+ )
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_21_',
+ ) # (256)
+ , trainable = False
+ )
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_22_',
+ ) # (256)
+ , trainable = False
+ )
+ self._02ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 64, 1, 1],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_23_',
+ ), # (256, 64, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_24_',
+ ) # (1, 1, 64, 256)
+ )
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_25_',
+ ) # (256)
+ )
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_26_',
+ ) # (256)
+ )
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_27_',
+ ) # (256)
+ , trainable = False
+ )
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_28_',
+ ) # (256)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_29_',
+ ), # (64, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_30_',
+ ) # (1, 1, 256, 64)
+ )
+ self._03ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_31_',
+ ) # (64)
+ )
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_32_',
+ ) # (64)
+ )
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_33_',
+ ) # (64)
+ )
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_34_',
+ ) # (64)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_35_',
+ ) # (64)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 64, 3, 3],
+ mean=0.0,
+ stddev=0.058925565,
+ dtype=tf.dtypes.float32,
+ name='normal_36_',
+ ), # (64, 64, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_37_',
+ ) # (3, 3, 64, 64)
+ )
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_38_',
+ ) # (64)
+ )
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_39_',
+ ) # (64)
+ )
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_40_',
+ ) # (64)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_41_',
+ ) # (64)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 64, 1, 1],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_42_',
+ ), # (256, 64, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_43_',
+ ) # (1, 1, 64, 256)
+ )
+ self._03ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_44_',
+ ) # (256)
+ )
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_45_',
+ ) # (256)
+ )
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_46_',
+ ) # (256)
+ )
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_47_',
+ ) # (256)
+ , trainable = False
+ )
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_48_',
+ ) # (256)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_49_',
+ ), # (64, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_50_',
+ ) # (1, 1, 256, 64)
+ )
+ self._04ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_51_',
+ ) # (64)
+ )
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_52_',
+ ) # (64)
+ )
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_53_',
+ ) # (64)
+ )
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_54_',
+ ) # (64)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_55_',
+ ) # (64)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[64, 64, 3, 3],
+ mean=0.0,
+ stddev=0.058925565,
+ dtype=tf.dtypes.float32,
+ name='normal_56_',
+ ), # (64, 64, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_57_',
+ ) # (3, 3, 64, 64)
+ )
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_58_',
+ ) # (64)
+ )
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_59_',
+ ) # (64)
+ )
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='zeros_60_',
+ ) # (64)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[64],
+ dtype=tf.dtypes.float32,
+ name='ones_61_',
+ ) # (64)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 64, 1, 1],
+ mean=0.0,
+ stddev=0.17677669,
+ dtype=tf.dtypes.float32,
+ name='normal_62_',
+ ), # (256, 64, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_63_',
+ ) # (1, 1, 64, 256)
+ )
+ self._04ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_64_',
+ ) # (256)
+ )
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_65_',
+ ) # (256)
+ )
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_66_',
+ ) # (256)
+ )
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_67_',
+ ) # (256)
+ , trainable = False
+ )
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_68_',
+ ) # (256)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_69_',
+ ), # (128, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_70_',
+ ) # (1, 1, 256, 128)
+ )
+ self._05ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_71_',
+ ) # (128)
+ )
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_72_',
+ ) # (128)
+ )
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_73_',
+ ) # (128)
+ )
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_74_',
+ ) # (128)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_75_',
+ ) # (128)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 128, 3, 3],
+ mean=0.0,
+ stddev=0.041666668,
+ dtype=tf.dtypes.float32,
+ name='normal_76_',
+ ), # (128, 128, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_77_',
+ ) # (3, 3, 128, 128)
+ )
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_78_',
+ ) # (128)
+ )
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_79_',
+ ) # (128)
+ )
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_80_',
+ ) # (128)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_81_',
+ ) # (128)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 128, 1, 1],
+ mean=0.0,
+ stddev=0.125,
+ dtype=tf.dtypes.float32,
+ name='normal_82_',
+ ), # (512, 128, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_83_',
+ ) # (1, 1, 128, 512)
+ )
+ self._05ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_84_',
+ ) # (512)
+ )
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_85_',
+ ) # (512)
+ )
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_86_',
+ ) # (512)
+ )
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_87_',
+ ) # (512)
+ , trainable = False
+ )
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_88_',
+ ) # (512)
+ , trainable = False
+ )
+ self._05ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_89_',
+ ), # (512, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_90_',
+ ) # (1, 1, 256, 512)
+ )
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_91_',
+ ) # (512)
+ )
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_92_',
+ ) # (512)
+ )
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_93_',
+ ) # (512)
+ , trainable = False
+ )
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_94_',
+ ) # (512)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_95_',
+ ), # (128, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_96_',
+ ) # (1, 1, 512, 128)
+ )
+ self._06ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_97_',
+ ) # (128)
+ )
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_98_',
+ ) # (128)
+ )
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_99_',
+ ) # (128)
+ )
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_100_',
+ ) # (128)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_101_',
+ ) # (128)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 128, 3, 3],
+ mean=0.0,
+ stddev=0.041666668,
+ dtype=tf.dtypes.float32,
+ name='normal_102_',
+ ), # (128, 128, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_103_',
+ ) # (3, 3, 128, 128)
+ )
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_104_',
+ ) # (128)
+ )
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_105_',
+ ) # (128)
+ )
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_106_',
+ ) # (128)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_107_',
+ ) # (128)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 128, 1, 1],
+ mean=0.0,
+ stddev=0.125,
+ dtype=tf.dtypes.float32,
+ name='normal_108_',
+ ), # (512, 128, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_109_',
+ ) # (1, 1, 128, 512)
+ )
+ self._06ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_110_',
+ ) # (512)
+ )
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_111_',
+ ) # (512)
+ )
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_112_',
+ ) # (512)
+ )
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_113_',
+ ) # (512)
+ , trainable = False
+ )
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_114_',
+ ) # (512)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_115_',
+ ), # (128, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_116_',
+ ) # (1, 1, 512, 128)
+ )
+ self._07ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_117_',
+ ) # (128)
+ )
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_118_',
+ ) # (128)
+ )
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_119_',
+ ) # (128)
+ )
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_120_',
+ ) # (128)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_121_',
+ ) # (128)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 128, 3, 3],
+ mean=0.0,
+ stddev=0.041666668,
+ dtype=tf.dtypes.float32,
+ name='normal_122_',
+ ), # (128, 128, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_123_',
+ ) # (3, 3, 128, 128)
+ )
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_124_',
+ ) # (128)
+ )
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_125_',
+ ) # (128)
+ )
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_126_',
+ ) # (128)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_127_',
+ ) # (128)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 128, 1, 1],
+ mean=0.0,
+ stddev=0.125,
+ dtype=tf.dtypes.float32,
+ name='normal_128_',
+ ), # (512, 128, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_129_',
+ ) # (1, 1, 128, 512)
+ )
+ self._07ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_130_',
+ ) # (512)
+ )
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_131_',
+ ) # (512)
+ )
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_132_',
+ ) # (512)
+ )
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_133_',
+ ) # (512)
+ , trainable = False
+ )
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_134_',
+ ) # (512)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_135_',
+ ), # (128, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_136_',
+ ) # (1, 1, 512, 128)
+ )
+ self._08ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_137_',
+ ) # (128)
+ )
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_138_',
+ ) # (128)
+ )
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_139_',
+ ) # (128)
+ )
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_140_',
+ ) # (128)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_141_',
+ ) # (128)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[128, 128, 3, 3],
+ mean=0.0,
+ stddev=0.041666668,
+ dtype=tf.dtypes.float32,
+ name='normal_142_',
+ ), # (128, 128, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_143_',
+ ) # (3, 3, 128, 128)
+ )
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_144_',
+ ) # (128)
+ )
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_145_',
+ ) # (128)
+ )
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='zeros_146_',
+ ) # (128)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[128],
+ dtype=tf.dtypes.float32,
+ name='ones_147_',
+ ) # (128)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 128, 1, 1],
+ mean=0.0,
+ stddev=0.125,
+ dtype=tf.dtypes.float32,
+ name='normal_148_',
+ ), # (512, 128, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_149_',
+ ) # (1, 1, 128, 512)
+ )
+ self._08ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_150_',
+ ) # (512)
+ )
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_151_',
+ ) # (512)
+ )
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_152_',
+ ) # (512)
+ )
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_153_',
+ ) # (512)
+ , trainable = False
+ )
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_154_',
+ ) # (512)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_155_',
+ ), # (256, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_156_',
+ ) # (1, 1, 512, 256)
+ )
+ self._09ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_157_',
+ ) # (256)
+ )
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_158_',
+ ) # (256)
+ )
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_159_',
+ ) # (256)
+ )
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_160_',
+ ) # (256)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_161_',
+ ) # (256)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_162_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_163_',
+ ) # (3, 3, 256, 256)
+ )
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_164_',
+ ) # (256)
+ )
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_165_',
+ ) # (256)
+ )
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_166_',
+ ) # (256)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_167_',
+ ) # (256)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_168_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_169_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._09ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_170_',
+ ) # (1024)
+ )
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_171_',
+ ) # (1024)
+ )
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_172_',
+ ) # (1024)
+ )
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_173_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_174_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._09ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_175_',
+ ), # (1024, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_176_',
+ ) # (1, 1, 512, 1024)
+ )
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_177_',
+ ) # (1024)
+ )
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_178_',
+ ) # (1024)
+ )
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_179_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_180_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_181_',
+ ), # (256, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_182_',
+ ) # (1, 1, 1024, 256)
+ )
+ self._10ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_183_',
+ ) # (256)
+ )
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_184_',
+ ) # (256)
+ )
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_185_',
+ ) # (256)
+ )
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_186_',
+ ) # (256)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_187_',
+ ) # (256)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_188_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_189_',
+ ) # (3, 3, 256, 256)
+ )
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_190_',
+ ) # (256)
+ )
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_191_',
+ ) # (256)
+ )
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_192_',
+ ) # (256)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_193_',
+ ) # (256)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_194_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_195_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._10ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_196_',
+ ) # (1024)
+ )
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_197_',
+ ) # (1024)
+ )
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_198_',
+ ) # (1024)
+ )
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_199_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_200_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_201_',
+ ), # (256, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_202_',
+ ) # (1, 1, 1024, 256)
+ )
+ self._11ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_203_',
+ ) # (256)
+ )
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_204_',
+ ) # (256)
+ )
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_205_',
+ ) # (256)
+ )
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_206_',
+ ) # (256)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_207_',
+ ) # (256)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_208_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_209_',
+ ) # (3, 3, 256, 256)
+ )
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_210_',
+ ) # (256)
+ )
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_211_',
+ ) # (256)
+ )
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_212_',
+ ) # (256)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_213_',
+ ) # (256)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_214_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_215_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._11ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_216_',
+ ) # (1024)
+ )
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_217_',
+ ) # (1024)
+ )
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_218_',
+ ) # (1024)
+ )
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_219_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_220_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_221_',
+ ), # (256, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_222_',
+ ) # (1, 1, 1024, 256)
+ )
+ self._12ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_223_',
+ ) # (256)
+ )
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_224_',
+ ) # (256)
+ )
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_225_',
+ ) # (256)
+ )
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_226_',
+ ) # (256)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_227_',
+ ) # (256)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_228_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_229_',
+ ) # (3, 3, 256, 256)
+ )
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_230_',
+ ) # (256)
+ )
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_231_',
+ ) # (256)
+ )
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_232_',
+ ) # (256)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_233_',
+ ) # (256)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_234_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_235_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._12ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_236_',
+ ) # (1024)
+ )
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_237_',
+ ) # (1024)
+ )
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_238_',
+ ) # (1024)
+ )
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_239_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_240_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_241_',
+ ), # (256, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_242_',
+ ) # (1, 1, 1024, 256)
+ )
+ self._13ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_243_',
+ ) # (256)
+ )
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_244_',
+ ) # (256)
+ )
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_245_',
+ ) # (256)
+ )
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_246_',
+ ) # (256)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_247_',
+ ) # (256)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_248_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_249_',
+ ) # (3, 3, 256, 256)
+ )
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_250_',
+ ) # (256)
+ )
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_251_',
+ ) # (256)
+ )
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_252_',
+ ) # (256)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_253_',
+ ) # (256)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_254_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_255_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._13ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_256_',
+ ) # (1024)
+ )
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_257_',
+ ) # (1024)
+ )
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_258_',
+ ) # (1024)
+ )
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_259_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_260_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_261_',
+ ), # (256, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_262_',
+ ) # (1, 1, 1024, 256)
+ )
+ self._14ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_263_',
+ ) # (256)
+ )
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_264_',
+ ) # (256)
+ )
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_265_',
+ ) # (256)
+ )
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_266_',
+ ) # (256)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_267_',
+ ) # (256)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[256, 256, 3, 3],
+ mean=0.0,
+ stddev=0.029462783,
+ dtype=tf.dtypes.float32,
+ name='normal_268_',
+ ), # (256, 256, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_269_',
+ ) # (3, 3, 256, 256)
+ )
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_270_',
+ ) # (256)
+ )
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_271_',
+ ) # (256)
+ )
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='zeros_272_',
+ ) # (256)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[256],
+ dtype=tf.dtypes.float32,
+ name='ones_273_',
+ ) # (256)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[1024, 256, 1, 1],
+ mean=0.0,
+ stddev=0.088388346,
+ dtype=tf.dtypes.float32,
+ name='normal_274_',
+ ), # (1024, 256, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_275_',
+ ) # (1, 1, 256, 1024)
+ )
+ self._14ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_276_',
+ ) # (1024)
+ )
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_277_',
+ ) # (1024)
+ )
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_278_',
+ ) # (1024)
+ )
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='zeros_279_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[1024],
+ dtype=tf.dtypes.float32,
+ name='ones_280_',
+ ) # (1024)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_281_',
+ ), # (512, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_282_',
+ ) # (1, 1, 1024, 512)
+ )
+ self._15ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_283_',
+ ) # (512)
+ )
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_284_',
+ ) # (512)
+ )
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_285_',
+ ) # (512)
+ )
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_286_',
+ ) # (512)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_287_',
+ ) # (512)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 512, 3, 3],
+ mean=0.0,
+ stddev=0.020833334,
+ dtype=tf.dtypes.float32,
+ name='normal_288_',
+ ), # (512, 512, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_289_',
+ ) # (3, 3, 512, 512)
+ )
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_290_',
+ ) # (512)
+ )
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_291_',
+ ) # (512)
+ )
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_292_',
+ ) # (512)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_293_',
+ ) # (512)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[2048, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_294_',
+ ), # (2048, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_295_',
+ ) # (1, 1, 512, 2048)
+ )
+ self._15ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_296_',
+ ) # (2048)
+ )
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_297_',
+ ) # (2048)
+ )
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_298_',
+ ) # (2048)
+ )
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_299_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_300_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._15ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[2048, 1024, 1, 1],
+ mean=0.0,
+ stddev=0.044194173,
+ dtype=tf.dtypes.float32,
+ name='normal_301_',
+ ), # (2048, 1024, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_302_',
+ ) # (1, 1, 1024, 2048)
+ )
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_303_',
+ ) # (2048)
+ )
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_304_',
+ ) # (2048)
+ )
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_305_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_306_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 2048, 1, 1],
+ mean=0.0,
+ stddev=0.03125,
+ dtype=tf.dtypes.float32,
+ name='normal_307_',
+ ), # (512, 2048, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_308_',
+ ) # (1, 1, 2048, 512)
+ )
+ self._16ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_309_',
+ ) # (512)
+ )
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_310_',
+ ) # (512)
+ )
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_311_',
+ ) # (512)
+ )
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_312_',
+ ) # (512)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_313_',
+ ) # (512)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 512, 3, 3],
+ mean=0.0,
+ stddev=0.020833334,
+ dtype=tf.dtypes.float32,
+ name='normal_314_',
+ ), # (512, 512, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_315_',
+ ) # (3, 3, 512, 512)
+ )
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_316_',
+ ) # (512)
+ )
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_317_',
+ ) # (512)
+ )
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_318_',
+ ) # (512)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_319_',
+ ) # (512)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[2048, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_320_',
+ ), # (2048, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_321_',
+ ) # (1, 1, 512, 2048)
+ )
+ self._16ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_322_',
+ ) # (2048)
+ )
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_323_',
+ ) # (2048)
+ )
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_324_',
+ ) # (2048)
+ )
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_325_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_326_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 2048, 1, 1],
+ mean=0.0,
+ stddev=0.03125,
+ dtype=tf.dtypes.float32,
+ name='normal_327_',
+ ), # (512, 2048, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_328_',
+ ) # (1, 1, 2048, 512)
+ )
+ self._17ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_329_',
+ ) # (512)
+ )
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_330_',
+ ) # (512)
+ )
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_331_',
+ ) # (512)
+ )
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_332_',
+ ) # (512)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_333_',
+ ) # (512)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[512, 512, 3, 3],
+ mean=0.0,
+ stddev=0.020833334,
+ dtype=tf.dtypes.float32,
+ name='normal_334_',
+ ), # (512, 512, 3, 3)
+ perm=[2, 3, 1, 0],
+ name='transpose_335_',
+ ) # (3, 3, 512, 512)
+ )
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_336_',
+ ) # (512)
+ )
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_337_',
+ ) # (512)
+ )
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='zeros_338_',
+ ) # (512)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[512],
+ dtype=tf.dtypes.float32,
+ name='ones_339_',
+ ) # (512)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable(
+ tf.transpose(
+ tf.random.normal(
+ shape=[2048, 512, 1, 1],
+ mean=0.0,
+ stddev=0.0625,
+ dtype=tf.dtypes.float32,
+ name='normal_340_',
+ ), # (2048, 512, 1, 1)
+ perm=[2, 3, 1, 0],
+ name='transpose_341_',
+ ) # (1, 1, 512, 2048)
+ )
+ self._17ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_342_',
+ ) # (2048)
+ )
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_343_',
+ ) # (2048)
+ )
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_344_',
+ ) # (2048)
+ )
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable(
+ tf.zeros(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='zeros_345_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable(
+ tf.ones(
+ shape=[2048],
+ dtype=tf.dtypes.float32,
+ name='ones_346_',
+ ) # (2048)
+ , trainable = False
+ )
+ self._20Linear_weight = tf.Variable(
+ tf.random.normal(
+ shape=[10, 2048],
+ mean=0.0,
+ stddev=0.03125,
+ dtype=tf.dtypes.float32,
+ name='normal_347_',
+ ) # (10, 2048)
+ )
+ self._20Linear_bias = tf.Variable(
+ tf.zeros(
+ shape=[10],
+ dtype=tf.dtypes.float32,
+ name='zeros_348_',
+ ) # (10)
+ )
+
+## 2
+ def call(self, x):
+ val1 = tf.nn.convolution(
+ x, # (111, 3, 32, 32)
+ filters=self._01Conv2d_weight, # (3, 3, 3, 64)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_349_',
+ ) # (111, 64, 32, 32)
+ (batchnorm1, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val1, # (111, 64, 32, 32)
+ filters=self._02ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 64, 64)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_350_',
+ ), # (111, 64, 32, 32)
+ bias=self._02ParallelBlock_01SequentialBlock_01Conv2d_bias, # (64)
+ data_format='NCHW',
+ name='bias_add_351_',
+ ), # (111, 64, 32, 32)
+ scale=self._02ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64)
+ offset=self._02ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64)
+ mean=self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64)
+ variance=self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_352_',
+ ) # (111, 64, 32, 32)
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm2, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm1, # (111, 64, 32, 32)
+ name='relu_353_',
+ ), # (111, 64, 32, 32)
+ filters=self._02ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_354_',
+ ), # (111, 64, 32, 32)
+ scale=self._02ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64)
+ offset=self._02ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64)
+ mean=self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64)
+ variance=self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_355_',
+ ) # (111, 64, 32, 32)
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm3, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm2, # (111, 64, 32, 32)
+ name='relu_356_',
+ ), # (111, 64, 32, 32)
+ filters=self._02ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_357_',
+ ), # (111, 256, 32, 32)
+ bias=self._02ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_358_',
+ ), # (111, 256, 32, 32)
+ scale=self._02ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256)
+ offset=self._02ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256)
+ mean=self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (256)
+ variance=self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_359_',
+ ) # (111, 256, 32, 32)
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ (batchnorm4, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ val1, # (111, 64, 32, 32)
+ filters=self._02ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 64, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_360_',
+ ), # (111, 256, 32, 32)
+ scale=self._02ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._02ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_361_',
+ ) # (111, 256, 32, 32)
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ val2 = tf.nn.relu(
+ tf.add(
+ batchnorm3, # (111, 256, 32, 32)
+ batchnorm4, # (111, 256, 32, 32)
+ name='add_362_',
+ ), # (111, 256, 32, 32)
+ name='relu_363_',
+ ) # (111, 256, 32, 32)
+ (batchnorm5, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val2, # (111, 256, 32, 32)
+ filters=self._03ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 64)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_364_',
+ ), # (111, 64, 32, 32)
+ bias=self._03ParallelBlock_01SequentialBlock_01Conv2d_bias, # (64)
+ data_format='NCHW',
+ name='bias_add_365_',
+ ), # (111, 64, 32, 32)
+ scale=self._03ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64)
+ offset=self._03ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64)
+ mean=self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64)
+ variance=self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_366_',
+ ) # (111, 64, 32, 32)
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm6, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm5, # (111, 64, 32, 32)
+ name='relu_367_',
+ ), # (111, 64, 32, 32)
+ filters=self._03ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_368_',
+ ), # (111, 64, 32, 32)
+ scale=self._03ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64)
+ offset=self._03ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64)
+ mean=self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64)
+ variance=self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_369_',
+ ) # (111, 64, 32, 32)
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm7, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm6, # (111, 64, 32, 32)
+ name='relu_370_',
+ ), # (111, 64, 32, 32)
+ filters=self._03ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_371_',
+ ), # (111, 256, 32, 32)
+ bias=self._03ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_372_',
+ ), # (111, 256, 32, 32)
+ scale=self._03ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256)
+ offset=self._03ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256)
+ mean=self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (256)
+ variance=self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_373_',
+ ) # (111, 256, 32, 32)
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val3 = tf.nn.relu(
+ tf.add(
+ batchnorm7, # (111, 256, 32, 32)
+ val2, # (111, 256, 32, 32)
+ name='add_374_',
+ ), # (111, 256, 32, 32)
+ name='relu_375_',
+ ) # (111, 256, 32, 32)
+ (batchnorm8, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val3, # (111, 256, 32, 32)
+ filters=self._04ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 64)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_376_',
+ ), # (111, 64, 32, 32)
+ bias=self._04ParallelBlock_01SequentialBlock_01Conv2d_bias, # (64)
+ data_format='NCHW',
+ name='bias_add_377_',
+ ), # (111, 64, 32, 32)
+ scale=self._04ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64)
+ offset=self._04ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64)
+ mean=self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64)
+ variance=self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_378_',
+ ) # (111, 64, 32, 32)
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm9, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm8, # (111, 64, 32, 32)
+ name='relu_379_',
+ ), # (111, 64, 32, 32)
+ filters=self._04ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_380_',
+ ), # (111, 64, 32, 32)
+ scale=self._04ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64)
+ offset=self._04ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64)
+ mean=self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64)
+ variance=self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_381_',
+ ) # (111, 64, 32, 32)
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm10, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm9, # (111, 64, 32, 32)
+ name='relu_382_',
+ ), # (111, 64, 32, 32)
+ filters=self._04ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_383_',
+ ), # (111, 256, 32, 32)
+ bias=self._04ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_384_',
+ ), # (111, 256, 32, 32)
+ scale=self._04ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256)
+ offset=self._04ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256)
+ mean=self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (256)
+ variance=self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_385_',
+ ) # (111, 256, 32, 32)
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val4 = tf.nn.relu(
+ tf.add(
+ batchnorm10, # (111, 256, 32, 32)
+ val3, # (111, 256, 32, 32)
+ name='add_386_',
+ ), # (111, 256, 32, 32)
+ name='relu_387_',
+ ) # (111, 256, 32, 32)
+ (batchnorm11, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val4, # (111, 256, 32, 32)
+ filters=self._05ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 128)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_388_',
+ ), # (111, 128, 16, 16)
+ bias=self._05ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128)
+ data_format='NCHW',
+ name='bias_add_389_',
+ ), # (111, 128, 16, 16)
+ scale=self._05ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128)
+ offset=self._05ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128)
+ mean=self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128)
+ variance=self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_390_',
+ ) # (111, 128, 16, 16)
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm12, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm11, # (111, 128, 16, 16)
+ name='relu_391_',
+ ), # (111, 128, 16, 16)
+ filters=self._05ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_392_',
+ ), # (111, 128, 16, 16)
+ scale=self._05ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128)
+ offset=self._05ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128)
+ mean=self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128)
+ variance=self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_393_',
+ ) # (111, 128, 16, 16)
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm13, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm12, # (111, 128, 16, 16)
+ name='relu_394_',
+ ), # (111, 128, 16, 16)
+ filters=self._05ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_395_',
+ ), # (111, 512, 16, 16)
+ bias=self._05ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_396_',
+ ), # (111, 512, 16, 16)
+ scale=self._05ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512)
+ offset=self._05ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512)
+ mean=self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512)
+ variance=self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_397_',
+ ) # (111, 512, 16, 16)
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ (batchnorm14, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ val4, # (111, 256, 32, 32)
+ filters=self._05ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 256, 512)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_398_',
+ ), # (111, 512, 16, 16)
+ scale=self._05ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (512)
+ offset=self._05ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (512)
+ mean=self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (512)
+ variance=self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_399_',
+ ) # (111, 512, 16, 16)
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ val5 = tf.nn.relu(
+ tf.add(
+ batchnorm13, # (111, 512, 16, 16)
+ batchnorm14, # (111, 512, 16, 16)
+ name='add_400_',
+ ), # (111, 512, 16, 16)
+ name='relu_401_',
+ ) # (111, 512, 16, 16)
+ (batchnorm15, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val5, # (111, 512, 16, 16)
+ filters=self._06ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_402_',
+ ), # (111, 128, 16, 16)
+ bias=self._06ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128)
+ data_format='NCHW',
+ name='bias_add_403_',
+ ), # (111, 128, 16, 16)
+ scale=self._06ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128)
+ offset=self._06ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128)
+ mean=self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128)
+ variance=self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_404_',
+ ) # (111, 128, 16, 16)
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm16, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm15, # (111, 128, 16, 16)
+ name='relu_405_',
+ ), # (111, 128, 16, 16)
+ filters=self._06ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_406_',
+ ), # (111, 128, 16, 16)
+ scale=self._06ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128)
+ offset=self._06ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128)
+ mean=self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128)
+ variance=self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_407_',
+ ) # (111, 128, 16, 16)
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm17, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm16, # (111, 128, 16, 16)
+ name='relu_408_',
+ ), # (111, 128, 16, 16)
+ filters=self._06ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_409_',
+ ), # (111, 512, 16, 16)
+ bias=self._06ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_410_',
+ ), # (111, 512, 16, 16)
+ scale=self._06ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512)
+ offset=self._06ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512)
+ mean=self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512)
+ variance=self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_411_',
+ ) # (111, 512, 16, 16)
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val6 = tf.nn.relu(
+ tf.add(
+ batchnorm17, # (111, 512, 16, 16)
+ val5, # (111, 512, 16, 16)
+ name='add_412_',
+ ), # (111, 512, 16, 16)
+ name='relu_413_',
+ ) # (111, 512, 16, 16)
+ (batchnorm18, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val6, # (111, 512, 16, 16)
+ filters=self._07ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_414_',
+ ), # (111, 128, 16, 16)
+ bias=self._07ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128)
+ data_format='NCHW',
+ name='bias_add_415_',
+ ), # (111, 128, 16, 16)
+ scale=self._07ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128)
+ offset=self._07ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128)
+ mean=self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128)
+ variance=self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_416_',
+ ) # (111, 128, 16, 16)
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm19, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm18, # (111, 128, 16, 16)
+ name='relu_417_',
+ ), # (111, 128, 16, 16)
+ filters=self._07ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_418_',
+ ), # (111, 128, 16, 16)
+ scale=self._07ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128)
+ offset=self._07ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128)
+ mean=self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128)
+ variance=self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_419_',
+ ) # (111, 128, 16, 16)
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm20, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm19, # (111, 128, 16, 16)
+ name='relu_420_',
+ ), # (111, 128, 16, 16)
+ filters=self._07ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_421_',
+ ), # (111, 512, 16, 16)
+ bias=self._07ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_422_',
+ ), # (111, 512, 16, 16)
+ scale=self._07ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512)
+ offset=self._07ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512)
+ mean=self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512)
+ variance=self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_423_',
+ ) # (111, 512, 16, 16)
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val7 = tf.nn.relu(
+ tf.add(
+ batchnorm20, # (111, 512, 16, 16)
+ val6, # (111, 512, 16, 16)
+ name='add_424_',
+ ), # (111, 512, 16, 16)
+ name='relu_425_',
+ ) # (111, 512, 16, 16)
+ (batchnorm21, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val7, # (111, 512, 16, 16)
+ filters=self._08ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_426_',
+ ), # (111, 128, 16, 16)
+ bias=self._08ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128)
+ data_format='NCHW',
+ name='bias_add_427_',
+ ), # (111, 128, 16, 16)
+ scale=self._08ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128)
+ offset=self._08ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128)
+ mean=self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128)
+ variance=self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_428_',
+ ) # (111, 128, 16, 16)
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm22, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm21, # (111, 128, 16, 16)
+ name='relu_429_',
+ ), # (111, 128, 16, 16)
+ filters=self._08ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_430_',
+ ), # (111, 128, 16, 16)
+ scale=self._08ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128)
+ offset=self._08ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128)
+ mean=self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128)
+ variance=self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_431_',
+ ) # (111, 128, 16, 16)
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm23, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm22, # (111, 128, 16, 16)
+ name='relu_432_',
+ ), # (111, 128, 16, 16)
+ filters=self._08ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_433_',
+ ), # (111, 512, 16, 16)
+ bias=self._08ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_434_',
+ ), # (111, 512, 16, 16)
+ scale=self._08ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512)
+ offset=self._08ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512)
+ mean=self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512)
+ variance=self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_435_',
+ ) # (111, 512, 16, 16)
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val8 = tf.nn.relu(
+ tf.add(
+ batchnorm23, # (111, 512, 16, 16)
+ val7, # (111, 512, 16, 16)
+ name='add_436_',
+ ), # (111, 512, 16, 16)
+ name='relu_437_',
+ ) # (111, 512, 16, 16)
+ (batchnorm24, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val8, # (111, 512, 16, 16)
+ filters=self._09ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 256)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_438_',
+ ), # (111, 256, 8, 8)
+ bias=self._09ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_439_',
+ ), # (111, 256, 8, 8)
+ scale=self._09ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._09ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_440_',
+ ) # (111, 256, 8, 8)
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm25, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm24, # (111, 256, 8, 8)
+ name='relu_441_',
+ ), # (111, 256, 8, 8)
+ filters=self._09ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_442_',
+ ), # (111, 256, 8, 8)
+ scale=self._09ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._09ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_443_',
+ ) # (111, 256, 8, 8)
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm26, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm25, # (111, 256, 8, 8)
+ name='relu_444_',
+ ), # (111, 256, 8, 8)
+ filters=self._09ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_445_',
+ ), # (111, 1024, 8, 8)
+ bias=self._09ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_446_',
+ ), # (111, 1024, 8, 8)
+ scale=self._09ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._09ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_447_',
+ ) # (111, 1024, 8, 8)
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ (batchnorm27, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ val8, # (111, 512, 16, 16)
+ filters=self._09ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 512, 1024)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_448_',
+ ), # (111, 1024, 8, 8)
+ scale=self._09ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (1024)
+ offset=self._09ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (1024)
+ mean=self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (1024)
+ variance=self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_449_',
+ ) # (111, 1024, 8, 8)
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ val9 = tf.nn.relu(
+ tf.add(
+ batchnorm26, # (111, 1024, 8, 8)
+ batchnorm27, # (111, 1024, 8, 8)
+ name='add_450_',
+ ), # (111, 1024, 8, 8)
+ name='relu_451_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm28, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val9, # (111, 1024, 8, 8)
+ filters=self._10ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_452_',
+ ), # (111, 256, 8, 8)
+ bias=self._10ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_453_',
+ ), # (111, 256, 8, 8)
+ scale=self._10ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._10ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_454_',
+ ) # (111, 256, 8, 8)
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm29, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm28, # (111, 256, 8, 8)
+ name='relu_455_',
+ ), # (111, 256, 8, 8)
+ filters=self._10ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_456_',
+ ), # (111, 256, 8, 8)
+ scale=self._10ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._10ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_457_',
+ ) # (111, 256, 8, 8)
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm30, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm29, # (111, 256, 8, 8)
+ name='relu_458_',
+ ), # (111, 256, 8, 8)
+ filters=self._10ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_459_',
+ ), # (111, 1024, 8, 8)
+ bias=self._10ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_460_',
+ ), # (111, 1024, 8, 8)
+ scale=self._10ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._10ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_461_',
+ ) # (111, 1024, 8, 8)
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val10 = tf.nn.relu(
+ tf.add(
+ batchnorm30, # (111, 1024, 8, 8)
+ val9, # (111, 1024, 8, 8)
+ name='add_462_',
+ ), # (111, 1024, 8, 8)
+ name='relu_463_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm31, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val10, # (111, 1024, 8, 8)
+ filters=self._11ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_464_',
+ ), # (111, 256, 8, 8)
+ bias=self._11ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_465_',
+ ), # (111, 256, 8, 8)
+ scale=self._11ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._11ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_466_',
+ ) # (111, 256, 8, 8)
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm32, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm31, # (111, 256, 8, 8)
+ name='relu_467_',
+ ), # (111, 256, 8, 8)
+ filters=self._11ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_468_',
+ ), # (111, 256, 8, 8)
+ scale=self._11ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._11ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_469_',
+ ) # (111, 256, 8, 8)
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm33, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm32, # (111, 256, 8, 8)
+ name='relu_470_',
+ ), # (111, 256, 8, 8)
+ filters=self._11ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_471_',
+ ), # (111, 1024, 8, 8)
+ bias=self._11ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_472_',
+ ), # (111, 1024, 8, 8)
+ scale=self._11ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._11ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_473_',
+ ) # (111, 1024, 8, 8)
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val11 = tf.nn.relu(
+ tf.add(
+ batchnorm33, # (111, 1024, 8, 8)
+ val10, # (111, 1024, 8, 8)
+ name='add_474_',
+ ), # (111, 1024, 8, 8)
+ name='relu_475_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm34, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val11, # (111, 1024, 8, 8)
+ filters=self._12ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_476_',
+ ), # (111, 256, 8, 8)
+ bias=self._12ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_477_',
+ ), # (111, 256, 8, 8)
+ scale=self._12ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._12ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_478_',
+ ) # (111, 256, 8, 8)
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm35, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm34, # (111, 256, 8, 8)
+ name='relu_479_',
+ ), # (111, 256, 8, 8)
+ filters=self._12ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_480_',
+ ), # (111, 256, 8, 8)
+ scale=self._12ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._12ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_481_',
+ ) # (111, 256, 8, 8)
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm36, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm35, # (111, 256, 8, 8)
+ name='relu_482_',
+ ), # (111, 256, 8, 8)
+ filters=self._12ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_483_',
+ ), # (111, 1024, 8, 8)
+ bias=self._12ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_484_',
+ ), # (111, 1024, 8, 8)
+ scale=self._12ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._12ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_485_',
+ ) # (111, 1024, 8, 8)
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val12 = tf.nn.relu(
+ tf.add(
+ batchnorm36, # (111, 1024, 8, 8)
+ val11, # (111, 1024, 8, 8)
+ name='add_486_',
+ ), # (111, 1024, 8, 8)
+ name='relu_487_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm37, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val12, # (111, 1024, 8, 8)
+ filters=self._13ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_488_',
+ ), # (111, 256, 8, 8)
+ bias=self._13ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_489_',
+ ), # (111, 256, 8, 8)
+ scale=self._13ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._13ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_490_',
+ ) # (111, 256, 8, 8)
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm38, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm37, # (111, 256, 8, 8)
+ name='relu_491_',
+ ), # (111, 256, 8, 8)
+ filters=self._13ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_492_',
+ ), # (111, 256, 8, 8)
+ scale=self._13ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._13ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_493_',
+ ) # (111, 256, 8, 8)
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm39, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm38, # (111, 256, 8, 8)
+ name='relu_494_',
+ ), # (111, 256, 8, 8)
+ filters=self._13ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_495_',
+ ), # (111, 1024, 8, 8)
+ bias=self._13ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_496_',
+ ), # (111, 1024, 8, 8)
+ scale=self._13ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._13ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_497_',
+ ) # (111, 1024, 8, 8)
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val13 = tf.nn.relu(
+ tf.add(
+ batchnorm39, # (111, 1024, 8, 8)
+ val12, # (111, 1024, 8, 8)
+ name='add_498_',
+ ), # (111, 1024, 8, 8)
+ name='relu_499_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm40, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val13, # (111, 1024, 8, 8)
+ filters=self._14ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_500_',
+ ), # (111, 256, 8, 8)
+ bias=self._14ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256)
+ data_format='NCHW',
+ name='bias_add_501_',
+ ), # (111, 256, 8, 8)
+ scale=self._14ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256)
+ offset=self._14ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256)
+ mean=self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256)
+ variance=self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_502_',
+ ) # (111, 256, 8, 8)
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm41, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm40, # (111, 256, 8, 8)
+ name='relu_503_',
+ ), # (111, 256, 8, 8)
+ filters=self._14ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_504_',
+ ), # (111, 256, 8, 8)
+ scale=self._14ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256)
+ offset=self._14ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256)
+ mean=self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256)
+ variance=self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_505_',
+ ) # (111, 256, 8, 8)
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm42, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm41, # (111, 256, 8, 8)
+ name='relu_506_',
+ ), # (111, 256, 8, 8)
+ filters=self._14ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_507_',
+ ), # (111, 1024, 8, 8)
+ bias=self._14ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024)
+ data_format='NCHW',
+ name='bias_add_508_',
+ ), # (111, 1024, 8, 8)
+ scale=self._14ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024)
+ offset=self._14ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024)
+ mean=self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024)
+ variance=self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_509_',
+ ) # (111, 1024, 8, 8)
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val14 = tf.nn.relu(
+ tf.add(
+ batchnorm42, # (111, 1024, 8, 8)
+ val13, # (111, 1024, 8, 8)
+ name='add_510_',
+ ), # (111, 1024, 8, 8)
+ name='relu_511_',
+ ) # (111, 1024, 8, 8)
+ (batchnorm43, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val14, # (111, 1024, 8, 8)
+ filters=self._15ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 512)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_512_',
+ ), # (111, 512, 4, 4)
+ bias=self._15ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_513_',
+ ), # (111, 512, 4, 4)
+ scale=self._15ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512)
+ offset=self._15ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512)
+ mean=self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512)
+ variance=self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_514_',
+ ) # (111, 512, 4, 4)
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm44, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm43, # (111, 512, 4, 4)
+ name='relu_515_',
+ ), # (111, 512, 4, 4)
+ filters=self._15ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_516_',
+ ), # (111, 512, 4, 4)
+ scale=self._15ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512)
+ offset=self._15ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512)
+ mean=self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512)
+ variance=self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_517_',
+ ) # (111, 512, 4, 4)
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm45, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm44, # (111, 512, 4, 4)
+ name='relu_518_',
+ ), # (111, 512, 4, 4)
+ filters=self._15ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_519_',
+ ), # (111, 2048, 4, 4)
+ bias=self._15ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048)
+ data_format='NCHW',
+ name='bias_add_520_',
+ ), # (111, 2048, 4, 4)
+ scale=self._15ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048)
+ offset=self._15ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048)
+ mean=self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048)
+ variance=self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_521_',
+ ) # (111, 2048, 4, 4)
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ (batchnorm46, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ val14, # (111, 1024, 8, 8)
+ filters=self._15ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 2048)
+ strides=[2, 2],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_522_',
+ ), # (111, 2048, 4, 4)
+ scale=self._15ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (2048)
+ offset=self._15ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (2048)
+ mean=self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (2048)
+ variance=self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (2048)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_523_',
+ ) # (111, 2048, 4, 4)
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ val15 = tf.nn.relu(
+ tf.add(
+ batchnorm45, # (111, 2048, 4, 4)
+ batchnorm46, # (111, 2048, 4, 4)
+ name='add_524_',
+ ), # (111, 2048, 4, 4)
+ name='relu_525_',
+ ) # (111, 2048, 4, 4)
+ (batchnorm47, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val15, # (111, 2048, 4, 4)
+ filters=self._16ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 2048, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_526_',
+ ), # (111, 512, 4, 4)
+ bias=self._16ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_527_',
+ ), # (111, 512, 4, 4)
+ scale=self._16ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512)
+ offset=self._16ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512)
+ mean=self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512)
+ variance=self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_528_',
+ ) # (111, 512, 4, 4)
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm48, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm47, # (111, 512, 4, 4)
+ name='relu_529_',
+ ), # (111, 512, 4, 4)
+ filters=self._16ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_530_',
+ ), # (111, 512, 4, 4)
+ scale=self._16ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512)
+ offset=self._16ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512)
+ mean=self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512)
+ variance=self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_531_',
+ ) # (111, 512, 4, 4)
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm49, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm48, # (111, 512, 4, 4)
+ name='relu_532_',
+ ), # (111, 512, 4, 4)
+ filters=self._16ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_533_',
+ ), # (111, 2048, 4, 4)
+ bias=self._16ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048)
+ data_format='NCHW',
+ name='bias_add_534_',
+ ), # (111, 2048, 4, 4)
+ scale=self._16ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048)
+ offset=self._16ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048)
+ mean=self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048)
+ variance=self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_535_',
+ ) # (111, 2048, 4, 4)
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ val16 = tf.nn.relu(
+ tf.add(
+ batchnorm49, # (111, 2048, 4, 4)
+ val15, # (111, 2048, 4, 4)
+ name='add_536_',
+ ), # (111, 2048, 4, 4)
+ name='relu_537_',
+ ) # (111, 2048, 4, 4)
+ (batchnorm50, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ val16, # (111, 2048, 4, 4)
+ filters=self._17ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 2048, 512)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_538_',
+ ), # (111, 512, 4, 4)
+ bias=self._17ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512)
+ data_format='NCHW',
+ name='bias_add_539_',
+ ), # (111, 512, 4, 4)
+ scale=self._17ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512)
+ offset=self._17ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512)
+ mean=self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512)
+ variance=self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_540_',
+ ) # (111, 512, 4, 4)
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean)
+ self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var)
+ (batchnorm51, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm50, # (111, 512, 4, 4)
+ name='relu_541_',
+ ), # (111, 512, 4, 4)
+ filters=self._17ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512)
+ strides=[1, 1],
+ padding='SAME',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_542_',
+ ), # (111, 512, 4, 4)
+ scale=self._17ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512)
+ offset=self._17ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512)
+ mean=self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512)
+ variance=self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512)
+ epsilon=2.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_543_',
+ ) # (111, 512, 4, 4)
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean)
+ self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var)
+ (batchnorm52, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm(
+ tf.nn.bias_add(
+ tf.nn.convolution(
+ tf.nn.relu(
+ batchnorm51, # (111, 512, 4, 4)
+ name='relu_544_',
+ ), # (111, 512, 4, 4)
+ filters=self._17ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048)
+ strides=[1, 1],
+ padding='VALID',
+ dilations=[1, 1],
+ data_format='NCHW',
+ name='convolution_545_',
+ ), # (111, 2048, 4, 4)
+ bias=self._17ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048)
+ data_format='NCHW',
+ name='bias_add_546_',
+ ), # (111, 2048, 4, 4)
+ scale=self._17ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048)
+ offset=self._17ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048)
+ mean=self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048)
+ variance=self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048)
+ epsilon=1.0E-5,
+ is_training=True,
+ exponential_avg_factor=0.9,
+ data_format='NCHW',
+ name='fused_batch_norm_547_',
+ ) # (111, 2048, 4, 4)
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean)
+ self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var)
+ result = tf.reshape(
+ tf.nn.bias_add(
+ tf.matmul(
+ tf.reshape(
+ tf.reshape(
+ tf.reduce_mean(
+ tf.nn.relu(
+ tf.add(
+ batchnorm52, # (111, 2048, 4, 4)
+ val16, # (111, 2048, 4, 4)
+ name='add_548_',
+ ), # (111, 2048, 4, 4)
+ name='relu_549_',
+ ), # (111, 2048, 4, 4)
+ axis=[2, 3],
+ name='reduce_mean_550_',
+ ), # (111, 2048, 1, 1)
+ shape=[-1, 2048],
+ name='reshape_551_',
+ ), # (111, 2048)
+ shape=[-1, 2048],
+ name='reshape_552_',
+ ), # (111, 2048)
+ b=self._20Linear_weight, # (10, 2048)
+ transpose_b=True,
+ name='matmul_553_',
+ ), # (111, 10)
+ bias=self._20Linear_bias, # (10)
+ data_format=None,
+ name='bias_add_554_',
+ ), # (111, 10)
+ shape=[-1, 10],
+ name='reshape_555_',
+ ) # (111, 10)
+ return result
+
+## 2
+def loss(label, prediction):
+ result = tf.reduce_mean(
+ tf.negative(
+ tf.gather(
+ tf.nn.log_softmax(
+ prediction, # (111, 10)
+ axis=-1,
+ name='log_softmax_556_',
+ ), # (111, 10)
+ indices=label, # (111)
+ batch_dims=1,
+ name='gather_557_',
+ ), # (111, 1)
+ name='negative_558_',
+ ), # (111, 1)
+ name='reduce_mean_559_',
+ ) # ()
+ return result
+
+# number of epochs was 2
+# number of prediction functions is 1
+# number of loss functions is 1
+
diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt
new file mode 100644
index 00000000000..ffba2064933
--- /dev/null
+++ b/examples/src/test/resources/yolov8_synset.txt
@@ -0,0 +1,84 @@
+# Classes for the COCO dataset on which YOLOv8 is trained
+# source config https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml.
+# COCO dataset website: https://cocodataset.org/#home
+# Ultralytics Coco doc page: https://docs.ultralytics.com/datasets/detect/coco/
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
\ No newline at end of file
diff --git a/examples/src/test/resources/yolov8_test.jpg b/examples/src/test/resources/yolov8_test.jpg
new file mode 100644
index 00000000000..01e43374348
Binary files /dev/null and b/examples/src/test/resources/yolov8_test.jpg differ
diff --git a/examples/src/test/resources/yolov8n.onnx b/examples/src/test/resources/yolov8n.onnx
new file mode 100644
index 00000000000..430f7f2beb0
Binary files /dev/null and b/examples/src/test/resources/yolov8n.onnx differ
diff --git a/extensions/audio/README.md b/extensions/audio/README.md
index 7e2c89692bc..95ed8c53a84 100644
--- a/extensions/audio/README.md
+++ b/extensions/audio/README.md
@@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.audio
audio
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/aws-ai/README.md b/extensions/aws-ai/README.md
index 829df0bb0ca..16d412904c5 100644
--- a/extensions/aws-ai/README.md
+++ b/extensions/aws-ai/README.md
@@ -58,6 +58,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.aws
aws-ai
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md
index 6f5a25064ea..f0c60d39bf1 100644
--- a/extensions/fasttext/README.md
+++ b/extensions/fasttext/README.md
@@ -34,7 +34,7 @@ You can pull the fastText engine from the central Maven repository by including
ai.djl.fasttext
fasttext-engine
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
index 5b421ff431f..4395ddf1a6c 100644
--- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
+++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
@@ -41,6 +41,7 @@
import java.io.IOException;
import java.io.InputStream;
+import java.net.URI;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -129,7 +130,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot
public void testBlazingText() throws IOException, ModelException {
TestRequirements.nightly();
- URL url = new URL("https://resources.djl.ai/test-models/blazingtext_classification.bin");
+ URL url =
+ URI.create("https://resources.djl.ai/test-models/blazingtext_classification.bin")
+ .toURL();
Path path = Paths.get("build/tmp/model");
Path modelFile = path.resolve("text_classification.bin");
if (!Files.exists(modelFile)) {
diff --git a/extensions/hadoop/README.md b/extensions/hadoop/README.md
index b3c4ebcc762..8a376e22d85 100644
--- a/extensions/hadoop/README.md
+++ b/extensions/hadoop/README.md
@@ -52,6 +52,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.hadoop
hadoop
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/opencv/README.md b/extensions/opencv/README.md
index d6c58f518dc..c23e0c58532 100644
--- a/extensions/opencv/README.md
+++ b/extensions/opencv/README.md
@@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.opencv
opencv
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/sentencepiece/README.md b/extensions/sentencepiece/README.md
index 4308308111f..de28d5334df 100644
--- a/extensions/sentencepiece/README.md
+++ b/extensions/sentencepiece/README.md
@@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.sentencepiece
sentencepiece
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/spark/README.md b/extensions/spark/README.md
index 02ebcc07a1d..da3171ca008 100644
--- a/extensions/spark/README.md
+++ b/extensions/spark/README.md
@@ -34,7 +34,7 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.spark
spark_2.12
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/tablesaw/README.md b/extensions/tablesaw/README.md
index 010c6395eb9..b4287d9733d 100644
--- a/extensions/tablesaw/README.md
+++ b/extensions/tablesaw/README.md
@@ -25,6 +25,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.tablesaw
tablesaw
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/timeseries/README.md b/extensions/timeseries/README.md
index 9706c9334a4..3ef6887825c 100644
--- a/extensions/timeseries/README.md
+++ b/extensions/timeseries/README.md
@@ -245,6 +245,6 @@ You can pull the module from the central Maven repository by including the follo
ai.djl.timeseries
timeseries
- 0.23.0
+ 0.26.0
```
diff --git a/extensions/timeseries/docs/forecast_with_M5_data.md b/extensions/timeseries/docs/forecast_with_M5_data.md
index a4f1a24a1d9..7b8e1c78210 100644
--- a/extensions/timeseries/docs/forecast_with_M5_data.md
+++ b/extensions/timeseries/docs/forecast_with_M5_data.md
@@ -1,5 +1,7 @@
# Forecast the future in a timeseries data with Deep Java Library (DJL)
+
## -- Demonstration on M5forecasting and airpassenger datasests
+
Junyuan Zhang, Kexin Feng
Time series data are commonly seen in the world. They can contain valued information that helps forecast for the future, monitor the status of a procedure and feedforward a control. Generic applications includes the following: sales forecasting, stock market analysis, yield projections, process and quality control, and many many more. See [link1](https://www.itl.nist.gov/div898/handbook/pmc/section4/pmc41.htm) and [link2](https://www.influxdata.com/time-series-forecasting-methods/#:~:text=Time%20series%20forecasting%20means%20to,on%20what%20has%20already%20happened) for further examples of timeseries data.
@@ -54,7 +56,7 @@ repositories {
}
dependencies {
implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.17.1"
- implementation platform("ai.djl:bom:0.23.0")
+ implementation platform("ai.djl:bom:0.26.0")
implementation "ai.djl:api"
implementation "ai.djl.timeseries"
runtimeOnly "ai.djl.mxnet:mxnet-engine"
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java
index 5b642285c3e..9edb45ff5f0 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java
@@ -94,15 +94,23 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
+ updateAccumulators(new String[] {key}, labels, predictions);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
Pair<Long, NDArray> update = evaluateHelper(labels, predictions);
- totalInstances.compute(key, (k, v) -> v + update.getKey());
- totalLoss.compute(
- key,
- (k, v) -> {
- try (NDArray array = update.getValue().sum()) {
- return v + array.getFloat();
- }
- });
+ for (String key : keys) {
+ totalInstances.compute(key, (k, v) -> v + update.getKey());
+ totalLoss.compute(
+ key,
+ (k, v) -> {
+ try (NDArray array = update.getValue().sum()) {
+ return v + array.getFloat();
+ }
+ });
+ }
}
/** {@inheritDoc} */
diff --git a/extensions/tokenizers/README.md b/extensions/tokenizers/README.md
index 1b85625572c..2cdf4f19137 100644
--- a/extensions/tokenizers/README.md
+++ b/extensions/tokenizers/README.md
@@ -23,7 +23,7 @@ You can pull the module from the central Maven repository by including the follo
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
- <version>0.23.0</version>
+ <version>0.26.0</version>
```
diff --git a/extensions/tokenizers/build.cmd b/extensions/tokenizers/build.cmd
index 3a481d33bab..d83f2c1ed74 100644
--- a/extensions/tokenizers/build.cmd
+++ b/extensions/tokenizers/build.cmd
@@ -3,7 +3,7 @@
@rem choco install rust -y
@rem choco install jdk8 -y
-set VERSION=python-v"%1"
+set VERSION=v"%1"
if exist "tokenizers" (
echo Found "tokenizers"
diff --git a/extensions/tokenizers/build.sh b/extensions/tokenizers/build.sh
index 4ba45a09965..229e8124914 100755
--- a/extensions/tokenizers/build.sh
+++ b/extensions/tokenizers/build.sh
@@ -10,7 +10,7 @@ elif [[ -n $(command -v sysctl) ]]; then
fi
PLATFORM=$(uname | tr '[:upper:]' '[:lower:]')
-VERSION=python-v$1
+VERSION=v$1
ARCH=$2
pushd $WORK_DIR
diff --git a/extensions/tokenizers/rust/Cargo.toml b/extensions/tokenizers/rust/Cargo.toml
index f6b846f636c..3418c8f5129 100644
--- a/extensions/tokenizers/rust/Cargo.toml
+++ b/extensions/tokenizers/rust/Cargo.toml
@@ -6,7 +6,7 @@ edition = "2018"
[dependencies]
jni = "0.19.0"
-tokenizers = { path = "../tokenizers/tokenizers", version = "*" }
+tokenizers = { path = "../tokenizers/tokenizers", version = "*", features = ["http"] }
[target.'cfg(target_os = "linux")'.dependencies]
openssl = { version = "0.10", features = ["vendored"] }
diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs
index d1c0c455c19..590099c2ecf 100644
--- a/extensions/tokenizers/rust/src/lib.rs
+++ b/extensions/tokenizers/rust/src/lib.rs
@@ -490,7 +490,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
}
}
let decoding: String = tokenizer
- .decode(decode_ids, skip_special_tokens == JNI_TRUE)
+ .decode(&*decode_ids, skip_special_tokens == JNI_TRUE)
.unwrap();
let ret = env
.new_string(decoding)
@@ -527,8 +527,12 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
}
batch_decode_input.push(decode_ids);
}
+ let mut references: Vec<&[u32]> = Vec::new();
+ for reference in batch_decode_input.iter() {
+ references.push(reference);
+ }
let decoding: Vec<String> = tokenizer
- .decode_batch(batch_decode_input, skip_special_tokens == JNI_TRUE)
+ .decode_batch(&references, skip_special_tokens == JNI_TRUE)
.unwrap();
let ret: jobjectArray = env
.new_object_array(batch_len, "java/lang/String", JObject::null())
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java
index e58d6ada5ee..887f01646dc 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java
@@ -27,6 +27,7 @@ public class Encoding {
private long[] specialTokenMask;
private CharSpan[] charTokenSpans;
private Encoding[] overflowing;
+ private boolean exceedMaxLength;
protected Encoding(
long[] ids,
@@ -36,6 +37,7 @@ protected Encoding(
long[] attentionMask,
long[] specialTokenMask,
CharSpan[] charTokenSpans,
+ boolean exceedMaxLength,
Encoding[] overflowing) {
this.ids = ids;
this.typeIds = typeIds;
@@ -44,6 +46,7 @@ protected Encoding(
this.attentionMask = attentionMask;
this.specialTokenMask = specialTokenMask;
this.charTokenSpans = charTokenSpans;
+ this.exceedMaxLength = exceedMaxLength;
this.overflowing = overflowing;
}
@@ -127,6 +130,15 @@ public CharSpan[] getCharTokenSpans() {
return charTokenSpans;
}
+ /**
+ * Returns if tokens exceed max length.
+ *
+ * @return {@code true} if tokens exceed max length
+ */
+ public boolean exceedMaxLength() {
+ return exceedMaxLength;
+ }
+
/**
* Returns an array of overflowing encodings.
*
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
index f75342b7cb8..ba4d61b79b1 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
@@ -44,6 +44,7 @@ public final class HuggingFaceTokenizer extends NativeResource<Long> implements
private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);
private boolean addSpecialTokens;
+ private boolean withOverflowingTokens;
private TruncationStrategy truncation;
private PaddingStrategy padding;
private int maxLength;
@@ -64,6 +65,8 @@ private HuggingFaceTokenizer(long handle, Map<String, String> options) {
if (options != null) {
val = options.getOrDefault("addSpecialTokens", "true");
addSpecialTokens = Boolean.parseBoolean(val);
+ val = options.getOrDefault("withOverflowingTokens", "false");
+ withOverflowingTokens = Boolean.parseBoolean(val);
modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512);
if (options.containsKey("truncation")) {
truncation = TruncationStrategy.fromValue(options.get("truncation"));
@@ -203,11 +206,12 @@ public void close() {
* @param text the input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
+ * @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
- public Encoding encode(String text, boolean addSpecialTokens) {
+ public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
- return toEncoding(encoding);
+ return toEncoding(encoding, withOverflowingTokens);
}
/**
@@ -217,7 +221,7 @@ public Encoding encode(String text, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text) {
- return encode(text, addSpecialTokens);
+ return encode(text, addSpecialTokens, withOverflowingTokens);
}
/**
@@ -227,12 +231,14 @@ public Encoding encode(String text) {
* @param textPair the second input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
+ * @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
- public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
+ public Encoding encode(
+ String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding =
TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
- return toEncoding(encoding);
+ return toEncoding(encoding, withOverflowingTokens);
}
/**
@@ -243,7 +249,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair) {
- return encode(text, textPair, addSpecialTokens);
+ return encode(text, textPair, addSpecialTokens, withOverflowingTokens);
}
/**
@@ -252,11 +258,13 @@ public Encoding encode(String text, String textPair) {
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
+ * @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentences
*/
- public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
+ public Encoding encode(
+ List<String> inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
- return encode(array, addSpecialTokens);
+ return encode(array, addSpecialTokens, withOverflowingTokens);
}
/**
@@ -266,7 +274,7 @@ public Encoding encode(List