From 5c0e4065e2ce7cf85e0f95b2b1e062d8fc975b6b Mon Sep 17 00:00:00 2001 From: Mark Hamilton Date: Thu, 28 Jun 2018 11:46:54 +0000 Subject: [PATCH] Enable CNTK memory re-use --- src/cntk-model/src/main/scala/CNTKModel.scala | 88 ++++++---- .../src/main/scala/Conversions.scala | 159 ++++++++++++++++++ .../src/main/scala/SerializableFunction.scala | 80 --------- .../src/test/scala/CNTKBindingSuite.scala | 27 ++- .../src/test/scala/CNTKModelSuite.scala | 4 +- 5 files changed, 230 insertions(+), 128 deletions(-) create mode 100644 src/cntk-model/src/main/scala/Conversions.scala diff --git a/src/cntk-model/src/main/scala/CNTKModel.scala b/src/cntk-model/src/main/scala/CNTKModel.scala index 5129cf95be..987576d3a7 100644 --- a/src/cntk-model/src/main/scala/CNTKModel.scala +++ b/src/cntk-model/src/main/scala/CNTKModel.scala @@ -4,10 +4,12 @@ package com.microsoft.ml.spark import com.microsoft.CNTK.CNTKExtensions._ -import com.microsoft.CNTK.CNTKExtensions import com.microsoft.CNTK.CNTKUtils._ -import com.microsoft.CNTK.{DataType => CNTKDataType, SerializableFunction => CNTKFunction, _} +import com.microsoft.CNTK.{CNTKExtensions, DataType => CNTKDataType, SerializableFunction => CNTKFunction, _} +import com.microsoft.ml.spark.ConversionUtils.GVV import com.microsoft.ml.spark.schema.DatasetExtensions +import com.microsoft.ml.spark.schema.DatasetExtensions.findUnusedColumnName +import org.apache.spark.SparkContext import org.apache.spark.broadcast._ import org.apache.spark.ml.Model import org.apache.spark.ml.linalg.DenseVector @@ -20,17 +22,15 @@ import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import spray.json.DefaultJsonProtocol._ -import DatasetExtensions.findUnusedColumnName -import org.apache.spark.SparkContext import scala.collection.JavaConversions._ private object CNTKModelUtils extends java.io.Serializable { def applyCNTKFunction(model: CNTKFunction, - feedDict: Map[Variable, GenericVectorVector], + feedDict: Map[Variable, GVV], outputVars: List[Variable], - device: DeviceDescriptor): List[GenericVectorVector] = { + device: DeviceDescriptor): List[GVV] = { val valueMap = feedDict.map { case (v, gvv) => gvv match { @@ -47,18 +47,25 @@ private object CNTKModelUtils extends java.io.Serializable { outputVars.foreach(ov => outputDataMap.add(ov, null)) model.evaluate(inputDataMap, outputDataMap, device) - outputVars.map { ov: Variable => + val out = outputVars.map { ov: Variable => ov.getDataType match { case CNTKDataType.Float => - val fvv = new FloatVectorVector() - outputDataMap.getitem(ov).copyVariableValueToFloat(ov, fvv) + val fvv = new FloatVectorVector() //TODO try re-using + val value = outputDataMap.getitem(ov) + value.copyVariableValueToFloat(ov, fvv) + value.delete() Left(fvv) case CNTKDataType.Double => - val dvv = new DoubleVectorVector() - outputDataMap.getitem(ov).copyVariableValueToDouble(ov, dvv) + val dvv = new DoubleVectorVector() //TODO try re-using + val value = outputDataMap.getitem(ov) + value.copyVariableValueToDouble(ov, dvv) + value.delete() Right(dvv) } } + + valueMap.values.foreach(_.delete()) + out } def applyModel(inputMap: Map[String, Int], @@ -76,34 +83,45 @@ private object CNTKModelUtils extends java.io.Serializable { val m = CNTKExtensions.fromSerializable(broadcastedModel.value).clone(ParameterCloningMethod.Share) val inputMapVar = inputMap.map { case (k, v) => v -> m.getInputVar(k) } + val outputMapVar = outputMap.map { case (k, v) => m.getOutputVar(v) -> k } - val preprocessFunction: Row => Map[Variable, GenericVectorVector] = { - val inputExtractors = inputMapVar.map { - case (colnum, variable) => - variable -> { - variable.getDataType match { - case CNTKDataType.Float => - r: Row => SSFToGVV(r.getAs[Seq[Seq[Float]]](colnum)) - case CNTKDataType.Double => - r: Row => SSDToGVV(r.getAs[Seq[Seq[Double]]](colnum)) - } - } + val inputExtractors = inputMapVar.map { + case (colnum, variable) => variable -> { + variable.getDataType match { + case CNTKDataType.Float => + r: Row => Left(r.getAs[Seq[Seq[Float]]](colnum)) + case CNTKDataType.Double => + r: Row => Right(r.getAs[Seq[Seq[Double]]](colnum)) + } } + } + val inputGVVs = inputMapVar.map { + case (colnum, variable) => variable -> { + variable.getDataType match { + case CNTKDataType.Float => + Left(new FloatVectorVector()) + case CNTKDataType.Double => + Right(new DoubleVectorVector()) + } + } + } - { row: Row => inputExtractors.mapValues(f => f(row)) } + // WARNING: DO NOT simplify this to mapValues, + // for some reason it calls the inner function more than it should + val preprocessFunction: (Row) => Map[Variable, GVV] = { + { row: Row => inputExtractors.map { case (k,f) => + k -> ConversionUtils.toGVV(f(row), inputGVVs(k)) }} } val outputVars = outputMapVar.keys.toList - val floatConverter = if (convertToDenseVector) { - { fvv: FloatVectorVector => toSeqDV(fvv) } - } else { - { fvv: FloatVectorVector => toSeqSeq(fvv) } - } - val doubleConverter = if (convertToDenseVector) { - { dvv: DoubleVectorVector => toSeqDV(dvv) } + val converter = if (convertToDenseVector) { + { gvv: GVV => ConversionUtils.toDV(gvv) } } else { - { dvv: DoubleVectorVector => toSeqSeq(dvv) } + { gvv: GVV => ConversionUtils.toSSG(gvv) match { + case Left(ssf) => ssf + case Right(ssd) => ssd + }} } val outputVarVector = new VariableVector() @@ -113,13 +131,9 @@ private object CNTKModelUtils extends java.io.Serializable { inputRows.map { row => val feedDict = preprocessFunction(row) val outputGVVs = applyCNTKFunction(of, feedDict, outputVars, device) - val resultRow = Row(outputGVVs.map { - case Left(vv) => floatConverter(vv) - case Right(vv) => doubleConverter(vv) - }:_*) + val resultRow = Row(outputGVVs.map(converter): _*) val outputRow = Row.merge(row, resultRow) - feedDict.values.foreach(deleteGVV) - outputGVVs.foreach(deleteGVV) + outputGVVs.foreach(ConversionUtils.deleteGVV) outputRow } } diff --git a/src/cntk-model/src/main/scala/Conversions.scala b/src/cntk-model/src/main/scala/Conversions.scala new file mode 100644 index 0000000000..84619263a8 --- /dev/null +++ b/src/cntk-model/src/main/scala/Conversions.scala @@ -0,0 +1,159 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark + +import com.microsoft.CNTK.{DoubleVector, DoubleVectorVector, FloatVector, FloatVectorVector} +import org.apache.spark.ml.linalg.DenseVector + +import scala.collection.mutable +import scala.language.implicitConversions + +object ConversionUtils { + + type GVV = Either[FloatVectorVector, DoubleVectorVector] + + type SSG = Either[Seq[Seq[Float]], Seq[Seq[Double]]] + + def toSSG(gvv: GVV): SSG = { + gvv match { + case Left(vv) => + Left((0 until vv.size.toInt).map { i => + val v = vv.get(i) + (0 until v.size.toInt).map { j => + v.get(j) + } + }) + case Right(vv) => + Right((0 until vv.size.toInt).map { i => + val v = vv.get(i) + (0 until v.size.toInt).map { j => + v.get(j) + } + }) + } + } + + def deleteGVV(gvv: GVV): Unit = { + gvv match { + case Left(fvv) => fvv.clear(); fvv.delete() + case Right(dvv) => dvv.clear(); dvv.delete() + } + } + + def toDV(gvv: GVV): Seq[DenseVector] = { + gvv match { + case Left(vv) => + (0 until vv.size.toInt).map { i => + val v = vv.get(i) + new DenseVector((0 until v.size.toInt).map { j => + v.get(j).toDouble + }.toArray) + } + case Right(vv) => + (0 until vv.size.toInt).map { i => + val v = vv.get(i) + new DenseVector((0 until v.size.toInt).map { j => + v.get(j) + }.toArray) + } + } + + } + + def toFV(v: Seq[Float], fv: FloatVector): FloatVector = { + val vs = v.size + val fvs = fv.size() + if (fvs==vs) { + () + v.zipWithIndex.foreach(p => fv.set(p._2, p._1)) + } else if (fvs>vs) { + fv.clear() + fv.reserve(vs.toLong) + v.foreach(fv.add) + } else { + fv.reserve(vs.toLong) + (0 until fvs.toInt).foreach(i => fv.set(i, v(i))) + (fvs.toInt until vs).foreach(i => fv.add(v(i))) + } + fv + } + + def toDV(v: Seq[Double], fv: DoubleVector): DoubleVector = { + val vs = v.size + val fvs = fv.size() + if (fvs==vs) { + () + v.zipWithIndex.foreach(p => fv.set(p._2, p._1)) + } else if (fvs>vs) { + fv.clear() + fv.reserve(vs.toLong) + v.foreach(fv.add) + } else { + fv.reserve(vs.toLong) + (0 until fvs.toInt).foreach(i => fv.set(i, v(i))) + (fvs.toInt until vs).foreach(i => fv.add(v(i))) + } + fv + } + + def toFV(v: Seq[Float]): FloatVector = { + val fv = new FloatVector(v.length.toLong) + v.zipWithIndex.foreach(p=>fv.set(p._2,p._1)) + fv + } + + def toDV(v: Seq[Double]): DoubleVector = { + val fv = new DoubleVector(v.length.toLong) + v.zipWithIndex.foreach(p=>fv.set(p._2,p._1)) + fv + } + + def toFVV(vv: Seq[Seq[Float]], fvv: FloatVectorVector): FloatVectorVector = { + val vvs = vv.size + val fvvs = fvv.size() + if (fvvs==vvs) { + () + vv.zipWithIndex.foreach(p=>toFV(p._1,fvv.get(p._2))) + } else if (fvvs>vvs) { + fvv.clear() + fvv.reserve(vvs.toLong) + vv.foreach { v => fvv.add(toFV(v))} + } else { + fvv.reserve(vvs.toLong) + (0 until fvvs.toInt).foreach(i => fvv.set(i, toFV(vv(i),fvv.get(i)))) + (fvvs.toInt until vvs).foreach(i => fvv.add(toFV(vv(i)))) + } + fvv + } + + def toDVV(vv: Seq[Seq[Double]], fvv: DoubleVectorVector): DoubleVectorVector = { + val vvs = vv.size + val fvvs = fvv.size() + if (fvvs==vvs) { + () + vv.zipWithIndex.foreach(p=>toDV(p._1,fvv.get(p._2))) + } else if (fvvs>vvs) { + fvv.clear() + fvv.reserve(vvs.toLong) + vv.foreach { v => fvv.add(toDV(v))} + } else { + fvv.reserve(vvs.toLong) + (0 until fvvs.toInt).foreach(i => fvv.set(i, toDV(vv(i),fvv.get(i)))) + (fvvs.toInt until vvs).foreach(i => fvv.add(toDV(vv(i)))) + } + fvv + } + + def toGVV(garr: SSG, existingGVV:GVV): GVV = { + (garr, existingGVV) match { + case (Left(arr), Left(fvv)) => + Left(toFVV(arr,fvv)) + case (Right(arr), Right(fvv)) => + Right(toDVV(arr,fvv)) + case _ => + throw new IllegalArgumentException("Need to have matching arrays and VectorVectors") + } + } + +} diff --git a/src/cntk-model/src/main/scala/SerializableFunction.scala b/src/cntk-model/src/main/scala/SerializableFunction.scala index dfdcf7a267..b07ab37868 100644 --- a/src/cntk-model/src/main/scala/SerializableFunction.scala +++ b/src/cntk-model/src/main/scala/SerializableFunction.scala @@ -10,7 +10,6 @@ import com.microsoft.CNTK.CNTKUtils._ import com.microsoft.ml.spark.StreamUtilities.using import org.apache.commons.io.FileUtils.{forceDelete, getTempDirectoryPath, writeByteArrayToFile} import org.apache.commons.io.IOUtils -import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructField, DataType => SDataType} import scala.language.implicitConversions @@ -23,85 +22,6 @@ object CNTKExtensions { object CNTKUtils { - type GenericVectorVector = Either[FloatVectorVector, DoubleVectorVector] - - def deleteGVV(fvv: GenericVectorVector): Unit = { - fvv match { - case Left(vv) => - val size = vv.size().toInt - (0 until size).foreach{i => - vv.get(i).delete() - } - vv.delete() - case Right(vv) => - val size = vv.size().toInt - (0 until size).foreach{i => - vv.get(i).delete() - } - vv.delete() - } - } - - def toSeqSeq(vv: FloatVectorVector): Seq[Seq[Float]] = { - (0 until vv.size.toInt).map { i => - val v = vv.get(i) - (0 until v.size.toInt).map { j => - v.get(j) - } - } - } - - def toSeqDV(vv: FloatVectorVector): Seq[DenseVector] = { - (0 until vv.size.toInt).map { i => - val v = vv.get(i) - new DenseVector((0 until v.size.toInt).map { j => - v.get(j).toDouble - }.toArray) - } - } - - def toSeqSeq(vv: DoubleVectorVector): Seq[Seq[Double]] = { - (0 until vv.size.toInt).map { i => - val v = vv.get(i) - (0 until v.size.toInt).map { j => - v.get(j) - } - } - } - - def toSeqDV(vv: DoubleVectorVector): Seq[DenseVector] = { - (0 until vv.size.toInt).map { i => - val v = vv.get(i) - new DenseVector((0 until v.size.toInt).map { j => - v.get(j) - }.toArray) - } - } - - def SSFToGVV(arr: Seq[Seq[Float]]): GenericVectorVector = { - val inputFVV = new FloatVectorVector(arr.length.toLong) - arr.zipWithIndex.foreach { case (vect, i) => - val fv = new FloatVector(vect.length.toLong) - vect.zipWithIndex.foreach { case (x, j) => - fv.set(j, x) - } - inputFVV.set(i, fv) - } - Left(inputFVV) - } - - def SSDToGVV(arr: Seq[Seq[Double]]): GenericVectorVector = { - val inputDVV = new DoubleVectorVector(arr.length.toLong) - arr.zipWithIndex.foreach { case (vect, i) => - val dv = new DoubleVector(vect.length.toLong) - vect.zipWithIndex.foreach { case (x, j) => - dv.set(j, x) - } - inputDVV.set(i, dv) - } - Right(inputDVV) - } - def loadModelFromBytes(bytes: Array[Byte], device: DeviceDescriptor = DeviceDescriptor.useDefaultDevice): Function = { diff --git a/src/cntk-model/src/test/scala/CNTKBindingSuite.scala b/src/cntk-model/src/test/scala/CNTKBindingSuite.scala index 075d5a4d03..924f0325f7 100644 --- a/src/cntk-model/src/test/scala/CNTKBindingSuite.scala +++ b/src/cntk-model/src/test/scala/CNTKBindingSuite.scala @@ -28,23 +28,33 @@ class CNTKBindingSuite extends LinuxOnly with CNTKTestUtils { } } - def randomSeqSeq(outerSize: Int, innerSize: Int = 32 * 32 * 3, seed: Int=123): Seq[Seq[Float]] = { + def randomSeqSeq(outerSize: Int, dim: Int, seed: Int): Seq[Seq[Float]] = { val r = scala.util.Random r.setSeed(seed.toLong) (1 to outerSize).map(i => { - (1 to innerSize).map(j => { + (1 to dim).map(j => { r.nextFloat() }) }) } + def randomFVV(batchSize: Int, dim: Int, seed: Int): FloatVectorVector = { + toFVV(randomSeqSeq(batchSize, dim, seed)) + } + def evaluateRandomMinibatch(model: CNTKFunction, batchSize: Int, outputNum: Int = 3, seed: Int = 123): Seq[Seq[Float]] = { + evaluateFVV( + model, randomFVV(batchSize,32*32*3, seed), new FloatVectorVector(), outputNum) + } + + def evaluateFVV(model: CNTKFunction, + inFvv: FloatVectorVector, + outFvv: FloatVectorVector, + outputNum: Int = 3): Seq[Seq[Float]] = { val inputVar = model.getArguments()(0) val inputShape = inputVar.getShape - val fakeImages = randomSeqSeq(batchSize, seed=seed) - val inputFVV = toFVV(fakeImages) - val inputVal = Value.createDenseFloat(inputShape, inputFVV, DeviceDescriptor.getCPUDevice) + val inputVal = Value.createDenseFloat(inputShape, inFvv, DeviceDescriptor.getCPUDevice) val inputDataMap = new UnorderedMapVariableValuePtr() inputDataMap.add(inputVar, inputVal) @@ -52,11 +62,10 @@ class CNTKBindingSuite extends LinuxOnly with CNTKTestUtils { val outputVar = model.getOutputs.get(outputNum) outputDataMap.add(outputVar, null) - println(s"evaluating shape ${inputVal.getShape().getDimensions.toList}") + println(s"evaluating shape ${inputVal.getShape.getDimensions.toList}") model.evaluate(inputDataMap, outputDataMap, DeviceDescriptor.getCPUDevice) - val outputFVV = new FloatVectorVector() - outputDataMap.getitem(outputVar).copyVariableValueToFloat(outputVar, outputFVV) - toSeqSeq(outputFVV) + outputDataMap.getitem(outputVar).copyVariableValueToFloat(outputVar, outFvv) + toSeqSeq(outFvv) } test(" A serializable CNTKModel should be serializable") { diff --git a/src/cntk-model/src/test/scala/CNTKModelSuite.scala b/src/cntk-model/src/test/scala/CNTKModelSuite.scala index 2aeaae38ff..75a088aa37 100644 --- a/src/cntk-model/src/test/scala/CNTKModelSuite.scala +++ b/src/cntk-model/src/test/scala/CNTKModelSuite.scala @@ -36,9 +36,9 @@ class CNTKModelSuite extends LinuxOnly with CNTKTestUtils with TransformerFuzzin test("A CNTK model should be able to support setting the input and output node") { val model = testModel().setInputNodeIndex(0) - val data = makeFakeData(session, 3, featureVectorLength) + val data = makeFakeData(session, 30, featureVectorLength) val result = model.transform(data) - assert(result.select(outputCol).count() == 3) + assert(result.select(outputCol).count() == 30) } test("A CNTK model should support finding a node by name") {