From 612acf06562985ba5bd825058b177c87baeac823 Mon Sep 17 00:00:00 2001 From: Tibor Mezei Date: Wed, 27 Jan 2021 18:40:33 -0800 Subject: [PATCH] Put the output of the model into existing NDArrays when provided --- .../main/java/ai/djl/inference/Predictor.java | 34 ++++++++++++++++--- .../java/ai/djl/modality/nlp/Encoder.java | 1 + .../ai/djl/modality/nlp/EncoderDecoder.java | 1 + .../nlp/embedding/TrainableTextEmbedding.java | 1 + api/src/main/java/ai/djl/nn/Block.java | 29 +++++++++++++++- api/src/main/java/ai/djl/nn/LambdaBlock.java | 1 + .../main/java/ai/djl/nn/ParallelBlock.java | 1 + .../main/java/ai/djl/nn/SequentialBlock.java | 1 + .../ai/djl/nn/convolutional/Convolution.java | 1 + .../djl/nn/convolutional/Deconvolution.java | 1 + .../ai/djl/nn/core/ConstantEmbedding.java | 1 + .../main/java/ai/djl/nn/core/Embedding.java | 1 + api/src/main/java/ai/djl/nn/core/Linear.java | 1 + api/src/main/java/ai/djl/nn/core/Prelu.java | 1 + .../main/java/ai/djl/nn/norm/BatchNorm.java | 1 + api/src/main/java/ai/djl/nn/norm/Dropout.java | 1 + .../ai/djl/nn/recurrent/RecurrentBlock.java | 1 + .../ScaledDotProductAttentionBlock.java | 1 + .../ssd/SingleShotDetection.java | 1 + .../basicmodelzoo/nlp/SimpleTextDecoder.java | 1 + .../ai/djl/pytorch/engine/PtSymbolBlock.java | 5 +-- .../java/ai/djl/pytorch/jni/IValueUtils.java | 23 ++++++++++--- .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ 23 files changed, 99 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/inference/Predictor.java b/api/src/main/java/ai/djl/inference/Predictor.java index 2e11f50739b..1b9565529aa 100644 --- a/api/src/main/java/ai/djl/inference/Predictor.java +++ b/api/src/main/java/ai/djl/inference/Predictor.java @@ -125,12 +125,24 @@ public Predictor(Model model, Translator translator, boolean copy) { */ @SuppressWarnings("PMD.AvoidRethrowingException") public O predict(I input) throws TranslateException { - return batchPredict(Collections.singletonList(input)).get(0); + return batchPredict(Collections.singletonList(input), null).get(0); } - private NDList predict(NDList ndList) { + /** + * Predicts an item for inference. + * + * @param input the input + * @return the output object defined by the user + * @throws TranslateException if an error occurs during prediction + */ + @SuppressWarnings("PMD.AvoidRethrowingException") + public O predict(I input, NDList output) throws TranslateException { + return batchPredict(Collections.singletonList(input), output).get(0); + } + + private NDList predict(NDList ndList, NDList output) { logger.trace("Predictor input data: {}", ndList); - return block.forward(parameterStore, ndList, false); + return block.forward(parameterStore, ndList, output, false); } /** @@ -142,6 +154,18 @@ private NDList predict(NDList ndList) { */ @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"}) public List batchPredict(List inputs) throws TranslateException { + return batchPredict(inputs, null); + } + + /** + * Predicts a batch for inference. + * + * @param inputs a list of inputs + * @return a list of output objects defined by the user + * @throws TranslateException if an error occurs during prediction + */ + @SuppressWarnings("PMD.AvoidRethrowingException") + public List batchPredict(List inputs, NDList output) throws TranslateException { long begin = System.nanoTime(); try (PredictorContext context = new PredictorContext()) { if (!prepared) { @@ -157,7 +181,7 @@ public List batchPredict(List inputs) throws TranslateException { NDList ndList = translator.processInput(context, input); preprocessEnd(ndList); - NDList result = predict(ndList); + NDList result = predict(ndList, output); predictEnd(result); ret.add(translator.processOutput(context, result)); @@ -170,7 +194,7 @@ public List batchPredict(List inputs) throws TranslateException { NDList inputBatch = processInputs(context, inputs); preprocessEnd(inputBatch); - NDList result = predict(inputBatch); + NDList result = predict(inputBatch, output); predictEnd(result); List ret = processOutputs(context, result); diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java index 645e3aabd5b..35d3bde968a 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java @@ -57,6 +57,7 @@ public Encoder(byte version, Block block) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { return block.forward(parameterStore, inputs, training, params); diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index 93845caeec0..e47dfdf8c36 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -72,6 +72,7 @@ public NDList forward(ParameterStore parameterStore, NDList inputs, boolean trai public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { if (training) { diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java index 6e1386f87e3..dc8cc12b10d 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java @@ -73,6 +73,7 @@ public List unembedText(NDArray textEmbedding) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { return trainableWordEmbedding.forward(parameterStore, inputs, training, params); diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 778c8a98dc8..9affc5a4dee 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -114,7 +114,33 @@ public interface Block { * @return the output of the forward pass */ default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) { - return forward(parameterStore, inputs, training, null); + return forward(parameterStore, inputs, null, training, null); + } + + /** + * Applies the operating function of the block once. This method should be called only on blocks + * that are initialized. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass + * @return the output of the forward pass + */ + default NDList forward(ParameterStore parameterStore, NDList inputs, NDList output, boolean training) { + return forward(parameterStore, inputs, output, training, null); + } + + /** + * Applies the operating function of the block once. This method should be called only on blocks + * that are initialized. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass + * @return the output of the forward pass + */ + default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { + return forward(parameterStore, inputs, null, training, params); } /** @@ -130,6 +156,7 @@ default NDList forward(ParameterStore parameterStore, NDList inputs, boolean tra NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params); diff --git a/api/src/main/java/ai/djl/nn/LambdaBlock.java b/api/src/main/java/ai/djl/nn/LambdaBlock.java index bd58fc9a428..f5c3bdfbc98 100644 --- a/api/src/main/java/ai/djl/nn/LambdaBlock.java +++ b/api/src/main/java/ai/djl/nn/LambdaBlock.java @@ -63,6 +63,7 @@ public static LambdaBlock singleton(Function lambda) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { return lambda.apply(inputs); diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java index 5061dc8d828..b7fc8d44bbf 100644 --- a/api/src/main/java/ai/djl/nn/ParallelBlock.java +++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java @@ -116,6 +116,7 @@ public final ParallelBlock add(Function f) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { return function.apply( diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java index e18317516d8..7102ede786d 100644 --- a/api/src/main/java/ai/djl/nn/SequentialBlock.java +++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java @@ -126,6 +126,7 @@ public void replaceLastBlock(Block block) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDList current = inputs; diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index 101d5dd9446..30bd0e2eb86 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -135,6 +135,7 @@ public Convolution(ConvolutionBuilder builder) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDArray input = inputs.singletonOrThrow(); diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 7ac983834c3..e1f5b2e0f03 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -114,6 +114,7 @@ public Deconvolution(DeconvolutionBuilder builder) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDArray input = inputs.singletonOrThrow(); diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index a0839479dfa..6e943ad644b 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -45,6 +45,7 @@ public ConstantEmbedding(NDArray embedding) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDManager manager = inputs.get(0).getManager(); diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index 077c338903a..d2076c1e4f0 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -126,6 +126,7 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDList opInputs = opInputs(parameterStore, inputs, training); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 2e844890a3b..ed67ab2337f 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -75,6 +75,7 @@ public class Linear extends AbstractBlock { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDArray input = inputs.singletonOrThrow(); diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index 7de18e7aa41..dde87cabc39 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -52,6 +52,7 @@ public Prelu() { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDArray input = inputs.singletonOrThrow(); diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java index 1bb582ecb28..917e9f7e09c 100644 --- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java @@ -110,6 +110,7 @@ public class BatchNorm extends AbstractBlock { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDArray input = inputs.singletonOrThrow(); diff --git a/api/src/main/java/ai/djl/nn/norm/Dropout.java b/api/src/main/java/ai/djl/nn/norm/Dropout.java index a1e51c9ef14..c04cad000a1 100644 --- a/api/src/main/java/ai/djl/nn/norm/Dropout.java +++ b/api/src/main/java/ai/djl/nn/norm/Dropout.java @@ -66,6 +66,7 @@ public class Dropout extends AbstractBlock { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { return dropout(inputs.singletonOrThrow(), rate, training); diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index f3ba165b1d8..ea896f9081d 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -124,6 +124,7 @@ public final void setStateOutputs(boolean stateOutputs) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList o, boolean training, PairList params) { inputs = opInputs(parameterStore, inputs, training); diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java index 13be24d55de..04c7f018cd4 100644 --- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java @@ -196,6 +196,7 @@ private NDArray createAttentionHeadsFromEmbeddings( public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { // E=embedding size diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java index a254f3320d0..911629021da 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java @@ -69,6 +69,7 @@ private SingleShotDetection(Builder builder) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { NDList networkOutput = inputs; diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java index 6f1820d67d2..eb2a4df6f61 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java @@ -86,6 +86,7 @@ public void initState(NDList encoderStates) { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList o, boolean training, PairList params) { if (training) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 503d0ad9e77..7f951af70f5 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -87,6 +87,7 @@ public void removeLastBlock() { public NDList forward( ParameterStore parameterStore, NDList inputs, + NDList output, boolean training, PairList params) { // TODO refactor the forward to not take ParameterStore @@ -106,7 +107,7 @@ public NDList forward( for (NDArray array : inputs) { inputDescriptions.add(array.getName(), array.getShape()); } - NDList outputs = IValueUtils.forward(this, inputs, training); + NDList outputs = IValueUtils.forward(this, inputs, output, training); for (NDArray array : outputs) { outputDescriptions.add(array.getName(), array.getShape()); } @@ -115,7 +116,7 @@ public NDList forward( } } } - return IValueUtils.forward(this, inputs, training); + return IValueUtils.forward(this, inputs, output, training); } /** {@inheritDoc} */ diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java index e1898713b2e..805c87487f7 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java @@ -136,6 +136,16 @@ public static PtNDArray toNDArray(long iValueHandle, PtNDManager manager) { return new PtNDArray(manager, ndHandle); } + /** + * Extract IValue with a {@link PtNDArray} value. + * + * @param arrayHandle array handle + * @param iValueHandle IValue {@link Pointer} + */ + public static void toNDArrayCopy(Pointer arrayHandle, Pointer iValueHandle) { + PyTorchLibrary.LIB.iValueToTensorCopy(arrayHandle, iValueHandle); + } + /** * Extract IValue to {@link NDList}. * @@ -190,15 +200,20 @@ public static Map toIValueMap(long iValueHandle) { return map; } - private static NDList forwardHelper(long iValueHandle, PtNDManager manager) { + private static NDList forwardHelper(long iValueHandle, NDList output, PtNDManager manager) { NDList list = new NDList(); if (isNDArray(iValueHandle)) { + if (output != null) { + toNDArrayCopy(((PtNDArray)output.get(0)).getHandle(), iValueHandle); + PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle); + return output; + } list.add(toNDArray(iValueHandle, manager)); } else if (isNDList(iValueHandle)) { list.addAll(toNDList(iValueHandle, manager)); } else if (isList(iValueHandle) || isTuple(iValueHandle)) { for (long handle : toIValueArray(iValueHandle)) { - list.addAll(forwardHelper(handle, manager)); + list.addAll(forwardHelper(handle, output, manager)); } } else if (isMap(iValueHandle)) { // Only allows type of map @@ -231,14 +246,14 @@ private static NDList forwardHelper(long iValueHandle, PtNDManager manager) { * @param isTrain is running on training mode * @return result {@link NDList} */ - public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) { + public static NDList forward(PtSymbolBlock block, NDList inputs, NDList output, boolean isTrain) { long[] arrayHandles = inputs.stream().mapToLong(input -> ((PtNDArray) input).getHandle()).toArray(); String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new); long[] iValueInputs = getInputs(arrayHandles, names); long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueInputs, isTrain); PtNDManager manager = (PtNDManager) inputs.get(0).getManager(); - return forwardHelper(result, manager); + return forwardHelper(result, output, manager); } private static boolean isNameList(String name) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 0cbf6ab5696..5f63c934d37 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -434,6 +434,8 @@ native long moduleLoad( native long iValueToTensor(long iValueHandle); + native void iValueToTensorCopy(long iValueHandle, long tensorHandle); + native long[] iValueToTensorList(long iValueHandle); native long[] iValueToList(long iValueHandle);