From 65d672ce1c93e558611c85eeb9fd5c136ecf726f Mon Sep 17 00:00:00 2001
From: "Yunhui Zhang (GitHub)" 
Date: Mon, 21 Oct 2024 21:34:18 -0400
Subject: [PATCH 01/15] Add MultiInOutTransformer to support Spark transformers
 that have inputCols and outputCols.

---
 .../ml/combust/bundle/util/FileUtil.scala     |  2 +-
 .../core/feature/StringIndexerModel.scala     | 56 +++++++++----
 .../core/feature/StringIndexerModelSpec.scala | 26 +++---
 .../bundle/ops/feature/StringIndexerOp.scala  | 15 ++--
 .../feature/MultiInOutTransformer.scala       | 28 +++++++
 .../transformer/feature/StringIndexer.scala   | 38 +++++----
 .../mleap/runtime/MleapSupportSpec.scala      |  2 +-
 .../mleap/runtime/javadsl/JavaDSLSpec.java    |  2 +-
 .../feature/StringIndexerSpec.scala           | 12 ++-
 .../ml/bundle/MultiInOutFormatSparkOp.scala   | 84 +++++++++----------
 .../ml/bundle/ops/feature/BinarizerOp.scala   | 27 ++----
 .../bundle/ops/feature/OneHotEncoderOp.scala  | 64 ++------------
 .../ops/feature/StopWordsRemoverOp.scala      | 13 ++-
 .../bundle/ops/feature/StringIndexerOp.scala  | 14 ++--
 .../feature/StringIndexerParitySpec.scala     |  8 +-
 15 files changed, 194 insertions(+), 197 deletions(-)
 create mode 100644 mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala

diff --git a/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala b/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
index 86732131c..1f15bb83c 100644
--- a/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
+++ b/bundle-ml/src/main/scala/ml/combust/bundle/util/FileUtil.scala
@@ -31,7 +31,7 @@ object FileUtil {
   def rmRF(toRemove: Path): Array[(String, Boolean)] = {
     def removeElement(path: Path): (String, Boolean) = {
       val result = Try {
-        Files.deleteIfExists(toRemove)
+        Files.deleteIfExists(path)
      } match {
        case Failure(_) => false
        case Success(value) => value
diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
index 00c8f9998..5cbeaed25 100644
--- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
+++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
@@ -1,7 +1,7 @@
 package ml.combust.mleap.core.feature
 
 import ml.combust.mleap.core.Model
-import ml.combust.mleap.core.types.{ScalarType, StructType}
+import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
 
 /** Class for string indexer model.
  *
@@ -12,29 +12,45 @@ import ml.combust.mleap.core.types.{ScalarType, StructType}
  * 'skip' (skips invalid data)
  * or 'keep' (put invalid data in a special bucket at index labels.size
  */
-case class StringIndexerModel(labels: Seq[String],
+case class StringIndexerModel(labels: Seq[Seq[String]],
                               handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
-  val stringToIndex: Map[String, Int] = labels.zipWithIndex.toMap
+  private val stringToIndex: Array[Map[String, Int]] = labels.map(_.zipWithIndex.toMap).toArray
   private val keepInvalid = handleInvalid == HandleInvalid.Keep
-  private val invalidValue = labels.length
+  private val invalidValue = labels.map(_.length)
+
+  /** Convert all strings into their integer representation. 
+ * + * @param values labels to index + * @return indexes of labels + */ + def apply(values: Seq[Any]): Seq[Double] = values.zipWithIndex.map { + case (v: Any, i: Int) => encoder(v, i).toDouble + case (null, i: Int) => encoder(null, i).toDouble + } + + def contains(values: Seq[Any]): Boolean = { + values.zipWithIndex.forall { + case (key, i) => stringToIndex(i).contains(key.toString) + } + } /** Convert a string into its integer representation. - * - * @param value label to index - * @return index of label - */ - def apply(value: Any): Int = if (value == null) { + * + * @param value label to index + * @return index of label + */ + private def encoder(value: Any, colIdx: Int): Int = if (value == null) { if (keepInvalid) { - invalidValue + invalidValue(colIdx) } else { throw new NullPointerException("StringIndexer encountered NULL value. " + s"To handle NULLS, set handleInvalid to ${HandleInvalid.Keep.asParamString}") } } else { val label = value.toString - stringToIndex.get(label) match { + stringToIndex(colIdx).get(label) match { case Some(v) => v case None => if (keepInvalid) { - invalidValue + invalidValue(colIdx) } else { throw new NoSuchElementException(s"Unseen label: $label. To handle unseen labels, " + s"set handleInvalid to ${HandleInvalid.Keep.asParamString}") @@ -46,9 +62,19 @@ case class StringIndexerModel(labels: Seq[String], * * @return reverse string indexer of this string indexer */ - def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels) +// def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels) - override def inputSchema: StructType = StructType("input" -> ScalarType.String).get + override def inputSchema: StructType = { + val f = labels.zipWithIndex.map { + case (_, i) => StructField(s"input$i", ScalarType.String) + } + StructType(f).get + } - override def outputSchema: StructType = StructType("output" -> ScalarType.Double.nonNullable).get + override def outputSchema: StructType = { + val f = labels.zipWithIndex.map { + case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable) + } + StructType(f).get + } } diff --git a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala index e598497ae..790db0d4f 100644 --- a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala +++ b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala @@ -10,48 +10,48 @@ import org.scalatest.prop.TableDrivenPropertyChecks class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with TableDrivenPropertyChecks { describe("#apply") { it("returns the index of the string") { - val indexer = StringIndexerModel(Array("hello", "there", "dude")) + val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude"))) - assert(indexer("hello") == 0.0) - assert(indexer("there") == 1.0) - assert(indexer("dude") == 2.0) + assert(indexer(Seq("hello")).head == 0.0) + assert(indexer(Seq("there")).head == 1.0) + assert(indexer(Seq("dude")).head == 2.0) } it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") { - val indexer = StringIndexerModel(Array("hello")) + val indexer = StringIndexerModel(Seq(Array("hello"))) assertThrows[NullPointerException](indexer(null)) } it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") { - val indexer = StringIndexerModel(Array("hello")) - val unseenLabels = 
Table("unknown1", "unknown2") + val indexer = StringIndexerModel(Seq(Array("hello"))) + val unseenLabels = Table("label", "unknown1", "unknown2") forAll(unseenLabels) { (label: Any) => intercept[NoSuchElementException] { - indexer(label) + indexer(Seq(label)) } } } it("returns default index for HandleInvalid.keep mode") { - val indexer = StringIndexerModel(Array("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep) + val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep) val invalidLabels = Table("unknown", "other unknown", null, None) forAll(invalidLabels) { (label: Any) => - assert(indexer(label) == 3.0) + assert(indexer(Seq(label)).head == 3.0) } } } describe("input/output schema") { - val indexer = StringIndexerModel(Array("hello", "there", "dude")) + val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude"))) it("has the right input schema") { - assert(indexer.inputSchema.fields == Seq(StructField("input", ScalarType.String))) + assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String))) } it("has the right output schema") { - assert(indexer.outputSchema.fields == Seq(StructField("output", ScalarType.Double.nonNullable))) + assert(indexer.outputSchema.fields == Seq(StructField("output0", ScalarType.Double.nonNullable))) } } } diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala index 3fd3075ce..4c21f726b 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala @@ -19,27 +19,28 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] { override def store(model: Model, obj: StringIndexerModel) (implicit context: BundleContext[MleapContext]): Model = { - model. + val m = model. withValue("labels_length", Value.int(1)). - withValue("labels_array_0", Value.stringList(obj.labels)). 
withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) - + obj.labels.zipWithIndex.foldLeft(m){ + case (m, (label, i)) => m.withValue(s"labels_array_$i", Value.stringList(label)) + } } override def load(model: Model) (implicit context: BundleContext[MleapContext]): StringIndexerModel = { val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default) val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1) - val labels: Seq[String] = label_length match { + val labels: Seq[Seq[String]] = label_length match { case -1 => // backawards compatibility with spark v2 - model.value("labels").getStringList - case 1 => model.value("labels_array_0").getStringList - case _ => throw new UnsupportedOperationException("Multi-input StringIndexer not supported yet.") + Seq(model.value("labels").getStringList) + case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList) } StringIndexerModel(labels = labels, handleInvalid = handleInvalid) } } override def model(node: StringIndexer): StringIndexerModel = node.model + } diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala new file mode 100644 index 000000000..e3ecd024a --- /dev/null +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala @@ -0,0 +1,28 @@ +package ml.combust.mleap.runtime.transformer.feature + +import ml.combust.mleap.core.types._ +import ml.combust.mleap.runtime.frame.Transformer + +trait MultiInOutTransformer extends Transformer { + override def inputSchema: StructType = { + if (shape.getInput("input").isDefined) { + val fields = model.inputSchema.getField("input0").map { + case StructField(_, dataType) => StructField(shape.input("input").name, dataType) + }.toSeq + StructType(fields).get + } else { + super.inputSchema + } + } + + override def outputSchema: StructType = { + if (shape.getOutput("output").isDefined) { + val fields = model.outputSchema.getField("output0").map { + case StructField(_, dataType) => StructField(shape.output("output").name, dataType) + }.toSeq + StructType(fields).get + } else { + super.outputSchema + } + } +} diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala index a83bdbe5e..660af4495 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala @@ -2,29 +2,39 @@ package ml.combust.mleap.runtime.transformer.feature import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} import ml.combust.mleap.core.types._ -import ml.combust.mleap.runtime.function.{FieldSelector, UserDefinedFunction} -import ml.combust.mleap.runtime.frame.{FrameBuilder, Transformer} +import ml.combust.mleap.runtime.function.{StructSelector, UserDefinedFunction} +import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, Transformer} import scala.util.Try /** - * Created by hwilkins on 10/22/15. - */ + * Created by hwilkins on 10/22/15. 
+ */ case class StringIndexer(override val uid: String = Transformer.uniqueName("string_indexer"), override val shape: NodeShape, - override val model: StringIndexerModel) extends Transformer { - val input: String = inputSchema.fields.head.name - val inputSelector: FieldSelector = input - val output: String = outputSchema.fields.head.name - val exec: UserDefinedFunction = (value: String) => model(value).toDouble + override val model: StringIndexerModel) extends Transformer with MultiInOutTransformer { + private val outputs: Seq[String] = outputSchema.fields.map(_.name) + private val inputs: Seq[String] = inputSchema.fields.map(_.name) + private val inputSelector: StructSelector = StructSelector(inputs) + private val filterSchema = StructType(Seq(StructField("output", ScalarType.Boolean.nonNullable))).get + private val exec: UserDefinedFunction = UserDefinedFunction((keys: Row) => { + val res = model(keys.toSeq) + Row(res:_*) + }, SchemaSpec(outputSchema), Seq(SchemaSpec(inputSchema))) + private val contains: UserDefinedFunction = UserDefinedFunction((keys: Row) => { + model.contains(keys.toSeq) + }, SchemaSpec(filterSchema), Seq(SchemaSpec(inputSchema))) override def transform[FB <: FrameBuilder[FB]](builder: FB): Try[FB] = { + def withColumns(builder: FB): Try[FB] = { + builder.withColumns(outputs, inputSelector)(exec) + } + if(model.handleInvalid == HandleInvalid.Skip) { - builder.filter(input) { - (key: String) => model.stringToIndex.contains(key) - }.flatMap(_.withColumn(output, inputSelector)(exec)) + builder.filter(inputSelector)(contains) + .flatMap(withColumns) } else { - builder.withColumn(output, inputSelector)(exec) + withColumns(builder) } } -} +} \ No newline at end of file diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala index dd1caed2a..c6f2651e7 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala @@ -16,7 +16,7 @@ class MleapSupportSpec extends org.scalatest.funspec.AnyFunSpec { private val stringIndexer = StringIndexer(shape = NodeShape(). withStandardInput("feature"). withStandardOutput("feature_index"), - model = StringIndexerModel(Seq("label1", "label2"))) + model = StringIndexerModel(Seq(Seq("label1", "label2")))) describe("URIBundleFileOps") { it("can save/load a bundle using a URI") { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java index 2651eac37..e947e6079 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java @@ -52,7 +52,7 @@ private static Map createMap() { new NodeShape(new ListMap<>(), new ListMap<>()). withStandardInput("string"). 
withStandardOutput("string_index"), - new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(), + new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq())).toSeq(), HandleInvalid$.MODULE$.fromString("error", true))); DefaultLeapFrame buildFrame() { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala index 19cda92ec..8add9cbd9 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala @@ -14,9 +14,11 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { val stringIndexer = StringIndexer( shape = NodeShape.feature( + inputPort="input0", + outputPort="output0", inputCol = "test_string", outputCol = "test_index"), - model = StringIndexerModel(Seq("index1", "index2", "index3"))) + model = StringIndexerModel(Seq(Seq("index1", "index2", "index3")))) describe("#transform") { it("converts input string into an index") { @@ -29,12 +31,16 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { } describe("with invalid input column") { - val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withStandardInput("bad_input"). - withStandardOutput("output")) + val stringIndexer2 = stringIndexer.copy(shape = NodeShape.feature( + inputPort="input0", + outputPort="output0", + inputCol = "bad_input", + outputCol = "bad_output")) it("returns a Failure") { assert(stringIndexer2.transform(frame).isFailure) } + } describe("with invalid string") { diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala index f43950762..02868c190 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala @@ -1,61 +1,53 @@ package org.apache.spark.ml.bundle -import ml.combust.bundle.dsl.{Model, Value} +import ml.combust.bundle.BundleContext +import ml.combust.bundle.dsl._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamValidators -import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} +import org.apache.spark.ml.param.shared._ +import ml.combust.bundle.op.OpModel -trait MultiInOutFormatSparkOp[ - N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols -]{ +import scala.reflect.ClassTag - protected def saveMultiInOutFormat(model: Model, obj: N): Model = { +abstract class MultiInOutOpModel[N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols] extends OpModel[SparkBundleContext, N] { + private def validateParams(obj: N): Unit = { ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) - val result = if(obj.isSet(obj.inputCols)) { - model.withValue("input_cols", Value.stringList(obj.getInputCols)) - } else { - model.withValue("input_col", Value.string(obj.getInputCol)) - } - if (obj.isSet(obj.outputCols)) { - 
result.withValue("output_cols", Value.stringList(obj.getOutputCols)) - } else { - result.withValue("output_col", Value.string(obj.getOutputCol)) - } } + override def store(model: Model, obj: N)(implicit context: BundleContext[SparkBundleContext]): Model = { + validateParams(obj) + model + } +} + +abstract class MultiInOutFormatSparkOp[ + N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols +](implicit ct: ClassTag[N]) extends SimpleSparkOp[N] { + import NodeShape._ - protected def loadMultiInOutFormat(model: Model, obj: N): N = { - val inputCol = model.getValue("input_col").map(_.getString) - val inputCols = model.getValue("input_cols").map(_.getStringList) - val outputCol = model.getValue("output_col").map(_.getString) - val outputCols = model.getValue("output_cols").map(_.getStringList) - val result: N = (inputCol, inputCols) match { - case (None, None) => obj - case (Some(col), None) => obj.set(obj.inputCol, col) - case (None, Some(cols)) => obj.set(obj.inputCols, cols.toArray) - case (_, _) => throw new UnsupportedOperationException("Cannot use both inputCol and inputCols") - } - (outputCol, outputCols) match { - case (None, None) => obj - case (Some(col), None) => result.set(result.outputCol, col) - case (None, Some(cols)) => result.set(result.outputCols, cols.toArray) - case (_, _) => throw new UnsupportedOperationException("Cannot use both outputCol and outputCols") - } + override def load(node: Node, model: N)(implicit context: BundleContext[SparkBundleContext]): N = { + val n = sparkLoad(node.name, node.shape, model) + SparkShapeLoader(node.shape, n, sparkInputs(n, node.shape), sparkOutputs(n, node.shape)).loadShape() + n } - def sparkInputs(obj: N): Seq[ParamSpec] = { - if (obj.isSet(obj.inputCols)) { - Seq(ParamSpec("input", obj.inputCols)) - } else{ - Seq(ParamSpec("input", obj.inputCol)) - } + def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort).isDefined, obj) + + def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort).isDefined, obj) + + def sparkInputs(obj: N): Seq[ParamSpec] = sparkInputs(obj.isSet(obj.inputCol), obj) + + def sparkOutputs(obj: N): Seq[ParamSpec] = sparkOutputs(obj.isSet(obj.outputCol), obj) + + def sparkInputs(hasInputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCol) { + Seq(ParamSpec(standardInputPort, obj.inputCol)) + } else { + Seq(ParamSpec(standardInputPort, obj.inputCols)) } - def sparkOutputs(obj: N): Seq[ParamSpec] = { - if (obj.isSet(obj.outputCols)) { - Seq(ParamSpec("output", obj.outputCols)) - } else{ - Seq(ParamSpec("output", obj.outputCol)) - } + def sparkOutputs(hasOutputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasOutputCol) { + Seq(ParamSpec(standardOutputPort, obj.outputCol)) + } else { + Seq(ParamSpec(standardOutputPort, obj.outputCols)) } -} +} \ No newline at end of file diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala index 80ba16cf5..af0b1d674 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala @@ -9,13 +9,11 @@ import org.apache.spark.ml.feature.Binarizer import org.apache.spark.ml.bundle.ops.OpsUtils import org.apache.spark.sql.mleap.TypeConverters._ import ml.combust.mleap.runtime.types.BundleTypeConverters._ -import 
org.apache.spark.ml.bundle.ops.OpsUtils -import org.apache.spark.ml.param.ParamValidators /** * Created by fshabbir on 12/1/16. */ -class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[Binarizer] { +class BinarizerOp extends MultiInOutFormatSparkOp[Binarizer] { override val Model: OpModel[SparkBundleContext, Binarizer] = new OpModel[SparkBundleContext, Binarizer] { override val klazz: Class[Binarizer] = classOf[Binarizer] @@ -24,21 +22,15 @@ class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[ override def store(model: Model, obj: Binarizer) (implicit context: BundleContext[SparkBundleContext]): Model = { assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz)) - val dataset = context.context.dataset.get - var result = { - ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) - if(obj.isSet(obj.inputCols)) { - val inputShapes = obj.getInputCols.map(i => sparkToMleapDataShape(dataset.schema(i), dataset): DataShape) - model.withValue("input_shapes_list", Value.dataShapeList(inputShapes)) - } else { - model.withValue("input_shapes", Value.dataShape(sparkToMleapDataShape(dataset.schema(obj.getInputCol), dataset))) - } + if(obj.isSet(obj.inputCols)) { + val inputShapes = obj.getInputCols.map(i => sparkToMleapDataShape(dataset.schema(i), dataset): DataShape) + model.withValue("input_shapes_list", Value.dataShapeList(inputShapes)) + .withValue("thresholds", Value.doubleList(obj.getThresholds)) + } else { + model.withValue("input_shapes", Value.dataShape(sparkToMleapDataShape(dataset.schema(obj.getInputCol), dataset))) + .withValue("threshold", Value.double(obj.getThreshold)) } - - if (obj.isSet(obj.threshold)) result = result.withValue("threshold", Value.double(obj.getThreshold)) - if (obj.isSet(obj.thresholds)) result = result.withValue("thresholds", Value.doubleList(obj.getThresholds)) - saveMultiInOutFormat(result, obj) } override def load(model: Model) @@ -46,13 +38,12 @@ class BinarizerOp extends SimpleSparkOp[Binarizer] with MultiInOutFormatSparkOp[ val threshold: Option[Double] = model.getValue("threshold").map(_.getDouble) val thresholds: Option[Seq[Double]] = model.getValue("thresholds").map(_.getDoubleList) val binarizer = new Binarizer() - val obj = (threshold, thresholds) match { + (threshold, thresholds) match { case (None, None) => throw new IllegalArgumentException("Neither threshold nor thresholds were found") case (Some(v), None) => binarizer.setThreshold(v) case (None, Some(v)) => binarizer.setThresholds(v.toArray) case (_, _) => throw new IllegalArgumentException("Both thresholds and threshold were found") } - loadMultiInOutFormat(model, obj) } } diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala index 13c7b854d..ef1dc356f 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala @@ -5,6 +5,7 @@ import ml.combust.bundle.dsl._ import ml.combust.bundle.op.OpModel import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute} import org.apache.spark.ml.bundle._ +import org.apache.spark.ml.bundle.ops.OpsUtils import org.apache.spark.ml.feature.OneHotEncoderModel import org.apache.spark.sql.types.StructField @@ -33,7 +34,7 @@ object OneHotEncoderOp { } } 
-class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] { +class OneHotEncoderOp extends MultiInOutFormatSparkOp[OneHotEncoderModel] { override val Model: OpModel[SparkBundleContext, OneHotEncoderModel] = new OpModel[SparkBundleContext, OneHotEncoderModel] { override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel] @@ -47,22 +48,9 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] { val inputCols = if (obj.isSet(obj.inputCol)) Array(obj.getInputCol) else obj.getInputCols val df = context.context.dataset.get val categorySizes = inputCols.map { f ⇒ OneHotEncoderOp.sizeForField(df.schema(f)) } - var m = model.withValue("category_sizes", Value.intList(categorySizes)) + model.withValue("category_sizes", Value.intList(categorySizes)) .withValue("drop_last", Value.boolean(obj.getDropLast)) .withValue("handle_invalid", Value.string(obj.getHandleInvalid)) - if (obj.isSet(obj.inputCol)) { - m = m.withValue("inputCol", Value.string(obj.getInputCol)) - } - if (obj.isSet(obj.inputCols)) { - m = m.withValue("inputCols", Value.stringList(obj.getInputCols)) - } - if (obj.isSet(obj.outputCol)) { - m = m.withValue("outputCol", Value.string(obj.getOutputCol)) - } - if (obj.isSet(obj.outputCols)) { - m = m.withValue("outputCols", Value.stringList(obj.getOutputCols)) - } - m } override def load(model: Model) @@ -70,55 +58,13 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] { val m = new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray) .setDropLast(model.value("drop_last").getBoolean) .setHandleInvalid(model.value("handle_invalid").getString) - if (model.getValue("inputCol").isDefined) { - m.setInputCol(model.value("inputCol").getString) - } - if (model.getValue("inputCols").isDefined) { - m.setInputCols(model.value("inputCols").getStringList.toArray) - } - if (model.getValue("outputCol").isDefined) { - m.setOutputCol(model.value("outputCol").getString) - } - if (model.getValue("outputCols").isDefined) { - m.setOutputCols(model.value("outputCols").getStringList.toArray) - } m } } override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = { - val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes) - .setDropLast(model.getDropLast) - .setHandleInvalid(model.getHandleInvalid) - if (model.isSet(model.inputCol)) { - m.setInputCol(model.getInputCol) - } - if (model.isSet(model.inputCols)) { - m.setInputCols(model.getInputCols) - } - if (model.isSet(model.outputCol)) { - m.setOutputCol(model.getOutputCol) - } - if (model.isSet(model.outputCols)) { - m.setOutputCols(model.getOutputCols) - } + val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes) + OpsUtils.copySparkStageParams(model, m) m } - - override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = { - obj.isSet(obj.inputCol) match { - case true => Seq(ParamSpec("input", obj.inputCol)) - case false => Seq(ParamSpec("input", obj.inputCols)) - } - - } - - override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = { - obj.isSet(obj.outputCol) match { - case true => Seq(ParamSpec("output", obj.outputCol)) - case false => Seq(ParamSpec("output", obj.outputCols)) - } - - } - } diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala index 71f0496aa..fdc58fdef 100644 --- 
a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala @@ -10,7 +10,7 @@ import org.apache.spark.ml.feature.StopWordsRemover /** * Created by mikhail on 10/16/16. */ -class StopWordsRemoverOp extends SimpleSparkOp[StopWordsRemover] with MultiInOutFormatSparkOp[StopWordsRemover] { +class StopWordsRemoverOp extends MultiInOutFormatSparkOp[StopWordsRemover] { override val Model: OpModel[SparkBundleContext, StopWordsRemover] = new OpModel[SparkBundleContext, StopWordsRemover] { override val klazz: Class[StopWordsRemover] = classOf[StopWordsRemover] @@ -18,16 +18,15 @@ class StopWordsRemoverOp extends SimpleSparkOp[StopWordsRemover] with MultiInOut override def store(model: Model, obj: StopWordsRemover) (implicit context: BundleContext[SparkBundleContext]): Model = { - val result = model.withValue("stop_words", Value.stringList(obj.getStopWords)). - withValue("case_sensitive", Value.boolean(obj.getCaseSensitive)) - saveMultiInOutFormat(result, obj) + model.withValue("stop_words", Value.stringList(obj.getStopWords)) + .withValue("case_sensitive", Value.boolean(obj.getCaseSensitive)) } override def load(model: Model) (implicit context: BundleContext[SparkBundleContext]): StopWordsRemover = { - val obj = new StopWordsRemover(uid = "").setStopWords(model.value("stop_words").getStringList.toArray). - setCaseSensitive(model.value("case_sensitive").getBoolean) - loadMultiInOutFormat(model, obj) + new StopWordsRemover(uid = "") + .setStopWords(model.value("stop_words").getStringList.toArray) + .setCaseSensitive(model.value("case_sensitive").getBoolean) } } diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala index 7e8749d6d..ef9e57873 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala @@ -10,21 +10,22 @@ import org.apache.spark.ml.feature.StringIndexerModel /** * Created by hollinwilkins on 8/21/16. */ -class StringIndexerOp extends SimpleSparkOp[StringIndexerModel] with MultiInOutFormatSparkOp[StringIndexerModel] { - override val Model: OpModel[SparkBundleContext, StringIndexerModel] = new OpModel[SparkBundleContext, StringIndexerModel] { +class StringIndexerOp extends MultiInOutFormatSparkOp[StringIndexerModel] { + override val Model: OpModel[SparkBundleContext, StringIndexerModel] = new MultiInOutOpModel[StringIndexerModel] { override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel] override def opName: String = Bundle.BuiltinOps.feature.string_indexer override def store(model: Model, obj: StringIndexerModel) (implicit context: BundleContext[SparkBundleContext]): Model = { - var result = model. + + var result = super.store(model, obj). withValue("labels_length", Value.int(obj.labelsArray.length)). 
withValue("handle_invalid", Value.string(obj.getHandleInvalid)) obj.labelsArray.indices.foreach( i => result = result.withValue(s"labels_array_$i", Value.stringList(obj.labelsArray(i))) ) - saveMultiInOutFormat(result, obj) + result } override def load(model: Model) @@ -36,15 +37,14 @@ class StringIndexerOp extends SimpleSparkOp[StringIndexerModel] with MultiInOutF } else { val collectedLabels = new Array[Array[String]](label_length) - for ( i <- 0 to label_length - 1) { + for ( i <- 0 until label_length) { collectedLabels(i) = model.value(s"labels_array_$i").getStringList.toArray } collectedLabels } - val obj = new StringIndexerModel(labelsArray = labelsArray). + new StringIndexerModel(labelsArray = labelsArray). setHandleInvalid(model.value("handle_invalid").getString) - loadMultiInOutFormat(model, obj) } } diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala index 64fa7db53..889c99deb 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala @@ -4,7 +4,6 @@ import org.apache.spark.ml.parity.SparkParityBase import org.apache.spark.ml.Transformer import org.apache.spark.ml.bundle.SparkBundleContext import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.param.Param import org.apache.spark.sql.DataFrame /** @@ -18,14 +17,13 @@ class StringIndexerParitySpec extends SparkParityBase { setInputCol("state"). setOutputCol("state_index").setHandleInvalid("keep"). fit(dataset) - - override val unserializedParams = Set("stringOrderType") + val ignoreParams = Set("inputCol", "outputCol", "inputCols", "outputCols") + override val unserializedParams = Set("stringOrderType").union(ignoreParams) it("serializes/deserializes the Spark model properly with multiple in/out columns"){ bundleCache = None // outputCol has a default value of "__output, so we ignore it in this test // since the uid will be different - val additionalIgnoreParams = Set("outputCol") val multiColTransformer = new StringIndexer(). setInputCols(Array("state", "loan_title")). 
@@ -35,7 +33,7 @@ class StringIndexerParitySpec extends SparkParityBase {
     val sparkTransformed = multiColTransformer.transform(baseDataset)
     implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
     val deserializedTransformer = deserializedSparkTransformer(multiColTransformer)
-    checkEquality(multiColTransformer, deserializedTransformer, additionalIgnoreParams)
+    checkEquality(multiColTransformer, deserializedTransformer, ignoreParams)
     equalityTest(sparkTransformed, deserializedTransformer.transform(baseDataset))
 
     bundleCache = None

From 77440133e0c7810a8e3caf10d5dfabd95ae6fa37 Mon Sep 17 00:00:00 2001
From: "Yunhui Zhang (GitHub)" 
Date: Tue, 22 Oct 2024 16:05:38 -0400
Subject: [PATCH 02/15] use labelsArray instead of labels for StringIndexer

---
 .../core/feature/StringIndexerModel.scala     | 30 +++++++++++++------
 .../core/feature/StringIndexerModelSpec.scala | 10 +++----
 .../bundle/ops/feature/StringIndexerOp.scala  | 12 ++++----
 .../mleap/runtime/MleapSupportSpec.scala      |  2 +-
 .../mleap/runtime/javadsl/JavaDSLSpec.java    |  4 +--
 .../feature/StringIndexerSpec.scala           |  2 +-
 6 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
index 5cbeaed25..b9283f168 100644
--- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
+++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
@@ -7,16 +7,22 @@ import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
  *
  * String indexer converts a string into an integer representation.
  *
- * @param labels list of labels that can be indexed
+ * @param labelsArray Array of ordered lists of labels, corresponding to the indices to be assigned for each input
  * @param handleInvalid how to handle invalid values (unseen or NULL labels): 'error' (throw an error),
  * 'skip' (skips invalid data)
  * or 'keep' (put invalid data in a special bucket at index labels.size
  */
-case class StringIndexerModel(labels: Seq[Seq[String]],
-                              handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
-  private val stringToIndex: Array[Map[String, Int]] = labels.map(_.zipWithIndex.toMap).toArray
+case class StringIndexerModel(labelsArray: Array[Array[String]],
+                              handleInvalid: HandleInvalid) extends Model {
+
+  private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap)
   private val keepInvalid = handleInvalid == HandleInvalid.Keep
-  private val invalidValue = labels.map(_.length)
+  private val invalidValue = labelsArray.map(_.length)
+
+
+
+  @deprecated("Use labelsArray instead")
+  def labels: Seq[String] = labelsArray(0).toSeq
 
   /** Convert all strings into their integer representation.
    *
@@ -59,22 +65,28 @@ case class StringIndexerModel(labels: Seq[Seq[String]],
 
   /** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
-   *
+   * ReverseStringIndexer only supports one input
    * @return reverse string indexer of this string indexer
    */
-//  def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels)
+  def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labelsArray(0))
 
   override def inputSchema: StructType = {
-    val f = labels.zipWithIndex.map {
+    val f = labelsArray.zipWithIndex.map {
       case (_, i) => StructField(s"input$i", ScalarType.String)
     }
     StructType(f).get
   }
 
   override def outputSchema: StructType = {
-    val f = labels.zipWithIndex.map {
+    val f = labelsArray.zipWithIndex.map {
      case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
     }
     StructType(f).get
   }
 }
+
+object StringIndexerModel {
+  def apply(labels: Seq[String], handleInvalid: HandleInvalid): StringIndexerModel = StringIndexerModel(Array(labels.toArray), handleInvalid)
+  def apply(labels: Seq[String]): StringIndexerModel = StringIndexerModel(Array(labels.toArray), HandleInvalid.Error)
+  def apply(labelsArray: Array[Array[String]]): StringIndexerModel = StringIndexerModel(labelsArray, HandleInvalid.Error)
+}
diff --git a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
index 790db0d4f..6e43bf6d7 100644
--- a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
+++ b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
@@ -10,7 +10,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
 describe("#apply") {
     it("returns the index of the string") {
-      val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")))
+      val indexer = StringIndexerModel(Array("hello", "there", "dude"))
 
       assert(indexer(Seq("hello")).head == 0.0)
       assert(indexer(Seq("there")).head == 1.0)
@@ -18,12 +18,12 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
     }
 
     it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") {
-      val indexer = StringIndexerModel(Seq(Array("hello")))
+      val indexer = StringIndexerModel(Array("hello"))
       assertThrows[NullPointerException](indexer(null))
     }
 
     it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") {
-      val indexer = StringIndexerModel(Seq(Array("hello")))
+      val indexer = StringIndexerModel(Array("hello"))
       val unseenLabels = Table("label", "unknown1", "unknown2")
 
       forAll(unseenLabels) { (label: Any) =>
@@ -34,7 +34,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
     }
 
     it("returns default index for HandleInvalid.keep mode") {
-      val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep)
+      val indexer = StringIndexerModel(Array("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep)
       val invalidLabels = Table("unknown", "other unknown", null, None)
 
       forAll(invalidLabels) { (label: Any) =>
@@ -44,7 +44,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
   }
 
   describe("input/output schema") {
-    val indexer = StringIndexerModel(Seq(Array("hello", "there", "dude")))
+    val indexer = StringIndexerModel(Array("hello", "there", "dude"))
 
     it("has the right input schema") {
       assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String)))
diff --git 
a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
index 4c21f726b..6c3416082 100644
--- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
+++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
@@ -22,8 +22,8 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
     val m = model.
       withValue("labels_length", Value.int(1)).
       withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString))
-    obj.labels.zipWithIndex.foldLeft(m){
-      case (m, (label, i)) => m.withValue(s"labels_array_$i", Value.stringList(label))
+    obj.labelsArray.zipWithIndex.foldLeft(m){
+      case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels))
     }
   }
 
@@ -31,13 +31,13 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
     (implicit context: BundleContext[MleapContext]): StringIndexerModel = {
     val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default)
     val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1)
-    val labels: Seq[Seq[String]] = label_length match {
+    val labelsArray: Array[Array[String]] = label_length match {
       case -1 => // backwards compatibility with spark v2
-        Seq(model.value("labels").getStringList)
-      case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList)
+        Array(model.value("labels").getStringList.toArray)
+      case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList.toArray).toArray
     }
 
-    StringIndexerModel(labels = labels, handleInvalid = handleInvalid)
+    StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid)
   }
 }
 
diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
index c6f2651e7..dd1caed2a 100644
--- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
+++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala
@@ -16,7 +16,7 @@ class MleapSupportSpec extends org.scalatest.funspec.AnyFunSpec {
   private val stringIndexer = StringIndexer(shape = NodeShape().
     withStandardInput("feature").
     withStandardOutput("feature_index"),
-    model = StringIndexerModel(Seq(Seq("label1", "label2"))))
+    model = StringIndexerModel(Seq("label1", "label2")))
 
   describe("URIBundleFileOps") {
     it("can save/load a bundle using a URI") {
diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java
index e947e6079..55bfd16f0 100644
--- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java
+++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java
@@ -1,7 +1,7 @@
 package ml.combust.mleap.runtime.javadsl;
 
 import ml.combust.mleap.core.feature.HandleInvalid$;
-import ml.combust.mleap.core.feature.StringIndexerModel;
+import ml.combust.mleap.core.feature.StringIndexerModel$;
 import ml.combust.mleap.core.types.*;
 import ml.combust.mleap.runtime.MleapContext;
 import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
@@ -52,7 +52,7 @@ private static Map createMap() {
             new NodeShape(new ListMap<>(), new ListMap<>()).
                 withStandardInput("string"). 
withStandardOutput("string_index"), - new StringIndexerModel(JavaConversions.asScalaBuffer(Collections.singletonList(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq())).toSeq(), + StringIndexerModel$.MODULE$.apply(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(), HandleInvalid$.MODULE$.fromString("error", true))); DefaultLeapFrame buildFrame() { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala index 8add9cbd9..eb254f714 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala @@ -18,7 +18,7 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { outputPort="output0", inputCol = "test_string", outputCol = "test_index"), - model = StringIndexerModel(Seq(Seq("index1", "index2", "index3")))) + model = StringIndexerModel(Seq("index1", "index2", "index3"))) describe("#transform") { it("converts input string into an index") { From 33d061dc60f79484406424f5149b09d0a3c6eb24 Mon Sep 17 00:00:00 2001 From: "Yunhui Zhang (GitHub)" Date: Tue, 22 Oct 2024 16:16:22 -0400 Subject: [PATCH 03/15] Clean up unwanted changes --- .../transformer/feature/MultiInOutTransformer.scala | 5 +++++ .../transformer/feature/StringIndexerSpec.scala | 10 ++-------- .../ml/parity/feature/StringIndexerParitySpec.scala | 7 +++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala index e3ecd024a..ab0b4fa27 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala @@ -3,6 +3,11 @@ package ml.combust.mleap.runtime.transformer.feature import ml.combust.mleap.core.types._ import ml.combust.mleap.runtime.frame.Transformer +/* +This Transformer trait only used when the Spark Transformer has inputCol and inputCols +, outputCol and outputCols. Because the saved shape will have port "input" instead of +"input0", "output" instead "output0". 
+ */ trait MultiInOutTransformer extends Transformer { override def inputSchema: StructType = { if (shape.getInput("input").isDefined) { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala index eb254f714..19cda92ec 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala @@ -14,8 +14,6 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { val stringIndexer = StringIndexer( shape = NodeShape.feature( - inputPort="input0", - outputPort="output0", inputCol = "test_string", outputCol = "test_index"), model = StringIndexerModel(Seq("index1", "index2", "index3"))) @@ -31,16 +29,12 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { } describe("with invalid input column") { - val stringIndexer2 = stringIndexer.copy(shape = NodeShape.feature( - inputPort="input0", - outputPort="output0", - inputCol = "bad_input", - outputCol = "bad_output")) + val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withStandardInput("bad_input"). + withStandardOutput("output")) it("returns a Failure") { assert(stringIndexer2.transform(frame).isFailure) } - } describe("with invalid string") { diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala index 889c99deb..252f347f2 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala @@ -17,14 +17,13 @@ class StringIndexerParitySpec extends SparkParityBase { setInputCol("state"). setOutputCol("state_index").setHandleInvalid("keep"). fit(dataset) - val ignoreParams = Set("inputCol", "outputCol", "inputCols", "outputCols") - override val unserializedParams = Set("stringOrderType").union(ignoreParams) + override val unserializedParams = Set("stringOrderType") it("serializes/deserializes the Spark model properly with multiple in/out columns"){ bundleCache = None // outputCol has a default value of "__output, so we ignore it in this test // since the uid will be different - + val additionalIgnoreParams = Set("outputCol") val multiColTransformer = new StringIndexer(). setInputCols(Array("state", "loan_title")). setOutputCols(Array("state_index", "loan_tile_index")). 
@@ -33,7 +32,7 @@ class StringIndexerParitySpec extends SparkParityBase {
     val sparkTransformed = multiColTransformer.transform(baseDataset)
     implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
     val deserializedTransformer = deserializedSparkTransformer(multiColTransformer)
-    checkEquality(multiColTransformer, deserializedTransformer, ignoreParams)
+    checkEquality(multiColTransformer, deserializedTransformer, additionalIgnoreParams)
     equalityTest(sparkTransformed, deserializedTransformer.transform(baseDataset))
 
     bundleCache = None

From cdd51c4bb193714bce02cedd3bd8a5c3244b87b7 Mon Sep 17 00:00:00 2001
From: austinzh 
Date: Tue, 22 Oct 2024 18:35:04 -0400
Subject: [PATCH 04/15] Create devcontainer.json

---
 .devcontainer/devcontainer.json | 10 ++++++++++
 1 file changed, 10 insertions(+)
 create mode 100644 .devcontainer/devcontainer.json

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 000000000..f7cc92bcc
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,10 @@
+{
+  "image": "mcr.microsoft.com/devcontainers/universal:2",
+  "features": {
+    "ghcr.io/devcontainers-extra/features/scala-sdkman:2": {
+      "version": "2.12.18",
+      "jdkVersion": "11",
+      "jdkDistro": "open"
+    }
+  }
+}

From 792f990e6f30384f3deba0a586897396bc5ea36f Mon Sep 17 00:00:00 2001
From: austinzh 
Date: Tue, 22 Oct 2024 18:49:41 -0400
Subject: [PATCH 05/15] Update devcontainer.json

---
 .devcontainer/devcontainer.json | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index f7cc92bcc..88106e5c9 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -3,8 +3,7 @@
   "features": {
     "ghcr.io/devcontainers-extra/features/scala-sdkman:2": {
       "version": "2.12.18",
-      "jdkVersion": "11",
-      "jdkDistro": "open"
+      "jdkVersion": "11"
     }
   }
 }

From 055c3ebc10f51816ad5a2c4b960ed86f5e0cb5b5 Mon Sep 17 00:00:00 2001
From: "Yunhui Zhang (GitHub)" 
Date: Wed, 23 Oct 2024 16:03:54 -0400
Subject: [PATCH 06/15] Refactor to use OneHotEncoder's logic

OneHotEncoder already implements logic to deal with multiple input columns
and output columns in MLeap. It's better to reuse what is already done.

---
 .../core/feature/StringIndexerModel.scala     | 25 ++++++--------
 .../core/feature/StringIndexerModelSpec.scala | 10 +++---
 .../mleap/bundle/ops/MultiInOutMleapOp.scala  | 29 ++++++++++++++++
 .../bundle/ops/feature/OneHotEncoderOp.scala  | 25 ++------------
 .../bundle/ops/feature/StringIndexerOp.scala  | 16 ++++-----
 .../feature/MultiInOutTransformer.scala       | 33 -------------------
 .../transformer/feature/StringIndexer.scala   |  2 +-
 .../mleap/runtime/MleapSupportSpec.scala      |  8 ++---
 .../mleap/runtime/javadsl/JavaDSLSpec.java    | 18 +++++-----
 .../feature/StringIndexerSpec.scala           | 12 +++----
 ...tSparkOp.scala => MultiInOutSparkOp.scala} | 25 +++++++-------
 .../spark/ml/bundle/SimpleSparkOp.scala       | 23 ++-----------
 .../ml/bundle/ops/feature/BinarizerOp.scala   |  2 +-
 .../bundle/ops/feature/OneHotEncoderOp.scala  |  2 +-
 .../ops/feature/StopWordsRemoverOp.scala      |  4 +--
 .../bundle/ops/feature/StringIndexerOp.scala  |  6 ++--
 .../feature/OneHotEncoderParitySpec.scala     | 27 +++++++--------
 .../feature/StringIndexerParitySpec.scala     | 28 ++++++----------
 18 files changed, 116 insertions(+), 179 deletions(-)
 create mode 100644 mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala
 delete mode 100644 
mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala
 rename mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/{MultiInOutFormatSparkOp.scala => MultiInOutSparkOp.scala} (82%)

diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
index b9283f168..f9e72a692 100644
--- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
+++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala
@@ -12,17 +12,15 @@ import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
  * 'skip' (skips invalid data)
  * or 'keep' (put invalid data in a special bucket at index labels.size
  */
-case class StringIndexerModel(labelsArray: Array[Array[String]],
-                              handleInvalid: HandleInvalid) extends Model {
+case class StringIndexerModel(labelsArray: Seq[Seq[String]],
+                              handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
 
-  private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap)
+  private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap).toArray
   private val keepInvalid = handleInvalid == HandleInvalid.Keep
   private val invalidValue = labelsArray.map(_.length)
 
-
-
-  @deprecated("Use labelsArray instead")
-  def labels: Seq[String] = labelsArray(0).toSeq
+  def labels: Seq[String] = labelsArray.head
 
   /** Convert all strings into their integer representation.
    *
@@ -68,7 +66,10 @@ case class StringIndexerModel(labelsArray: Array[Array[String]],
   /** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
    * ReverseStringIndexer only supports one input
    * @return reverse string indexer of this string indexer
    */
-  def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labelsArray(0))
+  def toReverse: ReverseStringIndexerModel = {
+    require(labelsArray.length == 1)
+    ReverseStringIndexerModel(labelsArray.head)
+  }
 
   override def inputSchema: StructType = {
     val f = labelsArray.zipWithIndex.map {
@@ -79,12 +80,12 @@ case class StringIndexerModel(labelsArray: Array[Array[String]],
       case (_, i) => StructField(s"input$i", ScalarType.String)
     }
     StructType(f).get
   }
 
   override def outputSchema: StructType = {
     val f = labelsArray.zipWithIndex.map {
-      case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
+      case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
     }
     StructType(f).get
   }
-}
+}
\ No newline at end of file
diff --git a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
index 6e43bf6d7..890a93008 100644
--- a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
+++ b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/StringIndexerModelSpec.scala
@@ -10,7 +10,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table
 describe("#apply") {
     it("returns the index of the string") {
-      val indexer = StringIndexerModel(Array("hello", "there", "dude"))
+      val indexer 
= StringIndexerModel(Seq(Seq("hello", "there", "dude"))) assert(indexer(Seq("hello")).head == 0.0) assert(indexer(Seq("there")).head == 1.0) @@ -18,12 +18,12 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table } it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") { - val indexer = StringIndexerModel(Array("hello")) + val indexer = StringIndexerModel(Seq(Seq("hello"))) assertThrows[NullPointerException](indexer(null)) } it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") { - val indexer = StringIndexerModel(Array("hello")) + val indexer = StringIndexerModel(Seq(Seq("hello"))) val unseenLabels = Table("label", "unknown1", "unknown2") forAll(unseenLabels) { (label: Any) => @@ -34,7 +34,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table } it("returns default index for HandleInvalid.keep mode") { - val indexer = StringIndexerModel(Array("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep) + val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep) val invalidLabels = Table("unknown", "other unknown", null, None) forAll(invalidLabels) { (label: Any) => @@ -44,7 +44,7 @@ class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with Table } describe("input/output schema") { - val indexer = StringIndexerModel(Array("hello", "there", "dude")) + val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude"))) it("has the right input schema") { assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String))) diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala new file mode 100644 index 000000000..8c662e500 --- /dev/null +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala @@ -0,0 +1,29 @@ +package ml.combust.mleap.bundle.ops + +import ml.bundle.Socket +import ml.combust.bundle.BundleContext +import ml.combust.bundle.dsl.{Node, NodeShape} +import ml.combust.mleap.core.types +import ml.combust.mleap.runtime.MleapContext +import ml.combust.mleap.runtime.frame.Transformer +import ml.combust.mleap.runtime.types.BundleTypeConverters._ + +import scala.reflect.ClassTag + +abstract class MultiInOutMleapOp[N <: Transformer, M <: AnyRef](implicit ct: ClassTag[N]) extends MleapOp[N, M] { + override def load(node: Node, model: M)(implicit context: BundleContext[MleapContext]): N = { + val ns = node.shape.getInput(NodeShape.standardInputPort) match { // Old versions need to translate serialized port names to the new expectation (input -> input0) + case Some(_) ⇒ translateLegacyShape(node.shape) + + // New version + case None ⇒ node.shape + } + klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).newInstance(node.name, ns.asBundle: types.NodeShape, model) + } + + private def translateLegacyShape(ns: NodeShape): NodeShape = { + val i = ns.getInput(NodeShape.standardInputPort).get + val o = ns.getOutput(NodeShape.standardOutputPort).get + NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name))) + } +} \ No newline at end of file diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala index 60dab4187..4940aa526 100644 ---
a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/OneHotEncoderOp.scala @@ -1,20 +1,17 @@ package ml.combust.mleap.bundle.ops.feature -import ml.bundle.Socket import ml.combust.bundle.BundleContext import ml.combust.bundle.dsl._ import ml.combust.bundle.op.OpModel -import ml.combust.mleap.bundle.ops.MleapOp +import ml.combust.mleap.bundle.ops.MultiInOutMleapOp import ml.combust.mleap.core.feature.{HandleInvalid, OneHotEncoderModel} -import ml.combust.mleap.core.types import ml.combust.mleap.runtime.MleapContext import ml.combust.mleap.runtime.transformer.feature.OneHotEncoder -import ml.combust.mleap.runtime.types.BundleTypeConverters._ /** * Created by hollinwilkins on 10/24/16. */ -class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] { +class OneHotEncoderOp extends MultiInOutMleapOp[OneHotEncoder, OneHotEncoderModel] { override val Model: OpModel[MleapContext, OneHotEncoderModel] = new OpModel[MleapContext, OneHotEncoderModel] { override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel] @@ -50,24 +47,6 @@ class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] { } } - override def load(node: Node, model: OneHotEncoderModel) - (implicit context: BundleContext[MleapContext]): OneHotEncoder = { - val ns = node.shape.getInput(NodeShape.standardInputPort) match { - // Old version of 1HE -- need to translate serialized port names to new expectation (input -> input0) - case Some(_) ⇒ translateLegacyShape(node.shape) - - // New version of 1HE - case None ⇒ node.shape - } - klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz). - newInstance(node.name, ns.asBundle: types.NodeShape, model) - } - override def model(node: OneHotEncoder): OneHotEncoderModel = node.model - private def translateLegacyShape(ns: NodeShape): NodeShape = { - val i = ns.getInput(NodeShape.standardInputPort).get - val o = ns.getOutput(NodeShape.standardOutputPort).get - NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name))) - } } diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala index 6c3416082..fdd596b49 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala @@ -1,17 +1,17 @@ package ml.combust.mleap.bundle.ops.feature import ml.combust.bundle.BundleContext -import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} -import ml.combust.mleap.runtime.transformer.feature.StringIndexer -import ml.combust.bundle.op.OpModel import ml.combust.bundle.dsl._ -import ml.combust.mleap.bundle.ops.MleapOp +import ml.combust.bundle.op.OpModel +import ml.combust.mleap.bundle.ops.MultiInOutMleapOp +import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} import ml.combust.mleap.runtime.MleapContext +import ml.combust.mleap.runtime.transformer.feature.StringIndexer /** * Created by hollinwilkins on 8/22/16. 
*/ -class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] { +class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerModel] { override val Model: OpModel[MleapContext, StringIndexerModel] = new OpModel[MleapContext, StringIndexerModel] { override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel] @@ -31,11 +31,11 @@ class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] { (implicit context: BundleContext[MleapContext]): StringIndexerModel = { val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default) val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1) - val labelsArray: Array[Array[String]] = label_length match { + val labelsArray: Seq[Seq[String]] = label_length match { case -1 => // backwards compatibility with Spark v2 - Array(model.value("labels").getStringList.toArray) - case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList.toArray).toArray + Seq(model.value("labels").getStringList) + case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList) } StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid) } diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala deleted file mode 100644 index ab0b4fa27..000000000 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/MultiInOutTransformer.scala +++ /dev/null @@ -1,33 +0,0 @@ -package ml.combust.mleap.runtime.transformer.feature - -import ml.combust.mleap.core.types._ -import ml.combust.mleap.runtime.frame.Transformer - -/* -This Transformer trait only used when the Spark Transformer has inputCol and inputCols -, outputCol and outputCols. Because the saved shape will have port "input" instead of -"input0", "output" instead "output0".
- */ -trait MultiInOutTransformer extends Transformer { - override def inputSchema: StructType = { - if (shape.getInput("input").isDefined) { - val fields = model.inputSchema.getField("input0").map { - case StructField(_, dataType) => StructField(shape.input("input").name, dataType) - }.toSeq - StructType(fields).get - } else { - super.inputSchema - } - } - - override def outputSchema: StructType = { - if (shape.getOutput("output").isDefined) { - val fields = model.outputSchema.getField("output0").map { - case StructField(_, dataType) => StructField(shape.output("output").name, dataType) - }.toSeq - StructType(fields).get - } else { - super.outputSchema - } - } -} diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala index 660af4495..e71fe5624 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala @@ -12,7 +12,7 @@ import scala.util.Try */ case class StringIndexer(override val uid: String = Transformer.uniqueName("string_indexer"), override val shape: NodeShape, - override val model: StringIndexerModel) extends Transformer with MultiInOutTransformer { + override val model: StringIndexerModel) extends Transformer { private val outputs: Seq[String] = outputSchema.fields.map(_.name) private val inputs: Seq[String] = inputSchema.fields.map(_.name) private val inputSelector: StructSelector = StructSelector(inputs) diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala index dd1caed2a..5ffc2d853 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/MleapSupportSpec.scala @@ -10,13 +10,13 @@ import MleapSupport._ import org.scalatest.funspec.AnyFunSpec -class MleapSupportSpec extends org.scalatest.funspec.AnyFunSpec { +class MleapSupportSpec extends AnyFunSpec { private val testDir = Files.createTempDirectory("MleapSupportSpec") private val stringIndexer = StringIndexer(shape = NodeShape(). - withStandardInput("feature"). 
- withStandardOutput("feature_index"), - model = StringIndexerModel(Seq("label1", "label2"))) + withInput("input0", "feature") + withOutput("output0", "feature_index"), + model = StringIndexerModel(Seq(Seq("label1", "label2")))) describe("URIBundleFileOps") { it("can save/load a bundle using a URI") { diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java index 55bfd16f0..9cacd9a1e 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/javadsl/JavaDSLSpec.java @@ -1,7 +1,7 @@ package ml.combust.mleap.runtime.javadsl; -import ml.combust.mleap.core.feature.HandleInvalid$; -import ml.combust.mleap.core.feature.StringIndexerModel$; +import ml.combust.mleap.core.feature.HandleInvalid; +import ml.combust.mleap.core.feature.StringIndexerModel; import ml.combust.mleap.core.types.*; import ml.combust.mleap.runtime.MleapContext; import ml.combust.mleap.runtime.frame.DefaultLeapFrame; @@ -16,7 +16,7 @@ import ml.combust.mleap.tensor.ByteString; import org.junit.jupiter.api.Test; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.immutable.ListMap; import java.io.File; @@ -50,10 +50,10 @@ private static Map createMap() { StringIndexer stringIndexer = new StringIndexer( "string_indexer", new NodeShape(new ListMap<>(), new ListMap<>()). - withStandardInput("string"). - withStandardOutput("string_index"), - StringIndexerModel$.MODULE$.apply(JavaConversions.asScalaBuffer(Collections.singletonList("hello")).toSeq(), - HandleInvalid$.MODULE$.fromString("error", true))); + withInput("input0", "string"). 
+ withOutput("output0","string_index"), + new StringIndexerModel(JavaConverters.asScalaBuffer(Collections.singletonList(JavaConverters.asScalaBuffer(Collections.singletonList("hello")).toSeq())), + HandleInvalid.Error$.MODULE$)); DefaultLeapFrame buildFrame() { List fields = Arrays.asList(frameBuilder.createField("bool", frameBuilder.createBoolean()), @@ -101,7 +101,7 @@ public void buildALeapFrameWithAllDataTypes() { assertEquals(row.getDouble(7), 44.5, 0.0000000000001); assertEquals(row.getByteString(8), new ByteString("hello_there".getBytes())); assertEquals(row.getList(9), Arrays.asList(23, 44, 55)); - assertEquals(JavaConversions.mapAsJavaMap(row.getMap(10)), mapCol ); + assertEquals(JavaConverters.mapAsJavaMap(row.getMap(10)), mapCol ); List tensorValues = tensorSupport.toArray(row.getTensor(11)); assertEquals(tensorValues, Arrays.asList(23d, 3d, 4d)); } @@ -117,7 +117,7 @@ public void transformASingleRowUsingARowTransformer() { @Test public void createTensorFieldWithDimension() { StructField tensorField = frameBuilder.createField("tensor", frameBuilder.createTensor(frameBuilder.createBasicByte(), Arrays.asList(1, 2), true)); - assertEquals(((TensorType)tensorField.dataType()).dimensions().get(), JavaConversions.asScalaBuffer(Arrays.asList(1, 2)).toSeq()); + assertEquals(((TensorType)tensorField.dataType()).dimensions().get(), JavaConverters.asScalaBuffer(Arrays.asList(1, 2)).toSeq()); } @Test diff --git a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala index 19cda92ec..d5177cbe7 100644 --- a/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala +++ b/mleap-runtime/src/test/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexerSpec.scala @@ -13,10 +13,10 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { val frame = DefaultLeapFrame(schema, dataset) val stringIndexer = StringIndexer( - shape = NodeShape.feature( - inputCol = "test_string", - outputCol = "test_index"), - model = StringIndexerModel(Seq("index1", "index2", "index3"))) + shape = NodeShape() + .withInput("input0", "test_string") + .withOutput("output0", "test_index"), + model = StringIndexerModel(Seq(Seq("index1", "index2", "index3")))) describe("#transform") { it("converts input string into an index") { @@ -29,8 +29,8 @@ class StringIndexerSpec extends org.scalatest.funspec.AnyFunSpec { } describe("with invalid input column") { - val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withStandardInput("bad_input"). - withStandardOutput("output")) + val stringIndexer2 = stringIndexer.copy(shape = NodeShape().withInput("input0","bad_input"). 
+ withOutput("output0","output")) it("returns a Failure") { assert(stringIndexer2.transform(frame).isFailure) diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala similarity index 82% rename from mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala rename to mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala index 02868c190..2e6926391 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutFormatSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala @@ -5,22 +5,9 @@ import ml.combust.bundle.dsl._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamValidators import org.apache.spark.ml.param.shared._ -import ml.combust.bundle.op.OpModel - import scala.reflect.ClassTag -abstract class MultiInOutOpModel[N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols] extends OpModel[SparkBundleContext, N] { - private def validateParams(obj: N): Unit = { - ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) - ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) - } - override def store(model: Model, obj: N)(implicit context: BundleContext[SparkBundleContext]): Model = { - validateParams(obj) - model - } -} - -abstract class MultiInOutFormatSparkOp[ +abstract class MultiInOutSparkOp[ N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols ](implicit ct: ClassTag[N]) extends SimpleSparkOp[N] { import NodeShape._ @@ -31,6 +18,11 @@ abstract class MultiInOutFormatSparkOp[ n } + override def shape(node: N)(implicit context: BundleContext[SparkBundleContext]): NodeShape = { + validateParams(node) + super.shape(node) + } + def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort).isDefined, obj) def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort).isDefined, obj) @@ -50,4 +42,9 @@ abstract class MultiInOutFormatSparkOp[ } else { Seq(ParamSpec(standardOutputPort, obj.outputCols)) } + + private def validateParams(obj: N): Unit = { + ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) + ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) + } } \ No newline at end of file diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala index 19fe507b4..341a81a34 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/SimpleSparkOp.scala @@ -30,27 +30,8 @@ abstract class SimpleSparkOp[N <: Transformer](implicit ct: ClassTag[N]) extends override def shape(node: N) (implicit context: BundleContext[SparkBundleContext]): NodeShape = { - val dataset = context.context.dataset.getOrElse { - throw new IllegalArgumentException( - """ - |Must provide a transformed data frame to MLeap for serializing a pipeline. - |The transformed data frame is used to extract data types and other metadata - |required for execution. 
- | - |Example usage: - |``` - |val sparkTransformer: org.apache.spark.ml.Transformer - |val transformedDataset = sparkTransformer.transform(trainingDataset) - | - |implicit val sbc = SparkBundleContext().withDataset(transformedDataset) - | - |Using(BundleFile(file)) { bf => - | sparkTransformer.writeBundle.format(SerializationFormat.Json).save(bf).get - |} - |``` - """.stripMargin) - } - SparkShapeSaver(dataset, + assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz)) + SparkShapeSaver(context.context.dataset.get, node, sparkInputs(node), sparkOutputs(node)).asNodeShape diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala index af0b1d674..f8733c6a6 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/BinarizerOp.scala @@ -13,7 +13,7 @@ import ml.combust.mleap.runtime.types.BundleTypeConverters._ /** * Created by fshabbir on 12/1/16. */ -class BinarizerOp extends MultiInOutFormatSparkOp[Binarizer] { +class BinarizerOp extends MultiInOutSparkOp[Binarizer] { override val Model: OpModel[SparkBundleContext, Binarizer] = new OpModel[SparkBundleContext, Binarizer] { override val klazz: Class[Binarizer] = classOf[Binarizer] diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala index ef1dc356f..f563426df 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/OneHotEncoderOp.scala @@ -34,7 +34,7 @@ object OneHotEncoderOp { } } -class OneHotEncoderOp extends MultiInOutFormatSparkOp[OneHotEncoderModel] { +class OneHotEncoderOp extends MultiInOutSparkOp[OneHotEncoderModel] { override val Model: OpModel[SparkBundleContext, OneHotEncoderModel] = new OpModel[SparkBundleContext, OneHotEncoderModel] { override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel] diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala index fdc58fdef..380dcce32 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StopWordsRemoverOp.scala @@ -4,13 +4,13 @@ import ml.combust.bundle.BundleContext import ml.combust.bundle.dsl._ import ml.combust.bundle.op.OpModel import org.apache.spark.ml.bundle.ops.OpsUtils -import org.apache.spark.ml.bundle.{MultiInOutFormatSparkOp, SimpleSparkOp, SparkBundleContext} +import org.apache.spark.ml.bundle.{MultiInOutSparkOp, SimpleSparkOp, SparkBundleContext} import org.apache.spark.ml.feature.StopWordsRemover /** * Created by mikhail on 10/16/16. 
*/ -class StopWordsRemoverOp extends MultiInOutFormatSparkOp[StopWordsRemover] { +class StopWordsRemoverOp extends MultiInOutSparkOp[StopWordsRemover] { override val Model: OpModel[SparkBundleContext, StopWordsRemover] = new OpModel[SparkBundleContext, StopWordsRemover] { override val klazz: Class[StopWordsRemover] = classOf[StopWordsRemover] diff --git a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala index ef9e57873..463b8c795 100644 --- a/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala +++ b/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/feature/StringIndexerOp.scala @@ -10,8 +10,8 @@ import org.apache.spark.ml.feature.StringIndexerModel /** * Created by hollinwilkins on 8/21/16. */ -class StringIndexerOp extends MultiInOutFormatSparkOp[StringIndexerModel] { - override val Model: OpModel[SparkBundleContext, StringIndexerModel] = new MultiInOutOpModel[StringIndexerModel] { +class StringIndexerOp extends MultiInOutSparkOp[StringIndexerModel] { + override val Model: OpModel[SparkBundleContext, StringIndexerModel] = new OpModel[SparkBundleContext, StringIndexerModel] { override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel] override def opName: String = Bundle.BuiltinOps.feature.string_indexer @@ -19,7 +19,7 @@ class StringIndexerOp extends MultiInOutFormatSparkOp[StringIndexerModel] { override def store(model: Model, obj: StringIndexerModel) (implicit context: BundleContext[SparkBundleContext]): Model = { - var result = super.store(model, obj). + var result = model. withValue("labels_length", Value.int(obj.labelsArray.length)). withValue("handle_invalid", Value.string(obj.getHandleInvalid)) obj.labelsArray.indices.foreach( diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala index bdb419fde..f66576b33 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/OneHotEncoderParitySpec.scala @@ -9,6 +9,17 @@ import org.apache.spark.sql.DataFrame /** * Created by hollinwilkins on 10/30/16. 
*/ + +class MIOOneHotEncoderParitySpec extends SparkParityBase { + override val dataset: DataFrame = baseDataset.select("state") + override val sparkTransformer: Transformer = new Pipeline() + .setStages(Array( + new StringIndexer().setInputCol("state").setOutputCol("state_index"), + new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh") + )).fit(dataset) + override val unserializedParams: Set[String] = Set("stringOrderType") +} + class OneHotEncoderParitySpec extends SparkParityBase { override val dataset: DataFrame = baseDataset.select("state") override val sparkTransformer: Transformer = @@ -24,21 +35,7 @@ class OneHotEncoderParitySpec extends SparkParityBase { override val unserializedParams = Set("stringOrderType") - it("serializes/deserializes the Spark model properly with one in/out column"){ - bundleCache = None - val additionalIgnoreParams = Set("outputCol") - val pipeline = new Pipeline() - .setStages(Array( - new StringIndexer().setInputCol("state").setOutputCol("state_index"), - new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh") - )).fit(dataset) - val sparkTransformed = pipeline.transform(dataset) - implicit val sbc = SparkBundleContext().withDataset(sparkTransformed) - val deserializedTransformer = deserializedSparkTransformer(pipeline) - checkEquality(pipeline, deserializedTransformer, additionalIgnoreParams) - equalityTest(sparkTransformed, deserializedTransformer.transform(dataset)) - bundleCache = None - } + it("fails to instantiate if the Spark model sets inputCol and inputCols"){ intercept[IllegalArgumentException] { diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala index 252f347f2..4e530b940 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala @@ -2,7 +2,6 @@ package org.apache.spark.ml.parity.feature import org.apache.spark.ml.parity.SparkParityBase import org.apache.spark.ml.Transformer -import org.apache.spark.ml.bundle.SparkBundleContext import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.sql.DataFrame @@ -11,30 +10,23 @@ import org.apache.spark.sql.DataFrame */ class StringIndexerParitySpec extends SparkParityBase { override val dataset: DataFrame = baseDataset.select("state") + override val unserializedParams = Set("stringOrderType") // setting to handle invalid to true override val sparkTransformer: Transformer = new StringIndexer(). setInputCol("state"). setOutputCol("state_index").setHandleInvalid("keep"). fit(dataset) - override val unserializedParams = Set("stringOrderType") - it("serializes/deserializes the Spark model properly with multiple in/out columns"){ - bundleCache = None - // outputCol has a default value of "__output, so we ignore it in this test - // since the uid will be different - val additionalIgnoreParams = Set("outputCol") - val multiColTransformer = new StringIndexer(). + +} + +class MIOStringIndexerParitySpec extends SparkParityBase { + override val dataset: DataFrame = baseDataset.select("state", "loan_title") + override val unserializedParams = Set("stringOrderType") + override val sparkTransformer: Transformer = new StringIndexer(). setInputCols(Array("state", "loan_title")). setOutputCols(Array("state_index", "loan_title_index")). setHandleInvalid("keep").
- fit(baseDataset) - val sparkTransformed = multiColTransformer.transform(baseDataset) - implicit val sbc = SparkBundleContext().withDataset(sparkTransformed) - val deserializedTransformer = deserializedSparkTransformer(multiColTransformer) - checkEquality(multiColTransformer, deserializedTransformer, additionalIgnoreParams) - equalityTest(sparkTransformed, deserializedTransformer.transform(baseDataset)) - - bundleCache = None - } -} + fit(dataset) +} \ No newline at end of file From 6d1f5ae1d926641cd41609e83ba3aa14c14e250f Mon Sep 17 00:00:00 2001 From: "Yunhui Zhang (GitHub)" Date: Wed, 23 Oct 2024 16:09:14 -0400 Subject: [PATCH 07/15] Reformat --- .../core/feature/StringIndexerModel.scala | 6 ++-- .../spark/ml/bundle/MultiInOutSparkOp.scala | 34 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala index f9e72a692..0ea020d10 100644 --- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala +++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala @@ -17,7 +17,7 @@ case class StringIndexerModel(labelsArray: Seq[Seq[String]], private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap).toArray private val keepInvalid = handleInvalid == HandleInvalid.Keep - private val invalidValue = labelsArray.map(_.length) + private val invalidValues = labelsArray.map(_.length) @deprecated("Use labelsArray instead") def labels: Seq[String] = labelsArray.head @@ -44,7 +44,7 @@ case class StringIndexerModel(labelsArray: Seq[Seq[String]], */ private def encoder(value: Any, colIdx: Int): Int = if (value == null) { if (keepInvalid) { - invalidValue(colIdx) + invalidValues(colIdx) } else { throw new NullPointerException("StringIndexer encountered NULL value. " + s"To handle NULLS, set handleInvalid to ${HandleInvalid.Keep.asParamString}") @@ -54,7 +54,7 @@ case class StringIndexerModel(labelsArray: Seq[Seq[String]], stringToIndex(colIdx).get(label) match { case Some(v) => v case None => if (keepInvalid) { - invalidValue(colIdx) + invalidValues(colIdx) } else { throw new NoSuchElementException(s"Unseen label: $label. 
To handle unseen labels, " + s"set handleInvalid to ${HandleInvalid.Keep.asParamString}") diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala index 2e6926391..8f49e9a49 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala @@ -5,11 +5,11 @@ import ml.combust.bundle.dsl._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamValidators import org.apache.spark.ml.param.shared._ + import scala.reflect.ClassTag -abstract class MultiInOutSparkOp[ - N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols -](implicit ct: ClassTag[N]) extends SimpleSparkOp[N] { +abstract class MultiInOutSparkOp[N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols](implicit ct: ClassTag[N]) extends SimpleSparkOp[N] { + import NodeShape._ override def load(node: Node, model: N)(implicit context: BundleContext[SparkBundleContext]): N = { @@ -18,33 +18,33 @@ abstract class MultiInOutSparkOp[ n } + def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort).isDefined, obj) + + def sparkInputs(hasInputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCol) { + Seq(ParamSpec(standardInputPort, obj.inputCol)) + } else { + Seq(ParamSpec(standardInputPort, obj.inputCols)) + } + + def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort).isDefined, obj) + override def shape(node: N)(implicit context: BundleContext[SparkBundleContext]): NodeShape = { validateParams(node) super.shape(node) } - def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort).isDefined, obj) - - def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort).isDefined, obj) + private def validateParams(obj: N): Unit = { + ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) + ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) + } def sparkInputs(obj: N): Seq[ParamSpec] = sparkInputs(obj.isSet(obj.inputCol), obj) def sparkOutputs(obj: N): Seq[ParamSpec] = sparkOutputs(obj.isSet(obj.outputCol), obj) - def sparkInputs(hasInputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCol) { - Seq(ParamSpec(standardInputPort, obj.inputCol)) - } else { - Seq(ParamSpec(standardInputPort, obj.inputCols)) - } - def sparkOutputs(hasOutputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasOutputCol) { Seq(ParamSpec(standardOutputPort, obj.outputCol)) } else { Seq(ParamSpec(standardOutputPort, obj.outputCols)) } - - private def validateParams(obj: N): Unit = { - ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols)) - ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) - } } \ No newline at end of file From 40f0c568bf53ebe9f86189f909dcb28a357a3510 Mon Sep 17 00:00:00 2001 From: "Yunhui Zhang (GitHub)" Date: Wed, 23 Oct 2024 16:25:44 -0400 Subject: [PATCH 08/15] format code using ide formatter --- .../mleap/bundle/ops/MultiInOutMleapOp.scala | 2 +- .../bundle/ops/feature/StringIndexerOp.scala | 18 +++++++++--------- .../feature/StringIndexerParitySpec.scala | 4 +--- 3 files changed, 11 
insertions(+), 13 deletions(-) diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala index 8c662e500..938a5bba8 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala @@ -26,4 +26,4 @@ abstract class MultiInOutMleapOp[N <: Transformer, M <: AnyRef](implicit ct: Cla val o = ns.getOutput(NodeShape.standardOutputPort).get NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name))) } -} \ No newline at end of file +} diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala index fdd596b49..a694aeb67 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala @@ -9,8 +9,8 @@ import ml.combust.mleap.runtime.MleapContext import ml.combust.mleap.runtime.transformer.feature.StringIndexer /** - * Created by hollinwilkins on 8/22/16. - */ + * Created by hollinwilkins on 8/22/16. + */ class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerModel] { override val Model: OpModel[MleapContext, StringIndexerModel] = new OpModel[MleapContext, StringIndexerModel] { override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel] @@ -19,12 +19,12 @@ class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerMode override def store(model: Model, obj: StringIndexerModel) (implicit context: BundleContext[MleapContext]): Model = { - val m = model. - withValue("labels_length", Value.int(1)). - withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) - obj.labelsArray.zipWithIndex.foldLeft(m){ - case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels)) - } + val m = model. + withValue("labels_length", Value.int(1)). + withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) + obj.labelsArray.zipWithIndex.foldLeft(m) { + case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels)) + } } override def load(model: Model) @@ -35,7 +35,7 @@ class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerMode case -1 => // backwards compatibility with Spark v2 Seq(model.value("labels").getStringList) - case _ => (0 until label_length).map(i=>model.value(s"labels_array_$i").getStringList) + case _ => (0 until label_length).map(i => model.value(s"labels_array_$i").getStringList) } StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid) } diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala index 4e530b940..62f2aee25 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala @@ -17,8 +17,6 @@ class StringIndexerParitySpec extends SparkParityBase { setInputCol("state"). setOutputCol("state_index").setHandleInvalid("keep").
fit(dataset) - - } class MIOStringIndexerParitySpec extends SparkParityBase { @@ -29,4 +27,4 @@ class MIOStringIndexerParitySpec extends SparkParityBase { setOutputCols(Array("state_index", "loan_title_index")). setHandleInvalid("keep"). fit(dataset) -} \ No newline at end of file +} From 95c385ea56e406c90e0a9c0934f5b1b77373c774 Mon Sep 17 00:00:00 2001 From: austinzh Date: Fri, 25 Oct 2024 14:08:44 -0400 Subject: [PATCH 09/15] Add new line at the end of file --- .../ml/combust/mleap/core/feature/StringIndexerModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala index 0ea020d10..7ac0e932b 100644 --- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala +++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/StringIndexerModel.scala @@ -84,4 +84,4 @@ case class StringIndexerModel(labelsArray: Seq[Seq[String]], } StructType(f).get } -} \ No newline at end of file +} From 06113199b0c39ae09b04f493d244d3f5e11b7f48 Mon Sep 17 00:00:00 2001 From: austinzh Date: Fri, 25 Oct 2024 14:09:18 -0400 Subject: [PATCH 10/15] Add new line EOF --- .../mleap/runtime/transformer/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala index e71fe5624..dd1faf553 100644 --- a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/transformer/feature/StringIndexer.scala @@ -37,4 +37,4 @@ case class StringIndexer(override val uid: String = Transformer.uniqueName("stri withColumns(builder) } } -} \ No newline at end of file +} From f3f29198b4ecb390b7cd95224ed5fceb7b3f6e62 Mon Sep 17 00:00:00 2001 From: austinzh Date: Fri, 25 Oct 2024 14:10:00 -0400 Subject: [PATCH 11/15] Add new line EOF --- .../scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala index 8f49e9a49..c43458b50 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala @@ -47,4 +47,4 @@ abstract class MultiInOutSparkOp[N <: Transformer with HasInputCol with HasInput } else { Seq(ParamSpec(standardOutputPort, obj.outputCols)) } -} \ No newline at end of file +} From 62b4199ac26aad54c4cd1cd9f14ee85171837f39 Mon Sep 17 00:00:00 2001 From: "Yunhui Zhang (GitHub)" Date: Wed, 6 Nov 2024 15:12:28 -0500 Subject: [PATCH 12/15] Fix outputCol not set issue. If outputCol is not set, we should check outputCols first. If outputCols is not set, we get the default value of outputCol. --- .devcontainer/devcontainer.json | 9 +++------ Makefile | 4 ++++ .../scala/ml/combust/bundle/BundleFile.scala | 1 - .../spark/ml/bundle/MultiInOutSparkOp.scala | 20 +++++++++---------- .../feature/StringIndexerParitySpec.scala | 10 ++++++++++ 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 88106e5c9..44af9fdb8
100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,9 +1,6 @@ { - "image": "mcr.microsoft.com/devcontainers/universal:2", - "features": { - "ghcr.io/devcontainers-extra/features/scala-sdkman:2": { - "version": "2.12.18", - "jdkVersion": "11" - } + "build": { + "dockerfile": "Dockerfile", + "context": "." } } diff --git a/Makefile b/Makefile index 45cbabecc..995dcdf42 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,10 @@ test_python37: test_python38: source scripts/scala_classpath_for_python.sh && make -C python py38_test +.PHONY: test_python +test_python: test_python37 test_python38 + @echo "All python tests run successfully" + .PHONY: test test: test_executor test_benchmark test_xgboost_runtime test_xgboost_spark test_root_sbt_project test_python @echo "All tests run successfully" diff --git a/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala b/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala index 5fcd00bf4..09e1e7ed9 100644 --- a/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala +++ b/bundle-ml/src/main/scala/ml/combust/bundle/BundleFile.scala @@ -132,7 +132,6 @@ case class BundleFile(fs: FileSystem, } override def finalize(): Unit = { - super.finalize() close() } diff --git a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala index c43458b50..55186697f 100644 --- a/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala +++ b/mleap-spark-base/src/main/scala/org/apache/spark/ml/bundle/MultiInOutSparkOp.scala @@ -18,15 +18,15 @@ abstract class MultiInOutSparkOp[N <: Transformer with HasInputCol with HasInput n } - def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort).isDefined, obj) + def sparkInputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkInputs(shape.getInput(standardInputPort+"0").isDefined, obj) - def sparkInputs(hasInputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCol) { - Seq(ParamSpec(standardInputPort, obj.inputCol)) - } else { + def sparkInputs(hasInputCols: Boolean, obj: N): Seq[ParamSpec] = if (hasInputCols) { Seq(ParamSpec(standardInputPort, obj.inputCols)) + } else { + Seq(ParamSpec(standardInputPort, obj.inputCol)) } - def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort).isDefined, obj) + def sparkOutputs(obj: N, shape: NodeShape): Seq[ParamSpec] = sparkOutputs(shape.getOutput(standardOutputPort+"0").isDefined, obj) override def shape(node: N)(implicit context: BundleContext[SparkBundleContext]): NodeShape = { validateParams(node) @@ -38,13 +38,13 @@ abstract class MultiInOutSparkOp[N <: Transformer with HasInputCol with HasInput ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols)) } - def sparkInputs(obj: N): Seq[ParamSpec] = sparkInputs(obj.isSet(obj.inputCol), obj) + def sparkInputs(obj: N): Seq[ParamSpec] = sparkInputs(obj.isSet(obj.inputCols), obj) - def sparkOutputs(obj: N): Seq[ParamSpec] = sparkOutputs(obj.isSet(obj.outputCol), obj) + def sparkOutputs(obj: N): Seq[ParamSpec] = sparkOutputs(obj.isSet(obj.outputCols), obj) - def sparkOutputs(hasOutputCol: Boolean, obj: N): Seq[ParamSpec] = if (hasOutputCol) { - Seq(ParamSpec(standardOutputPort, obj.outputCol)) - } else { + def sparkOutputs(hasOutputCols: Boolean, obj: N): Seq[ParamSpec] = if (hasOutputCols) { Seq(ParamSpec(standardOutputPort, 
obj.outputCols)) + } else { + Seq(ParamSpec(standardOutputPort, obj.outputCol)) } } diff --git a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala index 62f2aee25..89c28a034 100644 --- a/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala +++ b/mleap-spark/src/test/scala/org/apache/spark/ml/parity/feature/StringIndexerParitySpec.scala @@ -18,6 +18,16 @@ class StringIndexerParitySpec extends SparkParityBase { setOutputCol("state_index").setHandleInvalid("keep"). fit(dataset) } +class StringIndexerNoOutputColParitySpec extends SparkParityBase { + override val dataset: DataFrame = baseDataset.select("state") + override val unserializedParams = Set("stringOrderType") + + // setting to handle invalid to true + override val sparkTransformer: Transformer = new StringIndexer(). + setInputCol("state"). + setHandleInvalid("keep"). + fit(dataset) +} class MIOStringIndexerParitySpec extends SparkParityBase { override val dataset: DataFrame = baseDataset.select("state", "loan_title") From 1b72a349948999f737a0b3dd7ec0f50edf9c01ef Mon Sep 17 00:00:00 2001 From: "Yunhui Zhang (GitHub)" Date: Wed, 6 Nov 2024 15:46:44 -0500 Subject: [PATCH 13/15] Add Dockerfile --- .devcontainer/Dockerfile | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .devcontainer/Dockerfile diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 000000000..20e09bc16 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,34 @@ +FROM --platform=linux/amd64 ubuntu:22.04 + +ENV TZ=America/New_York +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +# Install necessary packages and tools +RUN DEBIAN_FRONTEND=noninteractive \ + apt-get update && apt-get install -y \ + curl \ + openjdk-11-jdk \ + unzip \ + zip \ + git \ + software-properties-common \ + python3-pip && \ + add-apt-repository ppa:deadsnakes/ppa && \ + apt-get update && apt-get install -y \ + python3.7 \ + python3.8 \ + python3.10 \ + tox \ + && rm -rf /var/lib/apt/lists/* + +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 + +# Install SDKMAN and Scala +SHELL ["/bin/bash", "-c"] + +RUN curl -s "https://get.sdkman.io" | bash \ + && source "$HOME/.sdkman/bin/sdkman-init.sh" \ + && sdk install scala 2.12.18 \ + && sdk install sbt + +# Set up your environment as needed From c90f52ccb569665aa2b5d198094ace28a5bc4a22 Mon Sep 17 00:00:00 2001 From: austinzh Date: Wed, 6 Nov 2024 19:18:18 -0500 Subject: [PATCH 14/15] Add sshd into devcontainer --- .devcontainer/devcontainer.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 44af9fdb8..ec3872e1c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,4 +1,9 @@ { + "features": { + "ghcr.io/devcontainers/features/sshd:1": { + "version": "latest" + } + }, "build": { "dockerfile": "Dockerfile", "context": "." 
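Taken together, the changes above mean a multi-column StringIndexer now round-trips through an MLeap bundle with the same recipe as the single-column case. The sketch below is illustrative only and mirrors MIOStringIndexerParitySpec: it assumes the usual MLeap Spark DSL import (ml.combust.mleap.spark.SparkSupport._ providing writeBundle), an active SparkSession, and a hypothetical DataFrame df with string columns "state" and "loan_title"; the bundle path is arbitrary.

```
import ml.combust.bundle.BundleFile
import ml.combust.bundle.serializer.SerializationFormat
import ml.combust.mleap.spark.SparkSupport._ // assumed DSL import supplying writeBundle
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.feature.StringIndexer
import scala.util.Using

// df is an assumed DataFrame holding the "state" and "loan_title" string columns.
val indexer = new StringIndexer().
  setInputCols(Array("state", "loan_title")).
  setOutputCols(Array("state_index", "loan_title_index")).
  setHandleInvalid("keep").
  fit(df)

// A transformed dataset must be registered so shape(node) can validate the
// single- vs multi-column params and resolve ports input0/input1, output0/output1.
val transformed = indexer.transform(df)
implicit val sbc = SparkBundleContext().withDataset(transformed)

// Write the bundle; loading it back through MleapSupport should yield a runtime
// StringIndexer whose NodeShape carries the numbered ports, while legacy
// single-column bundles are mapped onto input0/output0 by translateLegacyShape.
Using(BundleFile("jar:file:/tmp/multi-col-string-indexer.zip")) { bf =>
  indexer.writeBundle.format(SerializationFormat.Json).save(bf).get
}
```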
From fb0c52e5ce4a31d04cadca43f7edaf1d28a6eee7 Mon Sep 17 00:00:00 2001 From: austinzh Date: Wed, 6 Nov 2024 19:47:24 -0500 Subject: [PATCH 15/15] Change user to mleap --- .devcontainer/Dockerfile | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 20e09bc16..a3f8afb5c 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -23,6 +23,16 @@ RUN DEBIAN_FRONTEND=noninteractive \ RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 +ARG USERNAME=mleap +ARG USER_UID=1000 +ARG USER_GID=$USER_UID + +# Create the user +RUN groupadd --gid $USER_GID $USERNAME \ + && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME --shell /usr/bin/bash + +USER $USERNAME + # Install SDKMAN and Scala SHELL ["/bin/bash", "-c"]