Skip to content

Commit

Permalink
Merge pull request #875 from austinzh/u/austinzh/stringindexer_multip…
Browse files Browse the repository at this point in the history
…le_colums

Add MultiInOutTransformer
  • Loading branch information
austinzh authored Nov 12, 2024
2 parents 43993e1 + fb0c52e commit 11c4f54
Show file tree
Hide file tree
Showing 23 changed files with 317 additions and 307 deletions.
44 changes: 44 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Dev container image: Ubuntu 22.04 with JDK 11, Python 3.7/3.8/3.10, tox,
# and (for the non-root user) SDKMAN-managed Scala 2.12.18 and sbt.
FROM --platform=linux/amd64 ubuntu:22.04

# Configure the timezone up front so that installing tzdata does not prompt.
ENV TZ=America/New_York
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# Install necessary packages and tools.
# DEBIAN_FRONTEND is exported so it applies to every command in this RUN chain
# (a bare `VAR=value cmd1 && cmd2` prefix only affects cmd1), while still not
# being persisted into the final image the way an ENV instruction would be.
# `add-apt-repository -y` skips the interactive confirmation prompt.
RUN export DEBIAN_FRONTEND=noninteractive && \
    apt-get update && apt-get install -y \
        curl \
        openjdk-11-jdk \
        unzip \
        zip \
        git \
        software-properties-common \
        python3-pip && \
    add-apt-repository -y ppa:deadsnakes/ppa && \
    apt-get update && apt-get install -y \
        python3.7 \
        python3.8 \
        python3.10 \
        tox \
    && rm -rf /var/lib/apt/lists/*

# Make `python` resolve to Python 3.10.
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1

ARG USERNAME=mleap
ARG USER_UID=1000
ARG USER_GID=$USER_UID

# Create the user
RUN groupadd --gid $USER_GID $USERNAME \
    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME --shell /usr/bin/bash

USER $USERNAME

# Install SDKMAN and Scala
# bash is required because the SDKMAN init script is loaded with `source`.
SHELL ["/bin/bash", "-c"]

RUN curl -s "https://get.sdkman.io" | bash \
    && source "$HOME/.sdkman/bin/sdkman-init.sh" \
    && sdk install scala 2.12.18 \
    && sdk install sbt

# Set up your environment as needed
11 changes: 11 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"features": {
"ghcr.io/devcontainers/features/sshd:1": {
"version": "latest"
}
},
"build": {
"dockerfile": "Dockerfile",
"context": "."
}
}
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ test_python37:
test_python38:
source scripts/scala_classpath_for_python.sh && make -C python py38_test

.PHONY: test_python
test_python: test_python37 test_python38
@echo "All python tests run successfully"

.PHONY: test
test: test_executor test_benchmark test_xgboost_runtime test_xgboost_spark test_root_sbt_project test_python
@echo "All tests run successfully"
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ case class BundleFile(fs: FileSystem,
}

override def finalize(): Unit = {
super.finalize()
close()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object FileUtil {
def rmRF(toRemove: Path): Array[(String, Boolean)] = {
def removeElement(path: Path): (String, Boolean) = {
val result = Try {
Files.deleteIfExists(toRemove)
Files.deleteIfExists(path)
} match {
case Failure(_) => false
case Success(value) => value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,40 +1,60 @@
package ml.combust.mleap.core.feature

import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types.{ScalarType, StructType}
import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}

/** Class for string indexer model.
*
* String indexer converts a string into an integer representation.
*
* @param labels list of labels that can be indexed
* @param labelsArray Array of ordered list of labels, corresponding to indices to be assigned for each input
* @param handleInvalid how to handle invalid values (unseen or NULL labels): 'error' (throw an error),
* 'skip' (skips invalid data)
* or 'keep' (put invalid data in a special bucket at index labels.size)
*/
case class StringIndexerModel(labels: Seq[String],
case class StringIndexerModel(labelsArray: Seq[Seq[String]],
handleInvalid: HandleInvalid = HandleInvalid.Error) extends Model {
val stringToIndex: Map[String, Int] = labels.zipWithIndex.toMap

private val stringToIndex: Array[Map[String, Int]] = labelsArray.map(_.zipWithIndex.toMap).toArray
private val keepInvalid = handleInvalid == HandleInvalid.Keep
private val invalidValue = labels.length
private val invalidValues = labelsArray.map(_.length)

@deprecated("Use labelsArray instead")
def labels: Seq[String] = labelsArray.head

/** Convert each string into its integer representation.
*
* @param values labels to index
* @return indexes of labels
*/
def apply(values: Seq[Any]): Seq[Double] = values.zipWithIndex.map {
case (v: Any, i: Int) => encoder(v, i).toDouble
case (null, i: Int) => encoder(null, i).toDouble
}

def contains(values: Seq[Any]): Boolean = {
values.zipWithIndex.forall {
case (key, i) => stringToIndex(i).contains(key.toString)
}
}
/** Convert a string into its integer representation.
*
* @param value label to index
* @return index of label
*/
def apply(value: Any): Int = if (value == null) {
*
* @param value label to index
* @return index of label
*/
private def encoder(value: Any, colIdx: Int): Int = if (value == null) {
if (keepInvalid) {
invalidValue
invalidValues(colIdx)
} else {
throw new NullPointerException("StringIndexer encountered NULL value. " +
s"To handle NULLS, set handleInvalid to ${HandleInvalid.Keep.asParamString}")
}
} else {
val label = value.toString
stringToIndex.get(label) match {
stringToIndex(colIdx).get(label) match {
case Some(v) => v
case None => if (keepInvalid) {
invalidValue
invalidValues(colIdx)
} else {
throw new NoSuchElementException(s"Unseen label: $label. To handle unseen labels, " +
s"set handleInvalid to ${HandleInvalid.Keep.asParamString}")
Expand All @@ -43,12 +63,25 @@ case class StringIndexerModel(labels: Seq[String],
}

/** Create a [[ml.combust.mleap.core.feature.ReverseStringIndexerModel]] from this model.
*
* ReverseStringIndexer only supports one input.
* @return reverse string indexer of this string indexer
*/
def toReverse: ReverseStringIndexerModel = ReverseStringIndexerModel(labels)
def toReverse: ReverseStringIndexerModel = {
require(labelsArray.length == 1)
ReverseStringIndexerModel(labelsArray.head)
}

override def inputSchema: StructType = StructType("input" -> ScalarType.String).get
override def inputSchema: StructType = {
val f = labelsArray.zipWithIndex.map {
case (_, i) => StructField(s"input$i", ScalarType.String)
}
StructType(f).get
}

override def outputSchema: StructType = StructType("output" -> ScalarType.Double.nonNullable).get
override def outputSchema: StructType = {
val f = labelsArray.zipWithIndex.map {
case (_, i) => StructField(s"output$i", ScalarType.Double.nonNullable)
}
StructType(f).get
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,48 @@ import org.scalatest.prop.TableDrivenPropertyChecks
class StringIndexerModelSpec extends org.scalatest.funspec.AnyFunSpec with TableDrivenPropertyChecks {
describe("#apply") {
it("returns the index of the string") {
val indexer = StringIndexerModel(Array("hello", "there", "dude"))
val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")))

assert(indexer("hello") == 0.0)
assert(indexer("there") == 1.0)
assert(indexer("dude") == 2.0)
assert(indexer(Seq("hello")).head == 0.0)
assert(indexer(Seq("there")).head == 1.0)
assert(indexer(Seq("dude")).head == 2.0)
}

it("throws NullPointerException when encounters NULL/None and handleInvalid is not keep") {
val indexer = StringIndexerModel(Array("hello"))
val indexer = StringIndexerModel(Seq(Seq("hello")))
assertThrows[NullPointerException](indexer(null))
}

it("throws NoSuchElementException when encounters unseen label and handleInvalid is not keep") {
val indexer = StringIndexerModel(Array("hello"))
val unseenLabels = Table("unknown1", "unknown2")
val indexer = StringIndexerModel(Seq(Seq("hello")))
val unseenLabels = Table("label", "unknown1", "unknown2")

forAll(unseenLabels) { (label: Any) =>
intercept[NoSuchElementException] {
indexer(label)
indexer(Seq(label))
}
}
}

it("returns default index for HandleInvalid.keep mode") {
val indexer = StringIndexerModel(Array("hello", "there", "dude"), handleInvalid = HandleInvalid.Keep)
val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")), handleInvalid = HandleInvalid.Keep)
val invalidLabels = Table("unknown", "other unknown", null, None)

forAll(invalidLabels) { (label: Any) =>
assert(indexer(label) == 3.0)
assert(indexer(Seq(label)).head == 3.0)
}
}
}

describe("input/output schema") {
val indexer = StringIndexerModel(Array("hello", "there", "dude"))
val indexer = StringIndexerModel(Seq(Seq("hello", "there", "dude")))

it("has the right input schema") {
assert(indexer.inputSchema.fields == Seq(StructField("input", ScalarType.String)))
assert(indexer.inputSchema.fields == Seq(StructField("input0", ScalarType.String)))
}

it("has the right output schema") {
assert(indexer.outputSchema.fields == Seq(StructField("output", ScalarType.Double.nonNullable)))
assert(indexer.outputSchema.fields == Seq(StructField("output0", ScalarType.Double.nonNullable)))
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package ml.combust.mleap.bundle.ops

import ml.bundle.Socket
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl.{Node, NodeShape}
import ml.combust.mleap.core.types
import ml.combust.mleap.runtime.MleapContext
import ml.combust.mleap.runtime.frame.Transformer
import ml.combust.mleap.runtime.types.BundleTypeConverters._

import scala.reflect.ClassTag

/** Base bundle op for transformers whose node shape uses indexed port names.
  *
  * When a serialized node still exposes the standard (un-indexed) input port, its
  * shape is rewritten on load so that the input and output ports carry a "0" suffix;
  * any other shape is used exactly as serialized.
  *
  * @tparam N transformer type instantiated by [[load]]
  * @tparam M model type passed to the transformer's constructor
  */
abstract class MultiInOutMleapOp[N <: Transformer, M <: AnyRef](implicit ct: ClassTag[N]) extends MleapOp[N, M] {
  /** Instantiate the transformer reflectively via its (name, shape, model) constructor.
    *
    * @param node serialized node providing the transformer name and shape
    * @param model already-deserialized model for the transformer
    * @return a new transformer of type `N`
    */
  override def load(node: Node, model: M)(implicit context: BundleContext[MleapContext]): N = {
    val ns = node.shape.getInput(NodeShape.standardInputPort) match {
      // Old version: need to translate serialized port names to new expectation (input -> input0)
      case Some(_) => translateLegacyShape(node.shape)

      // New version
      case None => node.shape
    }
    klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).newInstance(node.name, ns.asBundle: types.NodeShape, model)
  }

  /** Build a single-input/single-output shape whose port names are the legacy names plus "0".
    *
    * Only called once the standard input port is known to exist; the standard output
    * port is read with `.get` as well, so a shape without it fails here.
    */
  private def translateLegacyShape(ns: NodeShape): NodeShape = {
    val i = ns.getInput(NodeShape.standardInputPort).get
    val o = ns.getOutput(NodeShape.standardOutputPort).get
    NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name)))
  }
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
package ml.combust.mleap.bundle.ops.feature

import ml.bundle.Socket
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.OpModel
import ml.combust.mleap.bundle.ops.MleapOp
import ml.combust.mleap.bundle.ops.MultiInOutMleapOp
import ml.combust.mleap.core.feature.{HandleInvalid, OneHotEncoderModel}
import ml.combust.mleap.core.types
import ml.combust.mleap.runtime.MleapContext
import ml.combust.mleap.runtime.transformer.feature.OneHotEncoder
import ml.combust.mleap.runtime.types.BundleTypeConverters._

/**
* Created by hollinwilkins on 10/24/16.
*/
class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] {
class OneHotEncoderOp extends MultiInOutMleapOp[OneHotEncoder, OneHotEncoderModel] {
override val Model: OpModel[MleapContext, OneHotEncoderModel] = new OpModel[MleapContext, OneHotEncoderModel] {
override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel]

Expand Down Expand Up @@ -50,24 +47,6 @@ class OneHotEncoderOp extends MleapOp[OneHotEncoder, OneHotEncoderModel] {
}
}

override def load(node: Node, model: OneHotEncoderModel)
(implicit context: BundleContext[MleapContext]): OneHotEncoder = {
val ns = node.shape.getInput(NodeShape.standardInputPort) match {
// Old version of 1HE -- need to translate serialized port names to new expectation (input -> input0)
case Some(_) translateLegacyShape(node.shape)

// New version of 1HE
case None node.shape
}
klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).
newInstance(node.name, ns.asBundle: types.NodeShape, model)
}

override def model(node: OneHotEncoder): OneHotEncoderModel = node.model

private def translateLegacyShape(ns: NodeShape): NodeShape = {
val i = ns.getInput(NodeShape.standardInputPort).get
val o = ns.getOutput(NodeShape.standardOutputPort).get
NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name)))
}
}
Original file line number Diff line number Diff line change
@@ -1,45 +1,46 @@
package ml.combust.mleap.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
import ml.combust.mleap.runtime.transformer.feature.StringIndexer
import ml.combust.bundle.op.OpModel
import ml.combust.bundle.dsl._
import ml.combust.mleap.bundle.ops.MleapOp
import ml.combust.bundle.op.OpModel
import ml.combust.mleap.bundle.ops.MultiInOutMleapOp
import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel}
import ml.combust.mleap.runtime.MleapContext
import ml.combust.mleap.runtime.transformer.feature.StringIndexer

/**
* Created by hollinwilkins on 8/22/16.
*/
class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] {
* Created by hollinwilkins on 8/22/16.
*/
class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerModel] {
override val Model: OpModel[MleapContext, StringIndexerModel] = new OpModel[MleapContext, StringIndexerModel] {
override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel]

override def opName: String = Bundle.BuiltinOps.feature.string_indexer

/** Serialize a string indexer model into the bundle model.
  *
  * Writes `handle_invalid`, the number of label lists as `labels_length`, and one
  * `labels_array_$i` string list per entry of `obj.labelsArray`. `labels_length`
  * must equal the number of lists written: `load` reads `labels_array_0` up to
  * `labels_array_(labels_length - 1)`, so a hard-coded 1 would drop every column
  * after the first when a multi-column model is read back.
  *
  * @param model bundle model to add the values to
  * @param obj string indexer model being stored
  * @return the bundle model with all values attached
  */
override def store(model: Model, obj: StringIndexerModel)
                  (implicit context: BundleContext[MleapContext]): Model = {
  val base = model.
    withValue("labels_length", Value.int(obj.labelsArray.length)).
    withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString))
  obj.labelsArray.zipWithIndex.foldLeft(base) {
    case (acc, (labels, i)) => acc.withValue(s"labels_array_$i", Value.stringList(labels))
  }
}

override def load(model: Model)
(implicit context: BundleContext[MleapContext]): StringIndexerModel = {
val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default)
val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1)
val labels: Seq[String] = label_length match {
val labelsArray: Seq[Seq[String]] = label_length match {
case -1 =>
// backwards compatibility with Spark v2
model.value("labels").getStringList
case 1 => model.value("labels_array_0").getStringList
case _ => throw new UnsupportedOperationException("Multi-input StringIndexer not supported yet.")
Seq(model.value("labels").getStringList)
case _ => (0 until label_length).map(i => model.value(s"labels_array_$i").getStringList)
}
StringIndexerModel(labels = labels, handleInvalid = handleInvalid)
StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid)
}
}

override def model(node: StringIndexer): StringIndexerModel = node.model

}
Loading

0 comments on commit 11c4f54

Please sign in to comment.