-
Notifications
You must be signed in to change notification settings - Fork 314
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #875 from austinzh/u/austinzh/stringindexer_multip…
…le_colums Add MultiInOutTransformer
- Loading branch information
Showing
23 changed files
with
317 additions
and
307 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
FROM --platform=linux/amd64 ubuntu:22.04 | ||
|
||
ENV TZ=America/New_York | ||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone | ||
|
||
# Install necessary packages and tools | ||
# DEBIAN_FRONTEND is exported so that it covers every apt-get call in this chain; | ||
# a bare VAR=value prefix would only apply to the first command (apt-get update). | ||
# add-apt-repository is given -y so it never waits for interactive confirmation. | ||
RUN export DEBIAN_FRONTEND=noninteractive && \ | ||
apt-get update && apt-get install -y \ | ||
curl \ | ||
openjdk-11-jdk \ | ||
unzip \ | ||
zip \ | ||
git \ | ||
software-properties-common \ | ||
python3-pip && \ | ||
add-apt-repository -y ppa:deadsnakes/ppa && \ | ||
apt-get update && apt-get install -y \ | ||
python3.7 \ | ||
python3.8 \ | ||
python3.10 \ | ||
tox \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
|
||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 | ||
|
||
ARG USERNAME=mleap | ||
ARG USER_UID=1000 | ||
ARG USER_GID=$USER_UID | ||
|
||
# Create the user | ||
RUN groupadd --gid $USER_GID $USERNAME \ | ||
&& useradd --uid $USER_UID --gid $USER_GID -m $USERNAME --shell /usr/bin/bash | ||
|
||
USER $USERNAME | ||
|
||
# Install SDKMAN and Scala | ||
SHELL ["/bin/bash", "-c"] | ||
|
||
RUN curl -s "https://get.sdkman.io" | bash \ | ||
&& source "$HOME/.sdkman/bin/sdkman-init.sh" \ | ||
&& sdk install scala 2.12.18 \ | ||
&& sdk install sbt | ||
|
||
# Set up your environment as needed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"features": { | ||
"ghcr.io/devcontainers/features/sshd:1": { | ||
"version": "latest" | ||
} | ||
}, | ||
"build": { | ||
"dockerfile": "Dockerfile", | ||
"context": "." | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/MultiInOutMleapOp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package ml.combust.mleap.bundle.ops | ||
|
||
import ml.bundle.Socket | ||
import ml.combust.bundle.BundleContext | ||
import ml.combust.bundle.dsl.{Node, NodeShape} | ||
import ml.combust.mleap.core.types | ||
import ml.combust.mleap.runtime.MleapContext | ||
import ml.combust.mleap.runtime.frame.Transformer | ||
import ml.combust.mleap.runtime.types.BundleTypeConverters._ | ||
|
||
import scala.reflect.ClassTag | ||
|
||
/** | ||
 * Base bundle op for transformers that are loaded with indexed port names | ||
 * (e.g. `input0` / `output0`) instead of the single standard port names. | ||
 * | ||
 * When a serialized node still exposes the standard input port, its shape is | ||
 * rewritten to the indexed form before the transformer is instantiated. | ||
 * | ||
 * @tparam N transformer type produced by [[load]]; it is instantiated reflectively and | ||
 *           therefore needs a public constructor taking | ||
 *           `(String, types.NodeShape, <model class>)` | ||
 * @tparam M model type handed to the transformer constructor | ||
 */ | ||
abstract class MultiInOutMleapOp[N <: Transformer, M <: AnyRef](implicit ct: ClassTag[N]) extends MleapOp[N, M] { | ||
/** | ||
 * Builds the transformer for `node`, translating a legacy shape first if needed. | ||
 * | ||
 * @param node    bundle node supplying the transformer name and shape | ||
 * @param model   already-loaded model passed straight to the constructor | ||
 * @param context bundle context (not used directly in this method) | ||
 * @return a new instance of `N` created through reflection | ||
 */ | ||
override def load(node: Node, model: M)(implicit context: BundleContext[MleapContext]): N = { | ||
val ns = node.shape.getInput(NodeShape.standardInputPort) match { // Old bundles serialize the standard port names; translate them to the indexed names expected now (input -> input0) | ||
case Some(_) ⇒ translateLegacyShape(node.shape) | ||
|
||
// New version: the shape is used exactly as serialized | ||
case None ⇒ node.shape | ||
} | ||
// NOTE(review): `klazz` is not defined in this block - presumably the runtime class of N inherited from MleapOp; confirm. | ||
// The type ascription on `ns.asBundle` requests a `types.NodeShape`; presumably resolved via the BundleTypeConverters import - confirm. | ||
klazz.getConstructor(classOf[String], classOf[types.NodeShape], Model.klazz).newInstance(node.name, ns.asBundle: types.NodeShape, model) | ||
} | ||
|
||
/** | ||
 * Rewrites a legacy shape by appending "0" to the port names of the standard | ||
 * input and output sockets, keeping each socket's name. | ||
 * | ||
 * The result contains exactly one input and one output; no other sockets of `ns` are carried over. | ||
 * | ||
 * NOTE(review): both lookups are unwrapped with `.get`. The caller only verifies that the | ||
 * standard input exists, so a shape lacking the standard output port would presumably | ||
 * fail here - confirm whether that can occur. | ||
 */ | ||
private def translateLegacyShape(ns: NodeShape): NodeShape = { | ||
val i = ns.getInput(NodeShape.standardInputPort).get | ||
val o = ns.getOutput(NodeShape.standardOutputPort).get | ||
NodeShape(inputs = Seq(Socket(i.port + "0", i.name)), outputs = Seq(Socket(o.port + "0", o.name))) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 18 additions & 17 deletions
35
mleap-runtime/src/main/scala/ml/combust/mleap/bundle/ops/feature/StringIndexerOp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,46 @@ | ||
package ml.combust.mleap.bundle.ops.feature | ||
|
||
import ml.combust.bundle.BundleContext | ||
import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} | ||
import ml.combust.mleap.runtime.transformer.feature.StringIndexer | ||
import ml.combust.bundle.op.OpModel | ||
import ml.combust.bundle.dsl._ | ||
import ml.combust.mleap.bundle.ops.MleapOp | ||
import ml.combust.bundle.op.OpModel | ||
import ml.combust.mleap.bundle.ops.MultiInOutMleapOp | ||
import ml.combust.mleap.core.feature.{HandleInvalid, StringIndexerModel} | ||
import ml.combust.mleap.runtime.MleapContext | ||
import ml.combust.mleap.runtime.transformer.feature.StringIndexer | ||
|
||
/** | ||
* Created by hollinwilkins on 8/22/16. | ||
*/ | ||
class StringIndexerOp extends MleapOp[StringIndexer, StringIndexerModel] { | ||
* Created by hollinwilkins on 8/22/16. | ||
*/ | ||
class StringIndexerOp extends MultiInOutMleapOp[StringIndexer, StringIndexerModel] { | ||
override val Model: OpModel[MleapContext, StringIndexerModel] = new OpModel[MleapContext, StringIndexerModel] { | ||
override val klazz: Class[StringIndexerModel] = classOf[StringIndexerModel] | ||
|
||
override def opName: String = Bundle.BuiltinOps.feature.string_indexer | ||
|
||
override def store(model: Model, obj: StringIndexerModel) | ||
(implicit context: BundleContext[MleapContext]): Model = { | ||
model. | ||
withValue("labels_length", Value.int(1)). | ||
withValue("labels_array_0", Value.stringList(obj.labels)). | ||
withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) | ||
|
||
val m = model. | ||
withValue("labels_length", Value.int(1)). | ||
withValue("handle_invalid", Value.string(obj.handleInvalid.asParamString)) | ||
obj.labelsArray.zipWithIndex.foldLeft(m) { | ||
case (m, (labels, i)) => m.withValue(s"labels_array_$i", Value.stringList(labels)) | ||
} | ||
} | ||
|
||
override def load(model: Model) | ||
(implicit context: BundleContext[MleapContext]): StringIndexerModel = { | ||
val handleInvalid = model.getValue("handle_invalid").map(_.getString).map(HandleInvalid.fromString(_)).getOrElse(HandleInvalid.default) | ||
val label_length = model.getValue("labels_length").map(_.getInt).getOrElse(-1) | ||
val labels: Seq[String] = label_length match { | ||
val labelsArray: Seq[Seq[String]] = label_length match { | ||
case -1 => | ||
// backwards compatibility with spark v2 | ||
model.value("labels").getStringList | ||
case 1 => model.value("labels_array_0").getStringList | ||
case _ => throw new UnsupportedOperationException("Multi-input StringIndexer not supported yet.") | ||
Seq(model.value("labels").getStringList) | ||
case _ => (0 until label_length).map(i => model.value(s"labels_array_$i").getStringList) | ||
} | ||
StringIndexerModel(labels = labels, handleInvalid = handleInvalid) | ||
StringIndexerModel(labelsArray = labelsArray, handleInvalid = handleInvalid) | ||
} | ||
} | ||
|
||
override def model(node: StringIndexer): StringIndexerModel = node.model | ||
|
||
} |
Oops, something went wrong.