diff --git a/README.md b/README.md
index 72962698..0d895dd2 100644
--- a/README.md
+++ b/README.md
@@ -9,21 +9,21 @@ Java library and command-line application for converting Apache Spark ML pipelin
* Feature extractors, transformers and selectors:
* [`feature.Binarizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Binarizer.html)
* [`feature.Bucketizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Bucketizer.html)
- * [`feature.ChiSqSelectorModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/ChiSqSelectorModel.html) (the result of fitting a `feature.ChiSqSelector`)
+ * [`feature.ChiSqSelectorModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/ChiSqSelectorModel.html) (the result of fitting a `feature.ChiSqSelector`)
* [`feature.ColumnPruner`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/ColumnPruner.html)
* [`feature.CountVectorizerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/CountVectorizerModel.html) (the result of fitting a `feature.CountVectorizer`)
* [`feature.IDFModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/IDFModel.html) (the result of fitting a `feature.IDF`)
* [`feature.ImputerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/ImputerModel.html) (the result of fitting a `feature.Imputer`)
* [`feature.IndexToString`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/IndexToString.html)
- * [`feature.Interaction`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Interaction.html)
- * [`feature.MaxAbsScalerModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/MaxAbsScalerModel.html) (the result of fitting a `feature.MaxAbsScaler`)
- * [`feature.MinMaxScalerModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) (the result of fitting a `feature.MinMaxScaler`)
+ * [`feature.Interaction`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Interaction.html)
+ * [`feature.MaxAbsScalerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/MaxAbsScalerModel.html) (the result of fitting a `feature.MaxAbsScaler`)
+ * [`feature.MinMaxScalerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) (the result of fitting a `feature.MinMaxScaler`)
* [`feature.NGram`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/NGram.html)
* [`feature.OneHotEncoder`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/OneHotEncoder.html)
* [`feature.OneHotEncoderModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/OneHotEncoderModel.html) (the result of fitting a `feature.OneHotEncoderEstimator`)
* [`feature.PCAModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/PCAModel.html) (the result of fitting a `feature.PCA`)
- * [`feature.QuantileDiscretizer`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html)
- * [`feature.RegexTokenizer`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/RegexTokenizer.html)
+ * [`feature.QuantileDiscretizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html)
+ * [`feature.RegexTokenizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/RegexTokenizer.html)
* [`feature.RFormulaModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/RFormulaModel.html) (the result of fitting a `feature.RFormula`)
* [`feature.SQLTransformer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/SQLTransformer.html)
* [`feature.StandardScalerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/StandardScalerModel.html) (the result of fitting a `feature.StandardScaler`)
@@ -33,15 +33,15 @@ Java library and command-line application for converting Apache Spark ML pipelin
* [`feature.VectorAssembler`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/VectorAssembler.html)
* [`feature.VectorAttributeRewriter`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/VectorAttributeRewriter.html)
* [`feature.VectorIndexerModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/VectorIndexerModel.html) (the result of fitting a `feature.VectorIndexer`)
- * [`feature.VectorSlicer`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/VectorSlicer.html)
+ * [`feature.VectorSlicer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/VectorSlicer.html)
* Prediction models:
* [`classification.DecisionTreeClassificationModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/DecisionTreeClassificationModel.html)
- * [`classification.GBTClassificationModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/GBTClassificationModel.html)
+ * [`classification.GBTClassificationModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/GBTClassificationModel.html)
* [`classification.LogisticRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html)
- * [`classification.MultilayerPerceptronClassificationModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/MultilayerPerceptronClassificationModel.html)
- * [`classification.NaiveBayesModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/NaiveBayesModel.html)
+ * [`classification.MultilayerPerceptronClassificationModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/MultilayerPerceptronClassificationModel.html)
+ * [`classification.NaiveBayesModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/NaiveBayesModel.html)
* [`classification.RandomForestClassificationModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html)
- * [`clustering.KMeansModel`](http://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/clustering/KMeansModel.html)
+ * [`clustering.KMeansModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/clustering/KMeansModel.html)
* [`regression.DecisionTreeRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/DecisionTreeRegressionModel.html)
* [`regression.GBTRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/GBTRegressionModel.html)
* [`regression.GeneralizedLinearRegressionModel`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/GeneralizedLinearRegressionModel.html)
@@ -66,7 +66,7 @@ Java library and command-line application for converting Apache Spark ML pipelin
## Library ##
-JPMML-SparkML library JAR file (together with accompanying Java source and Javadocs JAR files) is released via [Maven Central Repository](http://repo1.maven.org/maven2/org/jpmml/).
+JPMML-SparkML library JAR file (together with accompanying Java source and Javadocs JAR files) is released via [Maven Central Repository](https://repo1.maven.org/maven2/org/jpmml/).
The current version is **1.4.6** (2 October, 2018).
@@ -155,7 +155,7 @@ The downside of shading is that such relocated classes are incompatible with oth
## Example application ##
-Enter the project root directory and build using [Apache Maven](http://maven.apache.org/):
+Enter the project root directory and build using [Apache Maven](https://maven.apache.org/):
```
mvn clean install
```
@@ -198,7 +198,7 @@ JAXBUtil.marshalPMML(pmml, new StreamResult(System.out));
Please refer to the following resources for more ideas and code examples:
-* [Converting Apache Spark ML pipeline models to PMML](http://openscoring.io/blog/2018/07/09/converting_sparkml_pipeline_pmml/)
+* [Converting Apache Spark ML pipeline models to PMML](https://openscoring.io/blog/2018/07/09/converting_sparkml_pipeline_pmml/)
## Example application ##
@@ -218,7 +218,7 @@ spark-submit --master local --class org.jpmml.sparkml.Main target/jpmml-sparkml-
# License #
-JPMML-SparkML is dual-licensed under the [GNU Affero General Public License (AGPL) version 3.0](http://www.gnu.org/licenses/agpl-3.0.html), and a commercial license.
+JPMML-SparkML is dual-licensed under the [GNU Affero General Public License (AGPL) version 3.0](https://www.gnu.org/licenses/agpl-3.0.html), and a commercial license.
# Additional information #
diff --git a/pom.xml b/pom.xml
index 02e61f59..5236acad 100644
--- a/pom.xml
+++ b/pom.xml
@@ -71,7 +71,7 @@
org.jpmml
jpmml-converter
- 1.3.3
+ 1.3.4
@@ -91,13 +91,13 @@
org.jpmml
pmml-evaluator
- 1.4.3
+ 1.4.4
test
org.jpmml
pmml-evaluator-test
- 1.4.3
+ 1.4.4
test
diff --git a/src/main/java/org/jpmml/sparkml/DatasetUtil.java b/src/main/java/org/jpmml/sparkml/DatasetUtil.java
new file mode 100644
index 00000000..18bc9e17
--- /dev/null
+++ b/src/main/java/org/jpmml/sparkml/DatasetUtil.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2018 Villu Ruusmann
+ *
+ * This file is part of JPMML-SparkML
+ *
+ * JPMML-SparkML is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * JPMML-SparkML is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with JPMML-SparkML. If not, see .
+ */
+package org.jpmml.sparkml;
+
+import java.util.Collections;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalog.Catalog;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.execution.QueryExecution;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.IntegralType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.dmg.pmml.DataType;
+
+public class DatasetUtil {
+
+ private DatasetUtil(){
+ }
+
+ static
+ public LogicalPlan createAnalyzedLogicalPlan(SparkSession sparkSession, StructType schema, String statement){
+ String tableName = "sql2pmml_" + DatasetUtil.ID.getAndIncrement();
+
+ statement = statement.replace("__THIS__", tableName);
+
+ Dataset dataset = sparkSession.createDataFrame(Collections.emptyList(), schema);
+
+ dataset.createOrReplaceTempView(tableName);
+
+ try {
+ QueryExecution queryExecution = sparkSession.sql(statement).queryExecution();
+
+ return queryExecution.analyzed();
+ } finally {
+ Catalog catalog = sparkSession.catalog();
+
+ catalog.dropTempView(tableName);
+ }
+ }
+
+ static
+ public DataType translateDataType(org.apache.spark.sql.types.DataType sparkDataType){
+
+ if(sparkDataType instanceof StringType){
+ return DataType.STRING;
+ } else
+
+ if(sparkDataType instanceof IntegralType){
+ return DataType.INTEGER;
+ } else
+
+ if(sparkDataType instanceof DoubleType){
+ return DataType.DOUBLE;
+ } else
+
+ if(sparkDataType instanceof BooleanType){
+ return DataType.BOOLEAN;
+ } else
+
+ {
+ throw new IllegalArgumentException("Expected string, integral, double or boolean type, got " + sparkDataType.typeName() + " type");
+ }
+ }
+
+ private static final AtomicInteger ID = new AtomicInteger(1);
+}
\ No newline at end of file
diff --git a/src/main/java/org/jpmml/sparkml/ExpressionMapping.java b/src/main/java/org/jpmml/sparkml/ExpressionMapping.java
deleted file mode 100644
index ab3f7afa..00000000
--- a/src/main/java/org/jpmml/sparkml/ExpressionMapping.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (c) 2018 Villu Ruusmann
- *
- * This file is part of JPMML-SparkML
- *
- * JPMML-SparkML is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * JPMML-SparkML is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with JPMML-SparkML. If not, see .
- */
-package org.jpmml.sparkml;
-
-import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.dmg.pmml.DataType;
-
-public class ExpressionMapping {
-
- private Expression from = null;
-
- private org.dmg.pmml.Expression to = null;
-
- private DataType dataType = null;
-
-
- public ExpressionMapping(Expression from, org.dmg.pmml.Expression to, DataType dataType){
- setFrom(from);
- setTo(to);
- setDataType(dataType);
- }
-
- public Expression getFrom(){
- return this.from;
- }
-
- private void setFrom(Expression from){
- this.from = from;
- }
-
- public org.dmg.pmml.Expression getTo(){
- return this.to;
- }
-
- private void setTo(org.dmg.pmml.Expression to){
- this.to = to;
- }
-
- public DataType getDataType(){
- return this.dataType;
- }
-
- private void setDataType(DataType dataType){
- this.dataType = dataType;
- }
-}
\ No newline at end of file
diff --git a/src/main/java/org/jpmml/sparkml/ExpressionTranslator.java b/src/main/java/org/jpmml/sparkml/ExpressionTranslator.java
index 056bdf36..443c3d92 100644
--- a/src/main/java/org/jpmml/sparkml/ExpressionTranslator.java
+++ b/src/main/java/org/jpmml/sparkml/ExpressionTranslator.java
@@ -18,17 +18,18 @@
*/
package org.jpmml.sparkml;
+import java.util.Iterator;
import java.util.List;
-import org.apache.spark.sql.catalyst.FunctionIdentifier;
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias;
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
-import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.expressions.Add;
+import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.And;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic;
import org.apache.spark.sql.catalyst.expressions.BinaryComparison;
import org.apache.spark.sql.catalyst.expressions.BinaryOperator;
+import org.apache.spark.sql.catalyst.expressions.CaseWhen;
+import org.apache.spark.sql.catalyst.expressions.Cast;
import org.apache.spark.sql.catalyst.expressions.Divide;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -46,62 +47,57 @@
import org.apache.spark.sql.catalyst.expressions.Or;
import org.apache.spark.sql.catalyst.expressions.Subtract;
import org.apache.spark.sql.catalyst.expressions.UnaryExpression;
-import org.apache.spark.sql.types.BooleanType;
-import org.apache.spark.sql.types.DoubleType;
-import org.apache.spark.sql.types.IntegralType;
-import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.catalyst.expressions.UnaryMinus;
import org.dmg.pmml.Apply;
+import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
+import org.dmg.pmml.HasDataType;
import org.jpmml.converter.PMMLUtil;
+import org.jpmml.converter.visitors.ExpressionCompactor;
+import scala.Option;
+import scala.Tuple2;
import scala.collection.JavaConversions;
public class ExpressionTranslator {
static
- public ExpressionMapping translate(Expression expression, DataTypeResolver dataTypeResolver){
-
- if(expression instanceof UnresolvedAlias){
- UnresolvedAlias unresolvedAlias = (UnresolvedAlias)expression;
-
- Expression child = unresolvedAlias.child();
+ public org.dmg.pmml.Expression translate(Expression expression){
+ return translate(expression, true);
+ }
- return translate(child, dataTypeResolver);
- } else
+ static
+ public org.dmg.pmml.Expression translate(Expression expression, boolean compact){
+ org.dmg.pmml.Expression pmmlExpression = translateInternal(expression);
- if(expression instanceof UnresolvedAttribute){
- UnresolvedAttribute unresolvedAttribute = (UnresolvedAttribute)expression;
+ if(compact){
+ ExpressionCompactor expressionCompactor = new ExpressionCompactor();
- String name = unresolvedAttribute.name();
+ expressionCompactor.applyTo(pmmlExpression);
+ }
- return new ExpressionMapping(unresolvedAttribute, new FieldRef(FieldName.create(name)), dataTypeResolver.getDataType(name));
- } else
+ return pmmlExpression;
+ }
- if(expression instanceof UnresolvedFunction){
- UnresolvedFunction unresolvedFunction = (UnresolvedFunction)expression;
+ static
+ private org.dmg.pmml.Expression translateInternal(Expression expression){
- FunctionIdentifier name = unresolvedFunction.name();
- List children = JavaConversions.seqAsJavaList(unresolvedFunction.children());
+ if(expression instanceof Alias){
+ Alias alias = (Alias)expression;
- String identifier = name.identifier();
+ Expression child = alias.child();
- if("IF".equalsIgnoreCase(identifier) && children.size() == 3){
- return translate(new If(children.get(0), children.get(1), children.get(2)), dataTypeResolver);
- } else
+ return translateInternal(child);
+ } // End if
- if("ISNOTNULL".equalsIgnoreCase(identifier) && children.size() == 1){
- return translate(new IsNotNull(children.get(0)), dataTypeResolver);
- } else
+ if(expression instanceof AttributeReference){
+ AttributeReference attributeReference = (AttributeReference)expression;
- if("ISNULL".equalsIgnoreCase(identifier) && children.size() == 1){
- return translate(new IsNull(children.get(0)), dataTypeResolver);
- } else
+ String name = attributeReference.name();
- {
- throw new IllegalArgumentException(String.valueOf(unresolvedFunction));
- }
- } // End if
+ return new FieldRef(FieldName.create(name));
+ } else
if(expression instanceof BinaryOperator){
BinaryOperator binaryOperator = (BinaryOperator)expression;
@@ -111,27 +107,31 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
Expression left = binaryOperator.left();
Expression right = binaryOperator.right();
- DataType dataType;
-
if(expression instanceof And || expression instanceof Or){
- symbol = symbol.toLowerCase();
- dataType = DataType.BOOLEAN;
+ switch(symbol){
+ case "&&":
+ symbol = "and";
+ break;
+ case "||":
+ symbol = "or";
+ break;
+ default:
+ throw new IllegalArgumentException(String.valueOf(binaryOperator));
+ }
} else
if(expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract){
BinaryArithmetic binaryArithmetic = (BinaryArithmetic)binaryOperator;
- if((left.dataType()).acceptsType(right.dataType())){
- dataType = translateDataType(left.dataType());
- } else
-
- if((right.dataType()).acceptsType(left.dataType())){
- dataType = translateDataType(right.dataType());
- } else
-
- {
- throw new IllegalArgumentException(String.valueOf(binaryArithmetic));
+ switch(symbol){
+ case "+":
+ case "/":
+ case "*":
+ case "-":
+ break;
+ default:
+ throw new IllegalArgumentException(String.valueOf(binaryArithmetic));
}
} else
@@ -157,15 +157,77 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
default:
throw new IllegalArgumentException(String.valueOf(binaryComparison));
}
-
- dataType = DataType.BOOLEAN;
} else
{
throw new IllegalArgumentException(String.valueOf(binaryOperator));
}
- return new ExpressionMapping(binaryOperator, PMMLUtil.createApply(symbol, translateChild(left, dataTypeResolver), translateChild(right, dataTypeResolver)), dataType);
+ return PMMLUtil.createApply(symbol, translateInternal(left), translateInternal(right));
+ } else
+
+ if(expression instanceof CaseWhen){
+ CaseWhen caseWhen = (CaseWhen)expression;
+
+ List> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
+
+ Option elseValue = caseWhen.elseValue();
+
+ Apply apply = null;
+
+ Iterator> branchIt = branches.iterator();
+
+ Apply prevBranchApply = null;
+
+ do {
+ Tuple2 branch = branchIt.next();
+
+ Expression predicate = branch._1();
+ Expression value = branch._2();
+
+ Apply branchApply = PMMLUtil.createApply("if")
+ .addExpressions(translateInternal(predicate), translateInternal(value));
+
+ if(apply == null){
+ apply = branchApply;
+ } // End if
+
+ if(prevBranchApply != null){
+ prevBranchApply.addExpressions(branchApply);
+ }
+
+ prevBranchApply = branchApply;
+ } while(branchIt.hasNext());
+
+ if(elseValue.isDefined()){
+ Expression value = elseValue.get();
+
+ prevBranchApply.addExpressions(translateInternal(value));
+ }
+
+ return apply;
+ } else
+
+ if(expression instanceof Cast){
+ Cast cast = (Cast)expression;
+
+ Expression child = cast.child();
+
+ DataType dataType = DatasetUtil.translateDataType(cast.dataType());
+
+ org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
+
+ if(pmmlExpression instanceof HasDataType){
+ HasDataType> hasDataType = (HasDataType>)pmmlExpression;
+
+ hasDataType.setDataType(dataType);
+
+ return pmmlExpression;
+ } else
+
+ {
+ throw new IllegalArgumentException(String.valueOf(cast));
+ }
} else
if(expression instanceof If){
@@ -176,16 +238,8 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
Expression trueValue = _if.trueValue();
Expression falseValue = _if.falseValue();
- if(!(trueValue.dataType()).sameType(falseValue.dataType())){
- throw new IllegalArgumentException(String.valueOf(_if));
- }
-
- DataType dataType = translateDataType(trueValue.dataType());
-
- Apply apply = PMMLUtil.createApply("if", translateChild(predicate, dataTypeResolver))
- .addExpressions(translateChild(trueValue, dataTypeResolver), translateChild(falseValue, dataTypeResolver));
-
- return new ExpressionMapping(_if, apply, dataType);
+ return PMMLUtil.createApply("if", translateInternal(predicate))
+ .addExpressions(translateInternal(trueValue), translateInternal(falseValue));
} else
if(expression instanceof In){
@@ -195,13 +249,13 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
List elements = JavaConversions.seqAsJavaList(in.list());
- Apply apply = PMMLUtil.createApply("isIn", translateChild(value, dataTypeResolver));
+ Apply apply = PMMLUtil.createApply("isIn", translateInternal(value));
for(Expression element : elements){
- apply.addExpressions(translateChild(element, dataTypeResolver));
+ apply.addExpressions(translateInternal(element));
}
- return new ExpressionMapping(in, apply, DataType.BOOLEAN);
+ return apply;
} else
if(expression instanceof Literal){
@@ -209,9 +263,9 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
Object value = literal.value();
- DataType dataType = translateDataType(literal.dataType());
+ DataType dataType = DatasetUtil.translateDataType(literal.dataType());
- return new ExpressionMapping(literal, PMMLUtil.createConstant(value, dataType), dataType);
+ return PMMLUtil.createConstant(value, dataType);
} else
if(expression instanceof Not){
@@ -219,7 +273,7 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
Expression child = not.child();
- return new ExpressionMapping(not, PMMLUtil.createApply("not", translateChild(child, dataTypeResolver)), DataType.BOOLEAN);
+ return PMMLUtil.createApply("not", translateInternal(child));
} else
if(expression instanceof UnaryExpression){
@@ -228,57 +282,38 @@ public ExpressionMapping translate(Expression expression, DataTypeResolver dataT
Expression child = unaryExpression.child();
if(expression instanceof IsNotNull){
- return new ExpressionMapping(unaryExpression, PMMLUtil.createApply("isNotMissing", translateChild(child, dataTypeResolver)), DataType.BOOLEAN);
+ return PMMLUtil.createApply("isNotMissing", translateInternal(child));
} else
if(expression instanceof IsNull){
- return new ExpressionMapping(unaryExpression, PMMLUtil.createApply("isMissing", translateChild(child, dataTypeResolver)), DataType.BOOLEAN);
+ return PMMLUtil.createApply("isMissing", translateInternal(child));
} else
- {
- throw new IllegalArgumentException(String.valueOf(unaryExpression));
- }
- } else
+ if(expression instanceof UnaryMinus){
+ UnaryMinus unaryMinus = (UnaryMinus)unaryExpression;
- {
- throw new IllegalArgumentException(String.valueOf(expression));
- }
- }
-
- static
- private org.dmg.pmml.Expression translateChild(Expression expression, DataTypeResolver dataTypeResolver){
- ExpressionMapping expressionMapping = translate(expression, dataTypeResolver);
-
- return expressionMapping.getTo();
- }
+ org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
- static
- private DataType translateDataType(org.apache.spark.sql.types.DataType sparkDataType){
+ if(pmmlExpression instanceof Constant){
+ Constant constant = (Constant)pmmlExpression;
- if(sparkDataType instanceof StringType){
- return DataType.STRING;
- } else
+ constant.setValue("-" + constant.getValue());
- if(sparkDataType instanceof IntegralType){
- return DataType.INTEGER;
- } else
+ return constant;
+ } else
- if(sparkDataType instanceof DoubleType){
- return DataType.DOUBLE;
- } else
+ {
+ return PMMLUtil.createApply("*", PMMLUtil.createConstant(-1), pmmlExpression);
+ }
+ } else
- if(sparkDataType instanceof BooleanType){
- return DataType.BOOLEAN;
+ {
+ throw new IllegalArgumentException(String.valueOf(unaryExpression));
+ }
} else
{
- throw new IllegalArgumentException("Expected string, integral, double or boolean type, got " + sparkDataType.typeName() + " type");
+ throw new IllegalArgumentException(String.valueOf(expression));
}
}
-
- static
- public interface DataTypeResolver {
-
- DataType getDataType(String name);
- }
}
\ No newline at end of file
diff --git a/src/main/java/org/jpmml/sparkml/ModelConverter.java b/src/main/java/org/jpmml/sparkml/ModelConverter.java
index c38ebbb1..3e402a91 100644
--- a/src/main/java/org/jpmml/sparkml/ModelConverter.java
+++ b/src/main/java/org/jpmml/sparkml/ModelConverter.java
@@ -85,7 +85,7 @@ public Schema encodeSchema(SparkMLEncoder encoder){
if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
- DataField dataField = encoder.getDataField(categoricalFeature.getName());
+ DataField dataField = (DataField)categoricalFeature.getField();
label = new CategoricalLabel(dataField);
} else
diff --git a/src/main/java/org/jpmml/sparkml/PMMLBuilder.java b/src/main/java/org/jpmml/sparkml/PMMLBuilder.java
index 8277ab7a..f3c55345 100644
--- a/src/main/java/org/jpmml/sparkml/PMMLBuilder.java
+++ b/src/main/java/org/jpmml/sparkml/PMMLBuilder.java
@@ -59,10 +59,8 @@
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
-import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
-import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;
@@ -149,39 +147,14 @@ public PMML build(){
List postProcessorNames = new ArrayList<>(derivedFields.keySet());
postProcessorNames.removeAll(preProcessorNames);
- org.dmg.pmml.Model rootModel;
+ org.dmg.pmml.Model model;
if(models.size() == 1){
- rootModel = Iterables.getOnlyElement(models);
+ model = Iterables.getOnlyElement(models);
} else
if(models.size() > 1){
- List targetMiningFields = new ArrayList<>();
-
- for(org.dmg.pmml.Model model : models){
- MiningSchema miningSchema = model.getMiningSchema();
-
- List miningFields = miningSchema.getMiningFields();
- for(MiningField miningField : miningFields){
- MiningField.UsageType usageType = miningField.getUsageType();
-
- switch(usageType){
- case PREDICTED:
- case TARGET:
- targetMiningFields.add(miningField);
- break;
- default:
- break;
- }
- }
- }
-
- MiningSchema miningSchema = new MiningSchema(targetMiningFields);
-
- MiningModel miningModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.emptyList()))
- .setMiningSchema(miningSchema);
-
- rootModel = miningModel;
+ model = MiningModelUtil.createModelChain(models);
} else
{
@@ -193,7 +166,7 @@ public PMML build(){
encoder.removeDerivedField(postProcessorName);
- Output output = ModelUtil.ensureOutput(rootModel);
+ Output output = ModelUtil.ensureOutput(model);
OutputField outputField = new OutputField(derivedField.getName(), derivedField.getDataType())
.setOpType(derivedField.getOpType())
@@ -203,7 +176,7 @@ public PMML build(){
output.addOutputFields(outputField);
}
- PMML pmml = encoder.encodePMML(rootModel);
+ PMML pmml = encoder.encodePMML(model);
if((predictionColumns.size() > 0 || probabilityColumns.size() > 0) && (verification != null)){
Dataset dataset = verification.getDataset();
@@ -213,7 +186,7 @@ public PMML build(){
List inputColumns = new ArrayList<>();
- MiningSchema miningSchema = rootModel.getMiningSchema();
+ MiningSchema miningSchema = model.getMiningSchema();
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
@@ -262,7 +235,7 @@ public PMML build(){
}
}
- rootModel.setModelVerification(ModelUtil.createModelVerification(data));
+ model.setModelVerification(ModelUtil.createModelVerification(data));
}
return pmml;
diff --git a/src/main/java/org/jpmml/sparkml/TermFeature.java b/src/main/java/org/jpmml/sparkml/TermFeature.java
index 4e88974d..a4894077 100644
--- a/src/main/java/org/jpmml/sparkml/TermFeature.java
+++ b/src/main/java/org/jpmml/sparkml/TermFeature.java
@@ -26,7 +26,6 @@
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
-import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
@@ -57,16 +56,7 @@ public TermFeature(PMMLEncoder encoder, DefineFunction defineFunction, Feature f
@Override
public ContinuousFeature toContinuousFeature(){
- PMMLEncoder encoder = ensureEncoder();
-
- DerivedField derivedField = encoder.getDerivedField(getName());
- if(derivedField == null){
- Apply apply = createApply();
-
- derivedField = encoder.createDerivedField(getName(), OpType.CONTINUOUS, getDataType(), apply);
- }
-
- return new ContinuousFeature(encoder, derivedField);
+ return toContinuousFeature(getName(), getDataType(), () -> createApply());
}
public WeightedTermFeature toWeightedTermFeature(Number weight){
diff --git a/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java b/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java
index 4ff47fb8..f1231155 100644
--- a/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java
+++ b/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java
@@ -25,6 +25,7 @@
import org.apache.spark.ml.feature.Interaction;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
+import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.sparkml.FeatureConverter;
@@ -40,7 +41,7 @@ public InteractionConverter(Interaction transformer){
public List encodeFeatures(SparkMLEncoder encoder){
Interaction transformer = getTransformer();
- String name = "";
+ StringBuilder sb = new StringBuilder();
List result = new ArrayList<>();
@@ -50,14 +51,29 @@ public List encodeFeatures(SparkMLEncoder encoder){
List features = encoder.getFeatures(inputCol);
+ if(features.size() == 1){
+ Feature feature = features.get(0);
+
+ if(feature instanceof CategoricalFeature){
+ CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
+
+ FieldName name = categoricalFeature.getName();
+
+ // XXX
+ inputCol = name.getValue();
+
+ features = OneHotEncoderConverter.encodeFeature(categoricalFeature.getEncoder(), categoricalFeature, categoricalFeature.getValues());
+ }
+ } // End if
+
if(i == 0){
- name = inputCol;
+ sb.append(inputCol);
result = features;
} else
{
- name += (":" + inputCol);
+ sb.append(':').append(inputCol);
List interactionFeatures = new ArrayList<>();
@@ -66,7 +82,7 @@ public List encodeFeatures(SparkMLEncoder encoder){
for(Feature left : result){
for(Feature right : features){
- interactionFeatures.add(new InteractionFeature(encoder, FieldName.create(name + "[" + index + "]"), DataType.DOUBLE, Arrays.asList(left, right)));
+ interactionFeatures.add(new InteractionFeature(encoder, FieldName.create(sb.toString() + "[" + index + "]"), DataType.DOUBLE, Arrays.asList(left, right)));
index++;
}
diff --git a/src/main/java/org/jpmml/sparkml/feature/OneHotEncoderConverter.java b/src/main/java/org/jpmml/sparkml/feature/OneHotEncoderConverter.java
index 413ed754..5e39d8dd 100644
--- a/src/main/java/org/jpmml/sparkml/feature/OneHotEncoderConverter.java
+++ b/src/main/java/org/jpmml/sparkml/feature/OneHotEncoderConverter.java
@@ -25,6 +25,7 @@
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
+import org.jpmml.converter.PMMLEncoder;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.Option;
@@ -39,6 +40,8 @@ public OneHotEncoderConverter(OneHotEncoder transformer){
public List encodeFeatures(SparkMLEncoder encoder){
OneHotEncoder transformer = getTransformer();
+ CategoricalFeature categoricalFeature = (CategoricalFeature)encoder.getOnlyFeature(transformer.getInputCol());
+
boolean dropLast = true;
Option