diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
new file mode 100644
index 000000000..a3f8afb5c
--- /dev/null
+++ b/.devcontainer/Dockerfile
@@ -0,0 +1,44 @@
+FROM --platform=linux/amd64 ubuntu:22.04
+
+ENV TZ=America/New_York
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# Install necessary packages and tools
+RUN DEBIAN_FRONTEND=noninteractive \
+    apt-get update && apt-get install -y \
+    curl \
+    openjdk-11-jdk \
+    unzip \
+    zip \
+    git \
+    software-properties-common \
+    python3-pip && \
+    add-apt-repository ppa:deadsnakes/ppa && \
+    apt-get update && apt-get install -y \
+    python3.7 \
+    python3.8 \
+    python3.10 \
+    tox \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
+
+ARG USERNAME=mleap
+ARG USER_UID=1000
+ARG USER_GID=$USER_UID
+
+# Create the user
+RUN groupadd --gid $USER_GID $USERNAME \
+    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME --shell /usr/bin/bash
+
+USER $USERNAME
+
+# Install SDKMAN and Scala
+SHELL ["/bin/bash", "-c"]
+
+RUN curl -s "https://get.sdkman.io" | bash \
+    && source "$HOME/.sdkman/bin/sdkman-init.sh" \
+    && sdk install scala 2.12.18 \
+    && sdk install sbt
+
+# Set up your environment as needed
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 000000000..ec3872e1c
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,11 @@
+{
+    "features": {
+        "ghcr.io/devcontainers/features/sshd:1": {
+            "version": "latest"
+        }
+    },
+    "build": {
+        "dockerfile": "Dockerfile",
+        "context": "."
+    }
+}
diff --git a/Makefile b/Makefile
index 45cbabecc..995dcdf42 100644
--- a/Makefile
+++ b/Makefile
@@ -31,6 +31,10 @@ test_python37:
 test_python38:
 	source scripts/scala_classpath_for_python.sh && make -C python py38_test
 
+.PHONY: test_python
+test_python: test_python37 test_python38
+	@echo "All python tests run successfully"
+
 .PHONY: test
 test: test_executor test_benchmark test_xgboost_runtime test_xgboost_spark test_root_sbt_project test_python
 	@echo "All tests run successfully"
diff --git a/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala b/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala
index 5fcd00bf4..09e1e7ed9 100644
--- a/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala
+++ b/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala
@@ -132,7 +132,6 @@ case class BundleFile(fs: FileSystem,
 
   override def finalize(): Unit = {
-    super.finalize()
     close()
   }
 
diff --git a/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala b/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
index 86732131c..1f15bb83c 100644
--- a/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
+++ b/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
@@ -31,7 +31,7 @@ object FileUtil {
   def rmRF(toRemove: Path): Array[(String, Boolean)] = {
     def removeElement(path: Path): (String, Boolean) = {
       val result = Try {
-        Files.deleteIfExists(toRemove)
+        Files.deleteIfExists(path)
      } match {
        case Failure(_) => false
        case Success(value) => value
diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
index 00c8f9998..7ac0e932b 100644
--- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
+++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
@@ -1,40 +1,60 @@
 package ml.combust.mleap.core.feature
 
 import ml.combust.mleap.core.Model
-import ml.combust.mleap.core.types.{ScalarType, StructType}
+import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
 
 /** Class for string indexer model.
   *
   * String indexer converts a string into an integer representation.
   *
-  * @param labels list of labels that can be indexed
+  * @param labelsArray array of ordered label lists, one per input column; a label's position in its list is the index it is assigned
   * @param handleInvalid how to handle invalid values (unseen or NULL labels): 'error' (throw an error),
   *                      'skip' (skips invalid data)
   *                      or 'keep' (put invalid data in a special bucket at index labels.size)
   */
-case class StringIndexerModel(labels: Seq[String],
+case class StringIndexerModel(labelsArray: Seq[Seq[String]],
                               handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
-  val stringToIndex: Map[String, Int] = labels.zipWithIndex.toMap
+
+  private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap).toArray
   private val keepInvalid = handleInvalid == HandleInvalid.Keep
-  private val invalidValue = labels.length
+  private val invalidValues = labelsArray.map(_.length)
+
+  @deprecated("Use labelsArray instead")
+  def labels: Seq[String] = labelsArray.head
+
+  /** Convert all strings into their integer representations.
+    *
+    * @param values labels to index
+    * @return indexes of labels
+    */
+  def apply(values: Seq[Any]): Seq[Double] = values.zipWithIndex.map {
+    case (v: Any, i: Int) => encoder(v, i).toDouble
+    case (null, i: Int) => encoder(null, i).toDouble
+  }
+
+  def contains(values: Seq[Any]): Boolean = {
+    values.zipWithIndex.forall {
+      case (key, i) => key != null && stringToIndex(i).contains(key.toString)
+    }
+  }
 
   /** Convert a string into its integer representation.
-   *
-   * @param value label to index
-   * @return index of label
-   */
-  def apply(value: Any): Int = if (value == null) {
+    *
+    * @param value label to index
+    * @return index of label
+    */
+  private def encoder(value: Any, colIdx: Int): Int = if (value == null) {
     if (keepInvalid) {
-      invalidValue
+      invalidValues(colIdx)
     } else {
       throw new NullPointerException("StringIndexer encountered NULL value. " +
         s"To handle NULLS, set handleInvalid to ${HandleInvalid.Keep.asParamString}")
     }
   } else {
     val label = value.toString
-    stringToIndex.get(label) match {
+    stringToIndex(colIdx).get(label) match {
       case Some(v) => v
       case None => if (keepInvalid) {
-        invalidValue
+        invalidValues(colIdx)
       } else {
         throw new NoSuchElementException(s"Unseen label: $label. To handle unseen labels, " +
           s"set handleInvalid to ${HandleInvalid.Keep.asParamString}")
@@ -43,12 +63,25 @@ case class StringIndexerModel(labels: Seq[String],
     }
   }
 
   /** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
-    *
+    * ReverseStringIndexer only supports one input.
     * @return reverse string indexer of this string indexer
     */
-  def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels)
+  def toReverse: ReverseStringIndexerModel = {
+    require(labelsArray.length == 1, "toReverse is only supported for single-input string indexers")
+    ReverseStringIndexerModel(labelsArray.head)
+  }
 
-  override def inputSchema: StructType = StructType("input" -> ScalarType.String).get
+  override def inputSchema: StructType = {
+    val f = labelsArray.zipWithIndex.map {
+      case (_, i) => StructField(s"input$i", ScalarType.String)
+    }
+    StructType(f).get
+  }
 
-  override def outputSchema: StructType = StructType("output" -> ScalarType.Double.nonNullable).get
+  override def outputSchema: StructType = {
+    val f = labelsArray.zipWithIndex.map {
+      case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
+    }
+    StructType(f).get
+  }
 }
diff --git a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
index e598497ae..890a93008 100644
--- a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
+++ b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
@@ -10,48 +10,48 @@ import org.scalatest.prop.TableDrivenPropertyChecks
 class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec
   with TableDrivenPropertyChecks {
   describe("#apply") {
     it("returns the index of the string") {
-      val indexer = StringIndexerModel(Array("hello", "there", "dude"))
+      val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")))
 
-      assert(indexer("hello") == 0.0)
-      assert(indexer("there") == 1.0)
-      assert(indexer("dude") == 2.0)
+      assert(indexer(Seq("hello")).head == 0.0)
+      assert(indexer(Seq("there")).head == 1.0)
+      assert(indexer(Seq("dude")).head == 2.0)
     }
 
     it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") {
-      val indexer = StringIndexerModel(Array("hello"))
+      val indexer = StringIndexerModel(Seq(Seq("hello")))
       assertThrows[NullPointerException](indexer(Seq(null)))
     }
 
     it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") {
-      val indexer = StringIndexerModel(Array("hello"))
-      val unseenLabels = Table("unknown1", "unknown2")
+      val indexer = StringIndexerModel(Seq(Seq("hello")))
+      val unseenLabels = Table("label", "unknown1", "unknown2")
       forAll(unseenLabels) { (label: Any) =>
         intercept[NoSuchElementException] {
-          indexer(label)
+          indexer(Seq(label))
         }
       }
     }
 
     it("returns default index for HandleInvalid.keep mode") {
-      val indexer = StringIndexerModel(Array("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep)
+      val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep)
       val invalidLabels = Table("unknown", "other unknown", null, None)
       forAll(invalidLabels) { (label: Any) =>
-        assert(indexer(label) == 3.0)
+        assert(indexer(Seq(label)).head == 3.0)
       }
     }
   }
 
   describe("input/output schema") {
-    val indexer = StringIndexerModel(Array("hello", "there", "dude"))
+    val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")))
     it("has the right input schema") {
-      assert(indexer.inputSchema.fields == Seq(StructField("input", ScalarType.String)))
+      assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String)))
     }
 
     it("has the right output schema") {
-      assert(indexer.outputSchema.fields == Seq(StructField("output", ScalarType.Double.nonNullable)))
+      assert(indexer.outputSchema.fields == Seq(StructField("output0", ScalarType.Double.nonNullable)))
     }
   }
 }
diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala
new file mode 100644
index 000000000..938a5bba8
--- /dev/null
+++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala
@@ -0,0 +1,29 @@
+package ml.combust.mleap.bundle.ops
+
+import ml.bundle.Socket
+import ml.combust.bundle.BundleContext
+import ml.combust.bundle.dsl.{Node, NodeShape}
+import ml.combust.mleap.core.types
+import ml.combust.mleap.runtime.MleapContext
+import ml.combust.mleap.runtime.frame.Transformer
+import ml.combust.mleap.runtime.types.BundleTypeConverters._
+
+import scala.reflect.ClassTag
+
+abstract class MultiInOutMleapOp[N <: Transformer, M <: AnyRef](implicit ct: ClassTag[N]) extends MleapOp[N, M] {
+  override def load(node: Node, model: M)(implicit context: BundleContext[MleapContext]): N = {
+    val ns = node.shape.getInput(NodeShape.standardInputPort) match { // Old versions need their serialized port names translated to the new expectation (input -> input0)
+      case Some(_) ⇒ translateLegacyShape(node.shape)
+
+      // New version
+      case None ⇒ node.shape
+    }
+    klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).newInstance(node.name, ns.asBundle: types.NodeShape, model)
+  }
+
+  private def translateLegacyShape(ns: NodeShape): NodeShape = {
+    val i = ns.getInput(NodeShape.standardInputPort).get
+    val o = ns.getOutput(NodeShape.standardOutputPort).get
+    NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name)))
+  }
+}
diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala
index 60dab4187..4940aa526 100644
--- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala
+++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala
@@ -1,20 +1,17 @@
 package ml.combust.mleap.bundle.ops.feature
 
-import ml.bundle.Socket
 import ml.combust.bundle.BundleContext
 import ml.combust.bundle.dsl._
 import ml.combust.bundle.op.OpModel
-import ml.combust.mleap.bundle.ops.MleapOp
+import ml.combust.mleap.bundle.ops.MultiInOutMleapOp
 import ml.combust.mleap.core.feature.{HandleInvalid, OneHotEncoderModel}
-import ml.combust.mleap.core.types
 import ml.combust.mleap.runtime.MleapContext
 import ml.combust.mleap.runtime.transformer.feature.OneHotEncoder
-import ml.combust.mleap.runtime.types.BundleTypeConverters._
 
 /**
   * Created by hollinwilkins on 10/24/16.
   */
-class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] {
+class OneHotEncoderOp extends MultiInOutMleapOp[OneHotEncoder, OneHotEncoderModel] {
   override val Model: OpModel[MleapContext, OneHotEncoderModel] = new OpModel[MleapContext, OneHotEncoderModel] {
     override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel]
@@ -50,24 +47,6 @@ class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] {
     }
   }
 
-  override def load(node: Node, model: OneHotEncoderModel)
-                   (implicit context: BundleContext[MleapContext]): OneHotEncoder = {
-    val ns = node.shape.getInput(NodeShape.standardInputPort) match {
-      // Old version of 1HE -- need to translate serialized port names to new expectation (input -> input0)
-      case Some(_) ⇒ translateLegacyShape(node.shape)
-
-      // New version of 1HE
-      case None ⇒ node.shape
-    }
-    klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).
-      newInstance(node.name, ns.asBundle: types.NodeShape, model)
-  }
-
   override def model(node: OneHotEncoder): OneHotEncoderModel = node.model
 
-  private def translateLegacyShape(ns: NodeShape): NodeShape = {
-    val i = ns.getInput(NodeShape.standardInputPort).get
-    val o = ns.getOutput(NodeShape.standardOutputPort).get
-    NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name)))
-  }
 }
diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
index 3fd3075ce..a694aeb67 100644
--- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
+++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
@@ -1,17 +1,17 @@
 package ml.combust.mleap.bundle.ops.feature
 
 import ml.combust.bundle.BundleContext
-import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
-import ml.combust.mleap.runtime.transformer.feature.StringIndexer
-import ml.combust.bundle.op.OpModel
 import ml.combust.bundle.dsl._
-import ml.combust.mleap.bundle.ops.MleapOp
+import ml.combust.bundle.op.OpModel
+import ml.combust.mleap.bundle.ops.MultiInOutMleapOp
+import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
 import ml.combust.mleap.runtime.MleapContext
+import ml.combust.mleap.runtime.transformer.feature.StringIndexer
 
 /**
-  * Created by hollinwilkins on 8/22/16.
-  */
-class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
+  * Created by hollinwilkins on 8/22/16.
+  */
+class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerModel] {
   override val Model: OpModel[MleapContext, StringIndexerModel] = new OpModel[MleapContext, StringIndexerModel] {
     override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel]
@@ -19,27 +19,28 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
 
     override def store(model: Model, obj: StringIndexerModel)
                       (implicit context: BundleContext[MleapContext]): Model = {
-      model.
-        withValue("labels_length", Value.int(1)).
-        withValue("labels_array_0", Value.stringList(obj.labels)).
-        withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString))
-
+      val m = model.
+        withValue("labels_length", Value.int(obj.labelsArray.length)).
+ withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) + obj.labelsArray.zipWithIndex.foldLeft(m) { + case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels)) + } } override def load(model: Model) (implicit context: BundleContext[MleapContext]): StringIndexerModel = { val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default) val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1) - val labels: Seq[String] = label_length match { + val labelsArray: Seq[Seq[String]] = label_length match { case -1 => // backawards compatibility with spark v2 - model.value("labels").getStringList - case 1 => model.value("labels_array_0").getStringList - case _ => throw new UnsupportedOperationException("Multi-input StringIndexer not supported yet.") + Seq(model.value("labels").getStringList) + case _ => (0 until label_length).map(i => model.value(s"labels_array_$i").getStringList) } - StringIndexerModel(labels = labels, handleInvalid = handleInvalid) + StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid) } } override def model(node: StringIndexer): StringIndexerModel = node.model + } diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala index a83bdbe5e..dd1faf553 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala @@ -2,29 +2,39 @@ package ml.combust.mleap.runtime.transformer.feature import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} import ml.combust.mleap.core.types._ -import ml.combust.mleap.runtime.function.{FieldSelector, UserDefinedFunction} -import ml.combust.mleap.runtime.frame.{FrameBuilder, Transformer} +import ml.combust.mleap.runtime.function.{StructSelector, UserDefinedFunction} +import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, Transformer} import scala.util.Try /** - * Created by hwilkins on 10/22/15. - */ + * Created by hwilkins on 10/22/15. 
+  */
 case class StringIndexer(override val uid: String = Transformer.uniqueName("string_indexer"),
                          override val shape: NodeShape,
                          override val model: StringIndexerModel) extends Transformer {
-  val input: String = inputSchema.fields.head.name
-  val inputSelector: FieldSelector = input
-  val output: String = outputSchema.fields.head.name
-  val exec: UserDefinedFunction = (value: String) => model(value).toDouble
+  private val outputs: Seq[String] = outputSchema.fields.map(_.name)
+  private val inputs: Seq[String] = inputSchema.fields.map(_.name)
+  private val inputSelector: StructSelector = StructSelector(inputs)
+  private val filterSchema = StructType(Seq(StructField("output", ScalarType.Boolean.nonNullable))).get
+  private val exec: UserDefinedFunction = UserDefinedFunction((keys: Row) => {
+    val res = model(keys.toSeq)
+    Row(res: _*)
+  }, SchemaSpec(outputSchema), Seq(SchemaSpec(inputSchema)))
+  private val contains: UserDefinedFunction = UserDefinedFunction((keys: Row) => {
+    model.contains(keys.toSeq)
+  }, SchemaSpec(filterSchema), Seq(SchemaSpec(inputSchema)))
 
   override def transform[FB <: FrameBuilder[FB]](builder: FB): Try[FB] = {
+    def withColumns(builder: FB): Try[FB] = {
+      builder.withColumns(outputs, inputSelector)(exec)
+    }
+
     if(model.handleInvalid == HandleInvalid.Skip) {
-      builder.filter(input) {
-        (key: String) => model.stringToIndex.contains(key)
-      }.flatMap(_.withColumn(output, inputSelector)(exec))
+      builder.filter(inputSelector)(contains)
+        .flatMap(withColumns)
     } else {
-      builder.withColumn(output, inputSelector)(exec)
+      withColumns(builder)
     }
   }
 }
diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
index dd1caed2a..5ffc2d853 100644
--- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
+++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
@@ -10,13 +10,13 @@ import MleapSupport._
 import org.scalatest.funspec.AnyFunSpec
 
-class MleapSupportSpec extends org.scalatest.funspec.AnyFunSpec {
+class MleapSupportSpec extends AnyFunSpec {
   private val testDir = Files.createTempDirectory("MleapSupportSpec")
 
   private val stringIndexer = StringIndexer(shape = NodeShape().
-    withStandardInput("feature").
- withStandardOutput("feature_index"), - model = StringIndexerModel(Seq("label1", "label2"))) + withInput("input0", "feature") + withOutput("output0", "feature_index"), + model = StringIndexerModel(Seq(Seq("label1", "label2")))) describe("URIBundleFileOps") { it("can save/load a bundle using a URI") { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java index 2651eac37..9cacd9a1e 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java @@ -1,6 +1,6 @@ package ml.combust.mleap.runtime.javadsl; -import ml.combust.mleap.core.feature.HandleInvalid$; +import ml.combust.mleap.core.feature.HandleInvalid; import ml.combust.mleap.core.feature.StringIndexerModel; import ml.combust.mleap.core.types.*; import ml.combust.mleap.runtime.MleapContext; @@ -16,7 +16,7 @@ import ml.combust.mleap.tensor.ByteString; import org.junit.jupiter.api.Test; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.immutable.ListMap; import java.io.File; @@ -50,10 +50,10 @@ private static Map createMap() { StringIndexer stringIndexer = new StringIndexer( "string_indexer", new NodeShape(new ListMap<>(), new ListMap<>()). - withStandardInput("string"). - withStandardOutput("string_index"), - new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(), - HandleInvalid$.MODULE$.fromString("error", true))); + withInput("input0", "string"). + withOutput("output0","string_index"), + new StringIndexerModel(JavaConverters.asScalaBuffer(Collections.singletonList(JavaConverters.asScalaBuffer(Collections.singletonList("hello")).toSeq())), + HandleInvalid.Error$.MODULE$)); DefaultLeapFrame buildFrame() { List fields = Arrays.asList(frameBuilder.createField("bool", frameBuilder.createBoolean()), @@ -101,7 +101,7 @@ public void buildALeapFrameWithAllDataTypes() { assertEquals(row.getDouble(7), 44.5, 0.0000000000001); assertEquals(row.getByteString(8), new ByteString("hello_there".getBytes())); assertEquals(row.getList(9), Arrays.asList(23, 44, 55)); - assertEquals(JavaConversions.mapAsJavaMap(row.getMap(10)), mapCol ); + assertEquals(JavaConverters.mapAsJavaMap(row.getMap(10)), mapCol ); List tensorValues = tensorSupport.toArray(row.getTensor(11)); assertEquals(tensorValues, Arrays.asList(23d, 3d, 4d)); } @@ -117,7 +117,7 @@ public void transformASingleRowUsingARowTransformer() { @Test public void createTensorFieldWithDimension() { StructField tensorField = frameBuilder.createField("tensor", frameBuilder.createTensor(frameBuilder.createBasicByte(), Arrays.asList(1, 2), true)); - assertEquals(((TensorType)tensorField.dataType()).dimensions().get(), JavaConversions.asScalaBuffer(Arrays.asList(1, 2)).toSeq()); + assertEquals(((TensorType)tensorField.dataType()).dimensions().get(), JavaConverters.asScalaBuffer(Arrays.asList(1, 2)).toSeq()); } @Test diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala index 19cda92ec..d5177cbe7 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala @@ -13,10 +13,10 @@ 
 class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec {
   val frame = DefaultLeapFrame(schema, dataset)
 
   val stringIndexer = StringIndexer(
-    shape = NodeShape.feature(
-      inputCol = "test_string",
-      outputCol = "test_index"),
-    model = StringIndexerModel(Seq("index1", "index2", "index3")))
+    shape = NodeShape()
+      .withInput("input0", "test_string")
+      .withOutput("output0", "test_index"),
+    model = StringIndexerModel(Seq(Seq("index1", "index2", "index3"))))
 
   describe("#transform") {
     it("converts input string into an index") {
@@ -29,8 +29,8 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec {
   }
 
   describe("with invalid input column") {
-    val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withStandardInput("bad_input").
-      withStandardOutput("output"))
+    val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withInput("input0", "bad_input").
+      withOutput("output0", "output"))
 
     it("returns a Failure") {
       assert(stringIndexer2.transform(frame).isFailure)
diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala
deleted file mode 100644
index f43950762..000000000
--- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-package org.apache.spark.ml.bundle
-
-import ml.combust.bundle.dsl.{Model, Value}
-import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.ParamValidators
-import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
-
-trait MultiInOutFormatSparkOp[
-  N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols
-]{
-
-  protected def saveMultiInOutFormat(model: Model, obj: N): Model = {
-    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols))
-    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols))
-    val result = if(obj.isSet(obj.inputCols)) {
-      model.withValue("input_cols", Value.stringList(obj.getInputCols))
-    } else {
-      model.withValue("input_col", Value.string(obj.getInputCol))
-    }
-    if (obj.isSet(obj.outputCols)) {
-      result.withValue("output_cols", Value.stringList(obj.getOutputCols))
-    } else {
-      result.withValue("output_col", Value.string(obj.getOutputCol))
-    }
-  }
-
-  protected def loadMultiInOutFormat(model: Model, obj: N): N = {
-    val inputCol = model.getValue("input_col").map(_.getString)
-    val inputCols = model.getValue("input_cols").map(_.getStringList)
-    val outputCol = model.getValue("output_col").map(_.getString)
-    val outputCols = model.getValue("output_cols").map(_.getStringList)
-    val result: N = (inputCol, inputCols) match {
-      case (None, None) => obj
-      case (Some(col), None) => obj.set(obj.inputCol, col)
-      case (None, Some(cols)) => obj.set(obj.inputCols, cols.toArray)
-      case (_, _) => throw new UnsupportedOperationException("Cannot use both inputCol and inputCols")
-    }
-    (outputCol, outputCols) match {
-      case (None, None) => obj
-      case (Some(col), None) => result.set(result.outputCol, col)
-      case (None, Some(cols)) => result.set(result.outputCols, cols.toArray)
-      case (_, _) => throw new UnsupportedOperationException("Cannot use both outputCol and outputCols")
-    }
-  }
-
-  def sparkInputs(obj: N): Seq[ParamSpec] = {
-    if (obj.isSet(obj.inputCols)) {
-      Seq(ParamSpec("input", obj.inputCols))
-    } else{
-      Seq(ParamSpec("input", obj.inputCol))
-    }
-  }
-
-  def sparkOutputs(obj: N): Seq[ParamSpec] = {
-    if (obj.isSet(obj.outputCols)) {
-      Seq(ParamSpec("output", obj.outputCols))
-    } else{
-      Seq(ParamSpec("output", obj.outputCol))
-    }
-  }
-}
diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala
new file mode 100644
index 000000000..55186697f
--- /dev/null
+++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala
@@ -0,0 +1,50 @@
+package org.apache.spark.ml.bundle
+
+import ml.combust.bundle.BundleContext
+import ml.combust.bundle.dsl._
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.ParamValidators
+import org.apache.spark.ml.param.shared._
+
+import scala.reflect.ClassTag
+
+abstract class MultiInOutSparkOp[N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols](implicit ct: ClassTag[N]) extends SimpleSparkOp[N] {
+
+  import NodeShape._
+
+  override def load(node: Node, model: N)(implicit context: BundleContext[SparkBundleContext]): N = {
+    val n = sparkLoad(node.name, node.shape, model)
+    SparkShapeLoader(node.shape, n, sparkInputs(n, node.shape), sparkOutputs(n, node.shape)).loadShape()
+    n
+  }
+
+  def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort + "0").isDefined, obj)
+
+  def sparkInputs(hasInputCols: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCols) {
+    Seq(ParamSpec(standardInputPort, obj.inputCols))
+  } else {
+    Seq(ParamSpec(standardInputPort, obj.inputCol))
+  }
+
+  def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort + "0").isDefined, obj)
+
+  override def shape(node: N)(implicit context: BundleContext[SparkBundleContext]): NodeShape = {
+    validateParams(node)
+    super.shape(node)
+  }
+
+  private def validateParams(obj: N): Unit = {
+    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols))
+    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols))
+  }
+
+  def sparkInputs(obj: N): Seq[ParamSpec] = sparkInputs(obj.isSet(obj.inputCols), obj)
+
+  def sparkOutputs(obj: N): Seq[ParamSpec] = sparkOutputs(obj.isSet(obj.outputCols), obj)
+
+  def sparkOutputs(hasOutputCols: Boolean, obj: N): Seq[ParamSpec] = if (hasOutputCols) {
+    Seq(ParamSpec(standardOutputPort, obj.outputCols))
+  } else {
+    Seq(ParamSpec(standardOutputPort, obj.outputCol))
+  }
+}
diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala
index 19fe507b4..341a81a34 100644
--- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala
+++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala
@@ -30,27 +30,8 @@ abstract class SimpleSparkOp[N <: Transformer](implicit ct: ClassTag[N]) extends
 
   override def shape(node: N)
                     (implicit context: BundleContext[SparkBundleContext]): NodeShape = {
-    val dataset = context.context.dataset.getOrElse {
-      throw new IllegalArgumentException(
-        """
-          |Must provide a transformed data frame to MLeap for serializing a pipeline.
-          |The transformed data frame is used to extract data types and other metadata
-          |required for execution.
-          |
-          |Example usage:
-          |```
-          |val sparkTransformer: org.apache.spark.ml.Transformer
-          |val transformedDataset = sparkTransformer.transform(trainingDataset)
-          |
-          |implicit val sbc = SparkBundleContext().withDataset(transformedDataset)
-          |
-          |Using(BundleFile(file)) { bf =>
-          |  sparkTransformer.writeBundle.format(SerializationFormat.Json).save(bf).get
-          |}
-          |```
-        """.stripMargin)
-    }
-    SparkShapeSaver(dataset,
+    assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
+    SparkShapeSaver(context.context.dataset.get,
       node,
       sparkInputs(node),
       sparkOutputs(node)).asNodeShape
diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala
index 80ba16cf5..f8733c6a6 100644
--- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala
+++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala
@@ -9,13 +9,11 @@ import org.apache.spark.ml.feature.Binarizer
 import org.apache.spark.ml.bundle.ops.OpsUtils
 import org.apache.spark.sql.mleap.TypeConverters._
 import ml.combust.mleap.runtime.types.BundleTypeConverters._
-import org.apache.spark.ml.bundle.ops.OpsUtils
-import org.apache.spark.ml.param.ParamValidators
 
 /**
   * Created by fshabbir on 12/1/16.
   */
-class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[Binarizer] {
+class BinarizerOp extends MultiInOutSparkOp[Binarizer] {
   override val Model: OpModel[SparkBundleContext, Binarizer] = new OpModel[SparkBundleContext, Binarizer] {
     override val klazz: Class[Binarizer] = classOf[Binarizer]
 
@@ -24,21 +22,15 @@ class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[
     override def store(model: Model, obj: Binarizer)
                       (implicit context: BundleContext[SparkBundleContext]): Model = {
       assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
       val dataset = context.context.dataset.get
-      var result = {
-        ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols))
-        if(obj.isSet(obj.inputCols)) {
-          val inputShapes = obj.getInputCols.map(i => sparkToMleapDataShape(dataset.schema(i), dataset): DataShape)
-          model.withValue("input_shapes_list", Value.dataShapeList(inputShapes))
-        } else {
-          model.withValue("input_shapes", Value.dataShape(sparkToMleapDataShape(dataset.schema(obj.getInputCol), dataset)))
-        }
+      if(obj.isSet(obj.inputCols)) {
+        val inputShapes = obj.getInputCols.map(i => sparkToMleapDataShape(dataset.schema(i), dataset): DataShape)
+        model.withValue("input_shapes_list", Value.dataShapeList(inputShapes))
+          .withValue("thresholds", Value.doubleList(obj.getThresholds))
+      } else {
+        model.withValue("input_shapes", Value.dataShape(sparkToMleapDataShape(dataset.schema(obj.getInputCol), dataset)))
+          .withValue("threshold", Value.double(obj.getThreshold))
       }
-
-      if (obj.isSet(obj.threshold)) result = result.withValue("threshold", Value.double(obj.getThreshold))
-      if (obj.isSet(obj.thresholds)) result = result.withValue("thresholds", Value.doubleList(obj.getThresholds))
-      saveMultiInOutFormat(result, obj)
     }
 
     override def load(model: Model)
@@ -46,13 +38,12 @@ class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[
       val threshold: Option[Double] = model.getValue("threshold").map(_.getDouble)
       val thresholds: Option[Seq[Double]] = model.getValue("thresholds").map(_.getDoubleList)
      val binarizer = new Binarizer()
-      val obj = (threshold, thresholds) match {
+      (threshold, thresholds) match {
         case (None, None) => throw new IllegalArgumentException("Neither threshold nor thresholds were found")
         case (Some(v), None) => binarizer.setThreshold(v)
         case (None, Some(v)) => binarizer.setThresholds(v.toArray)
         case (_, _) => throw new IllegalArgumentException("Both thresholds and threshold were found")
       }
-      loadMultiInOutFormat(model, obj)
     }
   }
diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala
index 13c7b854d..f563426df 100644
--- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala
+++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala
@@ -5,6 +5,7 @@ import ml.combust.bundle.dsl._
 import ml.combust.bundle.op.OpModel
 import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute}
 import org.apache.spark.ml.bundle._
+import org.apache.spark.ml.bundle.ops.OpsUtils
 import org.apache.spark.ml.feature.OneHotEncoderModel
 import org.apache.spark.sql.types.StructField
 
@@ -33,7 +34,7 @@ object OneHotEncoderOp {
   }
 }
 
-class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
+class OneHotEncoderOp extends MultiInOutSparkOp[OneHotEncoderModel] {
   override val Model: OpModel[SparkBundleContext, OneHotEncoderModel] = new OpModel[SparkBundleContext, OneHotEncoderModel] {
     override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel]
 
@@ -47,22 +48,9 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
       val inputCols = if (obj.isSet(obj.inputCol)) Array(obj.getInputCol) else obj.getInputCols
       val df = context.context.dataset.get
       val categorySizes = inputCols.map { f ⇒ OneHotEncoderOp.sizeForField(df.schema(f)) }
-      var m = model.withValue("category_sizes", Value.intList(categorySizes))
+      model.withValue("category_sizes", Value.intList(categorySizes))
         .withValue("drop_last", Value.boolean(obj.getDropLast))
         .withValue("handle_invalid", Value.string(obj.getHandleInvalid))
-      if (obj.isSet(obj.inputCol)) {
-        m = m.withValue("inputCol", Value.string(obj.getInputCol))
-      }
-      if (obj.isSet(obj.inputCols)) {
-        m = m.withValue("inputCols", Value.stringList(obj.getInputCols))
-      }
-      if (obj.isSet(obj.outputCol)) {
-        m = m.withValue("outputCol", Value.string(obj.getOutputCol))
-      }
-      if (obj.isSet(obj.outputCols)) {
-        m = m.withValue("outputCols", Value.stringList(obj.getOutputCols))
-      }
-      m
     }
 
     override def load(model: Model)
@@ -70,55 +58,13 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
       val m = new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
         .setDropLast(model.value("drop_last").getBoolean)
         .setHandleInvalid(model.value("handle_invalid").getString)
-      if (model.getValue("inputCol").isDefined) {
-        m.setInputCol(model.value("inputCol").getString)
-      }
-      if (model.getValue("inputCols").isDefined) {
-        m.setInputCols(model.value("inputCols").getStringList.toArray)
-      }
-      if (model.getValue("outputCol").isDefined) {
-        m.setOutputCol(model.value("outputCol").getString)
-      }
-      if (model.getValue("outputCols").isDefined) {
-        m.setOutputCols(model.value("outputCols").getStringList.toArray)
-      }
       m
     }
   }
 
   override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = {
-    val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
-      .setDropLast(model.getDropLast)
-      .setHandleInvalid(model.getHandleInvalid)
-    if (model.isSet(model.inputCol)) {
-      m.setInputCol(model.getInputCol)
-    }
-    if (model.isSet(model.inputCols)) {
-      m.setInputCols(model.getInputCols)
-    }
-    if (model.isSet(model.outputCol)) {
-      m.setOutputCol(model.getOutputCol)
-    }
-    if (model.isSet(model.outputCols)) {
-      m.setOutputCols(model.getOutputCols)
-    }
+    val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
+    OpsUtils.copySparkStageParams(model, m)
     m
   }
-
-  override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
-    obj.isSet(obj.inputCol) match {
-      case true => Seq(ParamSpec("input", obj.inputCol))
-      case false => Seq(ParamSpec("input", obj.inputCols))
-    }
-
-  }
-
-  override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
-    obj.isSet(obj.outputCol) match {
-      case true => Seq(ParamSpec("output", obj.outputCol))
-      case false => Seq(ParamSpec("output", obj.outputCols))
-    }
-
-  }
-
 }
diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala
index 71f0496aa..380dcce32 100644
--- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala
+++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala
@@ -4,13 +4,13 @@ import ml.combust.bundle.BundleContext
 import ml.combust.bundle.dsl._
 import ml.combust.bundle.op.OpModel
 import org.apache.spark.ml.bundle.ops.OpsUtils
-import org.apache.spark.ml.bundle.{MultiInOutFormatSparkOp, SimpleSparkOp, SparkBundleContext}
+import org.apache.spark.ml.bundle.{MultiInOutSparkOp, SparkBundleContext}
 import org.apache.spark.ml.feature.StopWordsRemover
 
 /**
   * Created by mikhail on 10/16/16.
   */
-class StopWordsRemoverOp extends SimpleSparkOp[StopWordsRemover] with MultiInOutFormatSparkOp[StopWordsRemover] {
+class StopWordsRemoverOp extends MultiInOutSparkOp[StopWordsRemover] {
   override val Model: OpModel[SparkBundleContext, StopWordsRemover] = new OpModel[SparkBundleContext, StopWordsRemover] {
     override val klazz: Class[StopWordsRemover] = classOf[StopWordsRemover]
 
@@ -18,16 +18,15 @@ class StopWordsRemoverOp extends SimpleSparkOp[StopWordsRemover] with MultiInOut
     override def store(model: Model, obj: StopWordsRemover)
                       (implicit context: BundleContext[SparkBundleContext]): Model = {
-      val result = model.withValue("stop_words", Value.stringList(obj.getStopWords)).
-        withValue("case_sensitive", Value.boolean(obj.getCaseSensitive))
-      saveMultiInOutFormat(result, obj)
+      model.withValue("stop_words", Value.stringList(obj.getStopWords))
+        .withValue("case_sensitive", Value.boolean(obj.getCaseSensitive))
     }
 
     override def load(model: Model)
                      (implicit context: BundleContext[SparkBundleContext]): StopWordsRemover = {
-      val obj = new StopWordsRemover(uid = "").setStopWords(model.value("stop_words").getStringList.toArray).
-        setCaseSensitive(model.value("case_sensitive").getBoolean)
-      loadMultiInOutFormat(model, obj)
+      new StopWordsRemover(uid = "")
+        .setStopWords(model.value("stop_words").getStringList.toArray)
+        .setCaseSensitive(model.value("case_sensitive").getBoolean)
     }
   }
diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala
index 7e8749d6d..463b8c795 100644
--- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala
+++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala
@@ -10,7 +10,7 @@ import org.apache.spark.ml.feature.StringIndexerModel
 /**
   * Created by hollinwilkins on 8/21/16.
   */
-class StringIndexerOp extends SimpleSparkOp[StringIndexerModel] with MultiInOutFormatSparkOp[StringIndexerModel] {
+class StringIndexerOp extends MultiInOutSparkOp[StringIndexerModel] {
   override val Model: OpModel[SparkBundleContext, StringIndexerModel] = new OpModel[SparkBundleContext, StringIndexerModel] {
     override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel]
 
@@ -18,13 +18,14 @@ class StringIndexerOp extends SimpleSparkOp[StringIndexerModel] with MultiInOutF
     override def store(model: Model, obj: StringIndexerModel)
                       (implicit context: BundleContext[SparkBundleContext]): Model = {
+
       var result = model.
         withValue("labels_length", Value.int(obj.labelsArray.length)).
         withValue("handle_invalid", Value.string(obj.getHandleInvalid))
       obj.labelsArray.indices.foreach(
         i => result = result.withValue(s"labels_array_$i", Value.stringList(obj.labelsArray(i)))
       )
-      saveMultiInOutFormat(result, obj)
+      result
     }
 
     override def load(model: Model)
@@ -36,15 +37,14 @@
       } else {
         val collectedLabels = new Array[Array[String]](label_length)
-        for ( i <- 0 to label_length - 1) {
+        for ( i <- 0 until label_length) {
           collectedLabels(i) = model.value(s"labels_array_$i").getStringList.toArray
         }
         collectedLabels
       }
-      val obj = new StringIndexerModel(labelsArray = labelsArray).
+      new StringIndexerModel(labelsArray = labelsArray).
         setHandleInvalid(model.value("handle_invalid").getString)
-      loadMultiInOutFormat(model, obj)
     }
   }
diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala
index bdb419fde..f66576b33 100644
--- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala
+++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala
@@ -9,6 +9,17 @@ import org.apache.spark.sql.DataFrame
 /**
   * Created by hollinwilkins on 10/30/16.
   */
+
+class MIOOneHotEncoderParitySpec extends SparkParityBase {
+  override val dataset: DataFrame = baseDataset.select("state")
+  override val sparkTransformer: Transformer = new Pipeline()
+    .setStages(Array(
+      new StringIndexer().setInputCol("state").setOutputCol("state_index"),
+      new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh")
+    )).fit(dataset)
+  override val unserializedParams: Set[String] = Set("stringOrderType")
+}
+
 class OneHotEncoderParitySpec extends SparkParityBase {
   override val dataset: DataFrame = baseDataset.select("state")
   override val sparkTransformer: Transformer =
@@ -24,21 +35,7 @@ class OneHotEncoderParitySpec extends SparkParityBase {
 
   override val unserializedParams = Set("stringOrderType")
 
-  it("serializes/deserializes the Spark model properly with one in/out column"){
-    bundleCache = None
-    val additionalIgnoreParams = Set("outputCol")
-    val pipeline = new Pipeline()
-      .setStages(Array(
-        new StringIndexer().setInputCol("state").setOutputCol("state_index"),
-        new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh")
-      )).fit(dataset)
-    val sparkTransformed = pipeline.transform(dataset)
-    implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
-    val deserializedTransformer = deserializedSparkTransformer(pipeline)
-    checkEquality(pipeline, deserializedTransformer, additionalIgnoreParams)
-    equalityTest(sparkTransformed, deserializedTransformer.transform(dataset))
-    bundleCache = None
-  }
+
   it("fails to instantiate if the Spark model sets inputCol and inputCols"){
     intercept[IllegalArgumentException] {
diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala
index 64fa7db53..89c28a034 100644
--- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala
+++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala
@@ -2,9 +2,7 @@ package org.apache.spark.ml.parity.feature
 
 import org.apache.spark.ml.parity.SparkParityBase
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.bundle.SparkBundleContext
 import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.param.Param
 import org.apache.spark.sql.DataFrame
 
 /**
@@ -12,32 +10,31 @@ import org.apache.spark.sql.DataFrame
   */
 class StringIndexerParitySpec extends SparkParityBase {
   override val dataset: DataFrame = baseDataset.select("state")
+  override val unserializedParams = Set("stringOrderType")
   // setting handleInvalid to keep
   override val sparkTransformer: Transformer = new StringIndexer().
     setInputCol("state").
     setOutputCol("state_index").setHandleInvalid("keep").
     fit(dataset)
-
+}
+class StringIndexerNoOutputColParitySpec extends SparkParityBase {
+  override val dataset: DataFrame = baseDataset.select("state")
   override val unserializedParams = Set("stringOrderType")
-  it("serializes/deserializes the Spark model properly with multiple in/out columns"){
-    bundleCache = None
-    // outputCol has a default value of "__output, so we ignore it in this test
-    // since the uid will be different
-    val additionalIgnoreParams = Set("outputCol")
+  // setting handleInvalid to keep
+  override val sparkTransformer: Transformer = new StringIndexer().
+    setInputCol("state").
+    setHandleInvalid("keep").
+    fit(dataset)
+}
-    val multiColTransformer = new StringIndexer().
+class MIOStringIndexerParitySpec extends SparkParityBase {
+  override val dataset: DataFrame = baseDataset.select("state", "loan_title")
+  override val unserializedParams = Set("stringOrderType")
+  override val sparkTransformer: Transformer = new StringIndexer().
     setInputCols(Array("state", "loan_title")).
     setOutputCols(Array("state_index", "loan_title_index")).
     setHandleInvalid("keep").
-    fit(baseDataset)
-    val sparkTransformed = multiColTransformer.transform(baseDataset)
-    implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
-    val deserializedTransformer = deserializedSparkTransformer(multiColTransformer)
-    checkEquality(multiColTransformer, deserializedTransformer, additionalIgnoreParams)
-    equalityTest(sparkTransformed, deserializedTransformer.transform(baseDataset))
-
-    bundleCache = None
-  }
+    fit(dataset)
 }
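
Reviewer note (not part of the patch): a minimal sketch of how the multi-column StringIndexer introduced above can be exercised from the MLeap runtime alone. The column names here are hypothetical; the `StringIndexerModel(labelsArray, ...)` constructor, the `Seq`-based `apply`, and the indexed `input0`/`output0` port names are the ones added in this diff.

```scala
import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
import ml.combust.mleap.core.types.NodeShape
import ml.combust.mleap.runtime.transformer.feature.StringIndexer

// One ordered label list per input column; a label's position is its index.
val model = StringIndexerModel(
  labelsArray = Seq(Seq("CA", "NY"), Seq("car", "house")),
  handleInvalid = HandleInvalid.Keep)

// apply takes one value per column; in Keep mode, unseen labels and NULLs
// map to the extra bucket at labels.length for that column (here: 2).
assert(model(Seq("NY", "house")) == Seq(1.0, 1.0))
assert(model(Seq("TX", null)) == Seq(2.0, 2.0))

// Ports are indexed per column. Legacy bundles that used the single
// "input"/"output" ports are rewritten to "input0"/"output0" at load time
// by MultiInOutMleapOp.translateLegacyShape.
val indexer = StringIndexer(
  shape = NodeShape()
    .withInput("input0", "state").withInput("input1", "loan_title")
    .withOutput("output0", "state_index").withOutput("output1", "loan_title_index"),
  model = model)
```

Serialization stays backward compatible: new bundles write `labels_length` plus one `labels_array_$i` attribute per column, while the old single `labels` list is still readable through the `label_length == -1` branch of the runtime `StringIndexerOp.load`.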