diff --git a/engine/src/nn/mxnetapi.cpp b/engine/src/nn/mxnetapi.cpp index 1d5d7bbe..77df7b6c 100644 --- a/engine/src/nn/mxnetapi.cpp +++ b/engine/src/nn/mxnetapi.cpp @@ -157,7 +157,7 @@ void MXNetAPI::init_nn_design() set_shape(nnDesign.valueOutputShape, executor->outputs[nnDesign.valueOutputIdx].GetShape()); nnDesign.hasAuxiliaryOutputs = executor->outputs.size() > 2; if (nnDesign.hasAuxiliaryOutputs) { - set_shape(nnDesign.valueOutputShape, executor->outputs[nnDesign.auxiliaryOutputIdx].GetShape()); + set_shape(nnDesign.auxiliaryOutputShape, executor->outputs[nnDesign.auxiliaryOutputIdx].GetShape()); } float* inputPlanes = new float[batchSize*StateConstants::NB_VALUES_TOTAL()];