From e73ebf3f4b8ab99a680f98f307445f1b6ba9909e Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 21 Dec 2018 18:28:50 +0200 Subject: [PATCH 01/10] Updated version information --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f05702a6..1381c2c7 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,13 @@ Java library and command-line application for converting Apache Spark ML pipelin JPMML-SparkML library JAR file (together with accompanying Java source and Javadocs JAR files) is released via [Maven Central Repository](https://repo1.maven.org/maven2/org/jpmml/). -The current version is **1.1.21** (2 October, 2018). +The current version is **1.1.22** (21 December, 2018). ```xml org.jpmml jpmml-sparkml - 1.1.21 + 1.1.22 ``` @@ -82,7 +82,7 @@ Compatibility matrix: |-----------------------|----------------------|--------------| | 1.0.0 through 1.0.9 | 1.5.X and 1.6.X | 4.2 | | 1.1.0 | 2.0.X | 4.2 | -| 1.1.1 through 1.1.21 | 2.0.X | 4.3 | +| 1.1.1 through 1.1.22 | 2.0.X | 4.3 | JPMML-SparkML depends on the latest and greatest version of the [JPMML-Model](https://github.com/jpmml/jpmml-model) library, which is in conflict with the legacy version that is part of the Apache Spark distribution. From 4950b968ede65678a4074c0aeb7fa968ea8bcaea Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sat, 5 Jan 2019 00:45:51 +0200 Subject: [PATCH 02/10] Updated JPMML-Evaluator dependency --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index fdb0a817..e0fd39fe 100644 --- a/pom.xml +++ b/pom.xml @@ -90,13 +90,13 @@ org.jpmml pmml-evaluator - 1.4.4 + 1.4.5 test org.jpmml pmml-evaluator-test - 1.4.4 + 1.4.5 test From 88b9967e6844717465df7f331208e6a04c4df7d7 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sat, 19 Jan 2019 09:49:04 +0200 Subject: [PATCH 03/10] Updated JPMML-Converter dependency --- pom.xml | 6 +-- .../org/jpmml/sparkml/ModelConverter.java | 42 ++++++------------- .../jpmml/sparkml/model/TreeModelUtil.java | 7 ++-- .../sparkml/visitors/TreeModelCompactor.java | 2 +- 4 files changed, 20 insertions(+), 37 deletions(-) diff --git a/pom.xml b/pom.xml index e0fd39fe..e013b008 100644 --- a/pom.xml +++ b/pom.xml @@ -77,7 +77,7 @@ org.jpmml jpmml-converter - 1.3.4 + 1.3.5 @@ -90,13 +90,13 @@ org.jpmml pmml-evaluator - 1.4.5 + 1.4.6 test org.jpmml pmml-evaluator-test - 1.4.5 + 1.4.6 test diff --git a/src/main/java/org/jpmml/sparkml/ModelConverter.java b/src/main/java/org/jpmml/sparkml/ModelConverter.java index 3e402a91..6442d9bf 100644 --- a/src/main/java/org/jpmml/sparkml/ModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/ModelConverter.java @@ -34,9 +34,6 @@ import org.dmg.pmml.Output; import org.dmg.pmml.OutputField; import org.dmg.pmml.mining.MiningModel; -import org.dmg.pmml.mining.Segment; -import org.dmg.pmml.mining.Segmentation; -import org.dmg.pmml.mining.Segmentation.MultipleModelMethod; import org.jpmml.converter.BooleanFeature; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoricalLabel; @@ -46,6 +43,7 @@ import org.jpmml.converter.Label; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.Schema; +import org.jpmml.converter.mining.MiningModelUtil; abstract public class ModelConverter & HasFeaturesCol & HasPredictionCol> extends TransformerConverter { @@ -175,39 +173,23 @@ public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder){ List sparkOutputFields = registerOutputFields(label, encoder); if(sparkOutputFields != 
null && sparkOutputFields.size() > 0){ - org.dmg.pmml.Model lastModel = getLastModel(model); + Output output; - Output output = ModelUtil.ensureOutput(lastModel); + if(model instanceof MiningModel){ + MiningModel miningModel = (MiningModel)model; - List outputFields = output.getOutputFields(); - - outputFields.addAll(0, sparkOutputFields); - } - - return model; - } - - protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model){ - - if(model instanceof MiningModel){ - MiningModel miningModel = (MiningModel)model; + org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(miningModel); - Segmentation segmentation = miningModel.getSegmentation(); + output = ModelUtil.ensureOutput(finalModel); + } else - MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod(); - switch(multipleModelMethod){ - case MODEL_CHAIN: - List segments = segmentation.getSegments(); + { + output = ModelUtil.ensureOutput(model); + } - if(segments.size() > 0){ - Segment lastSegment = segments.get(segments.size() - 1); + List outputFields = output.getOutputFields(); - return lastSegment.getModel(); - } - break; - default: - break; - } + outputFields.addAll(0, sparkOutputFields); } return model; diff --git a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java index 4fc76352..efac0c70 100644 --- a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java +++ b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java @@ -41,6 +41,7 @@ import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; import org.dmg.pmml.Visitor; +import org.dmg.pmml.tree.ComplexNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.BinaryFeature; @@ -161,7 +162,7 @@ public void encode(Node node, LeafNode leafNode){ static private & DecisionTreeModel> TreeModel encodeTreeModel(M model, PredicateManager predicateManager, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema){ - Node root = new Node() + Node root = new ComplexNode() .setPredicate(new True()); encodeNode(root, model.rootNode(), predicateManager, new CategoryManager(), scoreEncoder, schema); @@ -283,10 +284,10 @@ private void encodeNode(Node node, org.apache.spark.ml.tree.Node sparkNode, Pred throw new IllegalArgumentException(); } - Node leftChild = new Node() + Node leftChild = new ComplexNode() .setPredicate(leftPredicate); - Node rightChild = new Node() + Node rightChild = new ComplexNode() .setPredicate(rightPredicate); encodeNode(leftChild, internalNode.leftChild(), predicateManager, leftCategoryManager, scoreEncoder, schema); diff --git a/src/main/java/org/jpmml/sparkml/visitors/TreeModelCompactor.java b/src/main/java/org/jpmml/sparkml/visitors/TreeModelCompactor.java index a3f95ccd..8c61a1ec 100644 --- a/src/main/java/org/jpmml/sparkml/visitors/TreeModelCompactor.java +++ b/src/main/java/org/jpmml/sparkml/visitors/TreeModelCompactor.java @@ -43,7 +43,7 @@ public class TreeModelCompactor extends AbstractTreeModelTransformer { @Override public void enterNode(Node node){ String id = node.getId(); - String score = node.getScore(); + Object score = node.getScore(); if(id != null){ throw new IllegalArgumentException(); From 1b278fcc1262e249488e1f0e900b9cba6bfd384f Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sat, 19 Jan 2019 10:45:24 +0200 Subject: [PATCH 04/10] Optimized the encoding of Node elements --- .../jpmml/sparkml/model/TreeModelUtil.java | 58 ++++++++++--------- 1 file changed, 32 insertions(+), 26 
deletions(-) diff --git a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java index efac0c70..a57d880a 100644 --- a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java +++ b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java @@ -29,8 +29,6 @@ import org.apache.spark.ml.tree.CategoricalSplit; import org.apache.spark.ml.tree.ContinuousSplit; import org.apache.spark.ml.tree.DecisionTreeModel; -import org.apache.spark.ml.tree.InternalNode; -import org.apache.spark.ml.tree.LeafNode; import org.apache.spark.ml.tree.Split; import org.apache.spark.ml.tree.TreeEnsembleModel; import org.apache.spark.mllib.tree.impurity.ImpurityCalculator; @@ -41,7 +39,9 @@ import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; import org.dmg.pmml.Visitor; +import org.dmg.pmml.tree.BranchNode; import org.dmg.pmml.tree.ComplexNode; +import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.BinaryFeature; @@ -108,10 +108,10 @@ private & DecisionTreeModel> TreeModel encodeDecisionTree(Mo ScoreEncoder scoreEncoder = new ScoreEncoder(){ @Override - public void encode(Node node, LeafNode leafNode){ - String score = ValueUtil.formatValue(leafNode.prediction()); + public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){ + node.setScore(leafNode.prediction()); - node.setScore(score); + return node; } }; @@ -125,7 +125,11 @@ public void encode(Node node, LeafNode leafNode){ @Override - public void encode(Node node, LeafNode leafNode){ + public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){ + // XXX + node = new ComplexNode() + .setPredicate(node.getPredicate()); + int index = ValueUtil.asInt(leafNode.prediction()); node.setScore(this.categoricalLabel.getValue(index)); @@ -134,12 +138,16 @@ public void encode(Node node, LeafNode leafNode){ node.setRecordCount((double)impurityCalculator.count()); + List scoreDistributions = node.getScoreDistributions(); + double[] stats = impurityCalculator.stats(); for(int i = 0; i < stats.length; i++){ ScoreDistribution scoreDistribution = new ScoreDistribution(this.categoricalLabel.getValue(i), stats[i]); - node.addScoreDistributions(scoreDistribution); + scoreDistributions.add(scoreDistribution); } + + return node; } }; @@ -162,10 +170,7 @@ public void encode(Node node, LeafNode leafNode){ static private & DecisionTreeModel> TreeModel encodeTreeModel(M model, PredicateManager predicateManager, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema){ - Node root = new ComplexNode() - .setPredicate(new True()); - - encodeNode(root, model.rootNode(), predicateManager, new CategoryManager(), scoreEncoder, schema); + Node root = encodeNode(new True(), model.rootNode(), predicateManager, new CategoryManager(), scoreEncoder, schema); TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT); @@ -174,16 +179,19 @@ private & DecisionTreeModel> TreeModel encodeTreeModel(M mod } static - private void encodeNode(Node node, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, ScoreEncoder scoreEncoder, Schema schema){ + private Node encodeNode(Predicate predicate, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, ScoreEncoder scoreEncoder, Schema schema){ - if(sparkNode 
instanceof LeafNode){ - LeafNode leafNode = (LeafNode)sparkNode; + if(sparkNode instanceof org.apache.spark.ml.tree.LeafNode){ + org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode)sparkNode; - scoreEncoder.encode(node, leafNode); + Node result = new LeafNode() + .setPredicate(predicate); + + return scoreEncoder.encode(result, leafNode); } else - if(sparkNode instanceof InternalNode){ - InternalNode internalNode = (InternalNode)sparkNode; + if(sparkNode instanceof org.apache.spark.ml.tree.InternalNode){ + org.apache.spark.ml.tree.InternalNode internalNode = (org.apache.spark.ml.tree.InternalNode)sparkNode; CategoryManager leftCategoryManager = categoryManager; CategoryManager rightCategoryManager = categoryManager; @@ -284,16 +292,14 @@ private void encodeNode(Node node, org.apache.spark.ml.tree.Node sparkNode, Pred throw new IllegalArgumentException(); } - Node leftChild = new ComplexNode() - .setPredicate(leftPredicate); - - Node rightChild = new ComplexNode() - .setPredicate(rightPredicate); + Node leftChild = encodeNode(leftPredicate, internalNode.leftChild(), predicateManager, leftCategoryManager, scoreEncoder, schema); + Node rightChild = encodeNode(rightPredicate, internalNode.rightChild(), predicateManager, rightCategoryManager, scoreEncoder, schema); - encodeNode(leftChild, internalNode.leftChild(), predicateManager, leftCategoryManager, scoreEncoder, schema); - encodeNode(rightChild, internalNode.rightChild(), predicateManager, rightCategoryManager, scoreEncoder, schema); + Node result = new BranchNode() + .setPredicate(predicate) + .addNodes(leftChild, rightChild); - node.addNodes(leftChild, rightChild); + return result; } else { @@ -335,7 +341,7 @@ private List selectValues(List values, double[] categories, java interface ScoreEncoder { - void encode(Node node, LeafNode leafNode); + Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode); } private static final double[] TRUE = {1.0d}; From 1af41c046d3bfaeadff83db6540e4d28ec8d0cc1 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Tue, 19 Feb 2019 22:14:48 +0200 Subject: [PATCH 05/10] Updated JPMML-Evaluator dependency --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index e013b008..593b013f 100644 --- a/pom.xml +++ b/pom.xml @@ -90,13 +90,13 @@ org.jpmml pmml-evaluator - 1.4.6 + 1.4.7 test org.jpmml pmml-evaluator-test - 1.4.6 + 1.4.7 test From 4d3ca46cb64ed66944e4becdfd3f231ddf7b967f Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Wed, 20 Feb 2019 10:58:38 +0200 Subject: [PATCH 06/10] Added utility method VectorUtil#checkSize(int, Vector[]) --- .../java/org/jpmml/sparkml/MatrixUtil.java | 43 +++++++++++++++++++ .../java/org/jpmml/sparkml/VectorUtil.java | 11 +++++ .../sparkml/feature/IDFModelConverter.java | 8 ++-- .../feature/MaxAbsScalerModelConverter.java | 8 ++-- .../feature/MinMaxScalerModelConverter.java | 14 +++--- .../sparkml/feature/PCAModelConverter.java | 8 ++-- .../feature/StandardScalerModelConverter.java | 23 ++++++---- .../model/NaiveBayesModelConverter.java | 8 +++- 8 files changed, 91 insertions(+), 32 deletions(-) create mode 100644 src/main/java/org/jpmml/sparkml/MatrixUtil.java diff --git a/src/main/java/org/jpmml/sparkml/MatrixUtil.java b/src/main/java/org/jpmml/sparkml/MatrixUtil.java new file mode 100644 index 00000000..a70ab994 --- /dev/null +++ b/src/main/java/org/jpmml/sparkml/MatrixUtil.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2019 Villu Ruusmann + * + * This file is part of JPMML-SparkML + * + * 
JPMML-SparkML is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-SparkML is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-SparkML. If not, see . + */ +package org.jpmml.sparkml; + +import org.apache.spark.ml.linalg.Matrix; + +public class MatrixUtil { + + private MatrixUtil(){ + } + + static + public void checkColumns(int columns, Matrix matrix){ + + if(matrix.numCols() != columns){ + throw new IllegalArgumentException("Expected " + columns + " column(s), got " + matrix.numCols() + " column(s)"); + } + } + + static + public void checkRows(int rows, Matrix matrix){ + + if(matrix.numRows() != rows){ + throw new IllegalArgumentException("Expected " + rows + " row(s), got " + matrix.numRows() + " row(s)"); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/jpmml/sparkml/VectorUtil.java b/src/main/java/org/jpmml/sparkml/VectorUtil.java index 56d6e622..42bf23f9 100644 --- a/src/main/java/org/jpmml/sparkml/VectorUtil.java +++ b/src/main/java/org/jpmml/sparkml/VectorUtil.java @@ -29,6 +29,17 @@ public class VectorUtil { private VectorUtil(){ } + static + public void checkSize(int size, Vector... vectors){ + + for(Vector vector : vectors){ + + if(vector.size() != size){ + throw new IllegalArgumentException("Expected " + size + " element(s), got " + vector.size() + " element(s)"); + } + } + } + static public List toList(Vector vector){ DenseVector denseVector = vector.toDense(); diff --git a/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java index 00f5a032..39ac9dd8 100644 --- a/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java @@ -29,6 +29,7 @@ import org.jpmml.sparkml.FeatureConverter; import org.jpmml.sparkml.SparkMLEncoder; import org.jpmml.sparkml.TermFeature; +import org.jpmml.sparkml.VectorUtil; import org.jpmml.sparkml.WeightedTermFeature; public class IDFModelConverter extends FeatureConverter { @@ -41,12 +42,11 @@ public IDFModelConverter(IDFModel transformer){ public List encodeFeatures(SparkMLEncoder encoder){ IDFModel transformer = getTransformer(); + Vector idf = transformer.idf(); + List features = encoder.getFeatures(transformer.getInputCol()); - Vector idf = transformer.idf(); - if(idf.size() != features.size()){ - throw new IllegalArgumentException(); - } + VectorUtil.checkSize(features.size(), idf); List result = new ArrayList<>(); diff --git a/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java index e177d86b..1ae12739 100644 --- a/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java @@ -33,6 +33,7 @@ import org.jpmml.converter.ValueUtil; import org.jpmml.sparkml.FeatureConverter; import org.jpmml.sparkml.SparkMLEncoder; +import org.jpmml.sparkml.VectorUtil; public class MaxAbsScalerModelConverter extends FeatureConverter { @@ 
-44,12 +45,11 @@ public MaxAbsScalerModelConverter(MaxAbsScalerModel transformer){ public List encodeFeatures(SparkMLEncoder encoder){ MaxAbsScalerModel transformer = getTransformer(); + Vector maxAbs = transformer.maxAbs(); + List features = encoder.getFeatures(transformer.getInputCol()); - Vector maxAbs = transformer.maxAbs(); - if(maxAbs.size() != features.size()){ - throw new IllegalArgumentException(); - } + VectorUtil.checkSize(features.size(), maxAbs); List result = new ArrayList<>(); diff --git a/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java index 6a78d0be..996e501b 100644 --- a/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java @@ -33,6 +33,7 @@ import org.jpmml.converter.ValueUtil; import org.jpmml.sparkml.FeatureConverter; import org.jpmml.sparkml.SparkMLEncoder; +import org.jpmml.sparkml.VectorUtil; public class MinMaxScalerModelConverter extends FeatureConverter { @@ -47,17 +48,12 @@ public List encodeFeatures(SparkMLEncoder encoder){ double rescaleFactor = (transformer.getMax() - transformer.getMin()); double rescaleConstant = transformer.getMin(); - List features = encoder.getFeatures(transformer.getInputCol()); - Vector originalMax = transformer.originalMax(); - if(originalMax.size() != features.size()){ - throw new IllegalArgumentException(); - } - Vector originalMin = transformer.originalMin(); - if(originalMin.size() != features.size()){ - throw new IllegalArgumentException(); - } + + List features = encoder.getFeatures(transformer.getInputCol()); + + VectorUtil.checkSize(features.size(), originalMax, originalMin); List result = new ArrayList<>(); diff --git a/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java index edafe8c0..5599f584 100644 --- a/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java @@ -33,6 +33,7 @@ import org.jpmml.converter.PMMLUtil; import org.jpmml.converter.ValueUtil; import org.jpmml.sparkml.FeatureConverter; +import org.jpmml.sparkml.MatrixUtil; import org.jpmml.sparkml.SparkMLEncoder; public class PCAModelConverter extends FeatureConverter { @@ -45,12 +46,11 @@ public PCAModelConverter(PCAModel transformer){ public List encodeFeatures(SparkMLEncoder encoder){ PCAModel transformer = getTransformer(); + DenseMatrix pc = transformer.pc(); + List features = encoder.getFeatures(transformer.getInputCol()); - DenseMatrix pc = transformer.pc(); - if(pc.numRows() != features.size()){ - throw new IllegalArgumentException(); - } + MatrixUtil.checkRows(features.size(), pc); List result = new ArrayList<>(); diff --git a/src/main/java/org/jpmml/sparkml/feature/StandardScalerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/StandardScalerModelConverter.java index fd455c5e..3bf15489 100644 --- a/src/main/java/org/jpmml/sparkml/feature/StandardScalerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/StandardScalerModelConverter.java @@ -37,6 +37,7 @@ import org.jpmml.converter.ValueUtil; import org.jpmml.sparkml.FeatureConverter; import org.jpmml.sparkml.SparkMLEncoder; +import org.jpmml.sparkml.VectorUtil; public class StandardScalerModelConverter extends FeatureConverter { @@ -48,16 +49,20 @@ public StandardScalerModelConverter(StandardScalerModel transformer){ public List 
encodeFeatures(SparkMLEncoder encoder){ StandardScalerModel transformer = getTransformer(); + Vector mean = transformer.mean(); + Vector std = transformer.std(); + + boolean withMean = transformer.getWithMean(); + boolean withStd = transformer.getWithStd(); + List features = encoder.getFeatures(transformer.getInputCol()); - Vector mean = transformer.mean(); - if(transformer.getWithMean() && mean.size() != features.size()){ - throw new IllegalArgumentException(); - } + if(withMean){ + VectorUtil.checkSize(features.size(), mean); + } // End if - Vector std = transformer.std(); - if(transformer.getWithStd() && std.size() != features.size()){ - throw new IllegalArgumentException(); + if(withStd){ + VectorUtil.checkSize(features.size(), std); } List result = new ArrayList<>(); @@ -69,7 +74,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ Expression expression = null; - if(transformer.getWithMean()){ + if(withMean){ double meanValue = mean.apply(i); if(!ValueUtil.isZero(meanValue)){ @@ -79,7 +84,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ } } // End if - if(transformer.getWithStd()){ + if(withStd){ double stdValue = std.apply(i); if(!ValueUtil.isOne(stdValue)){ diff --git a/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java b/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java index d317b73a..1a2361f0 100644 --- a/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java @@ -34,6 +34,7 @@ import org.jpmml.converter.Schema; import org.jpmml.converter.regression.RegressionModelUtil; import org.jpmml.sparkml.ClassificationModelConverter; +import org.jpmml.sparkml.MatrixUtil; import org.jpmml.sparkml.VectorUtil; public class NaiveBayesModelConverter extends ClassificationModelConverter implements HasRegressionOptions { @@ -71,10 +72,13 @@ public RegressionModel encodeModel(Schema schema){ Vector pi = model.pi(); Matrix theta = model.theta(); - List intercepts = VectorUtil.toList(pi); - CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); + VectorUtil.checkSize(categoricalLabel.size(), pi); + MatrixUtil.checkRows(categoricalLabel.size(), theta); + + List intercepts = VectorUtil.toList(pi); + scala.collection.Iterator thetaRows = theta.rowIter(); RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null) From 4c2a4ae80b89d1821a80203c973038f4e20c75cd Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Wed, 20 Feb 2019 12:21:10 +0200 Subject: [PATCH 07/10] Refined exception messages --- .../org/jpmml/sparkml/ModelConverter.java | 11 ++--- .../java/org/jpmml/sparkml/SchemaUtil.java | 46 +++++++++++++++++++ .../org/jpmml/sparkml/SparkMLEncoder.java | 4 +- .../CountVectorizerModelConverter.java | 2 +- .../feature/SQLTransformerConverter.java | 4 +- .../feature/StopWordsRemoverConverter.java | 2 +- .../feature/StringIndexerModelConverter.java | 2 +- .../feature/VectorIndexerModelConverter.java | 8 ++-- ...ralizedLinearRegressionModelConverter.java | 9 ++-- ...erceptronClassificationModelConverter.java | 11 ++--- .../model/NaiveBayesModelConverter.java | 4 +- .../jpmml/sparkml/model/TreeModelUtil.java | 2 +- 12 files changed, 73 insertions(+), 32 deletions(-) create mode 100644 src/main/java/org/jpmml/sparkml/SchemaUtil.java diff --git a/src/main/java/org/jpmml/sparkml/ModelConverter.java b/src/main/java/org/jpmml/sparkml/ModelConverter.java index 6442d9bf..3f1e35c3 100644 
--- a/src/main/java/org/jpmml/sparkml/ModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/ModelConverter.java @@ -134,12 +134,11 @@ public Schema encodeSchema(SparkMLEncoder encoder){ if(model instanceof ClassificationModel){ ClassificationModel classificationModel = (ClassificationModel)model; + int numClasses = classificationModel.numClasses(); + CategoricalLabel categoricalLabel = (CategoricalLabel)label; - int numClasses = classificationModel.numClasses(); - if(numClasses != categoricalLabel.size()){ - throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories"); - } + SchemaUtil.checkSize(numClasses, categoricalLabel); } String featuresCol = model.getFeaturesCol(); @@ -150,8 +149,8 @@ public Schema encodeSchema(SparkMLEncoder encoder){ PredictionModel predictionModel = (PredictionModel)model; int numFeatures = predictionModel.numFeatures(); - if(numFeatures != -1 && features.size() != numFeatures){ - throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features"); + if(numFeatures != -1){ + SchemaUtil.checkSize(numFeatures, features); } } diff --git a/src/main/java/org/jpmml/sparkml/SchemaUtil.java b/src/main/java/org/jpmml/sparkml/SchemaUtil.java new file mode 100644 index 00000000..b89f53a3 --- /dev/null +++ b/src/main/java/org/jpmml/sparkml/SchemaUtil.java @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Villu Ruusmann + * + * This file is part of JPMML-SparkML + * + * JPMML-SparkML is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-SparkML is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-SparkML. If not, see . 
+ */ +package org.jpmml.sparkml; + +import java.util.List; + +import org.jpmml.converter.CategoricalLabel; +import org.jpmml.converter.Feature; + +public class SchemaUtil { + + private SchemaUtil(){ + } + + static + public void checkSize(int size, CategoricalLabel categoricalLabel){ + + if(categoricalLabel.size() != size){ + throw new IllegalArgumentException("Expected " + size + " target categories, got " + categoricalLabel.size() + " target categories"); + } + } + + static + public void checkSize(int size, List features){ + + if(features.size() != size){ + throw new IllegalArgumentException("Expected " + size + " feature(s), got " + features.size() + " feature(s)"); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java index 1d823727..99bef8fe 100644 --- a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java +++ b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java @@ -141,7 +141,7 @@ public void putFeatures(String column, List features){ if(existingFeatures != null && existingFeatures.size() > 0){ if(features.size() != existingFeatures.size()){ - throw new IllegalArgumentException("Expected " + existingFeatures.size() + " features, got " + features.size() + " features"); + throw new IllegalArgumentException("Expected " + existingFeatures.size() + " feature(s), got " + features.size() + " feature(s)"); } for(int i = 0; i < existingFeatures.size(); i++){ @@ -149,7 +149,7 @@ public void putFeatures(String column, List features){ Feature feature = features.get(i); if(!(feature.getName()).equals(existingFeature.getName())){ - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Expected '" + existingFeature.getName() + "' feature, got '" + feature.getName() + "' feature"); } } } diff --git a/src/main/java/org/jpmml/sparkml/feature/CountVectorizerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/CountVectorizerModelConverter.java index f2d1f719..002a1382 100644 --- a/src/main/java/org/jpmml/sparkml/feature/CountVectorizerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/CountVectorizerModelConverter.java @@ -114,7 +114,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ String term = vocabulary[i]; if(TermUtil.hasPunctuation(term)){ - throw new IllegalArgumentException(term); + throw new IllegalArgumentException("Punctuated vocabulary terms (" + term + ") are not supported"); } result.add(new TermFeature(encoder, defineFunction, documentFeature, term)); diff --git a/src/main/java/org/jpmml/sparkml/feature/SQLTransformerConverter.java b/src/main/java/org/jpmml/sparkml/feature/SQLTransformerConverter.java index 93d33272..2c90abec 100644 --- a/src/main/java/org/jpmml/sparkml/feature/SQLTransformerConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/SQLTransformerConverter.java @@ -104,7 +104,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ opType = OpType.CATEGORICAL; break; default: - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Data type " + dataType + " is not supported"); } org.dmg.pmml.Expression pmmlExpression = ExpressionTranslator.translate(expression); @@ -125,7 +125,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ feature = new BooleanFeature(encoder, derivedField); break; default: - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Data type " + dataType + " is not supported"); } encoder.putOnlyFeature(name, feature); diff --git 
a/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java b/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java index 067e0ec6..fdae00dc 100644 --- a/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java @@ -46,7 +46,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ for(String stopWord : stopWords){ if(TermUtil.hasPunctuation(stopWord)){ - throw new IllegalArgumentException(stopWord); + throw new IllegalArgumentException("Punctuated stop words (" + stopWord + ") are not supported"); } stopWordSet.add(stopWord); diff --git a/src/main/java/org/jpmml/sparkml/feature/StringIndexerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/StringIndexerModelConverter.java index 01d232b2..72483a69 100644 --- a/src/main/java/org/jpmml/sparkml/feature/StringIndexerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/StringIndexerModelConverter.java @@ -62,7 +62,7 @@ public List encodeFeatures(SparkMLEncoder encoder){ invalidValueTreatmentMethod = InvalidValueTreatmentMethod.RETURN_INVALID; break; default: - throw new IllegalArgumentException(handleInvalid); + throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported"); } InvalidValueDecorator invalidValueDecorator = new InvalidValueDecorator() diff --git a/src/main/java/org/jpmml/sparkml/feature/VectorIndexerModelConverter.java b/src/main/java/org/jpmml/sparkml/feature/VectorIndexerModelConverter.java index 4eeb6bca..e4e2d88f 100644 --- a/src/main/java/org/jpmml/sparkml/feature/VectorIndexerModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/feature/VectorIndexerModelConverter.java @@ -35,6 +35,7 @@ import org.jpmml.converter.PMMLUtil; import org.jpmml.converter.ValueUtil; import org.jpmml.sparkml.FeatureConverter; +import org.jpmml.sparkml.SchemaUtil; import org.jpmml.sparkml.SparkMLEncoder; public class VectorIndexerModelConverter extends FeatureConverter { @@ -47,12 +48,11 @@ public VectorIndexerModelConverter(VectorIndexerModel transformer){ public List encodeFeatures(SparkMLEncoder encoder){ VectorIndexerModel transformer = getTransformer(); + int numFeatures = transformer.numFeatures(); + List features = encoder.getFeatures(transformer.getInputCol()); - int numFeatures = transformer.numFeatures(); - if(numFeatures != features.size()){ - throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features"); - } + SchemaUtil.checkSize(numFeatures, features); Map> categoryMaps = transformer.javaCategoryMaps(); diff --git a/src/main/java/org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.java b/src/main/java/org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.java index 5712fc9d..c78306e8 100644 --- a/src/main/java/org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.java @@ -33,6 +33,7 @@ import org.jpmml.converter.Schema; import org.jpmml.converter.general_regression.GeneralRegressionModelUtil; import org.jpmml.sparkml.RegressionModelConverter; +import org.jpmml.sparkml.SchemaUtil; import org.jpmml.sparkml.SparkMLEncoder; import org.jpmml.sparkml.VectorUtil; @@ -85,9 +86,7 @@ public GeneralRegressionModel encodeModel(Schema schema){ case CLASSIFICATION: CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); - if(categoricalLabel.size() != 2){ - throw 
new IllegalArgumentException(); - } + SchemaUtil.checkSize(2, categoricalLabel); targetCategory = categoricalLabel.getValue(1); break; @@ -123,7 +122,7 @@ private GeneralRegressionModel.Distribution parseFamily(String family){ case "poisson": return GeneralRegressionModel.Distribution.POISSON; default: - throw new IllegalArgumentException(family); + throw new IllegalArgumentException("Distribution family " + family + " is not supported"); } } @@ -146,7 +145,7 @@ private GeneralRegressionModel.LinkFunction parseLinkFunction(String link){ case "sqrt": return GeneralRegressionModel.LinkFunction.POWER; default: - throw new IllegalArgumentException(link); + throw new IllegalArgumentException("Link function " + link + " is not supported"); } } diff --git a/src/main/java/org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.java b/src/main/java/org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.java index 67f223e5..cd9beb3b 100644 --- a/src/main/java/org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.java @@ -39,6 +39,7 @@ import org.jpmml.converter.Schema; import org.jpmml.converter.neural_network.NeuralNetworkUtil; import org.jpmml.sparkml.ClassificationModelConverter; +import org.jpmml.sparkml.SchemaUtil; import org.jpmml.sparkml.SparkMLEncoder; public class MultilayerPerceptronClassificationModelConverter extends ClassificationModelConverter { @@ -71,14 +72,10 @@ public NeuralNetwork encodeModel(Schema schema){ Vector weights = model.weights(); CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); - if(categoricalLabel.size() != layers[layers.length - 1]){ - throw new IllegalArgumentException(); - } - List features = schema.getFeatures(); - if(features.size() != layers[0]){ - throw new IllegalArgumentException(); - } + + SchemaUtil.checkSize(layers[layers.length - 1], categoricalLabel); + SchemaUtil.checkSize(layers[0], features); NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE); diff --git a/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java b/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java index 1a2361f0..968bd652 100644 --- a/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java +++ b/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java @@ -52,7 +52,7 @@ public RegressionModel encodeModel(Schema schema){ case "multinomial": break; default: - throw new IllegalArgumentException(modelType); + throw new IllegalArgumentException("Model type " + modelType + " is not supported"); } try { @@ -62,7 +62,7 @@ public RegressionModel encodeModel(Schema schema){ double threshold = thresholds[i]; if(threshold != 0d){ - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Non-zero thresholds are not supported"); } } } catch(NoSuchElementException nsee){ diff --git a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java index a57d880a..6fc8bf01 100644 --- a/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java +++ b/src/main/java/org/jpmml/sparkml/model/TreeModelUtil.java @@ -212,7 +212,7 @@ private Node encodeNode(Predicate predicate, org.apache.spark.ml.tree.Node spark BooleanFeature booleanFeature = (BooleanFeature)feature; if(threshold != 0d){ - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Invalid 
split threshold value " + threshold + " for a boolean feature");
 				}
 
 				leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));

From 81ffd6f0c98f78805a7cada37b25b2ec4473f306 Mon Sep 17 00:00:00 2001
From: Villu Ruusmann
Date: Wed, 20 Feb 2019 13:36:52 +0200
Subject: [PATCH 08/10] Checking the sanity of supervised learning model schemas. Fixes #47

---
 .../org/jpmml/sparkml/ModelConverter.java  |  2 ++
 .../java/org/jpmml/sparkml/SchemaUtil.java | 19 +++++++++++++++++++
 .../org/jpmml/sparkml/SparkMLEncoder.java  | 17 +----------------
 3 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/src/main/java/org/jpmml/sparkml/ModelConverter.java b/src/main/java/org/jpmml/sparkml/ModelConverter.java
index 3f1e35c3..c23f9767 100644
--- a/src/main/java/org/jpmml/sparkml/ModelConverter.java
+++ b/src/main/java/org/jpmml/sparkml/ModelConverter.java
@@ -156,6 +156,8 @@ public Schema encodeSchema(SparkMLEncoder encoder){
 
 		Schema result = new Schema(label, features);
 
+		SchemaUtil.checkSchema(result);
+
 		return result;
 	}
 
diff --git a/src/main/java/org/jpmml/sparkml/SchemaUtil.java b/src/main/java/org/jpmml/sparkml/SchemaUtil.java
index b89f53a3..8b00871e 100644
--- a/src/main/java/org/jpmml/sparkml/SchemaUtil.java
+++ b/src/main/java/org/jpmml/sparkml/SchemaUtil.java
@@ -19,15 +19,34 @@
 package org.jpmml.sparkml;
 
 import java.util.List;
+import java.util.Objects;
 
 import org.jpmml.converter.CategoricalLabel;
 import org.jpmml.converter.Feature;
+import org.jpmml.converter.Label;
+import org.jpmml.converter.Schema;
 
 public class SchemaUtil {
 
 	private SchemaUtil(){
 	}
 
+	static
+	public void checkSchema(Schema schema){
+		Label label = schema.getLabel();
+		List<Feature> features = schema.getFeatures();
+
+		if(label != null){
+
+			for(Feature feature : features){
+
+				if(Objects.equals(label.getName(), feature.getName())){
+					throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
+				}
+			}
+		}
+	}
+
 	static
 	public void checkSize(int size, CategoricalLabel categoricalLabel){
 
diff --git a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
index 99bef8fe..fb8b9cbe 100644
--- a/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
+++ b/src/main/java/org/jpmml/sparkml/SparkMLEncoder.java
@@ -59,21 +59,6 @@ public boolean hasFeatures(String column){
 		return this.columnFeatures.containsKey(column);
 	}
 
-	public List<Feature> getSchemaFeatures(){
-		StructType schema = getSchema();
-
-		List<Feature> result = new ArrayList<>();
-
-		StructField[] fields = schema.fields();
-		for(StructField field : fields){
-			Feature feature = getOnlyFeature(field.name());
-
-			result.add(feature);
-		}
-
-		return result;
-	}
-
 	public Feature getOnlyFeature(String column){
 		List<Feature> features = getFeatures(column);
 
@@ -149,7 +149,7 @@ public void putFeatures(String column, List<Feature> features){
 			Feature feature = features.get(i);
 
 			if(!(feature.getName()).equals(existingFeature.getName())){
-				throw new IllegalArgumentException("Expected '" + existingFeature.getName() + "' feature, got '" + feature.getName() + "' feature");
+				throw new IllegalArgumentException("Expected feature column '" + existingFeature.getName() + "', got feature column '" + feature.getName() + "'");
 			}
 		}
 	}
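
Note: patches 06 through 08 replace bare IllegalArgumentException() throws with named checks (VectorUtil#checkSize, MatrixUtil#checkRows/#checkColumns, SchemaUtil#checkSize/#checkSchema) whose messages spell out what was expected and what was found. Below is a minimal sketch of the new failure mode, using the helper methods added above; the only outside assumptions are the standard org.apache.spark.ml.linalg factory methods and the made-up toy sizes and values.

```java
import org.apache.spark.ml.linalg.Matrices;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.VectorUtil;

public class CheckUtilDemo {

	public static void main(String[] args){
		// Three IDF weights checked against a two-element feature list
		Vector idf = Vectors.dense(0.5, 1.25, 2.0);

		try {
			VectorUtil.checkSize(2, idf);
		} catch(IllegalArgumentException iae){
			// Prints: Expected 2 element(s), got 3 element(s)
			System.err.println(iae.getMessage());
		}

		// A 3x2 principal components matrix checked against a four-element feature list
		Matrix pc = Matrices.dense(3, 2, new double[]{1d, 0d, 0d, 0d, 1d, 0d});

		try {
			MatrixUtil.checkRows(4, pc);
		} catch(IllegalArgumentException iae){
			// Prints: Expected 4 row(s), got 3 row(s)
			System.err.println(iae.getMessage());
		}
	}
}
```

Both messages come straight from the string templates in the MatrixUtil and VectorUtil sources above.
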
From 885d5f84852e53153b178ad2dfeb2b48e84be97c Mon Sep 17 00:00:00 2001
From: Villu Ruusmann
Date: Wed, 20 Feb 2019 13:44:48 +0200
Subject: [PATCH 09/10] [maven-release-plugin] prepare release 1.1.23

---
 pom.xml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pom.xml b/pom.xml
index 593b013f..a7008b28 100644
--- a/pom.xml
+++ b/pom.xml
@@ -10,7 +10,7 @@
 
 	<groupId>org.jpmml</groupId>
 	<artifactId>jpmml-sparkml</artifactId>
-	<version>1.1-SNAPSHOT</version>
+	<version>1.1.23</version>
 
 	<name>JPMML-SparkML</name>
 	<description>Java library and command-line application for converting Spark ML pipelines to PMML</description>
@@ -35,7 +35,7 @@
 		<connection>scm:git:git@github.com:jpmml/jpmml-sparkml.git</connection>
 		<developerConnection>scm:git:git@github.com:jpmml/jpmml-sparkml.git</developerConnection>
 		<url>git://github.com/jpmml/jpmml-sparkml.git</url>
-		<tag>HEAD</tag>
+		<tag>1.1.23</tag>
 	</scm>
 	<issueManagement>
 		<system>GitHub</system>

From d4c5685c731a8a88a85423e33ef924cc42ea7422 Mon Sep 17 00:00:00 2001
From: Villu Ruusmann
Date: Wed, 20 Feb 2019 13:44:48 +0200
Subject: [PATCH 10/10] [maven-release-plugin] prepare for next development iteration

---
 pom.xml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pom.xml b/pom.xml
index a7008b28..593b013f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -10,7 +10,7 @@
 
 	<groupId>org.jpmml</groupId>
 	<artifactId>jpmml-sparkml</artifactId>
-	<version>1.1.23</version>
+	<version>1.1-SNAPSHOT</version>
 
 	<name>JPMML-SparkML</name>
 	<description>Java library and command-line application for converting Spark ML pipelines to PMML</description>
@@ -35,7 +35,7 @@
 		<connection>scm:git:git@github.com:jpmml/jpmml-sparkml.git</connection>
 		<developerConnection>scm:git:git@github.com:jpmml/jpmml-sparkml.git</developerConnection>
 		<url>git://github.com/jpmml/jpmml-sparkml.git</url>
-		<tag>1.1.23</tag>
+		<tag>HEAD</tag>
 	</scm>
 	<issueManagement>
 		<system>GitHub</system>
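
Note: for a user of the freshly released 1.1.23, the most visible behavioural change in this series is the schema sanity check from patch 08 (SchemaUtil#checkSchema). The sketch below shows the kind of mistake it is aimed at: a VectorAssembler whose input columns accidentally include the label column. This is a hypothetical example assuming a local SparkSession with toy data and the ConverterUtil#toPMML(StructType, PipelineModel) entry point documented in the 1.1.x README; the quoted message text comes verbatim from patch 08, though which error surfaces first depends on the rest of the conversion.

```java
import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.sparkml.ConverterUtil;

public class LeakyLabelDemo {

	public static void main(String[] args) throws Exception {
		SparkSession spark = SparkSession.builder()
			.master("local[1]")
			.appName("LeakyLabelDemo")
			.getOrCreate();

		StructType schema = new StructType()
			.add("y", DataTypes.DoubleType)
			.add("x1", DataTypes.DoubleType)
			.add("x2", DataTypes.DoubleType);

		Dataset<Row> df = spark.createDataFrame(Arrays.asList(
			RowFactory.create(1.0, 1.0, 2.0),
			RowFactory.create(2.0, 2.0, 1.0),
			RowFactory.create(3.0, 0.0, 3.0),
			RowFactory.create(4.0, 3.0, 0.0)
		), schema);

		// Misconfiguration: the label column "y" is also listed as an assembler input
		VectorAssembler assembler = new VectorAssembler()
			.setInputCols(new String[]{"y", "x1", "x2"})
			.setOutputCol("features");

		LinearRegression regression = new LinearRegression()
			.setLabelCol("y")
			.setFeaturesCol("features");

		PipelineModel pipelineModel = new Pipeline()
			.setStages(new PipelineStage[]{assembler, regression})
			.fit(df);

		try {
			PMML pmml = ConverterUtil.toPMML(df.schema(), pipelineModel);
		} catch(IllegalArgumentException iae){
			// With patch 08 applied, the conversion is expected to be rejected with:
			// "Label column 'y' is contained in the list of feature columns"
			System.err.println(iae.getMessage());
		} finally {
			spark.stop();
		}
	}
}
```

The point of the check is to report this misconfiguration at conversion time, rather than letting the label column silently double as a feature in the resulting PMML document.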