Skip to content

Commit

Permalink
Enable CNTK memory re-use
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed Jun 28, 2018
1 parent 37e8f04 commit 5c0e406
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 128 deletions.
88 changes: 51 additions & 37 deletions src/cntk-model/src/main/scala/CNTKModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package com.microsoft.ml.spark

import com.microsoft.CNTK.CNTKExtensions._
import com.microsoft.CNTK.CNTKExtensions
import com.microsoft.CNTK.CNTKUtils._
import com.microsoft.CNTK.{DataType => CNTKDataType, SerializableFunction => CNTKFunction, _}
import com.microsoft.CNTK.{CNTKExtensions, DataType => CNTKDataType, SerializableFunction => CNTKFunction, _}
import com.microsoft.ml.spark.ConversionUtils.GVV
import com.microsoft.ml.spark.schema.DatasetExtensions
import com.microsoft.ml.spark.schema.DatasetExtensions.findUnusedColumnName
import org.apache.spark.SparkContext
import org.apache.spark.broadcast._
import org.apache.spark.ml.Model
import org.apache.spark.ml.linalg.DenseVector
Expand All @@ -20,17 +22,15 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import DatasetExtensions.findUnusedColumnName
import org.apache.spark.SparkContext

import scala.collection.JavaConversions._

private object CNTKModelUtils extends java.io.Serializable {

def applyCNTKFunction(model: CNTKFunction,
feedDict: Map[Variable, GenericVectorVector],
feedDict: Map[Variable, GVV],
outputVars: List[Variable],
device: DeviceDescriptor): List[GenericVectorVector] = {
device: DeviceDescriptor): List[GVV] = {

val valueMap = feedDict.map { case (v, gvv) =>
gvv match {
Expand All @@ -47,18 +47,25 @@ private object CNTKModelUtils extends java.io.Serializable {
outputVars.foreach(ov => outputDataMap.add(ov, null))
model.evaluate(inputDataMap, outputDataMap, device)

outputVars.map { ov: Variable =>
val out = outputVars.map { ov: Variable =>
ov.getDataType match {
case CNTKDataType.Float =>
val fvv = new FloatVectorVector()
outputDataMap.getitem(ov).copyVariableValueToFloat(ov, fvv)
val fvv = new FloatVectorVector() //TODO try re-using
val value = outputDataMap.getitem(ov)
value.copyVariableValueToFloat(ov, fvv)
value.delete()
Left(fvv)
case CNTKDataType.Double =>
val dvv = new DoubleVectorVector()
outputDataMap.getitem(ov).copyVariableValueToDouble(ov, dvv)
val dvv = new DoubleVectorVector() //TODO try re-using
val value = outputDataMap.getitem(ov)
value.copyVariableValueToDouble(ov, dvv)
value.delete()
Right(dvv)
}
}

valueMap.values.foreach(_.delete())
out
}

def applyModel(inputMap: Map[String, Int],
Expand All @@ -76,34 +83,45 @@ private object CNTKModelUtils extends java.io.Serializable {
val m = CNTKExtensions.fromSerializable(broadcastedModel.value).clone(ParameterCloningMethod.Share)

val inputMapVar = inputMap.map { case (k, v) => v -> m.getInputVar(k) }

val outputMapVar = outputMap.map { case (k, v) => m.getOutputVar(v) -> k }

val preprocessFunction: Row => Map[Variable, GenericVectorVector] = {
val inputExtractors = inputMapVar.map {
case (colnum, variable) =>
variable -> {
variable.getDataType match {
case CNTKDataType.Float =>
r: Row => SSFToGVV(r.getAs[Seq[Seq[Float]]](colnum))
case CNTKDataType.Double =>
r: Row => SSDToGVV(r.getAs[Seq[Seq[Double]]](colnum))
}
}
val inputExtractors = inputMapVar.map {
case (colnum, variable) => variable -> {
variable.getDataType match {
case CNTKDataType.Float =>
r: Row => Left(r.getAs[Seq[Seq[Float]]](colnum))
case CNTKDataType.Double =>
r: Row => Right(r.getAs[Seq[Seq[Double]]](colnum))
}
}
}
val inputGVVs = inputMapVar.map {
case (colnum, variable) => variable -> {
variable.getDataType match {
case CNTKDataType.Float =>
Left(new FloatVectorVector())
case CNTKDataType.Double =>
Right(new DoubleVectorVector())
}
}
}

{ row: Row => inputExtractors.mapValues(f => f(row)) }
// WARNING: DO NOT simplify this to mapValues,
// for some reason it calls the inner function more than it should
val preprocessFunction: (Row) => Map[Variable, GVV] = {
{ row: Row => inputExtractors.map { case (k,f) =>
k -> ConversionUtils.toGVV(f(row), inputGVVs(k)) }}
}

val outputVars = outputMapVar.keys.toList
val floatConverter = if (convertToDenseVector) {
{ fvv: FloatVectorVector => toSeqDV(fvv) }
} else {
{ fvv: FloatVectorVector => toSeqSeq(fvv) }
}
val doubleConverter = if (convertToDenseVector) {
{ dvv: DoubleVectorVector => toSeqDV(dvv) }
val converter = if (convertToDenseVector) {
{ gvv: GVV => ConversionUtils.toDV(gvv) }
} else {
{ dvv: DoubleVectorVector => toSeqSeq(dvv) }
{ gvv: GVV => ConversionUtils.toSSG(gvv) match {
case Left(ssf) => ssf
case Right(ssd) => ssd
}}
}

val outputVarVector = new VariableVector()
Expand All @@ -113,13 +131,9 @@ private object CNTKModelUtils extends java.io.Serializable {
inputRows.map { row =>
val feedDict = preprocessFunction(row)
val outputGVVs = applyCNTKFunction(of, feedDict, outputVars, device)
val resultRow = Row(outputGVVs.map {
case Left(vv) => floatConverter(vv)
case Right(vv) => doubleConverter(vv)
}:_*)
val resultRow = Row(outputGVVs.map(converter): _*)
val outputRow = Row.merge(row, resultRow)
feedDict.values.foreach(deleteGVV)
outputGVVs.foreach(deleteGVV)
outputGVVs.foreach(ConversionUtils.deleteGVV)
outputRow
}
}
Expand Down
159 changes: 159 additions & 0 deletions src/cntk-model/src/main/scala/Conversions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import com.microsoft.CNTK.{DoubleVector, DoubleVectorVector, FloatVector, FloatVectorVector}
import org.apache.spark.ml.linalg.DenseVector

import scala.collection.mutable
import scala.language.implicitConversions

object ConversionUtils {

type GVV = Either[FloatVectorVector, DoubleVectorVector]

type SSG = Either[Seq[Seq[Float]], Seq[Seq[Double]]]

def toSSG(gvv: GVV): SSG = {
gvv match {
case Left(vv) =>
Left((0 until vv.size.toInt).map { i =>
val v = vv.get(i)
(0 until v.size.toInt).map { j =>
v.get(j)
}
})
case Right(vv) =>
Right((0 until vv.size.toInt).map { i =>
val v = vv.get(i)
(0 until v.size.toInt).map { j =>
v.get(j)
}
})
}
}

def deleteGVV(gvv: GVV): Unit = {
gvv match {
case Left(fvv) => fvv.clear(); fvv.delete()
case Right(dvv) => dvv.clear(); dvv.delete()
}
}

def toDV(gvv: GVV): Seq[DenseVector] = {
gvv match {
case Left(vv) =>
(0 until vv.size.toInt).map { i =>
val v = vv.get(i)
new DenseVector((0 until v.size.toInt).map { j =>
v.get(j).toDouble
}.toArray)
}
case Right(vv) =>
(0 until vv.size.toInt).map { i =>
val v = vv.get(i)
new DenseVector((0 until v.size.toInt).map { j =>
v.get(j)
}.toArray)
}
}

}

def toFV(v: Seq[Float], fv: FloatVector): FloatVector = {
val vs = v.size
val fvs = fv.size()
if (fvs==vs) {
()
v.zipWithIndex.foreach(p => fv.set(p._2, p._1))
} else if (fvs>vs) {
fv.clear()
fv.reserve(vs.toLong)
v.foreach(fv.add)
} else {
fv.reserve(vs.toLong)
(0 until fvs.toInt).foreach(i => fv.set(i, v(i)))
(fvs.toInt until vs).foreach(i => fv.add(v(i)))
}
fv
}

def toDV(v: Seq[Double], fv: DoubleVector): DoubleVector = {
val vs = v.size
val fvs = fv.size()
if (fvs==vs) {
()
v.zipWithIndex.foreach(p => fv.set(p._2, p._1))
} else if (fvs>vs) {
fv.clear()
fv.reserve(vs.toLong)
v.foreach(fv.add)
} else {
fv.reserve(vs.toLong)
(0 until fvs.toInt).foreach(i => fv.set(i, v(i)))
(fvs.toInt until vs).foreach(i => fv.add(v(i)))
}
fv
}

def toFV(v: Seq[Float]): FloatVector = {
val fv = new FloatVector(v.length.toLong)
v.zipWithIndex.foreach(p=>fv.set(p._2,p._1))
fv
}

def toDV(v: Seq[Double]): DoubleVector = {
val fv = new DoubleVector(v.length.toLong)
v.zipWithIndex.foreach(p=>fv.set(p._2,p._1))
fv
}

def toFVV(vv: Seq[Seq[Float]], fvv: FloatVectorVector): FloatVectorVector = {
val vvs = vv.size
val fvvs = fvv.size()
if (fvvs==vvs) {
()
vv.zipWithIndex.foreach(p=>toFV(p._1,fvv.get(p._2)))
} else if (fvvs>vvs) {
fvv.clear()
fvv.reserve(vvs.toLong)
vv.foreach { v => fvv.add(toFV(v))}
} else {
fvv.reserve(vvs.toLong)
(0 until fvvs.toInt).foreach(i => fvv.set(i, toFV(vv(i),fvv.get(i))))
(fvvs.toInt until vvs).foreach(i => fvv.add(toFV(vv(i))))
}
fvv
}

def toDVV(vv: Seq[Seq[Double]], fvv: DoubleVectorVector): DoubleVectorVector = {
val vvs = vv.size
val fvvs = fvv.size()
if (fvvs==vvs) {
()
vv.zipWithIndex.foreach(p=>toDV(p._1,fvv.get(p._2)))
} else if (fvvs>vvs) {
fvv.clear()
fvv.reserve(vvs.toLong)
vv.foreach { v => fvv.add(toDV(v))}
} else {
fvv.reserve(vvs.toLong)
(0 until fvvs.toInt).foreach(i => fvv.set(i, toDV(vv(i),fvv.get(i))))
(fvvs.toInt until vvs).foreach(i => fvv.add(toDV(vv(i))))
}
fvv
}

def toGVV(garr: SSG, existingGVV:GVV): GVV = {
(garr, existingGVV) match {
case (Left(arr), Left(fvv)) =>
Left(toFVV(arr,fvv))
case (Right(arr), Right(fvv)) =>
Right(toDVV(arr,fvv))
case _ =>
throw new IllegalArgumentException("Need to have matching arrays and VectorVectors")
}
}

}
Loading

0 comments on commit 5c0e406

Please sign in to comment.